Exercise Status: All exercises complete and verified

[2.3] - PPO

Colab: exercises | solutions

Please send any problems / bugs on the #errata channel in the Slack group, and ask any questions on the dedicated channels for this chapter of material.

If you want to change to dark mode, you can do this by clicking the three horizontal lines in the top-right, then navigating to Settings → Theme.

Links to all other chapters: (0) Fundamentals, (1) Transformer Interpretability, (2) RL.

Introduction

Proximal Policy Optimization (PPO) is a policy gradient algorithm that has become one of the most widely used methods in deep reinforcement learning. As an improvement over earlier policy optimization methods, PPO addresses key challenges such as sample efficiency, stability, and robustness when training deep neural networks for reinforcement learning tasks. By striking a balance between exploration and exploitation, it has demonstrated strong performance across a wide range of complex environments, including robotics, game playing, and autonomous control systems.

In this section, you'll build your own agent to perform PPO on the CartPole environment. By the end, you should be able to train your agent to near perfect performance in about 30 seconds. You'll also be able to try out other things like reward shaping, to make it easier for your agent to learn to balance, or to do fun tricks! There are also additional exercises which allow you to experiment with other tasks, including Atari and the 3D physics engine MuJoCo.

A lot of the setup as we go through these exercises will be similar to what we did yesterday for DQN, so you might find yourself moving quickly through certain sections.

For a lecture on the material today, which provides some high-level understanding before you dive into the material, watch the video below:

Content & Learning Objectives

0️⃣ Whirlwind Tour of PPO

In this non-exercise-based section, we discuss some of the mathematical intuitions underpinning PPO. It's not compulsory to go through all of it (various recommended readings and online lectures may serve as better alternatives), although we strongly recommend that everyone at least read the summary boxes at the end of each subsection.

Learning Objectives
  • Understand the mathematical intuitions of PPO
  • Learn how expressions like the PPO objective function are derived

1️⃣ Setting up our agent

We'll start by building up most of our PPO infrastructure. Most importantly, this involves creating our actor & critic networks, and writing methods that use them to take steps in our environment. The result will be PPOAgent and ReplayMemory classes, analogous to the DQNAgent and ReplayBuffer classes from yesterday.

Learning Objectives
  • Understand the difference between the actor & critic networks, and what their roles are
  • Learn about & implement generalised advantage estimation
  • Build a replay memory to store & sample experiences
  • Design an agent class to step through the environment & record experiences
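As a preview of the generalised advantage estimation you'll implement, here's a minimal single-environment sketch in NumPy. The function name and argument layout are our own for illustration, not the exercise's API:

```python
import numpy as np

def compute_gae(rewards, values, next_value, dones, gamma=0.99, gae_lambda=0.95):
    """Generalised advantage estimation over one rollout of length T.

    rewards, values, dones: 1D arrays of length T (a single environment, for simplicity);
    next_value: the critic's estimate for the state after the final step.
    """
    T = len(rewards)
    advantages = np.zeros(T)
    last_adv = 0.0
    for step in reversed(range(T)):
        v_next = next_value if step == T - 1 else values[step + 1]
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        delta = rewards[step] + gamma * v_next * (1 - dones[step]) - values[step]
        # Recursion: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        last_adv = delta + gamma * gae_lambda * (1 - dones[step]) * last_adv
        advantages[step] = last_adv
    return advantages
```

With `gamma = gae_lambda = 1` this reduces to summing future rewards minus the current value estimate, which is a handy sanity check when debugging your own version.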

2️⃣ Learning Phase

The PPO objective function is considerably more complex than DQN's, and involves a lot of moving parts. In this section we'll go through each of those parts one by one, understanding its role and how to implement it.

Learning Objectives
  • Implement the total objective function (sum of three separate terms)
  • Understand the importance of each of these terms for the overall algorithm
  • Write a function to return an optimizer and learning rate scheduler for your model
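As a preview, the first of those three terms (the clipped surrogate objective) can be sketched as follows. This is a hedged sketch with our own function name, not the exercise's API; the full objective also subtracts a weighted value-function loss and adds a weighted entropy bonus:

```python
import torch as t

def clipped_surrogate_objective(logprobs, old_logprobs, advantages, clip_coef=0.2):
    """PPO's clipped surrogate objective (a quantity to be maximised).

    logprobs / old_logprobs: log pi(a|s) under the current / behaviour policy.
    advantages: estimated advantages for the same (state, action) pairs.
    """
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space for stability
    ratio = t.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    # Clipping the ratio removes the incentive to move the policy too far in one update
    clipped = t.clamp(ratio, 1 - clip_coef, 1 + clip_coef) * advantages
    return t.minimum(unclipped, clipped).mean()
```

When the current and old policies agree (`ratio = 1`), this is just the mean advantage; the `minimum` only bites when the policy has moved far enough for clipping to matter.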

3️⃣ Training Loop

Lastly, we'll assemble everything together into a PPOTrainer class just like our DQNTrainer class from yesterday, and use it to train on CartPole. We can also go further than yesterday by using reward shaping to fast-track our agent's learning trajectory.

Learning Objectives
  • Build a full training loop for the PPO algorithm
  • Train our agent, and visualise its performance with the Weights & Biases media logger
  • Use reward shaping to improve your agent's training (and make it do tricks!)
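To give a flavour of reward shaping: the idea is to replace the environment's sparse or uninformative reward with one that rewards intermediate progress. Here's a hypothetical shaped reward for CartPole (the function and the exact scaling are our own invention, not the exercise's solution; 0.2095 rad is CartPole's pole-angle failure threshold):

```python
def shaped_cartpole_reward(obs, base_reward=1.0, angle_limit=0.2095):
    """A hypothetical shaped reward for CartPole: scale the usual +1 per step
    by how upright the pole is, so the agent gets denser feedback.

    obs = (cart position, cart velocity, pole angle, pole angular velocity).
    """
    cart_pos, cart_vel, pole_angle, pole_ang_vel = obs
    # Full reward when the pole is vertical, decaying to zero at the failure threshold
    return base_reward * max(0.0, 1 - abs(pole_angle) / angle_limit)
```

You'd apply something like this inside a wrapper around the environment's `step` method; changing what you reward (e.g. penalising cart position, or rewarding a target angle) is how you make the agent "do tricks".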

4️⃣ Atari

Now that we've got training working on CartPole, we'll extend to the more complex environment of Atari. There are no major new concepts in this section, but we do have to deal with a very different architecture that takes into account the visual structure of our observations (Atari frames). In particular, this will require a shared architecture between the actor & critic networks.


Learning Objectives
  • Understand how PPO can be used in visual domains, with appropriate architectures (CNNs)
  • Understand the idea of policy and value heads
  • Train an agent to solve the Breakout environment
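The "shared architecture with policy and value heads" idea can be sketched like this (a minimal illustration assuming 4 stacked 84×84 greyscale frames, as is conventional for Atari; the class name and exact layer sizes are our own, not the exercise's solution):

```python
import torch as t
import torch.nn as nn

class SharedCNNAgent(nn.Module):
    """Sketch of a shared actor-critic architecture for image observations.

    The convolutional torso is shared between actor and critic; two small
    linear "heads" then output action logits (policy) and a scalar state
    value (critic) from the same hidden representation.
    """
    def __init__(self, num_actions: int):
        super().__init__()
        self.torso = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),  # 84x84 input -> 7x7 feature map
        )
        self.policy_head = nn.Linear(512, num_actions)
        self.value_head = nn.Linear(512, 1)

    def forward(self, obs):
        hidden = self.torso(obs)
        return self.policy_head(hidden), self.value_head(hidden)
```

Sharing the torso means the expensive visual features are learned once and used by both networks, which matters far more here than in CartPole's 4-dimensional observation space.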

5️⃣ MuJoCo

The last new set of environments we'll look at is MuJoCo. This is a 3D physics engine, which you might be familiar with from OpenAI's famous backflipping noodle, which laid the groundwork for RLHF (see tomorrow's material for more on this!). The most important concept MuJoCo introduces for us is the idea of a continuous action space, where actions aren't chosen from a finite set of discrete options but are instead sampled from a probability distribution (in this case, a parameterized normal distribution). This is one setting where PPO works but DQN can't be applied.


Learning Objectives
  • Understand how PPO can be used to train agents in continuous action spaces
  • Install and interact with the MuJoCo physics engine
  • Train an agent to solve the Hopper environment
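The "sampled from a parameterized normal distribution" idea looks roughly like this (a hedged sketch with our own function name; the actor outputs a mean and log-std per action dimension, and the log-std parameterization is a common trick to keep the standard deviation positive):

```python
import torch as t
from torch.distributions.normal import Normal

def sample_continuous_action(mu, log_sigma):
    """Sample an action from a diagonal Gaussian policy.

    mu, log_sigma: per-dimension mean and log standard deviation produced
    by the actor network. Returns the sampled action and its log-probability
    (summed over action dimensions, since the dimensions are independent).
    """
    dist = Normal(mu, t.exp(log_sigma))
    action = dist.sample()
    logprob = dist.log_prob(action).sum(-1)
    return action, logprob
```

The rest of PPO is unchanged: the ratio in the clipped objective is still new log-prob minus old log-prob, just computed from this Gaussian instead of a `Categorical`.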

☆ Bonus

We conclude with a set of optional bonus exercises, which you can try out before moving on to the RLHF sections.

Notes on today's workflow

Your implementation might get good benchmark scores by the end of the day, but don't worry if it struggles to learn the simplest of tasks. RL can be frustrating because the feedback you get is extremely noisy: the agent can fail even with correct code, and succeed with buggy code. Forming a systematic process for coping with the confusion and uncertainty is the point of today, more so than producing a working PPO implementation.

Some parts of your process could include:

  • Forming hypotheses about why it isn't working, and thinking about what tests you could write, or where you could set a breakpoint to confirm the hypothesis.
  • Implementing some of the even more basic gymnasium environments and testing your agent on those.
  • Getting a sense for the meaning of various logged metrics, and what they imply about the training process.
  • Noticing confusion and sections that don't make sense, and investigating this instead of hand-waving over it.

Readings

In section 0️⃣, we've included a whirlwind tour of PPO which is specifically tailored to today's exercises. Going through the entire thing isn't required (since it can get quite mathematically dense), but we strongly recommend everyone at least read the summary boxes at the end of each subsection. Many of the resources listed below are also useful, but they don't cover everything which is specifically relevant to today's exercises.

If you find that section sufficient, you can move straight on to the exercises; if not, the following reading is also strongly recommended:

  • An introduction to Policy Gradient methods - Deep RL (20 mins)
    • This is a useful video which motivates the core setup of PPO (and in particular the clipped objective function) without spending too much time with the precise derivations. We recommend watching this video before doing the exercises.
    • Note - you can ignore the short section on multi-GPU setup.
    • Also, near the end the video says that PPO outputs parameters $\mu$ and $\sigma$ from which actions are sampled. This is true for continuous action spaces (which we'll be using later on), but we'll start by implementing PPO on CartPole, where the action space is discrete just like yesterday's.
  • The 37 Implementation Details of Proximal Policy Optimization
    • This is not required reading before the exercises, but it will be a useful reference point as you go through the exercises (and it's also a useful thing to take away from the course as a whole, since your future work in RL will likely be less guided than these exercises).
    • The good news is that you won't need all 37 of these today, so no need to read to the end.
    • We will be tackling the 13 "core" details, though not in the same order as they're presented there. Some of the sections below are labelled with the number they correspond to on that page (e.g. Minibatch Update (detail #6)).
  • Proximal Policy Optimization Algorithms
    • This is not required reading before the exercises, but it will be a useful reference point for many of the key equations as you go through the exercises. In particular, you will find up to page 5 useful.

Optional Reading

Setup code

import itertools
import os
import sys
import time
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import einops
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
import torch.optim as optim
import wandb
from IPython.display import HTML, display
from jaxtyping import Bool, Float, Int
from matplotlib.animation import FuncAnimation
from numpy.random import Generator
from torch import Tensor
from torch.distributions.categorical import Categorical
from torch.optim.optimizer import Optimizer
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Make sure exercises are in the path
chapter = "chapter2_rl"
section = "part3_ppo"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

import part3_ppo.tests as tests
from part1_intro_to_rl.utils import set_global_seeds
from part3_ppo.utils import arg_help
from part21_dqn.solutions import (
    Probe1,
    Probe2,
    Probe3,
    Probe4,
    Probe5,
    get_episode_data_from_infos,
)
from plotly_utils import plot_cartpole_obs_and_dones
from rl_utils import make_env, prepare_atari_env

# Register our probes from last time
for idx, probe in enumerate([Probe1, Probe2, Probe3, Probe4, Probe5]):
    gym.envs.registration.register(id=f"Probe{idx + 1}-v0", entry_point=probe)

Arr = np.ndarray

device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")