4️⃣ Atari
Learning Objectives
- Understand how PPO can be used in visual domains, with appropriate architectures (CNNs)
- Understand the idea of policy and value heads
- Train an agent to solve the Breakout environment
Introduction
In this section, you'll extend your PPO implementation to play Atari games.
The gymnasium library supports a variety of different Atari games - you can find them here (if you get a message when you click on this link asking whether you want to switch to gymnasium, ignore this and proceed to the gym site). You can try whichever ones you want, but we recommend you stick with the easier environments like Pong, Breakout, and Space Invaders.
The environments in this section are very different. Rather than having observations of shape (4,) (representing a vector of (x, v, theta, omega)), the raw observations are now images of shape (210, 160, 3), representing the pixels of the game screen. This leads to a variety of additional challenges relative to the CartPole environment, for example:
- We need a much larger network, because finding the optimal strategy isn't as simple as solving a basic differential equation
- Reward shaping is much more difficult, because our observations are low-level and don't contain easily-accessible information about the high-level abstractions in the game (finding these abstractions in the first place is part of the model's challenge!)
The action space is also different for each environment. For example, in Breakout, the environment has 4 actions - run the code below to see this (if you get an error, try restarting the kernel and running everything again, minus the library installs).
```python
env = gym.make("ALE/Breakout-v5", render_mode="rgb_array")

print(env.action_space)  # Discrete(4): 4 actions to choose from
print(env.observation_space)  # Box(0, 255, (210, 160, 3), uint8): an RGB image of the game screen
```

```
Discrete(4)
Box(0, 255, (210, 160, 3), uint8)
```
These 4 actions are "do nothing", "fire the ball", "move right", and "move left" respectively, which you can see from:
```python
print(env.get_action_meanings())
```

```
['NOOP', 'FIRE', 'RIGHT', 'LEFT']
```
You can see more details on the game-specific documentation page. On this documentation page, you can also see information like the reward for this environment. In this case, the reward comes from breaking bricks in the wall (more reward from breaking bricks higher up). This is a more challenging reward function than the one for CartPole, where a very simple strategy (move in the direction you're tipping) leads directly to a higher reward by marginally prolonging episode length.
We can also run the code below to take some random steps in our environment and animate the results:
```python
def display_frames(frames: Int[Arr, "timesteps height width channels"], figsize=(4, 5)):
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(frames[0])
    plt.close()

    def update(frame):
        im.set_array(frame)
        return [im]

    ani = FuncAnimation(fig, update, frames=frames, interval=100)
    display(HTML(ani.to_jshtml()))


nsteps = 150

frames = []
obs, info = env.reset()
for _ in tqdm(range(nsteps)):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    frames.append(obs)

display_frames(np.stack(frames))
```
Playing Breakout
Just like for CartPole and MountainCar, we've given you a Python file that lets you play Atari games yourself. The file is called play_breakout.py, and running it (i.e. python play_breakout.py) will open up a window for you to play the game. Take note of the key instructions, which will be printed in your terminal.
You should also be able to try out other games, by changing the relevant parts of the play_breakout.py file to match those games' documentation pages.
Implementational details of Atari
The 37 Implementational Details of PPO post describes how to get PPO working for games like Atari. In the sections below, we'll go through these steps.
Wrappers (details #1-7, and #9)
All the extra details except for one are just wrappers on the environment, which implement specific behaviours. For example:
- Frame Skipping - we repeat the agent's action for a number of frames (by default 4), and sum the reward over these frames. This saves time when the model's forward pass is computationally more expensive than an environment step.
- Image Transformations - we resize the image from (210, 160) to (L, L) for some smaller value L (in this case we'll use 84), and convert it to grayscale.
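To make the first of these concrete, here's a minimal sketch of what a frame-skipping wrapper might look like. This isn't the implementation we actually use (that comes from the gymnasium wrapper collection via prepare_atari_env); the DummyEnv class below is a made-up stand-in so the sketch runs on its own:

```python
import numpy as np

class DummyEnv:
    """Hypothetical stand-in environment: every step returns reward 1 and never terminates."""
    def step(self, action):
        obs = np.zeros((210, 160, 3), dtype=np.uint8)
        return obs, 1.0, False, False, {}

class FrameSkipWrapper:
    """Sketch of frame skipping: repeat each action `skip` times, summing the rewards."""
    def __init__(self, env, skip: int = 4):
        self.env = env
        self.skip = skip

    def step(self, action):
        total_reward = 0.0
        for _ in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break  # stop repeating if the episode ends mid-skip
        return obs, total_reward, terminated, truncated, info

env_skip = FrameSkipWrapper(DummyEnv(), skip=4)
obs, reward, *_ = env_skip.step(0)
print(reward)  # 4.0 - one agent decision advanced the environment by 4 frames
```

The point is that the agent only does one (expensive) forward pass per 4 (cheap) environment steps.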
We've written some environment wrappers for you (and imported some others from the gymnasium library), combining them all together into the prepare_atari_env function in the part3_ppo/utils.py file. You can have a read of this and see how it works, but since we're implementing these for you, you won't have to worry about them too much.
The code below visualizes the results of them (with the frames stacked across rows, so we can see them all at once). You might want to have a think about the kind of information your actor & critic networks are getting here, and how this might make the RL task easier.
```python
env_wrapped = prepare_atari_env(env)

frames = []
obs, info = env_wrapped.reset()
for _ in tqdm(range(nsteps)):
    action = env_wrapped.action_space.sample()
    obs, reward, terminated, truncated, info = env_wrapped.step(action)
    obs = einops.repeat(np.array(obs), "frames h w -> h (frames w) 3")  # stack frames across the row
    frames.append(obs)

display_frames(np.stack(frames), figsize=(12, 3))
```
Shared CNN for actor & critic (detail #8)
This is the most interesting one conceptually. If we have a new observation space then it naturally follows that we need a new architecture, and if we're working with images then using a convolutional neural network is reasonable. But another particularly interesting feature here is that we use a shared architecture for the actor and critic networks. The idea behind this is that the early layers of our model extract features from the environment (i.e. they find the high-level abstractions contained in the image), and then the actor and critic heads turn these shared features into actions / value estimates respectively. This is commonly referred to as having a policy head and a value head. We'll see this idea come up later, when we perform RL on transformers.
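Here's a toy sketch of the head structure (the layer sizes are made up and unrelated to the real architecture below; it just shows how two heads share one trunk, so gradients from both losses update the shared feature extractor):

```python
import torch
import torch.nn as nn

shared = nn.Sequential(nn.Linear(8, 16), nn.ReLU())  # toy shared feature extractor
actor = nn.Sequential(shared, nn.Linear(16, 4))      # policy head: logits over 4 actions
critic = nn.Sequential(shared, nn.Linear(16, 1))     # value head: scalar value estimate

obs = torch.randn(5, 8)  # batch of 5 toy observations
logits, values = actor(obs), critic(obs)
print(logits.shape, values.shape)  # torch.Size([5, 4]) torch.Size([5, 1])

# Both heads backprop into the same shared parameters:
(logits.sum() + values.sum()).backward()
print(shared[0].weight.grad is not None)  # True
```

Note that because `shared` appears inside both `nn.Sequential` modules, its parameters are literally the same tensors in both networks, not copies.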
Exercise - rewrite get_actor_and_critic
The function get_actor_and_critic has an argument mode, which we ignored previously, but which we'll now return to. When this argument isn't "atari" then the function should behave exactly as it did before (i.e. the CartPole version), but when it is "atari" then it should return a shared CNN architecture for the actor and critic. The architecture should be as follows (you can open it in a new tab if it's hard to see clearly):
Note - when calculating the number of input features for the linear layer, you can assume that the value L is 4 modulo 8, i.e. we can write L = 8m + 4 for some integer m. This will make the convolutions easier to track. You shouldn't hardcode the number of input features assuming an input shape of (4, 84, 84); this is bad practice!
We leave the exercise of finding the number of input features to the linear layer as a challenge for you. If you're stuck, you can find a hint in the section below (this isn't a particularly conceptually important detail).
Help - I don't know what the number of inputs for the first linear layer should be.
You can test this empirically by just doing a forward pass through the first half of the network and seeing what the shape of the output is.
Alternatively, you can use the convolution formula. There's never any padding, so for a conv with parameters (size, stride), the dimensions change as L -> 1 + (L - size) // stride (see the documentation page). So we have:
8m+4 -> 1 + (8m-4)//4 = 2m
2m -> 1 + (2m-4)//2 = m-1
m-1 -> 1 + (m-4)//1 = m-3
For instance, if L = 84 then m = 10 and the post-convolution size is m - 3 = 7. So the linear layer is fed a flattened tensor of shape (64, 7, 7), i.e. 64 * 7 * 7 = 3136 input features.
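You can also sanity-check the formula in code; the helper below is just for illustration:

```python
def conv_out_size(L: int, size: int, stride: int) -> int:
    """Output length of a no-padding convolution: L -> 1 + (L - size) // stride."""
    return 1 + (L - size) // stride

L = 84
for size, stride in [(8, 4), (4, 2), (3, 1)]:  # the three conv layers in the architecture
    L = conv_out_size(L, size, stride)

print(L)  # 7, so the linear layer sees 64 * 7 * 7 = 3136 input features
```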
Now, you can fill in the get_actor_and_critic_atari function below, which is called when we call get_actor_and_critic with mode == "atari".
Note that we take the observation shape as argument, not the number of observations. It should be (4, L, L) as indicated by the diagram. The shape (4, L, L) is a reflection of the fact that we're using 4 frames of history per input (which helps the model calculate things like velocity), and each of these frames is a monochrome resized square image.
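As an aside, here's a sketch of how that 4-frame history can be maintained. The real version is one of the wrappers inside prepare_atari_env; this standalone version just illustrates the mechanism:

```python
from collections import deque
import numpy as np

L = 84
buffer = deque(maxlen=4)  # holds the 4 most recent grayscale frames

# Before we've seen 4 real frames, the history needs padding (here, with zeros)
for _ in range(4):
    buffer.append(np.zeros((L, L), dtype=np.uint8))

new_frame = np.ones((L, L), dtype=np.uint8)  # pretend this came from env.step
buffer.append(new_frame)  # the oldest frame is dropped automatically by the deque

obs = np.stack(buffer)  # shape (4, L, L) - this is what the CNN actually receives
print(obs.shape)  # (4, 84, 84)
```

Stacking consecutive frames like this is what lets a feedforward network infer velocities (e.g. which way the ball is moving), which a single frame can't convey.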
```python
def get_actor_and_critic_atari(obs_shape: tuple[int, int, int], num_actions: int) -> tuple[nn.Sequential, nn.Sequential]:
    """
    Returns (actor, critic) in the "atari" case, according to the diagram above.
    """
    assert obs_shape[-1] % 8 == 4

    raise NotImplementedError()


tests.test_get_actor_and_critic(get_actor_and_critic, mode="atari")
```
Solution
```python
def get_actor_and_critic_atari(obs_shape: tuple[int, int, int], num_actions: int) -> tuple[nn.Sequential, nn.Sequential]:
    """
    Returns (actor, critic) in the "atari" case, according to the diagram above.
    """
    assert obs_shape[-1] % 8 == 4

    L_after_convolutions = (obs_shape[-1] // 8) - 3
    in_features = 64 * L_after_convolutions * L_after_convolutions

    hidden = nn.Sequential(
        layer_init(nn.Conv2d(4, 32, 8, stride=4, padding=0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(32, 64, 4, stride=2, padding=0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(64, 64, 3, stride=1, padding=0)),
        nn.ReLU(),
        nn.Flatten(),
        layer_init(nn.Linear(in_features, 512)),
        nn.ReLU(),
    )

    actor = nn.Sequential(hidden, layer_init(nn.Linear(512, num_actions), std=0.01))
    critic = nn.Sequential(hidden, layer_init(nn.Linear(512, 1), std=1))
    return actor, critic
```
Training Atari
Now, you should be able to run an Atari training loop!
We recommend you use the following parameters, for fidelity to the original implementation:
```python
args = PPOArgs(
    env_id="ALE/Breakout-v5",
    wandb_project_name="PPOAtari",
    use_wandb=True,
    mode="atari",
    clip_coef=0.1,
    num_envs=8,
    video_log_freq=25,
)

trainer = PPOTrainer(args)
trainer.train()
```
Note that this will probably take a lot longer to train than your previous experiments, because the architecture is much larger and finding an initial strategy is much harder than it was for CartPole. Don't worry if it starts off with pretty bad performance (on my machine the code above takes about 40 minutes to run, and I only start seeing any improvement after about the 5-10 minute mark, or approx 70k total agent steps). You can always experiment with different methods to try and boost performance early on, like an entropy bonus which starts large and then decays (analogous to our epsilon scheduling in DQN, which reduced the probability of exploration over time).
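For instance, a linearly decaying entropy coefficient might look like the following (a hypothetical helper, not something provided in PPOArgs; the initial value of 0.02 is an arbitrary illustration):

```python
def entropy_coef(step: int, total_steps: int, initial: float = 0.02, final: float = 0.0) -> float:
    """Linearly anneal the entropy bonus coefficient from `initial` down to `final`."""
    frac = min(step / total_steps, 1.0)
    return initial + frac * (final - initial)

print(entropy_coef(0, 100_000))        # 0.02 - strong exploration incentive early on
print(entropy_coef(50_000, 100_000))   # ~0.01 - halfway through training
print(entropy_coef(100_000, 100_000))  # 0.0 - no extra exploration bonus by the end
```

You'd then multiply the entropy term in your objective function by this coefficient each learning phase, instead of a fixed ent_coef.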
Here is a video produced from a successful run, using the parameters above:
and here's the corresponding plot of episodic returns (with episodic lengths following a similar pattern):

A note on debugging crashed kernels
This section is more relevant if you're doing these exercises on VSCode; you can skip it if you're in Colab.
Because the gymnasium library is a bit fragile, sometimes you can get uninformative kernel errors like this:

which annoyingly doesn't tell you much about the nature or location of the error. When this happens, it's often good practice to replace your code with lower-level code bit by bit, until the error message starts being informative.
For instance, you might start with trainer.train(), and if this fails without an informative error message then you might try replacing this function call with the actual contents of the train function (which should involve the methods trainer.rollout_phase() and trainer.learning_phase()). If the problem is in rollout_phase, you can again replace this line with the actual contents of this method.
If you're working in .py files rather than .ipynb, a useful tip - as well as running Shift + Enter to run the cell your cursor is in, if you have text highlighted (and you've turned on Send Selection To Interactive Window in VSCode settings) then using Shift + Enter will run just the code you've highlighted. This could be a single variable name, a single line, or a single block of code.