2️⃣ VPG
Learning Objectives
- Understand the Policy Gradient Theorem
- Understand the VPG algorithm: how to perform on-policy policy gradient
- Implement VPG using PyTorch, on the CartPole environment
Policy Gradient Theorem
Instead of learning action-values and deriving a policy (as in Q-learning or DQN), policy gradient methods learn the policy directly.
- Policy is parameterized: $\pi_\theta(a|s)$ with parameters $\theta$ (often a neural network).
- Objective: Choose $\theta$ to maximize expected return $J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[G(\tau)]$ ("joy", a nickname you'll see again in the code), where $\tau$ is a trajectory and $G(\tau)$ its return.
We would like to update the policy directly via gradient ascent on $J(\theta)$:

$$\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)$$
The problem is that the return is a sum of rewards along the trajectory, and the trajectory itself is produced by sampling from the policy step after step, as well as being dependent on the environment's dynamics, which we do not have access to. There is no obvious way to directly compute the gradient of the expected return with respect to the policy parameters. The solution is the policy gradient theorem, which states that the log-probability-weighted return gives an unbiased estimator of the gradient of the expected return.
Derivation
The probability of sampling a trajectory $\tau = (s_0, a_0, s_1, a_1, \dots, s_T)$ is given by

$$P(\tau \mid \theta) = \mu(s_0) \prod_{t=0}^{T-1} \pi_\theta(a_t \mid s_t)\, \mu(s_{t+1} \mid s_t, a_t)$$

where $\mu$ denotes the environment's initial-state distribution and transition dynamics.
The dynamics $\mu$ do not depend on $\theta$, so taking logs turns the product into a sum in which only the policy terms survive differentiation:

$$\nabla_\theta \log P(\tau \mid \theta) = \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)$$
Thus, by the log-derivative trick $\nabla_\theta P = P \, \nabla_\theta \log P$:

$$\nabla_\theta P(\tau \mid \theta) = P(\tau \mid \theta) \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)$$
Plugging back into the gradient:

$$\nabla_\theta J(\theta) = \nabla_\theta \int P(\tau \mid \theta)\, G(\tau)\, d\tau = \int P(\tau \mid \theta) \left[\sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\right] G(\tau)\, d\tau = \mathbb{E}_{\tau \sim \pi_\theta}\left[ G(\tau) \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t) \right]$$
This is the Vanilla Policy Gradient estimator, also called REINFORCE. Each $\log \pi_\theta(a_t|s_t)$ term is multiplied by the full return $G(\tau)$. However, the action $a_t$ cannot influence rewards earned before time $t$, only those afterwards. This means that all the rewards before timestep $t$ merely add noise, since no change to the policy can affect them. To reduce variance, replace $G(\tau)$ with the return $G_t$ from timestep $t$ onwards, also called the reward-to-go:

$$G_t = \sum_{t'=t}^{T} \gamma^{t'-t}\, r_{t'}$$
Thus, the lower-variance (and still unbiased) estimator is:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[ \sum_{t=0}^{T-1} G_t\, \nabla_\theta \log \pi_\theta(a_t \mid s_t) \right]$$
There are many other variants of the policy gradient estimator, as described in Schulman, 2018.
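In code we never form this gradient by hand: we build a scalar surrogate objective whose autograd gradient is exactly the estimator above. A minimal single-trajectory sketch (the function name here is illustrative, not part of the exercises):

```python
import torch as t

def pg_surrogate(logprobs: t.Tensor, rewards_to_go: t.Tensor) -> t.Tensor:
    """
    logprobs:      log pi_theta(a_t | s_t) of the actions taken, shape (T,), requires grad
    rewards_to_go: G_t for each timestep, shape (T,), treated as a constant w.r.t. theta
    Backpropagating through this scalar yields sum_t G_t * grad log pi_theta(a_t | s_t),
    i.e. one sample of the policy gradient estimator.
    """
    return (logprobs * rewards_to_go.detach()).sum()
```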

Implementation
We make use of the same CartPole environment as before, but now in a vectorized version defined entirely in terms of tensor operations (see chapter2_rl/exercises/gpu_env.py). This environment behaves identically to the one used for DQN, but runs entirely on the GPU. This means:
* we don't need to constantly convert between numpy and torch tensors
* we can run large numbers of environments in parallel (~thousands of environments for ~millions of environmental steps per second)
* we avoid copying data back and forth between the CPU and GPU, which can be a significant bottleneck
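To give a feel for what "entirely in tensor operations" means, here is a toy batched step in the spirit of that environment (the real CartPole dynamics live in chapter2_rl/exercises/gpu_env.py; this is just a sketch of the pattern, with made-up dynamics):

```python
import torch as t

def step_batched(state: t.Tensor, actions: t.Tensor) -> tuple[t.Tensor, t.Tensor, t.Tensor]:
    """Toy batched dynamics: every environment advances in one tensor op, no Python loop.
    state: (num_envs, obs_dim); actions: (num_envs,) with values in {0, 1}."""
    force = actions.float() * 2 - 1               # map {0, 1} -> {-1.0, +1.0}
    new_state = state.clone()
    new_state[:, 0] = state[:, 0] + 0.02 * force  # update every env's position at once
    rewards = t.ones(state.shape[0], device=state.device)  # reward 1 per step, as in CartPole
    dones = new_state[:, 0].abs() > 2.4           # per-env termination mask
    return new_state, rewards, dones
```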
Policy Network
Here, the policy is learned directly as a neural network, rather than learning action-values and deriving the policy from them. We'll use the same architecture as the Q-network from DQN, so we've just included it here for you.
class PolicyNetwork(nn.Module):
"""
For consistency with your tests, please wrap your modules in a `nn.Sequential` called `layers`.
"""
layers: nn.Sequential
def __init__(self, obs_shape: tuple[int], num_actions: int, hidden_sizes: list[int] = [120, 84]):
super().__init__()
# assert len(obs_shape) == 1, f"Expecting a single vector of observations, got {obs_shape}"
assert len(hidden_sizes) == 2, f"Expecting 2 hidden layers, got {len(hidden_sizes)}"
self.layers = nn.Sequential(
nn.Linear(obs_shape[-1], hidden_sizes[0]),
nn.ReLU(),
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
nn.ReLU(),
nn.Linear(hidden_sizes[1], num_actions),
)
def forward(self, x: Tensor) -> Tensor:
return self.layers(x)
net = PolicyNetwork(obs_shape=(4,), num_actions=2)
summary(net)
Rollout Buffer
The way our implementation of VPG works is simple: we perform a rollout across `num_envs` many environments in parallel, and store the trajectory for each. We then learn from that set of rollouts, and discard it afterwards. One rollout, one learning step. This means we are always learning on-policy: we only ever learn from data that the current model actually generated. We will use a rollout buffer to store the trajectories.
Exercise - implement Rollout Buffer
The Rollout class will store a set of `num_envs` many trajectories. We do not shuffle anything, or break an episode up into individual transitions as we did for DQN. The smallest datapoint is one full trajectory.
The following methods need to be completed:
- `__init__` - initializes the rollout buffer, storing `obs`, `actions`, `logprobs`, `rewards`, `dones`, `infos`, and `timestep`
- `add_step` - adds information gathered from timestep $t$ to the rollout buffer
- `get_batches` - returns a list of `RolloutTensors` objects, each containing `batch_size` many trajectories
Hint
Use t.split to write get_batches.
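For reference, `t.split` chunks a tensor into pieces of a given size along one dimension:

```python
import torch as t

x = t.arange(12).reshape(6, 2)    # think: 6 environments, 2 timesteps
chunks = t.split(x, 2, dim=0)     # tuple of 3 tensors, each of shape (2, 2)
print([c.shape for c in chunks])  # [torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([2, 2])]
```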
RolloutTensors = namedtuple("RolloutTensors", ["obs", "actions", "logprobs", "rewards", "dones"])
class Rollout:
obs: Float[Tensor, " num_envs max_size *obs_shape"]
actions: Int[Tensor, " num_envs max_size *action_shape"]
logprobs: Float[Tensor, " num_envs max_size"]
rewards: Float[Tensor, " num_envs max_size"]
dones: Bool[Tensor, " num_envs max_size"]
infos: dict[str, Any]
timestep: int
def __init__(
self, num_envs: int, max_steps: int, obs_shape: tuple[int], action_shape: tuple[int], device: t.device
):
"""
Args:
num_envs: number of environments to rollout
max_steps: maximum number of steps to rollout per environment
obs_shape: shape of the observation
action_shape: shape of the action
device: device to use
"""
self.MAX_SIZE = max_steps
# self.max_rollout_steps = args.max_rollout_steps
# self.min_rollout_steps = args.min_rollout_steps
raise NotImplementedError()
self.tensors = RolloutTensors(self.obs, self.actions, self.logprobs, self.rewards, self.dones)
def add_step(
self,
obs: Float[Tensor, " num_envs *obs_shape"],
actions: Int[Tensor, " num_envs *action_shape"],
logprobs: Float[Tensor, " num_envs"],
rewards: Float[Tensor, " num_envs"],
dones: Bool[Tensor, " num_envs"],
infos: dict[str, Any],
):
"""
        Adds information to the rollout buffer for the current self.timestep
Don't forget to increment self.timestep afterwards!
"""
if self.timestep >= self.MAX_SIZE:
raise ValueError("Rollout is full, cannot add more steps")
raise NotImplementedError()
def reset(self):
self.timestep = 0
def get(self) -> tuple[Tensor, ...]:
assert self.timestep == self.MAX_SIZE, "Rollout is not full"
return self.tensors
def get_batches(self, batch_size: int) -> list[RolloutTensors]:
"""
Splits the rollout buffer into batches of size `batch_size`, and returns a list of
`RolloutTensors` objects, each containing `batch_size` many trajectories.
"""
raise NotImplementedError()
Solution
RolloutTensors = namedtuple("RolloutTensors", ["obs", "actions", "logprobs", "rewards", "dones"])
class Rollout:
obs: Float[Tensor, " num_envs max_size *obs_shape"]
actions: Int[Tensor, " num_envs max_size *action_shape"]
logprobs: Float[Tensor, " num_envs max_size"]
rewards: Float[Tensor, " num_envs max_size"]
dones: Bool[Tensor, " num_envs max_size"]
infos: dict[str, Any]
timestep: int
def __init__(
self, num_envs: int, max_steps: int, obs_shape: tuple[int], action_shape: tuple[int], device: t.device
):
"""
Args:
num_envs: number of environments to rollout
max_steps: maximum number of steps to rollout per environment
obs_shape: shape of the observation
action_shape: shape of the action
device: device to use
"""
self.MAX_SIZE = max_steps
# self.max_rollout_steps = args.max_rollout_steps
# self.min_rollout_steps = args.min_rollout_steps
self.obs = t.empty([num_envs, self.MAX_SIZE, *obs_shape], dtype=t.float32, device=device)
self.actions = t.empty([num_envs, self.MAX_SIZE, *action_shape], dtype=t.int64, device=device)
self.logprobs = t.empty([num_envs, self.MAX_SIZE], dtype=t.float32, device=device)
self.rewards = t.empty([num_envs, self.MAX_SIZE], dtype=t.float32, device=device)
self.dones = t.empty([num_envs, self.MAX_SIZE], dtype=t.bool, device=device)
self.infos = {}
self.timestep = 0
self.tensors = RolloutTensors(self.obs, self.actions, self.logprobs, self.rewards, self.dones)
def add_step(
self,
obs: Float[Tensor, " num_envs *obs_shape"],
actions: Int[Tensor, " num_envs *action_shape"],
logprobs: Float[Tensor, " num_envs"],
rewards: Float[Tensor, " num_envs"],
dones: Bool[Tensor, " num_envs"],
infos: dict[str, Any],
):
"""
        Adds information to the rollout buffer for the current self.timestep
Don't forget to increment self.timestep afterwards!
"""
if self.timestep >= self.MAX_SIZE:
raise ValueError("Rollout is full, cannot add more steps")
self.obs[:, self.timestep] = obs
self.actions[:, self.timestep] = actions
self.logprobs[:, self.timestep] = logprobs
self.rewards[:, self.timestep] = rewards
self.dones[:, self.timestep] = dones
self.infos[self.timestep] = infos
self.timestep += 1
def reset(self):
self.timestep = 0
def get(self) -> tuple[Tensor, ...]:
assert self.timestep == self.MAX_SIZE, "Rollout is not full"
return self.tensors
def get_batches(self, batch_size: int) -> list[RolloutTensors]:
"""
Splits the rollout buffer into batches of size `batch_size`, and returns a list of
`RolloutTensors` objects, each containing `batch_size` many trajectories.
"""
obs = t.split(self.obs, batch_size, dim=0)
acts = t.split(self.actions, batch_size, dim=0)
logprobs = t.split(self.logprobs, batch_size, dim=0)
rewards = t.split(self.rewards, batch_size, dim=0)
dones = t.split(self.dones, batch_size, dim=0)
batches = [RolloutTensors(*tensors) for tensors in zip(obs, acts, logprobs, rewards, dones)]
return batches
VPG Args
We've provided a dataclass for the training arguments, and will explain as needed later on.
@dataclass
class VPGArgs:
# Basic / global
seed: int = 1
env_id: str = "CartPole-gpu"
# Wandb / logging
use_wandb: bool = False
wandb_project_name: str = "VPGCartPole"
wandb_entity: str | None = None
video_log_freq: int | None = 50
# Duration of different phases / buffer memory settings
total_timesteps: int = 500_000
# max_rollout_steps: int = 500
# min_rollout_steps: int = 64
num_envs: int = 4
num_steps_per_rollout: int = 128
lr: float = 2.5e-4
gamma: float = 1
frac_dead_rollout: float = 1
ent_coef: float = 0.01
max_grad_norm: float = 0.5
rollout_use_count: int = 4
num_minibatches: int = 4
clip_coef: float = 0.2
compile: bool = False
device: str = "cpu"
normalize_returns: bool = True
show_probs: bool = False
num_batches_per_rollout: int = 1
# LR decay settings
use_lr_decay: bool = False
lr_end: Optional[float] = None
lr_frac: Optional[float] = None
use_iw: bool = False
def __post_init__(self):
self.batch_size = self.num_envs // self.num_batches_per_rollout
self.device = t.device(self.device)
if self.use_lr_decay:
assert self.lr_end is not None, "lr_end must be set if use_lr_decay is True"
assert self.lr_frac is not None, "lr_frac must be set if use_lr_decay is True"
self.env_steps_per_update = self.num_steps_per_rollout * self.num_envs // self.num_batches_per_rollout
if not self.use_iw:
assert self.rollout_use_count == 1, "rollout_use_count must be 1 if use_iw is False"
assert self.num_batches_per_rollout == 1, "num_batches_per_rollout must be 1 if use_iw is False"
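For example, the derived fields work out as follows (the values here are chosen purely for illustration):

```python
args = VPGArgs(num_envs=8, num_batches_per_rollout=2, num_steps_per_rollout=128, use_iw=True)
print(args.batch_size)            # 8 // 2 = 4 environments per batch
print(args.env_steps_per_update)  # 128 * 8 // 2 = 512 env steps consumed per gradient update
```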
VPG Agent
The following class will be our agent. It generates rollouts by interacting with the environment, sampling each action from the policy network. Recall that the policy network maps observations to logits over actions, so we can sample actions from the corresponding categorical distribution.
Exercise - implement VPGAgent
Implement the following functions:
* `gen_rollout` - computes the episode rollout by interacting with the environment for args.num_steps_per_rollout steps. If an episode terminates, we reset the environment and continue. We track the length of each episode in the lifespan variable, which records how long an episode runs before termination. For the CartPole environment, this lets us track performance (the longer the episode lasts, the better the agent is doing).
* `get_actions` - takes in an observation and returns the actions, logprobs, and entropy for that observation. You can use `t.distributions.Categorical(logits=logits)` to construct a distribution, from which you can get the actions, logprobs, and entropy (see the short example below, and the docs for details).
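For reference, here is how `t.distributions.Categorical` behaves on a batch of logits (the numbers are made up):

```python
import torch as t

logits = t.tensor([[2.0, 0.5], [0.1, 0.1]])        # (num_envs=2, num_actions=2)
dist = t.distributions.Categorical(logits=logits)  # one distribution per environment
actions = dist.sample()                            # (2,) sampled action indices
logprobs = dist.log_prob(actions)                  # (2,) log pi(a | s) of the sampled actions
entropy = dist.entropy()                           # (2,) entropy of each action distribution
```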
class VPGAgent:
"""Base Agent class handling the interaction with the environment."""
dead: Bool[Tensor, " num_envs"]
lifespan: Int[Tensor, " num_envs"]
def __init__(
self,
envs: gym.Env,
policy_network: PolicyNetwork,
args: VPGArgs,
rng: Optional[np.random.Generator] = None,
):
self.envs = envs
self.policy_network = policy_network
self.rng = rng
self.args = args
self.obs_shape = envs.observation_space.shape
self.action_shape = envs.action_space.shape
@t.no_grad()
def gen_rollout(self, rollout: Rollout) -> tuple[Rollout, dict[str, Any]]:
"""
        Generate a full rollout across all environments in parallel, adding each step to the rollout buffer.

        Returns the rollout buffer, plus an info dictionary containing the lifespan of each environment (used for logging).
"""
obs, _ = self.envs.reset() # Need a starting observation
device = self.args.device
dead = t.zeros(self.args.num_envs, dtype=t.bool, device=device)
lifespan = t.zeros(self.args.num_envs, dtype=t.int32, device=device)
rollout.reset()
raise NotImplementedError()
info = {"lifespan": lifespan}
return rollout, info
def get_actions(
self, obs: Float[Tensor, " num_envs *obs_shape"]
) -> tuple[Int[Tensor, " num_envs *action_shape"], Float[Tensor, " num_envs"], Float[Tensor, " num_envs"]]:
"""
        Computes the agent's turn: given an observation for each environment,
sample the action the agent takes, along with the log_probs of that action,
and the entropy of the action distribution.
"""
raise NotImplementedError()
Solution
class VPGAgent:
"""Base Agent class handling the interaction with the environment."""
dead: Bool[Tensor, " num_envs"]
lifespan: Int[Tensor, " num_envs"]
def __init__(
self,
envs: gym.Env,
policy_network: PolicyNetwork,
args: VPGArgs,
rng: Optional[np.random.Generator] = None,
):
self.envs = envs
self.policy_network = policy_network
self.rng = rng
self.args = args
self.obs_shape = envs.observation_space.shape
self.action_shape = envs.action_space.shape
@t.no_grad()
def gen_rollout(self, rollout: Rollout) -> tuple[Rollout, dict[str, Any]]:
"""
        Generate a full rollout across all environments in parallel, adding each step to the rollout buffer.

        Returns the rollout buffer, plus an info dictionary containing the lifespan of each environment (used for logging).
"""
obs, _ = self.envs.reset() # Need a starting observation
device = self.args.device
dead = t.zeros(self.args.num_envs, dtype=t.bool, device=device)
lifespan = t.zeros(self.args.num_envs, dtype=t.int32, device=device)
rollout.reset()
for timestep in range(self.args.num_steps_per_rollout):
actions, logprobs, entropy = self.get_actions(obs)
new_obs, rewards, terminates, _, info = self.envs.step(actions)
done = terminates
rollout.add_step(obs, actions, logprobs, rewards, done, info)
obs = new_obs
dead = dead | done
lifespan += ~dead
info = {"lifespan": lifespan}
return rollout, info
def get_actions(
self, obs: Float[Tensor, " num_envs *obs_shape"]
) -> tuple[Int[Tensor, " num_envs *action_shape"], Float[Tensor, " num_envs"], Float[Tensor, " num_envs"]]:
"""
        Computes the agent's turn: given an observation for each environment,
sample the action the agent takes, along with the log_probs of that action,
and the entropy of the action distribution.
"""
logits = self.policy_network(obs)
dist = t.distributions.Categorical(logits=logits)
actions = dist.sample()
entropy = dist.entropy()
logprobs = dist.log_prob(actions)
return actions, logprobs, entropy
Returns
To compute the REINFORCE loss, we need to compute the return at each step of the trajectory. This gets a little messy, since trajectories may have different lengths: an episode may have terminated partway through the rollout. You'll need to walk backwards through the trajectory, computing the return at each step.
Exercise - implement compute_returns
Compute the returns for each trajectory. It's easiest to write this as a simple reverse for-loop for now, though if you wish you can later try a vectorized solution.
def compute_returns(
rewards: Float[Tensor, " num_envs num_steps"], done: Bool[Tensor, " num_envs num_steps"], gamma: float = 0.9
):
"""
ARGS:
rewards: The rewards for each trajectory
done: A boolean tensor indicating if an episode finished on the current timestep
gamma: The discount factor
Returns:
The returns G_t for each trajectory.
For example:
- If Rewards = [0, 0, 1, 0, 1]
- And Done = [0, 0, 1, 0, 1]
    - Then Returns = [g**2, g, 1, g, 1]
"""
num_envs, num_steps = rewards.shape
returns = t.zeros_like(rewards)
raise NotImplementedError()
tests.test_compute_returns(compute_returns)
Solution
def compute_returns(
rewards: Float[Tensor, " num_envs num_steps"], done: Bool[Tensor, " num_envs num_steps"], gamma: float = 0.9
):
"""
ARGS:
rewards: The rewards for each trajectory
done: A boolean tensor indicating if an episode finished on the current timestep
gamma: The discount factor
Returns:
The returns G_t for each trajectory.
For example:
- If Rewards = [0, 0, 1, 0, 1]
- And Done = [0, 0, 1, 0, 1]
    - Then Returns = [g**2, g, 1, g, 1]
"""
num_envs, num_steps = rewards.shape
returns = t.zeros_like(rewards)
G = t.zeros_like(rewards[:, 0]) # (num_envs)
for i in reversed(range(num_steps)):
G = rewards[:, i] + gamma * G * (~done[:, i])
returns[:, i] = G
return returns
tests.test_compute_returns(compute_returns)
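As a quick sanity check, running the solution on the docstring example with gamma = 0.9 gives:

```python
rewards = t.tensor([[0.0, 0.0, 1.0, 0.0, 1.0]])
done = t.tensor([[False, False, True, False, True]])
print(compute_returns(rewards, done, gamma=0.9))
# tensor([[0.8100, 0.9000, 1.0000, 0.9000, 1.0000]])
```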
Exercise - implement compute_logprobs_and_entropy
Compute the logprobs of the actions taken, and the entropy of the action distribution at each timestep. Both are needed for the loss function.
def compute_logprobs_and_entropy(
tau: RolloutTensors, pi: PolicyNetwork
) -> tuple[Float[Tensor, " num_envs num_steps"], Float[Tensor, " num_envs num_steps"]]:
"""
Computes the logprobs and entropy of the action distribution on each timestep.
"""
raise NotImplementedError()
Solution
def compute_logprobs_and_entropy(
tau: RolloutTensors, pi: PolicyNetwork
) -> tuple[Float[Tensor, " num_envs num_steps"], Float[Tensor, " num_envs num_steps"]]:
"""
Computes the logprobs and entropy of the action distribution on each timestep.
"""
logits = pi(tau.obs)
log_probs = F.log_softmax(logits, dim=-1)
log_probs_taken = eindex(log_probs, tau.actions, "env time [env time] -> env time")
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # entropy of the full action distribution, shape (env, time)
return log_probs_taken, entropy
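If you don't want to depend on the `eindex` library, the same "logprob of the taken action" indexing can be written with plain `torch.gather`; this sketch assumes the same tensor shapes as the solution above:

```python
def logprobs_taken_with_gather(
    log_probs: Float[Tensor, " env time action"], actions: Int[Tensor, " env time"]
) -> Float[Tensor, " env time"]:
    # Select log pi(a_t | s_t) for the action actually taken at each (env, time) position
    return log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
```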
Building up to the loss function
We need to compute the probability ratio $\pi_\theta(a_t \mid s_t) / \pi_{\text{old}}(a_t \mid s_t)$ for each timestep taken in the rollout. This ratio gives the importance weights $\text{iw}_t$, which allow us to learn from slightly off-policy data (for example, when a rollout is reused for several updates). If args.clip_coef is not None, we also clamp the importance weights between 1 - args.clip_coef and 1 + args.clip_coef.
Exercise - implement compute_importance_weights
Keep the result numerically stable by exponentiating the difference between the logprobs, rather than dividing the raw probabilities.
Gradients should NOT flow through the importance weights. Make sure to use .detach() to prevent this.
def compute_importance_weights(logprobs_taken, tau: RolloutTensors, clip_coef: Optional[float]) -> t.Tensor:
raise NotImplementedError()
Solution
def compute_importance_weights(logprobs_taken, tau: RolloutTensors, clip_coef: Optional[float]) -> t.Tensor:
iw = t.exp(logprobs_taken - tau.logprobs).detach() # Detach to prevent gradient flow
if clip_coef is not None:
iw = t.clamp(iw, 1 - clip_coef, 1 + clip_coef)
return iw
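A quick numeric illustration (the probabilities are made up): if the new policy assigns probability 0.6 to an action the old policy assigned 0.5, the importance weight is 1.2, which a clip coefficient of 0.1 clamps down to 1.1:

```python
logprob_new = t.log(t.tensor([0.6]))
logprob_old = t.log(t.tensor([0.5]))
iw = t.exp(logprob_new - logprob_old)  # tensor([1.2000])
print(t.clamp(iw, 1 - 0.1, 1 + 0.1))   # tensor([1.1000])
```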
Exercise - implement normalize_returns
Normalize the returns to zero mean and unit variance across all trajectories and timesteps. Don't overthink this one; it should be a one-liner. Don't worry about whether the epsilon goes inside or outside the square root in the denominator; it doesn't really matter.
def normalize_returns(returns: Float[Tensor, " num_envs num_steps"]) -> Float[Tensor, " num_envs num_steps"]:
"""
Normalizes the returns by ensuring zero mean, unit variance across all trajectories and timesteps.
"""
raise NotImplementedError()
Solution
def normalize_returns(returns: Float[Tensor, " num_envs num_steps"]) -> Float[Tensor, " num_envs num_steps"]:
"""
Normalizes the returns by ensuring zero mean, unit variance across all trajectories and timesteps.
"""
return (returns - returns.mean()) / (returns.std() + 1e-8)
Exercise - implement compute_reinforce_loss
This should be easy with everything else you've got. The per-timestep objective is

$$L_t = \text{iw}_t \cdot \log \pi_\theta(a_t \mid s_t) \cdot (G_t - \bar{G})$$

where $\bar{G}$ is the mean return over the batch, acting as a simple baseline. The total objective is the mean over all timesteps, over all trajectories. Since the optimizer is constructed with `maximize=True`, return the quantity to be maximized rather than its negation.
def compute_reinforce_loss(
returns: Float[Tensor, " num_envs num_steps"],
logprobs_taken: Float[Tensor, " num_envs num_steps"],
iw: Float[Tensor, " num_envs num_steps"],
) -> Float[Tensor, ""]:
raise NotImplementedError()
Solution
def compute_reinforce_loss(
returns: Float[Tensor, " num_envs num_steps"],
logprobs_taken: Float[Tensor, " num_envs num_steps"],
iw: Float[Tensor, " num_envs num_steps"],
) -> Float[Tensor, ""]:
target = returns - returns.mean()
return (iw * logprobs_taken * target.detach()).mean()
Trainer
This is the function that will handle the full training loop. We've provided you with the template of a training loop which should be very similar to yesterday's.
Exercise - implement VPGTrainer
You should fill in the following method (ignore the logging; you can copy it from the solution later):
- `compute_loss` - computes the loss for the VPG objective function.
The training loop is fairly standard once everything else is done: we generate a rollout, split it into batches, then compute the loss and update the weights for each batch, so we've included that part for you.
class VPGTrainer:
def __init__(self, args: VPGArgs):
set_global_seeds(args.seed)
self.args = args
device = args.device
self.rng = t.Generator(device=device).manual_seed(args.seed)
self.run_name = f"{args.env_id}__{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"
if args.env_id == "CartPole-gpu":
self.envs = CartPole(args.num_envs, device=device)
elif args.env_id == "Probe4-v0":
self.envs = Probe4(args.num_envs)
elif args.env_id == "Probe5-v0":
self.envs = Probe5(args.num_envs)
else:
raise ValueError(f"Environment {args.env_id} not supported")
# Define some basic variables from our environment (note, we assume a single discrete action space)
self.num_envs = args.num_envs
self.action_shape = self.envs.action_space.shape
self.num_actions = self.envs.action_space.n
self.obs_shape = self.envs.observation_space.shape
# Create our networks & optimizer
self.policy_network = PolicyNetwork(self.obs_shape, self.num_actions).to(device)
# Compile the policy network for faster inference
if self.args.compile:
self.policy_network = t.compile(self.policy_network)
self.optimizer = t.optim.Adam(self.policy_network.parameters(), lr=args.lr, eps=1e-5, maximize=True)
self.optimizer.zero_grad()
# Create our agent
self.agent = VPGAgent(envs=self.envs, policy_network=self.policy_network, args=self.args, rng=self.rng)
def compute_loss(self, tau: RolloutTensors) -> tuple[t.Tensor, dict[str, Any]]:
raise NotImplementedError()
info = {
"entropy": avg_entropy.item(),
"r_joy": r_joy.item(),
"iw": iw.mean().item() if self.args.use_iw else None,
}
return joy, info
def update_learning_rate(self, time_steps, args):
if args.use_lr_decay and args.lr_frac > 0:
progress = min(1.0, max(time_steps / args.total_timesteps, 0) / args.lr_frac)
return (progress * args.lr_end) + ((1 - progress) * args.lr)
return args.lr
def train(self) -> None:
"""
Trains the agent by generating rollouts and updating the policy.
The progress bar tracks total environment steps.
"""
# --- Setup ---
rollout = Rollout(
num_envs=self.num_envs,
max_steps=self.args.num_steps_per_rollout,
obs_shape=self.obs_shape,
action_shape=self.action_shape,
device=self.args.device,
)
# Calculate the number of environment steps collected per rollout generation
# Calculate the total number of updates (rollouts) to perform
# Use integer division to ensure we don't exceed total_timesteps
env_steps_per_train_step = (
self.args.num_steps_per_rollout * self.args.num_envs // (self.args.num_batches_per_rollout)
)
num_updates = self.args.total_timesteps // env_steps_per_train_step
train_steps = 0 # Counter for gradient updates
# --- Training Loop ---
# The progress bar is managed manually with a `with` statement.
# `total` is set to the total environment steps we want to run.
# The loop iterates `num_updates` times, not `total_timesteps` times.
with tqdm(
total=self.args.total_timesteps,
unit=" env steps",
unit_scale=True,
desc="Training",
miniters=1,
mininterval=0.02,
) as pbar:
env_steps_consumed = 0
for update_num in range(num_updates):
# 1. Generate a new rollout from the environment
rollout, agent_info = self.agent.gen_rollout(rollout)
# 2. Split the rollout into batches along the num_envs dimension
rollout_batches = rollout.get_batches(self.args.batch_size)
# 3. Logging and Progress Bar Update
# This part is outside the inner loop to only log once per rollout
avg_lifespan = agent_info["lifespan"].float().mean().item()
std_lifespan = agent_info["lifespan"].float().std().item()
max_lifespan = agent_info["lifespan"].max().item()
if (avg_lifespan + 0.5) > self.args.num_steps_per_rollout and std_lifespan < 0.01:
print("Agent has learned to play optimally!")
break
# 4. For each batch, perform multiple gradient updates
for batch in rollout_batches:
for i in range(self.args.rollout_use_count):
loss, reinforce_info = self.compute_loss(batch)
info = {**agent_info, **reinforce_info}
loss.backward()
if self.args.max_grad_norm is not None:
t.nn.utils.clip_grad_norm_(
self.policy_network.parameters(), max_norm=self.args.max_grad_norm
)
grad_norm = t.nn.utils.clip_grad_norm_(self.policy_network.parameters(), max_norm=float("inf"))
self.optimizer.step()
self.optimizer.zero_grad()
train_steps += 1
new_lr = self.update_learning_rate(pbar.n, self.args)
for pg in self.optimizer.param_groups:
pg["lr"] = new_lr
# Create info string to display in the progress bar
current_lr = self.optimizer.param_groups[0]["lr"]
info_dict = {
"joy": f"{info['r_joy']:.4f}",
"traj_len": f"{avg_lifespan:.2f} ± {std_lifespan:.2f} (max: {max_lifespan:.2f})",
"H": f"{info['entropy']:.4f}",
"iw": f"{info['iw']:.4f}" if self.args.use_iw else None,
"∇": f"{grad_norm:.4f}",
"lr": f"{current_lr:.2e}",
}
pbar.set_postfix(info_dict)
pbar.update(env_steps_per_train_step)
# --- Cleanup ---
self.envs.close()
if self.args.use_wandb:
wandb.finish()
Solution
class VPGTrainer:
def __init__(self, args: VPGArgs):
set_global_seeds(args.seed)
self.args = args
device = args.device
self.rng = t.Generator(device=device).manual_seed(args.seed)
self.run_name = f"{args.env_id}__{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"
if args.env_id == "CartPole-gpu":
self.envs = CartPole(args.num_envs, device=device)
elif args.env_id == "Probe4-v0":
self.envs = Probe4(args.num_envs)
elif args.env_id == "Probe5-v0":
self.envs = Probe5(args.num_envs)
else:
raise ValueError(f"Environment {args.env_id} not supported")
# Define some basic variables from our environment (note, we assume a single discrete action space)
self.num_envs = args.num_envs
self.action_shape = self.envs.action_space.shape
self.num_actions = self.envs.action_space.n
self.obs_shape = self.envs.observation_space.shape
# Create our networks & optimizer
self.policy_network = PolicyNetwork(self.obs_shape, self.num_actions).to(device)
# Compile the policy network for faster inference
if self.args.compile:
self.policy_network = t.compile(self.policy_network)
self.optimizer = t.optim.Adam(self.policy_network.parameters(), lr=args.lr, eps=1e-5, maximize=True)
self.optimizer.zero_grad()
# Create our agent
self.agent = VPGAgent(envs=self.envs, policy_network=self.policy_network, args=self.args, rng=self.rng)
def compute_loss(self, tau: RolloutTensors) -> tuple[t.Tensor, dict[str, Any]]:
returns = compute_returns(tau.rewards, tau.dones, self.args.gamma) # (num_envs, timestep)
if self.args.normalize_returns:
returns = normalize_returns(returns)
logprobs_taken, entropy = compute_logprobs_and_entropy(tau, self.policy_network)
iw = compute_importance_weights(logprobs_taken, tau, self.args.clip_coef)
r_joy = compute_reinforce_loss(returns, logprobs_taken, iw)
avg_entropy = entropy.mean()
joy = r_joy + self.args.ent_coef * avg_entropy
info = {
"entropy": avg_entropy.item(),
"r_joy": r_joy.item(),
"iw": iw.mean().item() if self.args.use_iw else None,
}
return joy, info
def update_learning_rate(self, time_steps, args):
if args.use_lr_decay and args.lr_frac > 0:
progress = min(1.0, max(time_steps / args.total_timesteps, 0) / args.lr_frac)
return (progress * args.lr_end) + ((1 - progress) * args.lr)
return args.lr
def train(self) -> None:
"""
Trains the agent by generating rollouts and updating the policy.
The progress bar tracks total environment steps.
"""
# --- Setup ---
rollout = Rollout(
num_envs=self.num_envs,
max_steps=self.args.num_steps_per_rollout,
obs_shape=self.obs_shape,
action_shape=self.action_shape,
device=self.args.device,
)
# Calculate the number of environment steps collected per rollout generation
# Calculate the total number of updates (rollouts) to perform
# Use integer division to ensure we don't exceed total_timesteps
env_steps_per_train_step = (
self.args.num_steps_per_rollout * self.args.num_envs // (self.args.num_batches_per_rollout)
)
num_updates = self.args.total_timesteps // env_steps_per_train_step
train_steps = 0 # Counter for gradient updates
# --- Training Loop ---
# The progress bar is managed manually with a `with` statement.
# `total` is set to the total environment steps we want to run.
# The loop iterates `num_updates` times, not `total_timesteps` times.
with tqdm(
total=self.args.total_timesteps,
unit=" env steps",
unit_scale=True,
desc="Training",
miniters=1,
mininterval=0.02,
) as pbar:
env_steps_consumed = 0
for update_num in range(num_updates):
# 1. Generate a new rollout from the environment
rollout, agent_info = self.agent.gen_rollout(rollout)
# 2. Split the rollout into batches along the num_envs dimension
rollout_batches = rollout.get_batches(self.args.batch_size)
# 3. Logging and Progress Bar Update
# This part is outside the inner loop to only log once per rollout
avg_lifespan = agent_info["lifespan"].float().mean().item()
std_lifespan = agent_info["lifespan"].float().std().item()
max_lifespan = agent_info["lifespan"].max().item()
if (avg_lifespan + 0.5) > self.args.num_steps_per_rollout and std_lifespan < 0.01:
print("Agent has learned to play optimally!")
break
# 4. For each batch, perform multiple gradient updates
for batch in rollout_batches:
for i in range(self.args.rollout_use_count):
loss, reinforce_info = self.compute_loss(batch)
info = {**agent_info, **reinforce_info}
loss.backward()
if self.args.max_grad_norm is not None:
t.nn.utils.clip_grad_norm_(
self.policy_network.parameters(), max_norm=self.args.max_grad_norm
)
grad_norm = t.nn.utils.clip_grad_norm_(self.policy_network.parameters(), max_norm=float("inf"))
self.optimizer.step()
self.optimizer.zero_grad()
train_steps += 1
new_lr = self.update_learning_rate(pbar.n, self.args)
for pg in self.optimizer.param_groups:
pg["lr"] = new_lr
# Create info string to display in the progress bar
current_lr = self.optimizer.param_groups[0]["lr"]
info_dict = {
"joy": f"{info['r_joy']:.4f}",
"traj_len": f"{avg_lifespan:.2f} ± {std_lifespan:.2f} (max: {max_lifespan:.2f})",
"H": f"{info['entropy']:.4f}",
"iw": f"{info['iw']:.4f}" if self.args.use_iw else None,
"∇": f"{grad_norm:.4f}",
"lr": f"{current_lr:.2e}",
}
pbar.set_postfix(info_dict)
pbar.update(env_steps_per_train_step)
# --- Cleanup ---
self.envs.close()
if self.args.use_wandb:
wandb.finish()
Probes
As yesterday, we will be using probes to test our model. They've been implemented for you.
def test_probe(probe_idx: int):
"""
Tests a probe environment by training a network on it & verifying that the value functions are
in the expected range.
"""
# Train our network
args = VPGArgs(
env_id=f"Probe{probe_idx}-v0",
wandb_project_name=f"test-probe-{probe_idx}",
total_timesteps=[7500, 7500, 12500, 20000, 20000][probe_idx - 1],
lr=5e-3,
num_envs=4,
video_log_freq=None,
use_wandb=False,
device="cpu",
ent_coef=0.0,
clip_coef=None,
normalize_returns=False,
rollout_use_count=1,
show_probs=True,
)
trainer = VPGTrainer(args)
trainer.train()
agent = trainer.agent
# Get the correct set of observations, and corresponding values we expect
obs_for_probes = [[[0.0]], [[-1.0], [+1.0]], [[0.0], [1.0]], [[0.0]], [[0.0], [1.0]]]
expected_value_for_probes = [
[[1.0]],
[[-1.0], [+1.0]],
[[args.gamma], [1.0]],
[[1.0]],
[[1.0], [1.0]],
]
expected_probs_for_probes = [None, None, None, [[0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]
tolerances = [1e-3, 1e-3, 1e-3, 2e-3, 2e-3]
obs = t.tensor(obs_for_probes[probe_idx - 1]).to(args.device)
# Calculate the actual value & probs, and verify them
with t.inference_mode():
probs = agent.policy_network(obs).softmax(-1)
expected_probs = expected_probs_for_probes[probe_idx - 1]
if expected_probs is not None:
print(f"Probs: {probs}")
print(f"Expected probs: {t.tensor(expected_probs).to(args.device)}")
t.testing.assert_close(probs, t.tensor(expected_probs).to(args.device), atol=tolerances[probe_idx - 1], rtol=0)
print(f"Probe {probe_idx} tests passed!\n")
gym.envs.registration.register(id="Probe4-v0", entry_point=Probe4)
gym.envs.registration.register(id="Probe5-v0", entry_point=Probe5)
for probe_idx in [4, 5]:
test_probe(probe_idx)
Training Run
Vanilla Policy Gradient can often be a bit finicky and unstable to train (which is why in practice we use PPO instead). Nonetheless, I've tried to find a set of hyperparameters that works reasonably well, and a set that (if you're lucky) trains to optimality in ~15 seconds on CartPole!
args = VPGArgs(
    use_wandb=False,
    num_envs=4,
    num_batches_per_rollout=1,
    total_timesteps=500_000,
    num_steps_per_rollout=500,
    rollout_use_count=4,  # this seems to matter a lot
    ent_coef=0.3,  # works with zero
    clip_coef=0.2,  # can sometimes work with no clipping, but it helps
    max_grad_norm=0.5,
    normalize_returns=True,
    use_iw=True,
    lr=1e-4,
    gamma=0.99,
    device="cpu",  # may run faster on CPU due to few envs / small batch size
)
trainer = VPGTrainer(args)
trainer.train()
generate_and_plot_trajectory(trainer, args, mode="pg")
# There's a somewhat critical region where the cartpole really picks up,
# and we need the LR to decay rather fast before the gradients explode
# No guarantees that this will work for other environments, but it's a good starting point
# sub 15 seconds to optimal on A4000!!
# might need to rerun a few times to get a lucky initialization, it's rather sensitive!
device = t.device("cuda")
args_fast = VPGArgs(
use_wandb=False,
num_envs=256,
num_batches_per_rollout=4,
total_timesteps=4_000_000,
num_steps_per_rollout=500,
rollout_use_count=1, # this seems to matter a lot
ent_coef=0.5, # works with zero
clip_coef=0.1, # can sometimes work with no clipping, but it helps
max_grad_norm=1,
normalize_returns=True,
lr=1e-2, # risky!
use_lr_decay=True,
    use_iw=True,  # wouldn't be needed if each rollout were used only once in a single batch
lr_end=1e-3,
lr_frac=0.6,
compile=False,
gamma=0.99,
seed=1337,
device=device,
)
trainer = VPGTrainer(args_fast)
trainer.train()
generate_and_plot_trajectory(trainer, args_fast, mode="pg")