1️⃣ Deep Q-Learning
Learning Objectives
- Understand the DQN algorithm
- Learn more about RL debugging, and build probe environments to debug your agents
- Create a replay buffer to store environment transitions
- Implement DQN using PyTorch, on the CartPole environment
In this section, you'll implement Deep Q-Learning, often referred to as DQN for "Deep Q-Network". This was used in a landmark paper Playing Atari with Deep Reinforcement Learning.
At the time, the paper was very exciting: The agent would play the game by only looking at the same screen pixel data that a human player would be looking at, rather than a description of where the enemies in the game world are. The idea that convolutional neural networks could look at Atari game pixels and "see" gameplay-relevant features like a Space Invader was new and noteworthy. In 2022, we take for granted that convnets work, so we're going to focus on the RL aspect solely, and not the vision component.
Optional Readings
- Deep Q Networks Explained (25 minutes)
- A high-level distillation as to how DQN works.
- Read sections 1-4 (further sections optional).
- Andy Jones - Debugging RL, Without the Agonizing Pain (10 minutes)
- Useful tips for debugging your code when it's not working.
- Read up to (not including) the Common Fixes section. Also read the Practical Advice section, up to and including "Use probe agents". The rest of the post is optional, and you're recommended to come back to it near the end if you're stuck.
- The "probe environments" section (a collection of simple environments of increasing complexity) will be our first line of defense against bugs; you'll implement these in the exercises below.
Interesting Resources (not required reading)
- An Outsider's Tour of Reinforcement Learning - comparison of RL techniques with the engineering discipline of control theory.
- Towards Characterizing Divergence in Deep Q-Learning - analysis of what causes learning to diverge
- Divergence in Deep Q-Learning: Tips and Tricks - includes some plots of average returns for comparison
- Deep RL Bootcamp - 2017 bootcamp with video and slides. Good if you like videos.
- DQN debugging using OpenAI gym Cartpole - random dude's adventures in trying to get it to work.
- CleanRL DQN - single file implementations of RL algorithms. Your starter code today is based on this; try not to spoil things for yourself by looking at the solutions too early!
- Deep Reinforcement Learning Doesn't Work Yet - 2018 article describing difficulties preventing industrial adoption of RL.
- Deep Reinforcement Learning Works - Now What? - 2020 response to the previous article highlighting recent progress.
- Seed RL - example of distributed RL using Docker and GCP.
Conceptual overview of DQN
DQN is the natural extension of Q-Learning into the domain of deep learning. The main difference is that, instead of a table to store all the Q-values for each state-action pair, we train a neural network to learn this function for us. The usual implementation (which we'll use here) is for the Q-network to take the state as input, and output a vector of optimal Q-value estimates, one for each action, i.e. we're learning the function:
$$s \to \left[Q^*(s, a^1), \ldots, Q^*(s, a^n)\right]$$
Below is an algorithm showing the conceptual overview of DQN. We cycle through the following process:
- Generate a batch of experiences using our current policy, by epsilon-greedy sampling (i.e. we mostly take the action with the highest Q-value, but occasionally take a random action to encourage exploration). Store these experiences in the replay buffer.
- Sample a batch of experiences from the replay buffer, use them to calculate a TD (temporal difference) error, and update our network.
- To increase stability, we also have a target network we use for the "next step" part of the TD error. This is a lagged copy of the Q-network (i.e. we update our Q-network via gradient descent, and then every so often we copy the Q-network weights over to our target network).
- Repeat this until convergence.
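The update step in the loop above can be made concrete with a small worked example. The sketch below is a minimal pure-Python illustration, not the implementation you'll write later: the Q-values are made-up numbers standing in for network outputs, and the names `gamma`, `q_online` and `q_target_next` are assumptions chosen for this example.

```python
# Hypothetical numbers standing in for network outputs on a batch of 2 transitions.
gamma = 0.99

# Q(s_t, .) from the online Q-network: one row per transition, one column per action
q_online = [[1.0, 2.0], [0.5, 0.2]]
# Q_target(s_{t+1}, .) from the lagged target network
q_target_next = [[1.5, 3.0], [0.0, 0.4]]

actions = [1, 0]      # a_t, the actions actually taken
rewards = [1.0, 1.0]  # r_{t+1}

td_errors = []
for q_row, q_next_row, a, r in zip(q_online, q_target_next, actions, rewards):
    td_target = r + gamma * max(q_next_row)  # the "next step" part uses the target network
    td_errors.append(td_target - q_row[a])   # compare against the online prediction

# Mean squared TD error, which is what gradient descent would minimize
loss = sum(e**2 for e in td_errors) / len(td_errors)
print(td_errors, loss)
```

Only `q_online` would receive gradients in a real implementation; the target network values are treated as constants.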

Fast Feedback Loops
We want to have faster feedback loops, and learning from Atari pixels doesn't achieve that. It might take 15 minutes per training run to get an agent to do well on Breakout, and that's if your implementation is relatively optimized. Even waiting 5 minutes to learn Pong from pixels is going to limit your ability to iterate, compared to using environments that are as simple as possible.
CartPole
The classic environment "CartPole-v1" is simple to understand, yet hard enough for an RL agent to be interesting. By the end of the day your agent will be able to balance the pole and more!
If you'd like to try the CartPole environment yourself, click here to open the simulation in a new tab.
- Use Left/Right arrow keys to move the cart
- R to reset
- Q to quit
- Use F/S to make the simulation faster/slower
Unlike the real CartPole environment, this simulation will not terminate the episode if the pole falls over. We've also cheated here and added a hidden third no-op action, such that if no button is pressed, no force is applied to the cart. This makes the simulation a bit easier for you as the human to play. The real cartpole environment doesn't act like this: the agent must choose to push the cart either left or right on each timestep.
The description of the task is here. Note that unlike the previous environments, the observation here is now continuous. You can see the source for CartPole here; don't worry about the implementation but do read the documentation to understand the format of the actions and observations.
The simple physics involved would be very easy for a model-based algorithm to fit (this is a common assignment in control theory using proportional-integral-derivative (PID) controllers), but today we're doing it model-free: your agent has no idea that these observations represent positions or velocities, and it has no idea what the laws of physics are. The network has to learn in which direction to bump the cart in response to the current state of the world.
Each environment can have different versions registered to it. By consulting the Gym source you can see that CartPole-v0 and CartPole-v1 are the same environment, except that v1 has longer episodes. Again, a minor change like this can affect what algorithms score well; an agent might consistently survive for 200 steps in an unstable fashion that means it would fall over if run for 500 steps.
env = gym.make("CartPole-v1", render_mode="rgb_array")
print(env.action_space)  # 2 actions: left and right
print(env.observation_space)  # Box(4): each of the 4 observation components takes a continuous range of values
Discrete(2)
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)
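The bounds printed above are easier to read alongside the termination conditions. According to the CartPole docstring, an episode terminates when the cart position leaves ±2.4 or the pole angle leaves ±12°, and the position and angle bounds of the observation space are twice those thresholds (the two velocity components are effectively unbounded). The snippet below only checks this arithmetic; the threshold values are taken from the Gym documentation, so verify them against the version you have installed.

```python
import math

# Termination thresholds as described in the Gym CartPole docstring (assumed values).
x_threshold = 2.4                   # cart position limit
theta_threshold = math.radians(12)  # pole angle limit, roughly 0.2095 rad

# The Box bounds for position and angle are double the termination thresholds.
position_bound = 2 * x_threshold    # 4.8
angle_bound = 2 * theta_threshold   # 24 degrees in radians

print(f"position bound: {position_bound:.4f}")
print(f"angle bound:    {angle_bound:.7f}")
```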
Outline of the Exercises
The exercises are roughly split into 4 sections:
- Implement the Q-network that maps a state to an estimated value for each action.
- Implement a replay buffer to store experiences $e_t = (s_t, a_t, r_{t+1}, d_{t+1}, s_{t+1})$.
- Implement the policy which chooses actions based on the Q-network, plus epsilon greedy randomness to encourage exploration.
- Piece everything together into a training loop and train your agent.
The Q-Network
The Q-Network takes in an observation $s$ and outputs a vector $[Q^*(s, a^1), \ldots Q^*(s,a^n)]$ representing an estimate of the optimal Q-value for the given state $s$, and each possible action $\mathcal{A} = \{a^1, \ldots, a^n\}$. This replaces our Q-value table used in Q-learning.
For best results, the architecture of the Q-network can be customized to each particular problem. For example, the architecture of OpenAI Five used to play DOTA 2 is pretty complex and involves LSTMs.
For learning from pixels, a simple convolutional network and some fully connected layers do quite well. Here, where we already have processed features, it's even easier: an MLP of this size should be plenty large for any environment today.
Implement the Q-network using a standard MLP, constructed of alternating Linear and ReLU layers. The size of the input will match the dimensionality of the observation space, and the size of the output will match the number of actions to choose from (associating a Q-value estimate with each). The dimensions of the hidden_sizes are provided.
Here is a diagram of what our particular Q-Network will look like for CartPole (you can open it in a new tab if it's hard to see clearly):
Question - why do we not include a ReLU at the end?
If you end with a ReLU, then your network can only predict 0 or positive Q-values. This will cause problems as soon as you encounter an environment with negative rewards, or you try to do some scaling of the rewards.
Question - since CartPole-v1 gives +1 reward on every timestep, why do you think the network doesn't just learn the constant +1 function regardless of observation?
The network is learning Q-values (the sum of all future expected discounted rewards from this state/action pair), not rewards. Correspondingly, once the agent has learned a good policy, the Q-value associated with state action pair (pole is slightly left of vertical, move cart left) should be large, as we would expect a long episode (and correspondingly lots of reward) by taking actions to help to balance the pole. Pairs like (cart near right boundary, move cart right) cause the episode to terminate, and as such the network will learn low Q-values.
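To put numbers on this, the sketch below computes the discounted return for an agent that collects +1 reward per step for a given number of remaining steps. The discount factor `gamma = 0.99` and the helper name `discounted_return` are assumptions for illustration only.

```python
def discounted_return(num_steps: int, gamma: float = 0.99) -> float:
    """Sum of gamma^k for k = 0..num_steps-1, i.e. the return from +1 reward on every step."""
    return sum(gamma**k for k in range(num_steps))


# The more steps the agent expects to survive, the larger the Q-value,
# up to a ceiling of 1 / (1 - gamma) = 100 for gamma = 0.99.
for steps_remaining in [1, 10, 100, 500]:
    print(steps_remaining, round(discounted_return(steps_remaining), 2))
```

A state-action pair that ends the episode immediately is worth about 1, while one that leads to a long episode is worth close to 100, so the Q-function is far from constant even though the reward is.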
Exercise - implement QNetwork
Note - in this implementation we can assume that obs_shape is a tuple of length 1 (in the case of CartPole this will be (4,)), so you can treat it as just an integer value above, e.g. your first linear layer should be from obs_shape[0] to 120.
class QNetwork(nn.Module):
    """
    For consistency with your tests, please wrap your modules in a `nn.Sequential` called `layers`.
    """

    layers: nn.Sequential

    def __init__(self, obs_shape: tuple[int], num_actions: int, hidden_sizes: list[int] = [120, 84]):
        super().__init__()
        assert len(obs_shape) == 1, "Expecting a single vector of observations"
        raise NotImplementedError()

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)


net = QNetwork(obs_shape=(4,), num_actions=2)
n_params = sum((p.nelement() for p in net.parameters()))
assert isinstance(getattr(net, "layers", None), nn.Sequential)
print(net)
print(f"Total number of parameters: {n_params}")
print("You should manually verify network is Linear-ReLU-Linear-ReLU-Linear")
assert not isinstance(net.layers[-1], nn.ReLU)
assert n_params == 10934
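The expected parameter count in the final assertion can be checked by hand: each Linear layer contributes `in_features * out_features` weights plus `out_features` biases. The helper `mlp_param_count` below is a hypothetical name used only for this arithmetic check.

```python
def mlp_param_count(layer_sizes: list[int]) -> int:
    """Count the weights and biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))


# CartPole: 4 observation dims -> 120 -> 84 -> 2 actions
print(mlp_param_count([4, 120, 84, 2]))  # 600 + 10164 + 170 = 10934
```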
Solution
class QNetwork(nn.Module):
    """
    For consistency with your tests, please wrap your modules in a `nn.Sequential` called `layers`.
    """

    layers: nn.Sequential

    def __init__(self, obs_shape: tuple[int], num_actions: int, hidden_sizes: list[int] = [120, 84]):
        super().__init__()
        assert len(obs_shape) == 1, "Expecting a single vector of observations"
        in_features_list = [obs_shape[0]] + hidden_sizes
        out_features_list = hidden_sizes + [num_actions]
        layers = []
        for i, (in_features, out_features) in enumerate(zip(in_features_list, out_features_list)):
            layers.append(nn.Linear(in_features, out_features))
            # Add a ReLU after every layer except the last one
            if i < len(in_features_list) - 1:
                layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)
Replay Buffer
The goal of DQN is to reduce the reinforcement learning problem to a supervised learning problem. In supervised learning, training examples should be drawn independently and identically distributed (i.i.d.) from some distribution, and we hope to generalize to future examples from that distribution. Obviously perfect i.i.d. sampling isn't attainable, but we can approximate this by filling a buffer of past experiences and sampling from it. Note that for very complex problems we may need a very large buffer, because we want the policy to get a representative sample of all the diverse scenarios that can happen in the environment. OpenAI Five used batch sizes of over 2 million experiences for Dota 2! However we'll be working with the fairly simple CartPole environment, and so we can get away with a much smaller buffer.
In RL, the distribution of experiences $e_t = (s_t, a_t, r_{t+1}, s_{t+1})$ to train from depends on the policy $\pi$ followed, which in turn depends on the current state of the Q-value network, so DQN is always chasing a moving target. This is why the training loss curve isn't going to have a nice steady decrease like in supervised learning. We will extend experiences to $e_t = (s_t, a_t, r_{t+1}, d_{t+1}, s_{t+1})$. Here, $d_{t+1}$ is a boolean indicating that $s_{t+1}$ is a terminal observation, and that no further interaction happened beyond $s_{t+1}$ in the episode from which it was generated.
Termination vs Truncation
Note that we take $d_{t+1}$ to be terminated, not done = terminated | truncated. The reason is as follows: our time limit was imposed for practical reasons to help with learning, but if the agent views the environment timing out as a form of failure that terminates its reward then it would have no reason to prefer the behaviour "stay perfectly level" to the behaviour "stay level for the first 499 timesteps then immediately fall over"! We want to encourage the agent to perform well all the time, not just perform well until the environment times out. See this page for more discussion.
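The practical effect of this choice shows up in the TD target. The sketch below uses made-up numbers (a hypothetical `max_next_q` of 90 and `gamma = 0.99`) to contrast a transition where the pole fell over with one where the time limit was reached while the pole was still balanced.

```python
gamma = 0.99
reward = 1.0
max_next_q = 90.0  # hypothetical target-network estimate of max_a Q(s_{t+1}, a)


def td_target(reward: float, max_next_q: float, terminated: bool, gamma: float) -> float:
    """Bootstrap from the next state unless the episode genuinely terminated."""
    return reward + gamma * (0.0 if terminated else 1.0) * max_next_q


# Pole fell over: there is no future reward, so we don't bootstrap
target_terminated = td_target(reward, max_next_q, terminated=True, gamma=gamma)
# Time limit hit while still balanced: treated as non-terminal, so we still bootstrap
target_truncated = td_target(reward, max_next_q, terminated=False, gamma=gamma)

print(target_terminated, target_truncated)
```

If truncation were recorded as termination, the second target would collapse to the same value as the first, and the agent would learn that well-balanced states late in the episode are nearly worthless.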
ReplayBuffer and ReplayBufferSamples
We've given you 2 classes below. The first, ReplayBuffer, holds data from past experiences and also contains methods for sampling that data (the samples are instances of ReplayBufferSamples).
You should read these implementations carefully, making sure you understand how they work. A few things to note:
- The add method adds multiple experiences at once: the tensors like obs have shape (num_envs, *obs_shape). This is because we're using the SyncVectorEnv class which allows us to step through & generate experiences for multiple environments simultaneously. We'll see how this works in practice later.
- The add method will add these experiences to the end of the buffer, slicing the buffer if it's too long. Note that the slicing is done so that we remove the oldest experiences when the buffer is full.
- The sample method will return a ReplayBufferSamples object containing the experiences sampled from the buffer. These are sampled with replacement, and the data is converted to PyTorch tensors on the correct device.
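The two behaviours described in the list above (dropping the oldest items by slicing, and sampling with replacement) can be seen in isolation with plain Python lists. This is only an analogue of the NumPy-based class below: integers stand in for experiences, and all names here are hypothetical.

```python
import random

buffer_size = 5
buffer: list[int] = []

# Add transitions in batches of 2 (standing in for num_envs = 2)
for step in range(4):
    new_batch = [2 * step, 2 * step + 1]
    buffer = (buffer + new_batch)[-buffer_size:]  # keep only the newest `buffer_size` items

print(buffer)  # [3, 4, 5, 6, 7]

# Sampling with replacement: the same index may be drawn more than once.
# We draw indices below the current length, so this also works before the buffer is full.
rng = random.Random(0)
indices = [rng.randrange(len(buffer)) for _ in range(8)]
sample = [buffer[i] for i in indices]
print(sample)
```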
@dataclass
class ReplayBufferSamples:
    """
    Samples from the replay buffer, converted to PyTorch for use in neural network training.

    Data is equivalent to (s_t, a_t, r_{t+1}, d_{t+1}, s_{t+1}). Note - here, d_{t+1} is actually **terminated** rather
    than **done** (i.e. it records the times when we went out of bounds, not when the environment timed out).
    """

    obs: Float[Tensor, " sample_size *obs_shape"]
    actions: Int[Tensor, " sample_size *action_shape"]
    rewards: Float[Tensor, " sample_size"]
    terminated: Bool[Tensor, " sample_size"]
    next_obs: Float[Tensor, " sample_size *obs_shape"]


class ReplayBuffer:
    """
    Contains buffer; has a method to sample from it to return a ReplayBufferSamples object.
    """

    rng: np.random.Generator
    obs: Float[Arr, " buffer_size *obs_shape"]
    actions: Int[Arr, " buffer_size *action_shape"]
    rewards: Float[Arr, " buffer_size"]
    terminated: Bool[Arr, " buffer_size"]
    next_obs: Float[Arr, " buffer_size *obs_shape"]

    def __init__(
        self,
        num_envs: int,
        obs_shape: tuple[int],
        action_shape: tuple[int],
        buffer_size: int,
        seed: int,
    ):
        self.num_envs = num_envs
        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.buffer_size = buffer_size
        self.rng = np.random.default_rng(seed)

        # All arrays start empty and grow (up to buffer_size) as experiences are added
        self.obs = np.empty((0, *self.obs_shape), dtype=np.float32)
        self.actions = np.empty((0, *self.action_shape), dtype=np.int32)
        self.rewards = np.empty(0, dtype=np.float32)
        self.terminated = np.empty(0, dtype=bool)
        self.next_obs = np.empty((0, *self.obs_shape), dtype=np.float32)

    def add(
        self,
        obs: Float[Arr, " num_envs *obs_shape"],
        actions: Int[Arr, " num_envs *action_shape"],
        rewards: Float[Arr, " num_envs"],
        terminated: Bool[Arr, " num_envs"],
        next_obs: Float[Arr, " num_envs *obs_shape"],
    ) -> None:
        """
        Add a batch of transitions to the replay buffer.
        """
        # Check shapes & datatypes
        for data, expected_shape in zip(
            [obs, actions, rewards, terminated, next_obs],
            [self.obs_shape, self.action_shape, (), (), self.obs_shape],
        ):
            assert isinstance(data, np.ndarray)
            assert data.shape == (self.num_envs, *expected_shape)

        # Add data to buffer, slicing off the old elements
        self.obs = np.concatenate((self.obs, obs))[-self.buffer_size :]
        self.actions = np.concatenate((self.actions, actions))[-self.buffer_size :]
        self.rewards = np.concatenate((self.rewards, rewards))[-self.buffer_size :]
        self.terminated = np.concatenate((self.terminated, terminated))[-self.buffer_size :]
        self.next_obs = np.concatenate((self.next_obs, next_obs))[-self.buffer_size :]

    def sample(self, sample_size: int, device: t.device) -> ReplayBufferSamples:
        """
        Sample a batch of transitions from the buffer, with replacement.
        """
        # Draw indices below the current number of stored transitions, so that sampling also
        # works before the buffer is full (this equals buffer_size once the buffer has filled up)
        indices = self.rng.integers(0, len(self.obs), sample_size)

        return ReplayBufferSamples(
            obs=t.tensor(self.obs[indices], dtype=t.float32, device=device),
            actions=t.tensor(self.actions[indices], device=device),
            rewards=t.tensor(self.rewards[indices], dtype=t.float32, device=device),
            terminated=t.tensor(self.terminated[indices], device=device),
            next_obs=t.tensor(self.next_obs[indices], dtype=t.float32, device=device),
        )
Next, you can run the following code to visualize your cart's position and angle, and see how these look in both the buffer and the buffer's random samples. Do the samples look correctly shuffled? Also, based on the CartPole source code, do the angles & positions at which the cart terminates make sense? (Note that the min/max values in the table are different from the termination ranges; the latter can be found below the table in the docstring.)
Note that the code below uses the SyncVectorEnv class, which is what lets us step through multiple environments at once. We create it by passing it a list of functions which can be called to create environments (see the make_env function in utils.py for exactly how this works). Note that in this case we're just passing it a single environment; tomorrow we'll actually make full use of SyncVectorEnv by giving it multiple environments.
Lastly, note that we handle finished episodes slightly differently. If envs.step results in some environments terminating, the next_obs it returns for those environments is actually the first observation of the next episode. We want to use this as our starting observation for the next step, but we also need to make sure we record the correct terminal observation in our buffer; we do this by extracting it from the infos dict, which is where it gets stored. You can see this in the plots below: the vertical lines are the values $t$ where $d_{t+1}=1$ i.e. $s_{t+1}$ is terminal, and we can see that $s_t, s_{t+1}$ both refer to the terminated episode at this point and both refer to the new episode immediately after.
buffer = ReplayBuffer(num_envs=1, obs_shape=(4,), action_shape=(), buffer_size=256, seed=0)
envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", 0, 0, "test")])
obs, infos = envs.reset()

for i in range(256):
    # Choose random action, and take a step in the environment
    actions = envs.action_space.sample()
    next_obs, rewards, terminated, truncated, infos = envs.step(actions)

    # Get `true_next_obs` by finding all environments where the episode just ended & replacing
    # `next_obs` with the actual final observations of those episodes
    true_next_obs = next_obs.copy()
    for n in range(envs.num_envs):
        if (terminated | truncated)[n]:
            true_next_obs[n] = infos["final_observation"][n]

    # Add experience to buffer, using `true_next_obs` so that obs & next_obs are always from the
    # same episode
    buffer.add(obs, actions, rewards, terminated, true_next_obs)
    obs = next_obs

sample = buffer.sample(256, device="cpu")

plot_cartpole_obs_and_dones(
    buffer.obs,
    buffer.terminated,
    title="Current obs s<sub>t</sub><br>so when d<sub>t+1</sub> = 1, these are the states just before termination",
)
plot_cartpole_obs_and_dones(
    buffer.next_obs,
    buffer.terminated,
    title="Next obs s<sub>t+1</sub><br>so when d<sub>t+1</sub> = 1, these are the terminated states",
)
plot_cartpole_obs_and_dones(
    sample.obs,
    sample.terminated,
    title="Current obs s<sub>t</sub> (sampled)<br>this is what gets fed into our model for training",
)
Exploration
DQN makes no attempt to explore intelligently. The exploration strategy is the same as for Q-Learning: agents take a random action with probability epsilon, but now we gradually decrease epsilon. The Q-network is also randomly initialized (rather than initialized with zeros), so its predictions of what is the best action to take are also pretty random to start.
Some games like Montezuma's Revenge have sparse rewards that require more advanced exploration methods to obtain. The player is required to collect specific keys to unlock specific doors, but unlike humans, DQN has no prior knowledge about what a key or a door is, and it turns out that bumbling around randomly has too low of a probability of correctly matching a key to its door. Even if the agent does manage to do this, the long separation between finding the key and going to the door makes it hard to learn that picking the key up was important.
As a result, DQN scored an embarrassing 0% of average human performance on this game.
Reward Shaping
One solution to sparse rewards is to use human knowledge to define auxiliary reward functions that are more dense and make the problem easier (in exchange for leaking in side knowledge and making the algorithm more specific to the problem at hand). What could possibly go wrong?
The canonical example is for a game called CoastRunners, where the goal was given to maximize the score (hoping that the agent would learn to race around the map). Instead, it found it could gain more score by driving in a loop picking up power-ups just as they respawn, crashing and setting the boat alight in the process.
Reward Hacking
For Montezuma's Revenge, the reward was shaped by giving a small reward for picking up the key. One time this was tried, the reward was given slightly too early and the agent learned it could go close to the key without quite picking it up, obtain the auxiliary reward, and then back up and repeat.
A collected list of examples of Reward Hacking can be found here.
Advanced Exploration
It would be better if the agent didn't require these auxiliary rewards to be hardcoded by humans, but instead could rely on other signals from the environment that a state might be worth exploring. One idea is that a state which is "surprising" or "novel" (according to the agent's current belief of how the environment works) in some sense might be valuable. Designing an agent to be innately curious presents a potential solution to exploration, as the agent will focus exploration in areas it is unfamiliar with. In 2018, OpenAI released Random Network Distillation which made progress in formalizing this notion, by measuring the agent's ability to predict the output of a neural network on visited states. States that are hard to predict are poorly explored, and thus highly rewarded. In 2019, an excellent paper First return, then explore found an even better approach. Such reward shaping can also be gamed, leading to the noisy TV problem, where agents that seek novelty become entranced by a source of randomness in the environment (like an analog TV out of tune displaying white noise), and ignore everything else in the environment.
Exercise - implement linear scheduler
For now, implement the basic linearly decreasing exploration schedule.
def linear_schedule(
    current_step: int,
    start_e: float,
    end_e: float,
    exploration_fraction: float,
    total_timesteps: int,
) -> float:
    """
    Return the appropriate epsilon for the current step.

    Epsilon should be start_e at step 0 and decrease linearly to end_e at step (exploration_fraction
    * total_timesteps). In other words, we are in "explore mode" with start_e >= epsilon >= end_e
    for the first `exploration_fraction` fraction of total timesteps, and then stay at end_e for the
    rest of training.
    """
    raise NotImplementedError()


epsilons = [
    linear_schedule(step, start_e=1.0, end_e=0.05, exploration_fraction=0.5, total_timesteps=500)
    for step in range(500)
]
line(
    epsilons,
    labels={"x": "steps", "y": "epsilon"},
    title="Probability of random action",
    height=400,
    width=600,
)

tests.test_linear_schedule(linear_schedule)
Solution
def linear_schedule(
    current_step: int,
    start_e: float,
    end_e: float,
    exploration_fraction: float,
    total_timesteps: int,
) -> float:
    """
    Return the appropriate epsilon for the current step.

    Epsilon should be start_e at step 0 and decrease linearly to end_e at step (exploration_fraction
    * total_timesteps). In other words, we are in "explore mode" with start_e >= epsilon >= end_e
    for the first `exploration_fraction` fraction of total timesteps, and then stay at end_e for the
    rest of training.
    """
    return start_e + (end_e - start_e) * min(current_step / (exploration_fraction * total_timesteps), 1)
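As a quick sanity check, the snippet below evaluates the schedule at a few steps using the same arguments as the plotting code earlier. It restates the solution's formula so that it runs on its own; the `kwargs` and `checks` names are just for this illustration.

```python
def linear_schedule(
    current_step: int, start_e: float, end_e: float, exploration_fraction: float, total_timesteps: int
) -> float:
    """Restates the solution's formula so this snippet is self-contained."""
    progress = min(current_step / (exploration_fraction * total_timesteps), 1)
    return start_e + (end_e - start_e) * progress


kwargs = dict(start_e=1.0, end_e=0.05, exploration_fraction=0.5, total_timesteps=500)

# Step 0 is the start, 125 is halfway through the decay, 250 is where the decay ends
checks = {step: linear_schedule(step, **kwargs) for step in [0, 125, 250, 499]}
for step, eps in checks.items():
    print(f"step {step:3d}: epsilon = {eps:.3f}")
```

Epsilon should read 1.000, 0.525, 0.050 and 0.050 at these four steps.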
Epsilon Greedy Policy
In DQN, the policy is implicitly defined by the Q-network: we take the action with the maximum predicted Q-value. This gives a bias towards optimism. When we estimate the maximum of a set of values $v_1, \ldots, v_n$ using the maximum of some noisy estimates $\hat{v}_1, \ldots, \hat{v}_n$ with $\hat{v}_i \approx v_i$, some estimates will by chance have large positive noise, and these are exactly the ones the maximum tends to select. Hence, the agent will choose actions that the Q-network is overly optimistic about.
See Sutton and Barto, Section 6.7 if you'd like a more detailed explanation, or the original Double Q-Learning paper which notes this maximisation bias, and introduces a method to correct for it using two separate Q-value estimators, each used to update the other.
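The maximisation bias is easy to reproduce numerically. The sketch below is a toy simulation under assumed settings (10 actions that all have true value 0, and unit-variance Gaussian noise on each estimate); it is not part of the DQN implementation.

```python
import random

rng = random.Random(0)
num_actions = 10
num_trials = 2000

# Every action has the same true value, so the true maximum is 0.0
true_values = [0.0] * num_actions

max_of_estimates = []
for _ in range(num_trials):
    noisy = [v + rng.gauss(0.0, 1.0) for v in true_values]  # unbiased but noisy estimates
    max_of_estimates.append(max(noisy))

avg_max = sum(max_of_estimates) / num_trials
print(f"true max = {max(true_values):.2f}, average estimated max = {avg_max:.2f}")
```

Each individual estimate is unbiased, yet the average of the maximum comes out well above the true maximum of 0, which is the optimism described above.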
Exercise - implement the epsilon greedy policy
We've given you the first line of code, to convert the numpy array obs into a tensor on the correct device.
Note that we don't decide to explore for each environment individually: we either explore for all environments, or for none of them.
This means we can just avoid doing a forward pass for the Q-network entirely if we're exploring.
Other tips:
- Although you can technically use envs.action_space.sample() to sample actions, it's better practice to work with the random number generator rng that we've provided. You can use rng.random() to generate random numbers in the range $[0,1)$, and rng.integers(0, n, size) for an array of shape size of random integers in the range $0, 1, \ldots, n-1$.
- Don't forget to convert the result back to a np.ndarray, as this function expects.
- Use envs.single_action_space.n to get the number of actions.
def epsilon_greedy_policy(
    envs: gym.vector.SyncVectorEnv,
    q_network: QNetwork,
    rng: np.random.Generator,
    obs: Float[Arr, " num_envs *obs_shape"],
    epsilon: float,
) -> Int[Arr, " num_envs *action_shape"]:
    """
    With probability epsilon, take a random action. Otherwise, take a greedy action according to the
    q_network.

    Inputs:
        envs:       The family of environments to run against
        q_network:  The QNetwork used to approximate the Q-value function
        obs:        The current observation for each environment
        epsilon:    The probability of taking a random action
    Returns:
        actions:    The sampled action for each environment.
    """
    # Convert `obs` into a tensor so we can feed it into our model
    obs = t.from_numpy(obs).to(device)

    raise NotImplementedError()


tests.test_epsilon_greedy_policy(epsilon_greedy_policy)
Help - I'm confused about the action shape here.
In our case, the action shape is envs.single_action_space.shape = () (i.e. trivial, because our action is just a single integer not a vector or tensor) and the number of possible actions is envs.single_action_space.n = 2. This means your return value should just be a vector of ints of length num_envs, with each element being either 0 or 1 (sampled uniformly when exploring).
Solution
def epsilon_greedy_policy(
    envs: gym.vector.SyncVectorEnv,
    q_network: QNetwork,
    rng: np.random.Generator,
    obs: Float[Arr, " num_envs *obs_shape"],
    epsilon: float,
) -> Int[Arr, " num_envs *action_shape"]:
    """
    With probability epsilon, take a random action. Otherwise, take a greedy action according to the
    q_network.

    Inputs:
        envs:       The family of environments to run against
        q_network:  The QNetwork used to approximate the Q-value function
        obs:        The current observation for each environment
        epsilon:    The probability of taking a random action
    Returns:
        actions:    The sampled action for each environment.
    """
    # Convert `obs` into a tensor so we can feed it into our model
    obs = t.from_numpy(obs).to(device)

    num_actions = envs.single_action_space.n
    if rng.random() < epsilon:
        # Explore: uniformly random action for every environment (no forward pass needed)
        return rng.integers(0, num_actions, size=(envs.num_envs,))
    else:
        # Exploit: pick the action with the highest predicted Q-value in each environment
        q_scores = q_network(obs)
        return q_scores.argmax(-1).detach().cpu().numpy()
Probe Environments
Extremely simple probe environments are a great way to debug your algorithm. The first one is given below.
Let's try and break down how this environment works. We see that the function step always returns the same thing. The observation and reward are always the same, and done is always true (i.e. the episode always terminates after one action). We expect the agent to rapidly learn that the value of the constant observation [0.0] is +1. This is in some sense the simplest possible probe.
A note on observation spaces
The observation space we're using here is gym.spaces.Box. This means we're dealing with real-valued quantities, i.e. continuous not discrete. The first two arguments of Box are low and high, and these define a box in $\mathbb{R}^n$. For instance, if these arrays are (0, 0) and (1, 1) respectively, this defines the box $0 \leq x, y \leq 1$ in 2D space.
class Probe1(gym.Env):
    """
    One action, observation of [0.0], one timestep long, +1 reward.

    We expect the agent to rapidly learn that the value of the constant [0.0] observation is +1.0.
    Note we're using a continuous observation space for consistency with CartPole.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self, render_mode: str = "rgb_array"):
        super().__init__()
        self.observation_space = Box(np.array([0]), np.array([0]))
        self.action_space = Discrete(1)
        self.reset()

    def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]:
        return np.array([0]), 1.0, True, True, {}

    def reset(self, seed: int | None = None, options=None) -> ObsType | tuple[ObsType, dict]:
        super().reset(seed=seed)
        return np.array([0.0]), {}


gym.envs.registration.register(id="Probe1-v0", entry_point=Probe1)
env = gym.make("Probe1-v0")
assert env.observation_space.shape == (1,)
assert env.action_space.shape == ()
Exercise - read & understand other probe environments
For each of the probes below, read their implementation code, and understand how they correspond to their docstrings (and to the descriptions given in Andy Jones' post).
It's very important to understand how these probes work, and why they're useful tools for debugging. When you're working on your own RL projects, you might have to write your own probes to suit your particular use cases.
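One way to use the probes is to write down, before training, the Q-values a correct agent should converge to, and compare against them afterwards. The table below is a sketch derived from the docstrings of the first four probes; the dictionary layout and the value `gamma = 0.99` are assumptions for illustration, so substitute the discount factor from your own training configuration.

```python
gamma = 0.99  # hypothetical discount factor; use whatever your training config specifies

# Keys are (probe, observation, action); values are the Q-values a correct agent should learn
expected_q = {
    ("Probe1", 0.0, 0): 1.0,          # constant observation, constant +1 reward
    ("Probe2", -1.0, 0): -1.0,        # reward equals the observation
    ("Probe2", +1.0, 0): +1.0,
    ("Probe3", 0.0, 0): gamma * 1.0,  # reward arrives one step later, so it is discounted once
    ("Probe3", 1.0, 0): 1.0,          # reward arrives on this step
    ("Probe4", 0.0, 0): -1.0,         # action 0 is the bad action
    ("Probe4", 0.0, 1): +1.0,         # action 1 is the good action
}

for key, value in expected_q.items():
    print(key, value)
```

If Probe1 fails, suspect the loss or optimizer; if Probe2 fails but Probe1 passes, suspect how observations reach the network; if only Probe3 fails, suspect the discounting or the handling of the terminated flag.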
class Probe2(gym.Env):
    """
    One action, observation of [-1.0] or [+1.0], one timestep long, reward equals observation.

    We expect the agent to rapidly learn the value of each observation is equal to the observation.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self, render_mode: str = "rgb_array"):
        super().__init__()
        self.observation_space = Box(np.array([-1.0]), np.array([+1.0]))
        self.action_space = Discrete(1)
        self.reset()
        self.reward = None

    def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]:
        assert self.reward is not None
        return np.array([self.observation]), self.reward, True, True, {}

    def reset(self, seed: int | None = None, options=None) -> ObsType | tuple[ObsType, dict]:
        super().reset(seed=seed)
        self.reward = 1.0 if self.np_random.random() < 0.5 else -1.0
        self.observation = self.reward
        return np.array([self.reward]), {}


class Probe3(gym.Env):
    """
    One action, [0.0] then [1.0] observation, two timesteps, +1 reward at the end.

    We expect the agent to rapidly learn the discounted value of the initial observation.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self, render_mode: str = "rgb_array"):
        super().__init__()
        self.observation_space = Box(np.array([-0.0]), np.array([+1.0]))
        self.action_space = Discrete(1)
        self.reset()

    def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]:
        self.n += 1
        if self.n == 1:
            return np.array([1.0]), 0.0, False, False, {}
        elif self.n == 2:
            return np.array([0.0]), 1.0, True, True, {}
        raise ValueError(self.n)

    def reset(self, seed: int | None = None, options=None) -> ObsType | tuple[ObsType, dict]:
        super().reset(seed=seed)
        self.n = 0
        return np.array([0.0]), {}


class Probe4(gym.Env):
    """
    Two actions, [0.0] observation, one timestep, reward is -1.0 or +1.0 dependent on the action.

    We expect the agent to learn to choose the +1.0 action.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self, render_mode: str = "rgb_array"):
        self.observation_space = Box(np.array([-0.0]), np.array([+0.0]))
        self.action_space = Discrete(2)
        self.reset()

    def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]:
        reward = -1.0 if action == 0 else 1.0
        return np.array([0.0]), reward, True, True, {}

    def reset(self, seed: int | None = None, options=None) -> ObsType | tuple[ObsType, dict]:
        super().reset(seed=seed)
        return np.array([0.0]), {}


class Probe5(gym.Env):
    """
    Two actions, random 0/1 observation, one timestep, reward is 1 if action equals observation,
    otherwise -1.

    We expect the agent to learn to match its action to the observation.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self, render_mode: str = "rgb_array"):
        self.observation_space = Box(np.array([-1.0]), np.array([+1.0]))
        self.action_space = Discrete(2)
        self.reset()

    def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]:
        reward = 1.0 if action == self.obs else -1.0
return np.array([self.obs]), reward, True, True, {}
def reset(self, seed: int | None = None, options=None) -> ObsType | tuple[ObsType, dict]:
super().reset(seed=seed)
self.obs = 1.0 if self.np_random.random() < 0.5 else 0.0
return np.array([self.obs], dtype=float), {}
gym.envs.registration.register(id="Probe2-v0", entry_point=Probe2)
gym.envs.registration.register(id="Probe3-v0", entry_point=Probe3)
gym.envs.registration.register(id="Probe4-v0", entry_point=Probe4)
gym.envs.registration.register(id="Probe5-v0", entry_point=Probe5)
A brief summary of these, along with recommendations of where to go to debug if one of them fails (note that these won't be true 100% of the time, but should hopefully give you some useful direction):
Summary of probes
- Probe 1: Tests basic learning ability. If this fails, it means the agent has failed to learn to associate a constant observation with a constant reward. You should check your loss functions and optimizers in this case.
- Probe 2: Tests the agent's ability to differentiate between 2 different observations (and learn their respective values). If this fails, it means the agent has issues with handling multiple possible observations.
- Probe 3: Tests the agent's ability to handle time & reward delay. If this fails, it means the agent has problems with multi-step scenarios or discounting future rewards. You should look at how your agent step function works.
- Probe 4: Tests the agent's ability to learn from actions leading to different rewards. If this fails, it means the agent has failed to change its policy for different rewards, and you should look closer at how your agent is updating its policy based on the rewards it receives & the loss function.
- Probe 5: Tests the agent's ability to map observations to actions. If this fails, you should look at the code which handles multiple timesteps, as well as the code that handles the agent's map from observations to actions.
Main DQN Algorithm
We now combine all the elements we have designed thus far into the final DQN algorithm. Here, we assume the environment returns three parameters $(s_{new}, r, d)$, a new state $s_{new}$, a reward $r$ and a boolean $d$ indicating whether interaction has terminated yet.
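Our Q-value function $Q(s,a)$ is now a network $Q(s,a ; \theta)$ parameterised by weights $\theta$. The key idea, as in Q-learning, is to ensure the Q-value function satisfies the optimal Bellman equation
$$
Q(s_t, a_t ; \theta) = \mathbb{E}_{s_{t+1}, r_{t+1}} \left[ r_{t+1} + \gamma \max_a Q(s_{t+1}, a ; \theta) \right]
$$
Equivalently, the temporal difference error
$$
\delta_t = r_{t+1} + \gamma (1 - d_{t+1}) \max_a Q(s_{t+1}, a ; \theta) - Q(s_t, a_t ; \theta)
$$
should be zero in expectation. Note that $d_{t+1}$ is terminated, not terminated | truncated, here - we don't want the agent to learn that its value is always zero just before the episode ends and so there's no point in continuing to perform well!
Since we have an expression which should be zero in expectation for our true Q-value function, and we want the model to learn from a variety of experiences at once, we can sample batches of experiences $B = \{s_{t_i}, a_{t_i}, r_{t_i+1}, d_{t_i+1}, s_{t_i+1}\}_{i=1}^{|B|}$ from the replay buffer, and train against the loss function which equals the squared temporal difference error:
$$
L(\theta) = \frac{1}{|B|} \sum_{i=1}^{|B|} \left( r_{t_i+1} + \gamma (1 - d_{t_i+1}) \max_a Q(s_{t_i+1}, a ; \theta_\text{target}) - Q(s_{t_i}, a_{t_i} ; \theta) \right)^2
$$
where $\theta_\text{target}$ denotes the weights of the target network, a copy of $\theta$ which is only refreshed periodically.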
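To make the squared TD error loss described above concrete, here is a minimal pure-Python sketch of the same arithmetic on a two-transition batch. The function name `td_loss_sketch`, the dictionary-based Q-tables and the toy numbers are illustrative assumptions, not part of the course code (which works on tensors sampled from the replay buffer).

```python
def td_loss_sketch(batch, q_online, q_target, gamma):
    """
    batch: list of (obs, action, reward, terminated, next_obs) tuples.
    q_online / q_target: dicts mapping an observation to a list of Q-values (one per action).
    Returns the mean squared TD error over the batch.
    """
    squared_errors = []
    for obs, action, reward, terminated, next_obs in batch:
        # Bootstrap from the target network, unless the episode terminated at this transition
        bootstrap = 0.0 if terminated else gamma * max(q_target[next_obs])
        td_error = reward + bootstrap - q_online[obs][action]
        squared_errors.append(td_error**2)
    return sum(squared_errors) / len(squared_errors)


# Toy numbers, chosen so the arithmetic is easy to follow by hand
q_online = {"s0": [0.0, 1.0], "s1": [0.5, 2.0]}
q_target = {"s0": [0.0, 1.0], "s1": [0.0, 2.0]}
batch = [
    ("s0", 1, 1.0, False, "s1"),  # TD error = 1.0 + 0.5 * 2.0 - 1.0 = 1.0
    ("s1", 0, 1.0, True, "s0"),   # TD error = 1.0 + 0.0 - 0.5 = 0.5 (no bootstrapping)
]
loss = td_loss_sketch(batch, q_online, q_target, gamma=0.5)
print(loss)  # 0.625
```

The second transition shows the role of the termination flag: the bootstrap term is dropped entirely, so only the immediate reward contributes to the target.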
Below is the full DQN algorithm from a paper, for reference. The notation isn't identical to ours (e.g. they use an if/else statement to handle the terminal state case), but the basic algorithm is the same.

DQN Dataclass
Below is a dataclass for training your DQN. You can use the arg_help method to see a description of each argument (it will also highlight any arguments which have been changed from their default values).
The exact breakdown of training is as follows:
- The agent takes total_timesteps steps in the environment during the training loop.
- The first buffer_size of these steps are used to fill the replay buffer (we don't update gradients until the buffer is full).
- After this point, we perform an optimizer step every steps_per_train steps of our agent. We also copy the weights from our Q-network to our target network every trains_per_target_update steps of our Q-network.
This is shown in the diagram below (the actual numbers aren't representative of the values in our dataclass, they're just to make sure the diagram is understandable - obviously the scale is very different in our actual training).

For example, in the code below we decrease total_timesteps, and this also decreases total training steps (which is computed in the __post_init__ method of our dataclass, as a function of total_timesteps).
@dataclass
class DQNArgs:
    # Basic / global
    seed: int = 1
    env_id: str = "CartPole-v1"
    num_envs: int = 1

    # Wandb / logging
    use_wandb: bool = False
    wandb_project_name: str = "DQNCartPole"
    wandb_entity: str | None = None
    video_log_freq: int | None = 50
    steps_per_live_video: int | None = None

    # Duration of different phases / buffer memory settings
    total_timesteps: int = 500_000
    steps_per_train: int = 10
    trains_per_target_update: int = 100
    buffer_size: int = 10_000

    # Optimization hparams
    batch_size: int = 128
    learning_rate: float = 2.5e-4

    # RL-specific
    gamma: float = 0.99
    exploration_fraction: float = 0.2
    start_e: float = 1.0
    end_e: float = 0.1

    def __post_init__(self):
        assert self.total_timesteps - self.buffer_size >= self.steps_per_train
        self.total_training_steps = (self.total_timesteps - self.buffer_size) // self.steps_per_train
        self.video_save_path = section_dir / "videos"


args = DQNArgs(total_timesteps=400_000)  # changing total_timesteps will also change ???
utils.arg_help(args)
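The phase breakdown described above can be checked with a few lines of arithmetic. This is a hedged sketch: `training_phases` is a hypothetical helper (not part of the course code), and it assumes the target network is synced whenever the training step index is a multiple of trains_per_target_update, counting from step 0.

```python
def training_phases(
    total_timesteps: int,
    buffer_size: int,
    steps_per_train: int,
    trains_per_target_update: int,
) -> dict[str, int]:
    # Agent steps left after the buffer-filling phase, grouped into optimizer steps
    total_training_steps = (total_timesteps - buffer_size) // steps_per_train
    # Number of training step indices (0, N, 2N, ...) at which the target network is synced
    target_syncs = len(range(0, total_training_steps, trains_per_target_update))
    return {
        "buffer_fill_steps": buffer_size,
        "total_training_steps": total_training_steps,
        "target_syncs": target_syncs,
    }


default_phases = training_phases(500_000, 10_000, 10, 100)
shorter_phases = training_phases(400_000, 10_000, 10, 100)
print(default_phases["total_training_steps"], default_phases["target_syncs"])  # 49000 490
print(shorter_phases["total_training_steps"], shorter_phases["target_syncs"])  # 39000 390
```

With the default values, only 49,000 optimizer steps are taken across 500,000 agent steps, which is worth keeping in mind when comparing loss curves against step counts.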
Exercise - fill in the agent class
You should now fill in the methods for the DQNAgent class below. This is a class which is designed to handle taking steps in the environment (with an epsilon greedy policy), and updating the buffer.
- play_step should be somewhat similar to the demo code you saw earlier, which sampled a batch of experiences to add to the buffer. It should:
  - Get actions (using self.get_actions rather than randomly sampling like we did in the demo code before)
  - Step our environment with these actions
  - Add the new experiences to the buffer
  - Some of these observations
  - Set your new observation as self.obs, ready for the next step
- get_actions should do the following:
  - Set self.epsilon according to the linear schedule function & the current global step counter
  - Sample actions according to the epsilon-greedy policy (i.e. using your epsilon_greedy_policy function), and return them
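For reference, the kind of linear epsilon schedule that get_actions relies on can be sketched as follows. This is an assumption about its behaviour (linear decay from start_e to end_e over the first exploration_fraction of training, then constant), written as a hypothetical `linear_schedule_sketch` rather than the linear_schedule function from the earlier exercise.

```python
def linear_schedule_sketch(
    current_step: int,
    start_e: float,
    end_e: float,
    exploration_fraction: float,
    total_timesteps: int,
) -> float:
    # Number of steps over which epsilon decays
    decay_steps = exploration_fraction * total_timesteps
    # Fraction of the decay completed so far, capped at 1 once the decay phase is over
    progress = min(current_step / decay_steps, 1.0)
    return start_e + progress * (end_e - start_e)


# Values for the default DQNArgs settings (start_e=1.0, end_e=0.1, exploration_fraction=0.2)
for step in [0, 50_000, 100_000, 400_000]:
    print(step, round(linear_schedule_sketch(step, 1.0, 0.1, 0.2, 500_000), 4))
```

Under these assumptions epsilon reaches its floor of 0.1 after 100,000 agent steps and stays there for the rest of training.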
A small note on code practices here - the implementation below was designed to follow separation of concerns (SoC), a design principle used in software engineering. The DQNAgent class is only responsible for interacting with the environment; it doesn't do anything like create the Q-network or buffer on initialization. This is further reflected in the fact that we don't pass in args to our DQN agent, but instead pass in all the relevant variables separately (if we were forced to pass in args, this would be a sign that the DQN agent class might be doing too much work!).
class DQNAgent:
"""Base Agent class handling the interaction with the environment."""
def __init__(
self,
envs: gym.vector.SyncVectorEnv,
buffer: ReplayBuffer,
q_network: QNetwork,
start_e: float,
end_e: float,
exploration_fraction: float,
total_timesteps: int,
rng: np.random.Generator,
):
self.envs = envs
self.buffer = buffer
self.q_network = q_network
self.start_e = start_e
self.end_e = end_e
self.exploration_fraction = exploration_fraction
self.total_timesteps = total_timesteps
self.rng = rng
self.step = 0 # Tracking number of steps taken (across all environments)
self.obs, _ = self.envs.reset() # Need a starting observation
self.epsilon = start_e # Starting value (will be updated in `get_actions`)
def play_step(self) -> dict:
"""
Carries out a single interaction step between agent & environment, and adds results to the
replay buffer.
Returns `infos` (list of dictionaries containing info we will log).
"""
raise NotImplementedError()
self.step += self.envs.num_envs
return infos
def get_actions(self, obs: np.ndarray) -> np.ndarray:
"""
Samples actions according to the epsilon-greedy policy using the linear schedule for epsilon.
"""
raise NotImplementedError()
tests.test_agent(DQNAgent)
Solution
class DQNAgent:
    """Base Agent class handling the interaction with the environment."""

    def __init__(
        self,
        envs: gym.vector.SyncVectorEnv,
        buffer: ReplayBuffer,
        q_network: QNetwork,
        start_e: float,
        end_e: float,
        exploration_fraction: float,
        total_timesteps: int,
        rng: np.random.Generator,
    ):
        self.envs = envs
        self.buffer = buffer
        self.q_network = q_network
        self.start_e = start_e
        self.end_e = end_e
        self.exploration_fraction = exploration_fraction
        self.total_timesteps = total_timesteps
        self.rng = rng

        self.step = 0  # Tracking number of steps taken (across all environments)
        self.obs, _ = self.envs.reset()  # Need a starting observation
        self.epsilon = start_e  # Starting value (will be updated in `get_actions`)

    def play_step(self) -> dict:
        """
        Carries out a single interaction step between agent & environment, and adds results to the
        replay buffer.
        Returns `infos` (list of dictionaries containing info we will log).
        """
        self.obs = np.array(self.obs, dtype=np.float32)
        actions = self.get_actions(self.obs)
        next_obs, rewards, terminated, truncated, infos = self.envs.step(actions)

        # Get `true_next_obs` by finding all environments which terminated or were truncated, and
        # replacing `next_obs` (the first observation after the reset) with the actual final states
        true_next_obs = next_obs.copy()
        for n in range(self.envs.num_envs):
            if (terminated | truncated)[n]:
                true_next_obs[n] = infos["final_observation"][n]

        self.buffer.add(self.obs, actions, rewards, terminated, true_next_obs)
        self.obs = next_obs
        self.step += self.envs.num_envs
        return infos

    def get_actions(self, obs: np.ndarray) -> np.ndarray:
        """
        Samples actions according to the epsilon-greedy policy using the linear schedule for epsilon.
        """
        self.epsilon = linear_schedule(
            self.step, self.start_e, self.end_e, self.exploration_fraction, self.total_timesteps
        )
        actions = epsilon_greedy_policy(self.envs, self.q_network, self.rng, obs, self.epsilon)
        assert actions.shape == (len(self.envs.envs),)
        return actions
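As a reminder of what the epsilon-greedy selection inside get_actions is doing, here is a minimal pure-Python sketch. The name `epsilon_greedy_sketch` and its list-based inputs are illustrative assumptions; the epsilon_greedy_policy function used above takes the environments, Q-network and a NumPy generator instead.

```python
import random


def epsilon_greedy_sketch(q_values_per_env: list[list[float]], epsilon: float, rng: random.Random) -> list[int]:
    """Picks one action per environment: random with probability epsilon, greedy otherwise."""
    actions = []
    for q_values in q_values_per_env:
        if rng.random() < epsilon:
            # Explore: uniformly random action
            actions.append(rng.randrange(len(q_values)))
        else:
            # Exploit: action with the highest predicted Q-value
            actions.append(max(range(len(q_values)), key=lambda a: q_values[a]))
    return actions


rng = random.Random(0)
q_values = [[0.1, 0.9], [2.0, -1.0]]
greedy_actions = epsilon_greedy_sketch(q_values, epsilon=0.0, rng=rng)
random_actions = epsilon_greedy_sketch(q_values, epsilon=1.0, rng=rng)
print(greedy_actions)  # [1, 0]
```

With epsilon=0.0 the choice is fully greedy, and with epsilon=1.0 every action is sampled uniformly, which matches the two ends of the schedule.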
Before we move on to the big exercise of today (completing the DQNTrainer class), we'll briefly discuss logging to Weights and Biases in RL, plus some general advice on what kinds of variables you should be logging.
Logging to wandb in RL
In previous exercises in this chapter, we've just trained the agent, and then plotted the reward per episode after training. For small toy examples that train in a few seconds this is fine, but for longer runs we'd like to watch the run live and make sure the agent is doing something interesting (especially if we were planning to run the model overnight). Luckily, Weights and Biases has got us covered! When you run your experiments, you'll be able to view not only live plots of the loss and average reward per episode while the agent is training - you can also log and view animations, which visualise your agent's progress in real time! The code below will handle all logging.
Sadly, effective logging & debugging in RL isn't just about watching videos, since in the vast majority of cases where your algorithm has a bug, the agent will just fail to learn anything useful and the videos won't be informative. Debugging RL requires knowing what variables to log and how to interpret the results you're getting, which requires some understanding of the underlying theory! This is part of the reason why we've spent so much time discussing the theory behind DQN and other RL algorithms, rather than just giving you a black box to train.
As an example of how logged variables can be misleading and hard to interpret, consider our TD loss function in DQN. This loss only reflects how closely the Q-network's estimates match the experiences currently sampled from the replay buffer, which might not adequately represent what the world actually looks like. This means that once the agent starts to learn something and do better at the problem, it's expected for the loss to increase. For example, maybe the Q-network initially learned some state was bad, because an agent that reached it was just flapping around randomly and died shortly after. But now it's getting evidence that the same state is good, because the agent that reaches it has a better idea of what to do next. A higher loss is thus actually a good sign that something is happening (the agent hasn't stagnated), but it's not clear if it's learning anything useful without also checking how the total reward per episode has changed. Key point - just looking at one variable can be misleading, so we need to log multiple variables and derive a picture of what's happening by taking all of them into account!
Some useful variables to log during DQN training are:
- TD loss, i.e. the actual loss you're backpropagating through. This should start off high and decrease pretty quickly, but may not be monotonic (i.e. temporary spikes in loss aren't necessarily a bad thing)
- SPS (steps per second), i.e. the total number of agent steps divided by the total time. This helps us debug when the environment steps are a bottleneck (won't be the case in a simple environment like this one, but might matter more when we move to more complex environments)
- Q-values, i.e. the predicted Q-values from the Q-network. Can you guess how these should behave?
Question - what do you think the Q values will do when the agent moves closer to solving the cartpole environment?
Initially they should be near zero, thanks to the randomly initialized model weights. As our episode length gets closer to 500 (i.e. we can essentially solve the environment), they should tend to the limit of the total possible time-discounted reward available, which is the geometric sum $1 + \gamma + \gamma^2 + \cdots$ (since we get 1 reward for every timestep the pole stays up, and as previously discussed, the way we handle dones in the formula above doesn't assume a truncated environment causes future rewards to be terminated). The limit of this sum is $\frac{1}{1-\gamma}$, which for our default value $\gamma = 0.99$ is approximately 100.
Note, the Q values won't increase smoothly, they'll spike up immediately after we copy over the weights from our Q-network to our target network. This is because each time we copy over weights, our gradient changes and the Q-network rapidly "catches up" to this new target network, causing the Q values to change rapidly. However, our copying over of weights will be frequent enough that these jumps will be relatively small, and so the curve should still appear smooth.
Exercise - write DQN training loop
Now we'll create a new class DQNTrainer, which will handle the full training loop. We've filled in the __init__ for you, which defines all the things you need (the networks, optimizer, replay buffer, and the agent). We've also filled in train for you, which performs the main training loop: it optionally initializes Weights & Biases, fills the buffer using prepopulate_replay_buffer, then alternates between training steps (where we sample from the buffer) & adding to the buffer (adding args.steps_per_train new steps each time).
You should fill in the remaining 2 methods. First you should get the basic no-logging version working, then once you're running without error (even if maybe you're not learning anything useful) you should move onto logging as this will help you debug.
- add_to_replay_buffer
  - This calls self.agent.play_step() to take n steps in the environment, which adds the results to the replay buffer
  - It's used to fill the buffer before training starts, and before each training step to add new experiences to the buffer
- training_step
  - This performs an update step from a batch of experiences from the buffer, sampled using self.buffer.sample with batch size self.args.batch_size
  - An update step involves:
    - Getting the predicted Q-values $Q(s_{t_i}, a_{t_i} ; \theta)$ from the Q-network
    - Getting the max target Q-values $\max_a Q(s_{t_i+1}, a ; \theta_\text{target})$ from the target network (remember to use inference mode - we're not training the target network!)
    - Computing the TD loss $L(\theta)$ using the formula we gave earlier (we've also copied it below, for convenience)
    - Performing an update step with this loss
  - You should also copy weights from the Q-network to the target network every args.trains_per_target_update training steps (i.e. whenever the step argument of training_step is a multiple of this). The load_state_dict method might be useful here
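For convenience, here's the full TD loss formula again:
$$
L(\theta) = \frac{1}{|B|} \sum_{i=1}^{|B|} \left( r_{t_i+1} + \gamma (1 - d_{t_i+1}) \max_a Q(s_{t_i+1}, a ; \theta_\text{target}) - Q(s_{t_i}, a_{t_i} ; \theta) \right)^2
$$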
When you get to logging, there are 2 types of data you can log:
- Data for terminated episodes, during buffer filling
  - Terminated episode data can be found in the infos dict returned by the agent.play_step method. If environment env_idx terminated, then infos["final_info"][env_idx]["episode"] will be a dict containing the length l and reward r of the terminated episode
    - We've given you a helper function get_episode_data_from_infos which gives you a dict of the episode length & reward for the first terminated env, or None if no envs terminated. See the documentation page for an explanation.
  - You can also log the SPS (steps per second) if you like, this helps figure out if the environment transitions are the bottleneck for your algorithm
- Data during training steps
  - Mean TD loss, Q values, and the epsilon hyperparameter are all useful to log
Don't be discouraged if your code takes a while to work - it's normal for debugging RL to take longer than you would expect. Add asserts or your own tests, implement an appropriate probe environment, try anything in the Andy Jones post that sounds promising, and try to notice confusion. Reinforcement learning is tricky because, even if the algorithm has bugs, the agent might still learn something useful regardless (albeit maybe not as well), and even if everything is correct, the agent might just fail to learn anything useful (like how DQN failed to do anything on Montezuma's Revenge).
Since the environment is already known to be one DQN can solve, and we've already provided hyperparameters that work for this environment, hopefully that's isolated a lot of the problems one would usually have with solving real world problems with RL.
def get_episode_data_from_infos(infos: dict) -> dict[str, int | float] | None:
    """
    Helper function: returns dict of data from the first terminated environment, if at least one
    terminated.
    """
    for final_info in infos.get("final_info", []):
        if final_info is not None and "episode" in final_info:
            return {
                "episode_length": final_info["episode"]["l"].item(),
                "episode_reward": final_info["episode"]["r"].item(),
                "episode_duration": final_info["episode"]["t"].item(),
            }
class DQNTrainer:
def __init__(self, args: DQNArgs):
set_global_seeds(args.seed)
self.args = args
self.rng = np.random.default_rng(args.seed)
self.run_name = f"{args.env_id}__{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"
self.envs = gym.vector.SyncVectorEnv(
[make_env(idx=idx, run_name=self.run_name, **args.__dict__) for idx in range(args.num_envs)]
)
# Define some basic variables from our environment (note, we assume a single discrete action space)
num_envs = self.envs.num_envs
action_shape = self.envs.single_action_space.shape
num_actions = self.envs.single_action_space.n
obs_shape = self.envs.single_observation_space.shape
assert action_shape == ()
# Create our replay buffer
self.buffer = ReplayBuffer(num_envs, obs_shape, action_shape, args.buffer_size, args.seed)
# Create our networks & optimizer (target network should be initialized with a copy of the Q-network's weights)
self.q_network = QNetwork(obs_shape, num_actions).to(device)
self.target_network = QNetwork(obs_shape, num_actions).to(device)
self.target_network.load_state_dict(self.q_network.state_dict())
self.optimizer = t.optim.AdamW(self.q_network.parameters(), lr=args.learning_rate)
# Create our agent
self.agent = DQNAgent(
self.envs,
self.buffer,
self.q_network,
args.start_e,
args.end_e,
args.exploration_fraction,
args.total_timesteps,
self.rng,
)
def add_to_replay_buffer(self, n: int, verbose: bool = False):
"""
Takes n steps with the agent, adding to the replay buffer (and logging any results). Should
return a dict of data from the last terminated episode, if any.
Optional argument `verbose`: if True, we can use a progress bar (useful to check how long
the initial buffer filling is taking).
"""
raise NotImplementedError()
def prepopulate_replay_buffer(self):
"""
Called to fill the replay buffer before training starts.
"""
n_steps_to_fill_buffer = self.args.buffer_size // self.args.num_envs
self.add_to_replay_buffer(n_steps_to_fill_buffer, verbose=True)
def training_step(self, step: int) -> None:
"""
Samples once from the replay buffer, and takes a single training step.
Args:
step (int): The number of training steps taken (used for logging, and for deciding when
to update the target network)
"""
raise NotImplementedError()
def train(self) -> None:
if self.args.use_wandb:
wandb.init(
project=self.args.wandb_project_name,
entity=self.args.wandb_entity,
name=self.run_name,
monitor_gym=self.args.video_log_freq is not None,
)
wandb.watch(self.q_network, log="all", log_freq=50)
self.prepopulate_replay_buffer()
pbar = tqdm(range(self.args.total_training_steps))
last_logged_time = time.time() # so we don't update the progress bar too much
for step in pbar:
data = self.add_to_replay_buffer(self.args.steps_per_train)
if data is not None and time.time() - last_logged_time > 0.5:
last_logged_time = time.time()
pbar.set_postfix(**data)
self.training_step(step)
if self.args.steps_per_live_video is not None and step % self.args.steps_per_live_video == 0:
from IPython.display import display
html_animation = generate_and_plot_trajectory(self, self.args)
display(html_animation)
self.envs.close()
if self.args.use_wandb:
wandb.finish()
Solution (simple, no logging)
    def add_to_replay_buffer(self, n: int, verbose: bool = False):
        '''
        Takes n steps with the agent, adding to the replay buffer (and logging any results). Should return a dict of
        data from the last terminated episode, if any.
        Optional argument `verbose`: if True, we can use a progress bar (useful to check how long the initial buffer
        filling is taking).
        '''
        data = None
        for step in tqdm(range(n), disable=not verbose, desc="Adding to replay buffer"):
            infos = self.agent.play_step()
            data = data or get_episode_data_from_infos(infos)
        return data

    def prepopulate_replay_buffer(self):
        '''
        Called to fill the replay buffer before training starts.
        '''
        n_steps_to_fill_buffer = self.args.buffer_size // self.args.num_envs
        self.add_to_replay_buffer(n_steps_to_fill_buffer, verbose=True)

    def training_step(self, step: int) -> None:
        '''
        Samples once from the replay buffer, and takes a single training step. The `step` argument is used to track the
        number of training steps taken.
        '''
        data = self.buffer.sample(self.args.batch_size, device)  # s_t, a_t, r_{t+1}, d_{t+1}, s_{t+1}

        # Target values come from the target network, which we never backpropagate through
        with t.inference_mode():
            target_max = self.target_network(data.next_obs).max(-1).values
        predicted_q_vals = self.q_network(data.obs)[range(len(data.actions)), data.actions]

        td_error = data.rewards + self.args.gamma * target_max * (1 - data.terminated.float()) - predicted_q_vals
        loss = td_error.pow(2).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        # Periodically sync the target network with the Q-network
        if step % self.args.trains_per_target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
Solution (full logging)
    def add_to_replay_buffer(self, n: int, verbose: bool = False):
        '''
        Takes n steps with the agent, adding to the replay buffer (and logging any results). Should return a dict of
        data from the last terminated episode, if any.
        Optional argument `verbose`: if True, we can use a progress bar (useful to check how long the initial buffer
        filling is taking).
        '''
        data = None
        t0 = time.time()

        for step in tqdm(range(n), disable=not verbose, desc="Adding to replay buffer"):
            infos = self.agent.play_step()

            # Get data from environments, and log it if some environment did actually terminate
            new_data = get_episode_data_from_infos(infos)
            if new_data is not None:
                data = new_data  # makes sure we return a non-empty dict at the end, if some episode terminates
                if self.args.use_wandb:
                    wandb.log(new_data, step=self.agent.step)

        # Log SPS
        if self.args.use_wandb:
            wandb.log({"SPS": (n * self.envs.num_envs) / (time.time() - t0)}, step=self.agent.step)

        return data

    def prepopulate_replay_buffer(self):
        '''
        Called to fill the replay buffer before training starts.
        '''
        n_steps_to_fill_buffer = self.args.buffer_size // self.args.num_envs
        self.add_to_replay_buffer(n_steps_to_fill_buffer, verbose=True)

    def training_step(self, step: int) -> None:
        '''
        Samples once from the replay buffer, and takes a single training step. The `step` argument is used to track the
        number of training steps taken.
        '''
        data = self.buffer.sample(self.args.batch_size, device)  # s_t, a_t, r_{t+1}, d_{t+1}, s_{t+1}

        # Target values come from the target network, which we never backpropagate through
        with t.inference_mode():
            target_max = self.target_network(data.next_obs).max(-1).values
        predicted_q_vals = self.q_network(data.obs)[range(len(data.actions)), data.actions]

        td_error = data.rewards + self.args.gamma * target_max * (1 - data.terminated.float()) - predicted_q_vals
        loss = td_error.pow(2).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        # Periodically sync the target network with the Q-network
        if step % self.args.trains_per_target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        if self.args.use_wandb:
            wandb.log(
                {"td_loss": loss.item(), "q_values": predicted_q_vals.mean().item(), "epsilon": self.agent.epsilon},
                step=self.agent.step,
            )
Here's some boilerplate code to test out your various probes, which you should make sure you're passing before testing on Cartpole.
def test_probe(probe_idx: int):
    """
    Tests a probe environment by training a network on it & verifying that the value functions are
    in the expected range.
    """
    # Train our network on this probe env
    args = DQNArgs(
        env_id=f"Probe{probe_idx}-v0",
        wandb_project_name=f"test-probe-{probe_idx}",
        total_timesteps=3000 if probe_idx <= 2 else 5000,
        learning_rate=0.001,
        buffer_size=500,
        use_wandb=False,
        trains_per_target_update=20,
        video_log_freq=None,
    )
    trainer = DQNTrainer(args)
    trainer.train()

    # Get the correct set of observations, and corresponding values we expect
    obs_for_probes = [[[0.0]], [[-1.0], [+1.0]], [[0.0], [1.0]], [[0.0]], [[0.0], [1.0]]]
    expected_value_for_probes = [
        [[1.0]],
        [[-1.0], [+1.0]],
        [[args.gamma], [1.0]],
        [[-1.0, 1.0]],
        [[1.0, -1.0], [-1.0, 1.0]],
    ]
    tolerances = [5e-4, 5e-4, 5e-4, 5e-4, 1e-3]
    obs = t.tensor(obs_for_probes[probe_idx - 1]).to(device)

    # Calculate the actual value, and verify it
    value = trainer.q_network(obs)
    expected_value = t.tensor(expected_value_for_probes[probe_idx - 1]).to(device)
    t.testing.assert_close(value, expected_value, atol=tolerances[probe_idx - 1], rtol=0)
    print("Probe tests passed!\n")


for probe_idx in range(1, 6):
    test_probe(probe_idx)
Once you've passed the tests for all 5 probe environments, you should test your model on Cartpole. We recommend you start by not using wandb until you can get it running without error, because this will improve your feedback loops (however if you've passed all probe environments then there's a good chance this code will just work for you).
args = DQNArgs(use_wandb=True, steps_per_live_video=5_000)
trainer = DQNTrainer(args)
trainer.train()
Catastrophic forgetting
Note - you might see performance frequently drop off after it's achieved the maximum for a while, before eventually recovering again and repeating the cycle. Here's an example CartPole run using the solution code:

This is a well-known RL phenomenon called catastrophic forgetting. It happens when the replay buffer mostly contains successful experiences, and the model forgets how to adapt or recover from bad states. One way to fix this is to change your buffer to keep 10% of experiences from previous epochs, and 90% of experiences from the current phase. Can you implement this?
When we cover PPO tomorrow, we'll also introduce reward shaping, which is another way this kind of behaviour can be mitigated.
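One possible shape for the mixed buffer suggested above is sketched below in plain Python. Everything here is a hypothetical illustration (the class name `MixedReplayBuffer`, its capacities, and the random-overwrite rule for old experiences are all assumptions); it stores opaque transition objects rather than the tensors used by the ReplayBuffer in this section.

```python
import random
from collections import deque


class MixedReplayBuffer:
    """Hypothetical buffer: most samples come from recent experience, a few from older experience."""

    def __init__(self, recent_capacity: int, old_capacity: int, old_fraction: float = 0.1, seed: int = 0):
        self.recent = deque()  # newest transitions, first-in first-out
        self.old = []  # random subset of the transitions evicted from `recent`
        self.recent_capacity = recent_capacity
        self.old_capacity = old_capacity
        self.old_fraction = old_fraction
        self.rng = random.Random(seed)

    def add(self, transition) -> None:
        self.recent.append(transition)
        if len(self.recent) > self.recent_capacity:
            evicted = self.recent.popleft()
            if len(self.old) < self.old_capacity:
                self.old.append(evicted)
            else:
                # Overwrite a random slot, so `old` keeps a spread of past experience
                self.old[self.rng.randrange(self.old_capacity)] = evicted

    def sample(self, batch_size: int) -> list:
        # Sample with replacement: roughly `old_fraction` from old data, the rest from recent data
        n_old = round(batch_size * self.old_fraction) if self.old else 0
        old_part = self.rng.choices(self.old, k=n_old) if n_old else []
        recent_part = self.rng.choices(list(self.recent), k=batch_size - n_old)
        return old_part + recent_part


buffer = MixedReplayBuffer(recent_capacity=20, old_capacity=10)
for i in range(100):
    buffer.add(i)  # integers stand in for transitions; 80..99 end up in `recent`
batch = buffer.sample(10)
print(batch)
```

In this toy run, one of the ten sampled items comes from the old store (values below 80) and the other nine come from the most recent 20 additions.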
Beyond CartPole
If things go well and your agent masters CartPole, the next harder challenges are Acrobot-v1, and MountainCar-v0. These also have discrete action spaces, which are the only type we're dealing with today. Feel free to Google for appropriate hyperparameters for these other problems - in a real RL problem you would have to do hyperparameter search using the techniques we learned on a previous day because bad hyperparameters in RL often completely fail to learn, even if the algorithm is perfectly correct.
There are many more exciting environments to play in, but generally they're going to require more compute and more optimization than we have time for today. If you finish the main material, some we recommend are:
- Minimalistic Gridworld Environments - a fast gridworld environment for experiments with sparse rewards and natural language instruction.
- microRTS - a small real-time strategy game suitable for experimentation.
- Megastep - RL environment that runs fully on the GPU (fast!)
- Procgen - A family of 16 procedurally generated gym environments to measure the ability for an agent to generalize. Optimized to run quickly on the CPU.
- Atari - although you might want to wait until tomorrow to try this on DQN, because we'll be going through some guided exercises implementing Atari with PPO tomorrow!
Some (very unpolished) code for setting up Atari with DQN
This is based on a hybrid of tomorrow's actor/critic network setup for Atari, and the DQN implementation in this notebook. I've achieved decent performance in 40 mins training this, but not as good as we get when we do PPO on Atari tomorrow, so I think this is somewhat underoptimized - if anyone finds improvements then feel free to make a PR!
def layer_init(layer: nn.Linear, std=np.sqrt(2), bias_const=0.0):
t.nn.init.orthogonal_(layer.weight, std)
t.nn.init.constant_(layer.bias, bias_const)
return layer
class QNetwork(nn.Module):
layers: nn.Sequential
def __init__(self, obs_shape: tuple[int, ...], num_actions: int, hidden_sizes: list[int] = [120, 84]):
super().__init__()
assert len(obs_shape) == 3, "We're only supporting Atari for now, obs should be RGB images"
assert obs_shape[-1] % 8 == 4
L_after_convolutions = (obs_shape[-1] // 8) - 3
in_features = 64 * L_after_convolutions * L_after_convolutions
self.layers = nn.Sequential(
layer_init(nn.Conv2d(4, 32, 8, stride=4, padding=0)),
nn.ReLU(),
layer_init(nn.Conv2d(32, 64, 4, stride=2, padding=0)),
nn.ReLU(),
layer_init(nn.Conv2d(64, 64, 3, stride=1, padding=0)),
nn.ReLU(),
nn.Flatten(),
layer_init(nn.Linear(in_features, 512)),
nn.ReLU(),
layer_init(nn.Linear(512, num_actions), std=0.01),
)
def forward(self, x: Tensor) -> Tensor:
return self.layers(x)
args = DQNArgs(
use_wandb=True,
buffer_size=1000,
batch_size=32,
end_e=0.01,
learning_rate=1e-4,
total_timesteps=20_000,
steps_per_train=5,
mode="atari",
env_id="ALE/Breakout-v5",
wandb_project_name="DQNAtari",
num_envs=4,
)
trainer = DQNTrainer(args)
trainer.train()
Bonus
Target Network
Why have the target network? Modify the DQN code above, but this time use the same network for both the target and the Q-value network, rather than updating the target every so often.
Compare the performance of this against using the target network.
Shrink the Brain
Can DQN still learn to solve CartPole with a Q-network with fewer parameters? Could we get away with three-quarters or even half as many parameters? Try comparing the resulting training curves with a shrunken version of the Q-network. What about the same number of parameters, but with more/fewer layers, and fewer/more parameters per layer?
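When comparing network sizes, it helps to count parameters before training anything. The sketch below assumes a plain fully connected network with biases, a 4-dimensional CartPole observation, 2 actions, and hidden sizes of [120, 84]; `count_mlp_params` is a hypothetical helper, not part of the course code.

```python
def count_mlp_params(layer_sizes: list[int]) -> int:
    """Counts weights and biases in a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))


baseline = count_mlp_params([4, 120, 84, 2])
half_width = count_mlp_params([4, 60, 42, 2])
print(baseline)    # 10934
print(half_width)  # 2948
```

Note that halving every hidden width cuts the parameter count to roughly a quarter rather than a half, because the largest weight matrix scales with the product of two hidden widths.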
Dueling DQN
Implement dueling DQN according to the paper and compare its performance.
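As a starting point, the core of the dueling architecture is the way a state value and per-action advantages are combined into Q-values, with the mean advantage subtracted. The sketch below shows only that aggregation step in plain Python; `dueling_q_values` is a hypothetical name, and in a real implementation the two inputs would be the outputs of separate heads on a shared network body.

```python
def dueling_q_values(state_value: float, advantages: list[float]) -> list[float]:
    """Combines V(s) and A(s, a) into Q(s, a) = V(s) + (A(s, a) - mean over actions of A(s, a))."""
    mean_advantage = sum(advantages) / len(advantages)
    return [state_value + (advantage - mean_advantage) for advantage in advantages]


q_values = dueling_q_values(2.0, [1.0, 3.0])
print(q_values)  # [1.0, 3.0]
```

Subtracting the mean advantage means the Q-values always average to the state value, which stops the network from shifting an arbitrary constant between the two heads.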

