[2.2] - DQN & VPG
Please send any problems / bugs on the #errata channel in the Slack group, and ask any questions on the dedicated channels for this chapter of material.
Links to all other chapters: (0) Fundamentals, (1) Transformer Interpretability, (2) RL.

Introduction
In this section, you'll implement Deep Q-Learning, often referred to as DQN for "Deep Q-Network". This algorithm was introduced in the landmark paper Playing Atari with Deep Reinforcement Learning.
You'll also implement Vanilla Policy Gradient (VPG), the first policy gradient algorithm upon which many modern RL algorithms are based (including PPO).
Content & Learning Objectives
1️⃣ DQN
In this section you'll implement Deep Q-Learning (DQN, short for "Deep Q-Network"), the algorithm behind the landmark paper Playing Atari with Deep Reinforcement Learning.
You'll apply the technique of DQN to master the famous CartPole environment (below), and then (if you have time) move on to harder challenges like Acrobot and MountainCar.
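At the heart of DQN is the temporal-difference target $y = r + \gamma \max_{a'} Q_\text{target}(s', a')$, which the Q-network is trained to regress towards. As a preview, here's a hedged sketch of computing that loss from a sampled batch (the toy linear networks and batch shapes are illustrative assumptions, not the architecture you'll build in the exercises):

```python
import torch as t
import torch.nn as nn

q_net = nn.Linear(4, 2)       # toy Q-network: CartPole has 4 obs dims, 2 actions
target_net = nn.Linear(4, 2)  # periodically-synced copy of q_net
gamma = 0.99

# A fake sampled batch of 8 transitions (in practice, drawn from the replay buffer)
obs = t.randn(8, 4)
actions = t.randint(0, 2, (8,))
rewards = t.randn(8)
next_obs = t.randn(8, 4)
dones = t.zeros(8, dtype=t.bool)

with t.no_grad():
    max_next_q = target_net(next_obs).max(dim=-1).values
    targets = rewards + gamma * max_next_q * (~dones)  # no bootstrapping past episode end

predicted = q_net(obs).gather(1, actions.unsqueeze(1)).squeeze(1)  # Q(s, a) for taken actions
loss = ((predicted - targets) ** 2).mean()
```

Note the `no_grad` context: gradients flow only through the predicted Q-values, not the targets, which is what makes this a (semi-)supervised regression step.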
Learning Objectives
- Understand the DQN algorithm
- Learn more about RL debugging, and build probe environments to debug your agents
- Create a replay buffer to store environment transitions
- Implement DQN using PyTorch, on the CartPole environment
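As a preview of the replay buffer component, here's a minimal sketch of one (a hedged illustration, not the exact class you'll build in the exercises): a fixed-size circular store of transitions, sampled uniformly at random with replacement.

```python
import numpy as np

class ReplayBuffer:
    """Fixed-size circular buffer of (obs, action, reward, next_obs, done) transitions."""

    def __init__(self, capacity: int, obs_dim: int, seed: int = 0):
        self.capacity = capacity
        self.rng = np.random.default_rng(seed)
        self.obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.actions = np.zeros(capacity, dtype=np.int64)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=bool)
        self.ptr = 0   # next write index
        self.size = 0  # number of transitions currently stored

    def add(self, obs, action, reward, next_obs, done):
        self.obs[self.ptr] = obs
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.next_obs[self.ptr] = next_obs
        self.dones[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.capacity  # overwrite oldest when full
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int):
        idx = self.rng.integers(0, self.size, size=batch_size)  # uniform, with replacement
        return (self.obs[idx], self.actions[idx], self.rewards[idx],
                self.next_obs[idx], self.dones[idx])
```

Sampling uniformly from old and new transitions breaks the temporal correlation between consecutive environment steps, which is one of the key stabilizing tricks from the DQN paper.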
2️⃣ VPG
The Policy Gradient Theorem is what all policy gradient methods are based on: it lets us compute the gradient of the expected return with respect to the policy's parameters, something which naively wouldn't have a well-defined gradient (the return depends on the parameters only through sampled trajectories). We'll then implement Vanilla Policy Gradient (VPG) on the CartPole environment.
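Concretely, one common statement of the theorem (in its simplest, whole-trajectory-return form) is:

$$
\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)\right]
$$

where $R(\tau)$ is the return of trajectory $\tau$. The right-hand side is an expectation over trajectories sampled from the current policy, so it can be estimated from rollouts — which is exactly what VPG does.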
Learning Objectives
- Understand the Policy Gradient Theorem
- Understand the VPG algorithm: how to perform on-policy policy gradient
- Implement VPG using PyTorch, on the CartPole environment
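To make the on-policy objective concrete, the surrogate loss whose gradient matches the policy gradient can be sketched as follows (a minimal illustration under assumed discrete-action logits, not the full exercise implementation):

```python
import torch as t
import torch.nn.functional as F

def vpg_loss(logits: t.Tensor, actions: t.Tensor, returns: t.Tensor) -> t.Tensor:
    """Surrogate loss -E[log pi(a|s) * return]; minimizing it ascends the policy gradient."""
    log_probs = F.log_softmax(logits, dim=-1)                      # (batch, n_actions)
    chosen = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)  # log pi(a_t | s_t)
    return -(chosen * returns).mean()

# Toy example: 3 states, 2 actions, uniform policy (zero logits)
logits = t.zeros(3, 2, requires_grad=True)
actions = t.tensor([0, 1, 0])
returns = t.tensor([1.0, -0.5, 2.0])
loss = vpg_loss(logits, actions, returns)
loss.backward()  # logits.grad now estimates the (negated) policy gradient
```

Actions with high return get their log-probability pushed up, and vice versa; "on-policy" means `actions` and `returns` must come from trajectories sampled with the current `logits`-producing policy.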
Setup (don't read, just run!)
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import sys
import time
import warnings
from collections import namedtuple
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
import gymnasium as gym
import numpy as np
import torch as t
import torch.nn.functional as F
import wandb
from eindex import eindex
from gpu_env import CartPole
from gymnasium.spaces import Box, Discrete
from jaxtyping import Bool, Float, Int
from torch import Tensor, nn
from torchinfo import summary
from tqdm import tqdm
warnings.filterwarnings("ignore")
Arr = np.ndarray
ActType = Int
ObsType = Int
# Make sure exercises are in the path
chapter = "chapter2_rl"
section = "part2_q_learning_and_policy_gradient"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))
import part2_q_learning_and_policy_gradient.tests as tests
import part2_q_learning_and_policy_gradient.utils as utils
from part1_intro_to_rl.utils import set_global_seeds
from part2_q_learning_and_policy_gradient.probe import Probe4, Probe5
from part2_q_learning_and_policy_gradient.utils import make_env
from plotly_utils import line, plot_cartpole_obs_and_dones
from rl_utils import generate_and_plot_trajectory, make_env
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
MAIN = __name__ == "__main__"