4️⃣ Self-Play & Training
Learning Objectives
- Implement the self-play sampler: the tree policy, the network policy, and using the critic to estimate the value of rollouts.
- Understand the loss function for the network and how it distills the planning provided by the tree search.
- Train an agent to match a perfect solver's moves (and hopefully beat you too!)
Now we close the loop. We need two ingredients: a sampler that turns MCTS into training
data, and a loss that trains the network on that data. All the boring parts
we've already seen on PPO day (the replay buffer, the optimiser, the generation loop)
are given in the AlphaZeroTrainer below.
The value target: compute_z_targets
During a self-play generation we record, for every game b and move t, whether the move
ended the game (dones[b,t]) and the mover's reward (rewards[b,t]). The value target z[b,t]
is the eventual outcome of that game from the perspective of the mover at state t, so it
flips sign every ply and resets at each game boundary (games auto-reset and replay within a generation).
The clean way to compute this is a single backward scan over time. Going from the last move
to the first: if move t was terminal, the running value is just its reward; otherwise it's the
negation of the running value from t+1 (negamax again). This propagates each game's
outcome back to all its states with the correct alternating signs.
This is a similar function to compute_returns from PPO day, but with the alternating signs of the negamax.
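Written as a recurrence, the backward scan described above looks roughly as follows. This is a sketch in our own notation, not taken from the original material: $r_{b,t}$ stands for rewards[b,t], $d_{b,t}$ for dones[b,t], and the boundary value $z_{b,T} = 0$ past the end of the recorded window is an assumption (it matches starting the running value at zero).

$$
z_{b,t} =
\begin{cases}
r_{b,t} & \text{if } d_{b,t} = 1 \\
-\,z_{b,t+1} & \text{otherwise}
\end{cases}
\qquad t = T-1, \dots, 0, \qquad z_{b,T} = 0
$$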
Worked example:
Consider a batch of three games recorded over six plies.
The first game lasted two plies before a winning third ply, the second lasted four plies before
a winning fifth ply, and the third lasted five plies before ending in a draw on the sixth.
The corresponding dones
and rewards tensors are:
dones = torch.tensor([[0, 0, 1, 1, 1, 1],
                      [0, 0, 0, 0, 1, 0],
                      [0, 0, 0, 0, 0, 1]]).bool()
rewards = torch.tensor([[0, 0, 1, 0, 0, 0],
                        [0, 0, 0, 0, 1, 0],
                        [0, 0, 0, 0, 0, 0]]).float()
Recall that the environment can never give a negative reward: a player can only ever win or draw by making their move. A player cannot move and then immediately lose.
The corresponding z tensor would be:
z = tensor([[ 1., -1.,  1.,  0.,  0.,  0.],
            [ 1., -1.,  1., -1.,  1.,  0.],
            [-0.,  0., -0.,  0., -0.,  0.]])
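To check the worked example by hand, here is a minimal plain-Python sketch of the backward scan for a single game row. The helper name `z_targets_single_game` is hypothetical and uses lists rather than tensors; it only illustrates the rule, and the exercise below still asks for a batched torch version.

```python
def z_targets_single_game(dones, rewards):
    """Backward negamax scan over ONE row of plies (plain-Python reference sketch)."""
    z = [0.0] * len(dones)
    running = 0.0  # assumed value beyond the end of the recorded window
    for t in range(len(dones) - 1, -1, -1):
        # terminal ply: restart from its reward; otherwise negate the next ply's value
        running = float(rewards[t]) if dones[t] else -running
        z[t] = running
    return z


# the three rows from the worked example above
dones = [[0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 1]]
rewards = [[0, 0, 1, 0, 0, 0],
           [0, 0, 0, 0, 1, 0],
           [0, 0, 0, 0, 0, 0]]
z = [z_targets_single_game(d, r) for d, r in zip(dones, rewards)]
```

Each row of `z` should agree with the corresponding row of the tensor above (up to the sign of zero).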
Exercise - implement compute_z_targets
def compute_z_targets(
    dones: Bool[Tensor, "batch timesteps"],
    rewards: Float[Tensor, "batch timesteps"]
) -> Float[Tensor, "batch timesteps"]:
    """Negamax value targets for a batch of `B` self-play games of `T` plies.

    Walking each game backwards from its terminal rewards, the target at each ply is the game's final
    reward with its sign flipped once per step back. Every recorded value in this project is from the
    perspective of the player about to move; stepping back one ply changes whose turn it is, hence the
    negation (negamax: good for the mover is bad for its parent).

    Args:
        dones: (batch, timesteps) marks the ply where each game ended
        rewards: (batch, timesteps) rewards to the mover at each ply (nonzero only where dones)
    Returns:
        (batch, timesteps) the mover-perspective outcome `z` for every recorded state
    """
    batch, timesteps = dones.shape
    z = torch.zeros((batch, timesteps), device=dones.device)
    raise NotImplementedError()
    return z
tests.test_compute_z_targets(compute_z_targets)
Solution
def compute_z_targets(
    dones: Bool[Tensor, "batch timesteps"],
    rewards: Float[Tensor, "batch timesteps"]
) -> Float[Tensor, "batch timesteps"]:
    """Negamax value targets for a batch of `B` self-play games of `T` plies.

    Walking each game backwards from its terminal rewards, the target at each ply is the game's final
    reward with its sign flipped once per step back. Every recorded value in this project is from the
    perspective of the player about to move; stepping back one ply changes whose turn it is, hence the
    negation (negamax: good for the mover is bad for its parent).

    Args:
        dones: (batch, timesteps) marks the ply where each game ended
        rewards: (batch, timesteps) rewards to the mover at each ply (nonzero only where dones)
    Returns:
        (batch, timesteps) the mover-perspective outcome `z` for every recorded state
    """
    batch, timesteps = dones.shape
    z = torch.zeros((batch, timesteps), device=dones.device)
    running = torch.zeros((batch,), device=dones.device)
    # scan backwards in time: restart from the reward at terminal plies, otherwise negate
    for t in range(timesteps - 1, -1, -1):
        running = torch.where(dones[:, t], rewards[:, t], -running)
        z[:, t] = running
    return z
tests.test_compute_z_targets(compute_z_targets)
The AlphaZero loss: compute_az_loss
Given the network's value (N,) and logits (N,7) on a minibatch, and the targets
pi (N,7) (MCTS visit distribution) and z (N,) (game outcome), the loss is
$$ L = -\sum_a \pi_a \log \text{softmax}(\text{logits})_a + c_v \, (\text{value} - z)^2 $$
averaged over the minibatch, where $c_v$ is value_coef (the weight decay is not included explicitly in the loss, as it is handled by the optimiser).
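As a sanity check on the two terms, here is a small plain-Python sketch of the loss for a single position. The helper `az_loss_single` is hypothetical and deliberately avoids torch; the exercise below asks for the batched tensor version.

```python
import math


def az_loss_single(value, logits, pi, z, value_coef=1.0):
    """Policy cross-entropy + weighted squared value error for ONE position."""
    # numerically stable log-softmax: subtract the max logit before exponentiating
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    logprobs = [x - log_norm for x in logits]
    policy_loss = -sum(p * lp for p, lp in zip(pi, logprobs))
    value_loss = (value - z) ** 2
    return policy_loss + value_coef * value_loss


# uniform logits against a uniform target: the policy term is log(7),
# and the value term is (0.5 - 1.0)^2 = 0.25
uniform_logits = [0.0] * 7
uniform_pi = [1.0 / 7] * 7
loss = az_loss_single(value=0.5, logits=uniform_logits, pi=uniform_pi, z=1.0)
```

With a uniform network output the policy term equals the entropy of a uniform distribution over seven columns, which is a handy reference point for what an untrained network's loss should look like.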
Exercise - implement compute_az_loss
def compute_az_loss(
    value: Float[Tensor, "N"],
    logits: Float[Tensor, "N 7"],
    pi: Float[Tensor, "N 7"],
    z: Float[Tensor, "N"],
    value_coef: float = 1.0,
) -> Float[Tensor, ""]:
    """Scalar AlphaZero loss over a minibatch of `N` positions: policy cross-entropy + value MSE.

    Loss = mean of `-sum_a pi_a log softmax(logits)_a` + `value_coef * (value - z)^2`.

    Args:
        value: (N,) critic outputs
        logits: (N, 7) actor outputs
        pi: (N, 7) MCTS visit-count policy target
        z: (N,) game-outcome value target
        value_coef: weight on the value-MSE term
    Returns:
        scalar tensor: the mean total loss
    """
    assert value.shape == z.shape
    assert logits.shape == pi.shape
    raise NotImplementedError()
tests.test_compute_az_loss(compute_az_loss)
Solution
def compute_az_loss(
    value: Float[Tensor, "N"],
    logits: Float[Tensor, "N 7"],
    pi: Float[Tensor, "N 7"],
    z: Float[Tensor, "N"],
    value_coef: float = 1.0,
) -> Float[Tensor, ""]:
    """Scalar AlphaZero loss over a minibatch of `N` positions: policy cross-entropy + value MSE.

    Loss = mean of `-sum_a pi_a log softmax(logits)_a` + `value_coef * (value - z)^2`.

    Args:
        value: (N,) critic outputs
        logits: (N, 7) actor outputs
        pi: (N, 7) MCTS visit-count policy target
        z: (N,) game-outcome value target
        value_coef: weight on the value-MSE term
    Returns:
        scalar tensor: the mean total loss
    """
    assert value.shape == z.shape
    assert logits.shape == pi.shape
    logprobs = F.log_softmax(logits, dim=-1)
    policy_loss = -(pi * logprobs).sum(-1).mean()
    critic_loss = F.mse_loss(value, z)
    # alternative non-mse solution:
    # critic_loss = ((value - z) ** 2).mean()
    return policy_loss + value_coef * critic_loss
tests.test_compute_az_loss(compute_az_loss)
Training Hyperparameters
@dataclass
class AZConfig:
    """All the knobs for self-play + training. The defaults are a fast in-notebook recipe (4096
    self-play games/gen, 16 sims/move, 12 generations) that reaches ~85% Pons-solver accuracy in
    ~4-5 min on a GPU. For a stronger agent raise `sims` to 64 and `num_generations` to ~50
    (≈182k optimiser steps); to run faster still, dial `num_games` / `sims` / `num_generations` down.

    c_puct (1.5->1.0), lr (1e-3->5e-3) and buffer_gens (8->4) were tuned by a
    full-run sweep (pons_CE ~0.466->0.420 at ~1/4 the compute); every gain came from fresher self-play
    data; capacity/loss/regularisation were inert. Usable lr band ~3e-3..6e-3 (≥7e-3 is seed-unstable)."""

    # self-play / data
    num_games: int = 4096  # parallel self-play games per generation
    sims: int = 16  # MCTS simulations per move (32≈64 for learning but ~2x faster self-play; raise to 64 for a stronger run)
    num_generations: int = 12  # training generations (~85% Pons acc in ~4-5 min; raise to ~50 for the full recipe)
    buffer_gens: int = 4  # replay buffer = the last this-many generations (tuned 8->4: fresher data)
    moves_per_gen: int = 42  # plies per generation (a full Connect-4 game)
    temperature: float = 1.0  # visit-count sampling temperature (first `temp_cutoff` plies)
    temp_cutoff: int = 12  # after this many plies, play greedily
    augment: bool = True  # mirror-symmetry data augmentation
    # MCTS / exploration
    c_puct: float = 1.0  # tuned 1.5->1.0 (full-run sweep, both seeds)
    max_depth: int = 42
    dirichlet_alpha: float = 10 / 7  # ≈ 1.43, root exploration-noise concentration
    dirichlet_eps: float = 0.25  # weight of the root Dirichlet noise
    # optimiser / schedule
    lr: float = 5e-3  # initial learning rate (tuned 1e-3->5e-3; >=7e-3 is seed-unstable)
    lr_min: float = 2e-5  # cosine-decay target over the run
    weight_decay: float = 1e-4
    grad_clip: float = 1.0  # global grad-norm clip
    minibatch: int = 1024
    value_coef: float = 1.0  # weight on the value-MSE loss term
    # logging
    use_wandb: bool = False  # log loss / lr / Pons metrics to Weights & Biases
    wandb_project: str = "alphazero-connect4"
The replay buffer (given)
Self-play and training talk through a small given ReplayBuffer, in the same spirit as the VPG
rollout buffer from [2.2]. It preallocates the per-generation rollout tensors once, you write one
ply at a time, and end_generation turns the finished rollout into flat (obs, pi, z) training
examples (computing the negamax value targets and dropping states whose game didn't finish), keeping
the last cfg.buffer_gens generations. get_dataloader then hands back a shuffled DataLoader over
those examples. This keeps the exercises tiny: self_play_step just writes each ply, and the
training loop just iterates get_dataloader(...) calling your training_step.
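The "drop states whose game didn't finish" step can be pictured as a reverse cumulative OR over the dones of each row. Below is a minimal plain-Python sketch for one row; the helper name `keep_mask` is hypothetical, and the given buffer does the same thing on tensors.

```python
def keep_mask(dones_row):
    """True for every ply that has a terminal ply at or after it within the row."""
    keep = [False] * len(dones_row)
    seen_done = False
    # walk backwards so that a terminal ply marks itself and everything before it
    for t in range(len(dones_row) - 1, -1, -1):
        seen_done = seen_done or bool(dones_row[t])
        keep[t] = seen_done
    return keep


# two finished games (ending at t=2 and t=5), then two plies of a game that never finished
mask = keep_mask([0, 0, 1, 0, 0, 1, 0, 0])
```

Only the trailing plies after the last terminal are dropped: they belong to a game whose outcome was never observed, so they have no valid value target.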
class ReplayBuffer:
    """Given. Rolling replay of the last `cfg.buffer_gens` self-play generations.

    Usage: `write(...)` one ply at a time, `end_generation()` after each generation (computes value
    targets, drops unfinished states, flattens over (game, ply), evicts the oldest generation), then
    `get_dataloader(mb)` for a shuffled training `DataLoader`. Mirrors the VPG rollout buffer from [2.2]."""

    def __init__(self, cfg: AZConfig, device):
        self.cfg, self.device = cfg, device
        B, T = cfg.num_games, cfg.moves_per_gen
        # preallocated rollout for the CURRENT generation (written one ply at a time, in place)
        self.obs = torch.empty((B, T, 3, 6, 7), device=device)
        self.pi = torch.empty((B, T, 7), device=device)
        self.dones = torch.empty((B, T), dtype=torch.bool, device=device)
        self.rews = torch.empty((B, T), device=device)
        self.t = 0  # next free ply slot in the current generation
        self.gens = []  # rolling list of finished generations, each a flat (obs, pi, z)

    def write(self, obs_canon, pi, done, reward):
        """Record one ply of the current generation into row `self.t`, then advance."""
        self.obs[:, self.t] = obs_canon
        self.pi[:, self.t] = pi
        self.dones[:, self.t] = done
        self.rews[:, self.t] = reward
        self.t += 1

    def end_generation(self):
        """Finish the current generation: compute negamax `z`, keep only states whose game finished
        (reverse cumulative-OR of dones over time), flatten over (game, ply), append to the rolling
        buffer (evicting the oldest), and reset the write pointer."""
        z = compute_z_targets(self.dones, self.rews)  # (B, T)
        keep = (self.dones.int().flip(-1).cumsum(-1).flip(-1) > 0).reshape(-1)
        self.gens.append((self.obs.reshape(-1, 3, 6, 7)[keep],
                          self.pi.reshape(-1, 7)[keep],
                          z.reshape(-1)[keep]))
        if len(self.gens) > self.cfg.buffer_gens:
            self.gens.pop(0)
        self.t = 0

    def get_dataloader(self, batch_size):
        """Snapshot the whole buffer into a `DataLoader` over `(obs, pi, z)` training examples.

        The DataLoader handles shuffling + batching internally; iterate it once per epoch (it
        reshuffles each time). `drop_last` keeps every batch exactly `batch_size` (great for a fixed
        compiled forward), but only when there's at least one full batch, so tiny configs still train.
        Tensors already live on the GPU, so the default `num_workers=0` / no pinning is correct."""
        obs = torch.cat([g[0] for g in self.gens])
        pi = torch.cat([g[1] for g in self.gens])
        z = torch.cat([g[2] for g in self.gens])
        ds = TensorDataset(obs, pi, z)
        return DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=len(ds) >= batch_size)

    def reset(self):
        self.gens, self.t = [], 0

    def __len__(self):
        return sum(g[0].shape[0] for g in self.gens)
Evaluating against a perfect solver (given)
Recall that Connect-4 is a solved game, and fast solvers exist that can report the optimal move for any given position in milliseconds (with the help of a pre-computed opening book).
We use Pascal Pons' perfect Connect-4 solver to pregenerate ~6700 board positions labelled with the optimal move for each position, pre-filtered to keep only decisive positions (i.e. positions where the available moves do not all lead to the same win/draw/loss outcome under optimal play by both sides).
The given evaluate_policy runs the
network over this set in one batched forward pass and reports the following metrics:
- acc: fraction of positions where the policy's top move is solver-optimal.
- ce: the policy cross-entropy against the optimal-move set, $ - \log \sum_{a : a \text{ is optimal}} p_\theta(a|s) $, a smoother signal than acc.
- val_signacc: fraction of decisive positions where the value head has the right sign (predicts win vs loss correctly).
This eval is far cheaper than playing out games as it requires only one batched forward pass, while
providing a lot of feedback on how well the policy is performing.
You should see acc climb (and ce fall) over a handful of generations.
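To make the two policy metrics concrete, here is a plain-Python sketch for a single position. The helper `pons_metrics_single` and the probabilities are made up for illustration; the real `evaluate_policy` works on the whole labelled set in one batched pass, and its internals are not shown here.

```python
import math


def pons_metrics_single(probs, optimal_moves):
    """Top-1 hit and optimal-set cross-entropy for ONE position (illustrative helper)."""
    top_move = max(range(len(probs)), key=lambda a: probs[a])
    hit = top_move in optimal_moves
    # -log of the total probability mass placed on solver-optimal columns
    ce = -math.log(sum(probs[a] for a in optimal_moves))
    return hit, ce


# hypothetical policy output over the 7 columns; columns 3 and 4 are both optimal
probs = [0.05, 0.05, 0.10, 0.40, 0.30, 0.05, 0.05]
hit, ce = pons_metrics_single(probs, optimal_moves={3, 4})
```

Here the top move is optimal, so the position counts towards acc, while ce still rewards shifting the remaining 30% of the probability mass onto the two optimal columns.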
If you like, have a look in the pascal_pons directory to see how the dataset was generated. It's
included in the repo for convenience. It basically downloads the opening book and Pascal Pons' solver,
generates a series of games between epsilon-greedy optimal policies, and labels the positions with the optimal move.
It then filters out positions that are not decisive or duplicates, and saves the remaining positions to a file.
It's also available here on HuggingFace just
in case, and evaluate_policy will auto-download the dataset the first time you run it.
from pascal_pons.eval_pons import evaluate_policy
Exercise - implement self_play_step
self_play_step should run one ply of self-play. Given the current boards obs and who is to move
to_move (all num_games games at once):
- Search: run the batched MCTS to get the root visit counts.
- Policy target: turn the visit counts into the target policy pi (normalised visit counts).
- Act: sample an action from the tree policy (self.sample_actions) and step the environment.
- Record the ply into the given replay buffer: self.buffer.write(obs_canon, pi, done, reward), where obs_canon = canonicalise_obs(obs, to_move) is the board from the mover's perspective.
Return (next_obs, done). The given self_play loop calls this moves_per_gen times, uses done to
advance to_move (auto-reset), and calls self.buffer.end_generation() at the end.
The buffer handles the value targets, masking out unfinished states, the rolling replay, and serving minibatches.
Exercise - implement sample_actions
Implement the sample_actions function, which samples actions from the root visit counts using the temperature parameter.
Use torch.multinomial to sample an action from the resulting distribution.
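The effect of the temperature on the visit counts can be seen in a few lines of plain Python. This is an illustrative sketch with made-up counts and a hypothetical helper `tree_policy`; it only builds the distribution, and the sampling itself is left to the exercise.

```python
def tree_policy(visit_counts, temperature=1.0):
    """Normalise N^(1/temperature) into a probability distribution."""
    powered = [n ** (1.0 / temperature) for n in visit_counts]
    total = sum(powered)
    return [p / total for p in powered]


counts = [10, 30, 0, 60]  # made-up root visit counts for four actions
proportional = tree_policy(counts, temperature=1.0)  # proportional to the counts
sharper = tree_policy(counts, temperature=0.5)       # more mass on the most-visited action
flatter = tree_policy(counts, temperature=2.0)       # closer to uniform over visited actions
```

Whatever the temperature, an action with zero visits keeps probability zero, so it is never sampled.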
Exercise - implement training_step
training_step performs one optimiser step on a minibatch (obs, pi, z):
forward the net, compute the AlphaZero loss with compute_az_loss, backpropagate, take an optimiser step,
and return the scalar loss. Optionally, you can also clip the gradient norm to cfg.grad_clip
with nn.utils.clip_grad_norm_ (after the backward pass and before the optimiser step).
class AlphaZeroTrainer:
    """Owns the full per-generation loop pictured above. Each generation `train()` runs:
    1. `self_play()`: `moves_per_gen` plies of batched self-play with the frozen net
       (your `self_play_step` per ply, writing `(obs_canon, π, done, reward)` into `self.buffer`),
       then `buffer.end_generation()` turns the rollout into flat `(obs, π, z)` training rows.
    2. one supervised pass over `buffer.get_dataloader(minibatch)` (your `training_step` per batch),
    3. a cosine LR-schedule step and (periodically) `evaluate()` against the Pons solver.

    Attributes: `env` (Connect4Env), `cfg` (AZConfig), `model` (Connect4Model), `opt` (AdamW),
    `mcts` (BatchedMCTS, built from cfg's sims/c_puct/dirichlet settings), `buffer` (ReplayBuffer).
    """

    def __init__(self, env, cfg, model):
        self.env = env
        self.cfg = cfg
        self.device = env.device
        self.model = model
        self.opt = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.mcts = BatchedMCTS(env, MCTSConfig(
            sims=cfg.sims, c_puct=cfg.c_puct, max_depth=cfg.max_depth,
            dirichlet_alpha=cfg.dirichlet_alpha, dirichlet_eps=cfg.dirichlet_eps))
        self.buffer = ReplayBuffer(cfg, self.device)

    def sample_actions(self, root_N: Float[Tensor, "B 7"], temperature: float = 1.0) -> Float[Tensor, "B"]:
        """Sample one action per game from the tree policy π(a) ∝ N(s,a)^(1/τ).

        Args:
            root_N: (B, 7) root visit counts from the MCTS search
            temperature: τ; 1.0 samples proportionally to visit counts, →0 approaches greedy
                argmax, larger flattens the distribution
        Returns:
            (B, 1) sampled column indices (one per game). Actions with zero visits have
            probability 0 and are never sampled. (No need to special-case tiny temperatures here:
            the trainer always calls this with moderate τ; `sample_tree_policy` in section 2 is the one
            that must handle the greedy τ→0 limit.)
        """
        raise NotImplementedError()

    @torch.no_grad()
    def self_play_step(self, obs, to_move):
        """One ply of self-play for all `num_games` games at once: MCTS -> policy target -> sample ->
        step -> record the ply into `self.buffer`.

        Args:
            obs: (B, 3, 6, 7) current (absolute) boards
            to_move: (B,) whether player-1 (red) is to move in each game
        Returns:
            next_obs: (B, 3, 6, 7) boards after the move
            done: (B,) whether the move ended each game (the loop uses this to update `to_move`)
        """
        raise NotImplementedError()

    @torch.no_grad()
    def self_play(self, progress: bool = True):
        """Play one generation: `num_games` games for `moves_per_gen` plies, calling your
        `self_play_step` each ply (which `write`s into `self.buffer`), then finalise the generation in
        the buffer. The buffer handles the value targets, masking and replay -- nothing to return.

        The ply loop is wrapped in a tqdm bar with `unit_scale` (like the 2.2 SPS bar) so you get a
        live, SI-formatted env-steps/sec readout during self-play rather than one update per generation."""
        B, T = self.cfg.num_games, self.cfg.moves_per_gen
        obs = self.env.reset(B)
        to_move = torch.ones((B,), dtype=torch.bool, device=self.device)
        self.model.eval()
        # each ply runs `sims` batched env.steps over all B games (one per MCTS simulation), so the
        # generation does B * T * sims env transitions in total -- tqdm turns that into a live rate,
        # and the postfix shows cumulative MCTS simulations done out of `moves_per_gen * sims`.
        total_sims = T * self.cfg.sims
        bar = tqdm(total=B * T * self.cfg.sims, unit=" env steps", unit_scale=True,
                   desc="self-play", leave=False, disable=not progress)
        for ply in range(T):
            obs, done = self.self_play_step(obs, to_move)
            to_move = torch.where(done, torch.ones_like(to_move), ~to_move)  # auto-reset -> player 1
            bar.update(B * self.cfg.sims)
            bar.set_postfix_str(f"sims {(ply + 1) * self.cfg.sims}/{total_sims}")
        bar.close()
        self.buffer.end_generation()

    def training_step(self, obs, pi, z):
        """One optimiser step on a single minibatch `(obs, pi, z)`: forward the net, compute the
        AlphaZero loss (`compute_az_loss`), zero the grads (`self.opt.zero_grad(set_to_none=True)`),
        backprop, clip the gradient norm to `cfg.grad_clip`, and step the optimiser. AlphaZero's
        update is just supervised learning -- regress the value head onto `z` and the policy head
        onto `pi`.

        Args:
            obs: (mb, 3, 6, 7) mover-canonical boards
            pi: (mb, 7) MCTS policy targets
            z: (mb,) value targets
        Returns:
            float: the minibatch loss
        """
        raise NotImplementedError()

    @torch.no_grad()
    def evaluate(self) -> dict:
        """Given. Score the current network against the frozen Pons solver set (policy accuracy /
        cross-entropy / value sign-accuracy). One cached forward pass; see the section above."""
        return evaluate_policy(self.model, self.env)

    def train(self, num_generations=None, eval_every=1):
        """Given. The full training loop -- you don't need to touch this. Each generation: run
        self-play into the buffer, do one supervised pass over the buffer's `DataLoader`
        (calling your `training_step`), step the cosine LR schedule, and periodically `evaluate`
        against the Pons solver. Logs loss / lr / eval to a tqdm bar (and to wandb if `cfg.use_wandb`)."""
        import time
        num_generations = num_generations or self.cfg.num_generations
        # cosine-decay the LR; schedule over a >=10-gen horizon so a short quick run doesn't crater the
        # LR before it finishes (a 6-gen T_max=6 cosine would decay to lr_min by the last gen and stall).
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=max(num_generations, 10), eta_min=self.cfg.lr_min)
        # effective env transitions/sec for the tqdm bar: each generation steps the env
        # num_games * moves_per_gen * sims times (one batched env.step per MCTS simulation). It's just a
        # counter + a clock, so it adds no per-step overhead -- handy for spotting when self-play bottlenecks.
        t0, env_steps = time.time(), 0
        steps_per_gen = self.cfg.num_games * self.cfg.moves_per_gen * self.cfg.sims
        if self.cfg.use_wandb:
            import wandb
            wandb.init(project=self.cfg.wandb_project, config=asdict(self.cfg))
        metrics = {}
        bar = tqdm(range(1, num_generations + 1))
        for gen in bar:
            self.self_play()  # fill + roll the replay buffer
            env_steps += steps_per_gen
            self.model.train()
            loader = self.buffer.get_dataloader(self.cfg.minibatch)
            total_loss, n_batches = 0.0, 0
            tbar = tqdm(loader, desc="train", leave=False)  # one supervised pass over the buffer
            for obs, pi, z in tbar:
                total_loss += self.training_step(obs, pi, z)
                n_batches += 1
                tbar.set_postfix_str(f"loss={total_loss / n_batches:.3f}")  # live running-mean loss
            loss = total_loss / max(n_batches, 1)
            sched.step()
            if eval_every and gen % eval_every == 0:
                metrics = self.evaluate()  # Pons solver benchmark
            lr = sched.get_last_lr()[0]
            sps = env_steps / max(time.time() - t0, 1e-9)  # effective env steps/sec (cumulative)
            bar.set_postfix_str(f"loss={loss:.3f} acc={metrics.get('pons/acc', float('nan')):.3f} "
                                f"ce={metrics.get('pons/ce', float('nan')):.3f} env/s={fmt_si(sps)}")
            if self.cfg.use_wandb:
                wandb.log({"generation": gen, "loss": loss, "lr": lr, "env_steps_per_sec": sps, **metrics})
        if self.cfg.use_wandb:
            wandb.finish()
        return self.model
tests.test_sample_actions(AlphaZeroTrainer)
tests.test_self_play_step(AlphaZeroTrainer)
tests.test_training_step(AlphaZeroTrainer)
Solution
class AlphaZeroTrainer:
    """Owns the full per-generation loop pictured above. Each generation `train()` runs:
    1. `self_play()`: `moves_per_gen` plies of batched self-play with the frozen net
       (your `self_play_step` per ply, writing `(obs_canon, π, done, reward)` into `self.buffer`),
       then `buffer.end_generation()` turns the rollout into flat `(obs, π, z)` training rows.
    2. one supervised pass over `buffer.get_dataloader(minibatch)` (your `training_step` per batch),
    3. a cosine LR-schedule step and (periodically) `evaluate()` against the Pons solver.

    Attributes: `env` (Connect4Env), `cfg` (AZConfig), `model` (Connect4Model), `opt` (AdamW),
    `mcts` (BatchedMCTS, built from cfg's sims/c_puct/dirichlet settings), `buffer` (ReplayBuffer).
    """

    def __init__(self, env, cfg, model):
        self.env = env
        self.cfg = cfg
        self.device = env.device
        self.model = model
        self.opt = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.mcts = BatchedMCTS(env, MCTSConfig(
            sims=cfg.sims, c_puct=cfg.c_puct, max_depth=cfg.max_depth,
            dirichlet_alpha=cfg.dirichlet_alpha, dirichlet_eps=cfg.dirichlet_eps))
        self.buffer = ReplayBuffer(cfg, self.device)

    def sample_actions(self, root_N: Float[Tensor, "B 7"], temperature: float = 1.0) -> Float[Tensor, "B"]:
        """Sample one action per game from the tree policy π(a) ∝ N(s,a)^(1/τ).

        Args:
            root_N: (B, 7) root visit counts from the MCTS search
            temperature: τ; 1.0 samples proportionally to visit counts, →0 approaches greedy
                argmax, larger flattens the distribution
        Returns:
            (B, 1) sampled column indices (one per game). Actions with zero visits have
            probability 0 and are never sampled. (No need to special-case tiny temperatures here:
            the trainer always calls this with moderate τ; `sample_tree_policy` in section 2 is the one
            that must handle the greedy τ→0 limit.)
        """
        temp_visits = root_N ** (1 / temperature)
        probs = temp_visits / temp_visits.sum(-1, keepdim=True)
        action = torch.multinomial(probs, num_samples=1)
        return action

    @torch.no_grad()
    def self_play_step(self, obs, to_move):
        """One ply of self-play for all `num_games` games at once: MCTS -> policy target -> sample ->
        step -> record the ply into `self.buffer`.

        Args:
            obs: (B, 3, 6, 7) current (absolute) boards
            to_move: (B,) whether player-1 (red) is to move in each game
        Returns:
            next_obs: (B, 3, 6, 7) boards after the move
            done: (B,) whether the move ended each game (the loop uses this to update `to_move`)
        """
        root_N = self.mcts.search(self.model, obs, to_move, add_noise=True)  # root noise -> exploration
        pi = root_N / root_N.sum(-1, keepdim=True)
        obs_canon = canonicalise_obs(obs, to_move)
        action = self.sample_actions(root_N, self.cfg.temperature)
        next_obs, done, reward = self.env.step(obs, action, to_move)
        self.buffer.write(obs_canon, pi, done, reward)  # record this ply (mover-canonical board + targets)
        return next_obs, done

    @torch.no_grad()
    def self_play(self, progress: bool = True):
        """Play one generation: `num_games` games for `moves_per_gen` plies, calling your
        `self_play_step` each ply (which `write`s into `self.buffer`), then finalise the generation in
        the buffer. The buffer handles the value targets, masking and replay -- nothing to return.

        The ply loop is wrapped in a tqdm bar with `unit_scale` (like the 2.2 SPS bar) so you get a
        live, SI-formatted env-steps/sec readout during self-play rather than one update per generation."""
        B, T = self.cfg.num_games, self.cfg.moves_per_gen
        obs = self.env.reset(B)
        to_move = torch.ones((B,), dtype=torch.bool, device=self.device)
        self.model.eval()
        # each ply runs `sims` batched env.steps over all B games (one per MCTS simulation), so the
        # generation does B * T * sims env transitions in total -- tqdm turns that into a live rate,
        # and the postfix shows cumulative MCTS simulations done out of `moves_per_gen * sims`.
        total_sims = T * self.cfg.sims
        bar = tqdm(total=B * T * self.cfg.sims, unit=" env steps", unit_scale=True,
                   desc="self-play", leave=False, disable=not progress)
        for ply in range(T):
            obs, done = self.self_play_step(obs, to_move)
            to_move = torch.where(done, torch.ones_like(to_move), ~to_move)  # auto-reset -> player 1
            bar.update(B * self.cfg.sims)
            bar.set_postfix_str(f"sims {(ply + 1) * self.cfg.sims}/{total_sims}")
        bar.close()
        self.buffer.end_generation()

    def training_step(self, obs, pi, z):
        """One optimiser step on a single minibatch `(obs, pi, z)`: forward the net, compute the
        AlphaZero loss (`compute_az_loss`), zero the grads (`self.opt.zero_grad(set_to_none=True)`),
        backprop, clip the gradient norm to `cfg.grad_clip`, and step the optimiser. AlphaZero's
        update is just supervised learning -- regress the value head onto `z` and the policy head
        onto `pi`.

        Args:
            obs: (mb, 3, 6, 7) mover-canonical boards
            pi: (mb, 7) MCTS policy targets
            z: (mb,) value targets
        Returns:
            float: the minibatch loss
        """
        value, logits = self.model(obs.contiguous())
        loss = compute_az_loss(value, logits, pi, z, self.cfg.value_coef)
        self.opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
        self.opt.step()
        return float(loss.item())

    @torch.no_grad()
    def evaluate(self) -> dict:
        """Given. Score the current network against the frozen Pons solver set (policy accuracy /
        cross-entropy / value sign-accuracy). One cached forward pass; see the section above."""
        return evaluate_policy(self.model, self.env)

    def train(self, num_generations=None, eval_every=1):
        """Given. The full training loop -- you don't need to touch this. Each generation: run
        self-play into the buffer, do one supervised pass over the buffer's `DataLoader`
        (calling your `training_step`), step the cosine LR schedule, and periodically `evaluate`
        against the Pons solver. Logs loss / lr / eval to a tqdm bar (and to wandb if `cfg.use_wandb`)."""
        import time
        num_generations = num_generations or self.cfg.num_generations
        # cosine-decay the LR; schedule over a >=10-gen horizon so a short quick run doesn't crater the
        # LR before it finishes (a 6-gen T_max=6 cosine would decay to lr_min by the last gen and stall).
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=max(num_generations, 10), eta_min=self.cfg.lr_min)
        # effective env transitions/sec for the tqdm bar: each generation steps the env
        # num_games * moves_per_gen * sims times (one batched env.step per MCTS simulation). It's just a
        # counter + a clock, so it adds no per-step overhead -- handy for spotting when self-play bottlenecks.
        t0, env_steps = time.time(), 0
        steps_per_gen = self.cfg.num_games * self.cfg.moves_per_gen * self.cfg.sims
        if self.cfg.use_wandb:
            import wandb
            wandb.init(project=self.cfg.wandb_project, config=asdict(self.cfg))
        metrics = {}
        bar = tqdm(range(1, num_generations + 1))
        for gen in bar:
            self.self_play()  # fill + roll the replay buffer
            env_steps += steps_per_gen
            self.model.train()
            loader = self.buffer.get_dataloader(self.cfg.minibatch)
            total_loss, n_batches = 0.0, 0
            tbar = tqdm(loader, desc="train", leave=False)  # one supervised pass over the buffer
            for obs, pi, z in tbar:
                total_loss += self.training_step(obs, pi, z)
                n_batches += 1
                tbar.set_postfix_str(f"loss={total_loss / n_batches:.3f}")  # live running-mean loss
            loss = total_loss / max(n_batches, 1)
            sched.step()
            if eval_every and gen % eval_every == 0:
                metrics = self.evaluate()  # Pons solver benchmark
            lr = sched.get_last_lr()[0]
            sps = env_steps / max(time.time() - t0, 1e-9)  # effective env steps/sec (cumulative)
            bar.set_postfix_str(f"loss={loss:.3f} acc={metrics.get('pons/acc', float('nan')):.3f} "
                                f"ce={metrics.get('pons/ce', float('nan')):.3f} env/s={fmt_si(sps)}")
            if self.cfg.use_wandb:
                wandb.log({"generation": gen, "loss": loss, "lr": lr, "env_steps_per_sec": sps, **metrics})
        if self.cfg.use_wandb:
            wandb.finish()
        return self.model
tests.test_sample_actions(AlphaZeroTrainer)
tests.test_self_play_step(AlphaZeroTrainer)
tests.test_training_step(AlphaZeroTrainer)
Train your agent!
Put it all together: build the config, the model and the trainer, and call trainer.train(). The
given training loop handles self-play, the supervised passes over the buffer (your training_step),
the cosine LR schedule, the periodic Pons eval, and the tqdm / wandb logging. So there's nothing to
wire up here. Set use_wandb=True in the config to stream loss / lr / pons/* to Weights & Biases.
The defaults are a fast recipe (4096 games/gen, 16 sims, 12 generations) that reaches ~85% Pons-solver
accuracy in ~4-5 min on a GPU; raise sims to 64 and num_generations to ~50 for a stronger agent,
or dial them down to go faster. On a GPU you should see acc climb and ce fall each generation.
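To get a feel for what the cosine schedule does to the learning rate over the default run, here is a plain-Python sketch. It assumes the usual closed-form cosine-annealing curve and plugs in the default lr, lr_min and a 12-generation horizon; the function `cosine_lr` is hypothetical, and the trainer itself relies on torch's scheduler rather than this code.

```python
import math


def cosine_lr(step, lr_max=5e-3, lr_min=2e-5, t_max=12):
    """Closed-form cosine annealing from lr_max (step 0) down to lr_min (step t_max)."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / t_max))


# learning rate after 0, 1, ..., 12 scheduler steps (one step per generation)
schedule = [cosine_lr(gen) for gen in range(13)]
```

The curve starts at lr, passes through the midpoint of lr and lr_min halfway through, and reaches lr_min at the end of the horizon, so most of the decay happens in the middle generations.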
cfg = AZConfig()  # fast recipe (4096 games, 16 sims, 12 gens) -> ~85% Pons acc in ~4-5 min on a GPU
model = Connect4Model(device)
if TRAINING:
    # `BatchedMCTS` is built later, in the Vectorized-MCTS bonus. So that training runs even if you
    # haven't done that bonus yet, pull the finished class from `solutions` when it isn't defined yet.
    if "BatchedMCTS" not in globals():
        from solutions import BatchedMCTS
    trainer = AlphaZeroTrainer(env, cfg, model)
    trainer.train()  # eval + logging handled inside; set cfg.use_wandb=True to log to wandb
Play against your agent
Time to see how good it is! utils gives you two ready-made ways to play your trained model (which we've implemented for you):
- Web app (recommended): play_web(model, env, port=8080) serves a little browser game at http://localhost:8080. In VSCode / Cursor the port is auto-forwarded, so it should open straight away. You (🔴) move first and the agent (🟡) replies with MCTS. Interrupt the cell (■ / Ctrl-C) to stop.
- Terminal (fallback): play_cli(model, env) plays the same game in the cell output; type a legal column (0-6) on your turn.
from part5_mcts_alphazero.utils import play_web, play_cli
play_web(model, env, port=8080) # browser game on http://localhost:8080 (auto-forwarded); ■/Ctrl-C to stop
# play_cli(model, env) # ...or play in the terminal instead