4️⃣ Self-Play & Training

Learning Objectives
  • Implement the self-play sampler: the tree policy, the network policy, and using the critic to estimate the value of rollouts.
  • Understand the loss function for the network and how it distills the planning provided by the tree search.
  • Train an agent to match a perfect solver's moves (and hopefully beat you too!)

Now we close the loop. We need two ingredients: a sampler that turns MCTS into training data, and a loss that trains the network on that data. All the boring parts we've already seen on PPO day (the replay buffer, the optimiser, the generation loop) are given in the AlphaZeroTrainer below.

The value target: compute_z_targets

During a self-play generation we record, for every game b and move t, whether the move ended the game (dones[b,t]) and the mover's reward (rewards[b,t]). The value target z[b,t] is the eventual outcome of that game, from the perspective of the mover at state t: so it flips sign every ply, and resets at each game boundary (games auto-reset and replay within a generation).

The clean way to compute this is a single backward scan over time. Going from the last move to the first: if move t was terminal, the running value is just its reward; otherwise it's the negation of the running value from t+1 (negamax again). This propagates each game's outcome back to all its states with the correct alternating signs.

This is a similar function to compute_returns from PPO day, but with the alternating signs of the negamax.
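Here is how the two scans relate. The discounted-return line is the textbook recurrence and may differ in detail from the PPO-day code. Write $d_t$ for dones[b,t] and $r_t$ for rewards[b,t], and take the value past the last recorded ply to be 0:

$$G_t = r_t + \gamma\,(1 - d_t)\,G_{t+1}, \qquad z_t = \begin{cases} r_t & \text{if } d_t = 1,\\ -z_{t+1} & \text{otherwise.}\end{cases}$$

Rewards are zero on every non-terminal ply, so $z_t$ is exactly $G_t$ with $\gamma = -1$.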

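Worked example: consider a batch of three rows, each six plies long. In the first row, a game is won on its third ply (t=2). The later done flags in that row carry zero reward, so each one just resets the running value to 0. In the second row, a game is won on its fifth ply (t=4), and the new game that starts at t=5 is still unfinished when the rollout ends. In the third row, a game ends in a draw on its sixth ply (t=5). The corresponding done and reward tensors are: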

dones  = torch.Tensor([[0,0,1,1,1,1],
                       [0,0,0,0,1,0],
                       [0,0,0,0,0,1]]).bool()

reward = torch.Tensor([[0,0,1,0,0,0],
                       [0,0,0,0,1,0],
                       [0,0,0,0,0,0]]).float()

Recall that the environment can never actually give a negative reward, because a player can only win or draw by making their move. A player cannot move and then immediately lose.

The corresponding z tensor would be:

z = tensor([[ 1., -1.,  1.,  0.,  0.,  0.],
            [ 1., -1.,  1., -1.,  1., -0.],
            [-0.,  0., -0.,  0., -0.,  0.]])
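To check the example by hand, here is a minimal single-game sketch of the same backward scan in plain Python. The name negamax_targets is only for illustration and is not part of the exercise API. Each row of the example is scanned independently:

```python
def negamax_targets(dones: list[int], rewards: list[float]) -> list[float]:
    """Single-game (one row) version of the backward negamax scan."""
    z = [0.0] * len(dones)
    running = 0.0  # value "after" the last recorded ply: no outcome known yet
    for t in reversed(range(len(dones))):
        # a terminal ply takes its own reward; otherwise negate the next ply's value
        running = float(rewards[t]) if dones[t] else -running
        z[t] = running
    return z


dones = [[0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 1]]
rewards = [[0, 0, 1, 0, 0, 0],
           [0, 0, 0, 0, 1, 0],
           [0, 0, 0, 0, 0, 0]]

for d, r in zip(dones, rewards):
    print(negamax_targets(d, r))
# [1.0, -1.0, 1.0, 0.0, 0.0, 0.0]
# [1.0, -1.0, 1.0, -1.0, 1.0, -0.0]
# [-0.0, 0.0, -0.0, 0.0, -0.0, 0.0]
```

The sign-flipped zeros print as -0.0, which is why the tensor above contains -0. entries. Since -0.0 == 0.0, they are harmless as value targets.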

Exercise - implement compute_z_targets

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 15-20 minutes on this exercise.
def compute_z_targets(
    dones: Bool[Tensor, "batch timesteps"], 
    rewards: Float[Tensor, "batch timesteps"]
) -> Float[Tensor, "batch timesteps"]:
    """Negamax value targets for a batch of `B` self-play games of `T` plies.

    Walking each game backwards from its terminal rewards, the target at each ply is the game's final
    reward with its sign flipped once per step back. Every recorded value in this project is from the
    perspective of the player about to move; stepping back one ply changes whose turn it is, hence the
    negation (negamax: good for the mover is bad for its parent).

    Args:
        dones: (batch, timesteps) marks the ply where each game ended
        rewards:  (batch, timesteps) rewards to the mover at each ply (nonzero only where dones)

    Returns:
        (batch, timesteps) the mover-perspective outcome `z` for every recorded state
    """
    batch, timesteps = dones.shape
    z = torch.zeros((batch, timesteps), device=dones.device)
    raise NotImplementedError()
    return z


tests.test_compute_z_targets(compute_z_targets)
Solution
def compute_z_targets(
    dones: Bool[Tensor, "batch timesteps"], 
    rewards: Float[Tensor, "batch timesteps"]
) -> Float[Tensor, "batch timesteps"]:
    """Negamax value targets for a batch of `B` self-play games of `T` plies.

    Walking each game backwards from its terminal rewards, the target at each ply is the game's final
    reward with its sign flipped once per step back. Every recorded value in this project is from the
    perspective of the player about to move; stepping back one ply changes whose turn it is, hence the
    negation (negamax: good for the mover is bad for its parent).

    Args:
        dones: (batch, timesteps) marks the ply where each game ended
        rewards:  (batch, timesteps) rewards to the mover at each ply (nonzero only where dones)

    Returns:
        (batch, timesteps) the mover-perspective outcome `z` for every recorded state
    """
    batch, timesteps = dones.shape
    z = torch.zeros((batch, timesteps), device=dones.device)
    running = torch.zeros((batch,), device=dones.device)
    for t in range(timesteps - 1, -1, -1):
        running = torch.where(dones[:, t], rewards[:, t], -running)
        z[:, t] = running
    return z


tests.test_compute_z_targets(compute_z_targets)

The AlphaZero loss: compute_az_loss

Given the network's value (N,) and logits (N,7) on a minibatch, and the targets pi (N,7) (MCTS visit distribution) and z (N,) (game outcome), the loss is

$$\mathcal L = \underbrace{-\sum_a \pi_a \log \text{softmax}(\text{logits})_a}_{\text{policy cross-entropy}} \;+\; c_v \underbrace{(\text{value} - z)^2}_{\text{value MSE}},$$

averaged over the minibatch. The original AlphaZero loss also has an L2 term $c\lVert\theta\rVert^2$. We leave it out of the loss because the optimiser's weight decay handles regularisation.
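You can use this as a quick sanity check. With probability-valued targets, the policy term is exactly what F.cross_entropy computes, since PyTorch 1.10 and later accept class-probability targets with the same shape as the logits. The random tensors below are made up purely for the check:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 4
logits = torch.randn(N, 7)                       # made-up network outputs
pi = torch.softmax(torch.randn(N, 7), dim=-1)    # made-up targets; rows sum to 1 like visit distributions
value = torch.tanh(torch.randn(N))
z = torch.tensor([1.0, -1.0, 0.0, 1.0])
value_coef = 1.0

# the formula written out term by term
policy_ce = -(pi * F.log_softmax(logits, dim=-1)).sum(-1).mean()
value_mse = ((value - z) ** 2).mean()
manual = policy_ce + value_coef * value_mse

# the same quantity from built-ins (cross_entropy with class-probability targets)
builtin = F.cross_entropy(logits, pi) + value_coef * F.mse_loss(value, z)

# lower bound on the policy term: the entropy of the targets (Gibbs' inequality)
target_entropy = -(pi * pi.log()).sum(-1).mean()
print(f"manual={manual.item():.6f}  builtin={builtin.item():.6f}  H(pi)={target_entropy.item():.4f}")
```

Keep this in mind when reading the loss curve. Because pi is a soft target, the policy term bottoms out at the entropy of pi rather than at 0, so a total loss that plateaus well above zero is expected.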

Exercise - implement compute_az_loss

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 10-15 minutes on this exercise.
def compute_az_loss(
    value: Float[Tensor, "N"],
    logits: Float[Tensor, "N 7"],
    pi: Float[Tensor, "N 7"],
    z: Float[Tensor, "N"],
    value_coef: float = 1.0,
) -> Float[Tensor, ""]:
    """Scalar AlphaZero loss over a minibatch of `N` positions: policy cross-entropy + value MSE.

    Loss = mean of `-sum_a pi_a log softmax(logits)_a` + `value_coef * (value - z)^2`.

    Args:
        value:      (N,) critic outputs
        logits:     (N, 7) actor outputs
        pi:         (N, 7) MCTS visit-count policy target
        z:          (N,) game-outcome value target
        value_coef: weight on the value-MSE term

    Returns:
        scalar tensor: the mean total loss
    """
    assert value.shape == z.shape
    assert logits.shape == pi.shape
    raise NotImplementedError()


tests.test_compute_az_loss(compute_az_loss)
Solution
def compute_az_loss(
    value: Float[Tensor, "N"],
    logits: Float[Tensor, "N 7"],
    pi: Float[Tensor, "N 7"],
    z: Float[Tensor, "N"],
    value_coef: float = 1.0,
) -> Float[Tensor, ""]:
    """Scalar AlphaZero loss over a minibatch of `N` positions: policy cross-entropy + value MSE.

    Loss = mean of `-sum_a pi_a log softmax(logits)_a` + `value_coef * (value - z)^2`.

    Args:
        value:      (N,) critic outputs
        logits:     (N, 7) actor outputs
        pi:         (N, 7) MCTS visit-count policy target
        z:          (N,) game-outcome value target
        value_coef: weight on the value-MSE term

    Returns:
        scalar tensor: the mean total loss
    """
    assert value.shape == z.shape
    assert logits.shape == pi.shape
    logprobs = F.log_softmax(logits, dim=-1)
    policy_loss = -(pi * logprobs).sum(-1).mean()
    critic_loss = F.mse_loss(value, z)
    # alternative non-mse solution:
    # critic_loss = ((value - z) ** 2).mean()
    return policy_loss + value_coef * critic_loss


tests.test_compute_az_loss(compute_az_loss)

Training Hyperparameters

@dataclass
class AZConfig:
    """All the knobs for self-play + training. The defaults are a fast in-notebook recipe (4096
    self-play games/gen, 32 sims/move, 6 generations) that reaches ~85% Pons-solver accuracy in
    ~4-5 min on a GPU. For a stronger agent raise `sims` to 64 and `num_generations` to ~50
    (≈182k optimiser steps); to run faster still, dial `num_games` / `sims` / `num_generations` down.

    c_puct (1.5->1.0), lr (1e-3->5e-3) and buffer_gens (8->4) were tuned by a
    full-run sweep (pons_CE ~0.466->0.420 at ~1/4 the compute); every gain came from fresher self-play
    data — capacity/loss/regularisation were inert. Usable lr band ~3e-3..6e-3 (≥7e-3 is seed-unstable)."""
    # self-play / data
    num_games: int = 4096          # parallel self-play games per generation
    sims: int = 16                 # MCTS simulations per move (fewer = faster self-play; 32 learns about as well as 64 at ~2x the speed; raise to 64 for a stronger run)
    num_generations: int = 12      # training generations (~85% Pons acc in ~4-5 min; raise to ~50 for the full recipe)
    buffer_gens: int = 4           # replay buffer = the last this-many generations (tuned 8->4: fresher data)
    moves_per_gen: int = 42        # plies per generation (a full Connect-4 game)
    temperature: float = 1.0       # visit-count sampling temperature (first `temp_cutoff` plies)
    temp_cutoff: int = 12          # after this many plies, play greedily
    augment: bool = True           # mirror-symmetry data augmentation
    # MCTS / exploration
    c_puct: float = 1.0               # tuned 1.5->1.0 (full-run sweep, both seeds)
    max_depth: int = 42
    dirichlet_alpha: float = 10 / 7   # ≈ 1.43, root exploration-noise concentration
    dirichlet_eps: float = 0.25       # weight of the root Dirichlet noise
    # optimiser / schedule
    lr: float = 5e-3               # initial learning rate (tuned 1e-3->5e-3; >=7e-3 is seed-unstable)
    lr_min: float = 2e-5           # cosine-decay target over the run
    weight_decay: float = 1e-4
    grad_clip: float = 1.0         # global grad-norm clip
    minibatch: int = 1024
    value_coef: float = 1.0        # weight on the value-MSE loss term
    # logging
    use_wandb: bool = False        # log loss / lr / Pons metrics to Weights & Biases
    wandb_project: str = "alphazero-connect4"

The replay buffer (given)

Self-play and training communicate through a small, given ReplayBuffer, in the same spirit as the VPG rollout buffer from [2.2]. It preallocates the per-generation rollout tensors once, and you write one ply at a time. end_generation then turns the finished rollout into flat (obs, pi, z) training examples. Along the way it computes the negamax value targets, drops states whose game didn't finish, and keeps only the last cfg.buffer_gens generations. get_dataloader hands back a shuffled DataLoader over those examples. This keeps the exercises small: self_play_step just writes each ply, and the training loop just iterates over get_dataloader(...), calling your training_step on each batch.

class ReplayBuffer:
    """Given. Rolling replay of the last `cfg.buffer_gens` self-play generations.

    Usage: `write(...)` one ply at a time, `end_generation()` after each generation (computes value
    targets, drops unfinished states, flattens over (game, ply), evicts the oldest generation), then
    `get_dataloader(mb)` for a shuffled training `DataLoader`. Mirrors the VPG rollout buffer from [2.2]."""

    def __init__(self, cfg: AZConfig, device):
        self.cfg, self.device = cfg, device
        B, T = cfg.num_games, cfg.moves_per_gen
        # preallocated rollout for the CURRENT generation (written one ply at a time, in place)
        self.obs = torch.empty((B, T, 3, 6, 7), device=device)
        self.pi = torch.empty((B, T, 7), device=device)
        self.dones = torch.empty((B, T), dtype=torch.bool, device=device)
        self.rews = torch.empty((B, T), device=device)
        self.t = 0                 # next free ply slot in the current generation
        self.gens = []             # rolling list of finished generations, each a flat (obs, pi, z)

    def write(self, obs_canon, pi, done, reward):
        """Record one ply of the current generation into row `self.t`, then advance."""
        self.obs[:, self.t] = obs_canon
        self.pi[:, self.t] = pi
        self.dones[:, self.t] = done
        self.rews[:, self.t] = reward
        self.t += 1

    def end_generation(self):
        """Finish the current generation: compute negamax `z`, keep only states whose game finished
        (reverse cumulative-OR of dones over time), flatten over (game, ply), append to the rolling
        buffer (evicting the oldest), and reset the write pointer."""
        z = compute_z_targets(self.dones, self.rews)                          # (B, T)
        keep = (self.dones.int().flip(-1).cumsum(-1).flip(-1) > 0).reshape(-1)
        self.gens.append((self.obs.reshape(-1, 3, 6, 7)[keep],
                          self.pi.reshape(-1, 7)[keep],
                          z.reshape(-1)[keep]))
        if len(self.gens) > self.cfg.buffer_gens:
            self.gens.pop(0)
        self.t = 0

    def get_dataloader(self, batch_size):
        """Snapshot the whole buffer into a `DataLoader` over `(obs, pi, z)` training examples.

        The DataLoader handles shuffling + batching internally; iterate it once per epoch (it
        reshuffles each time). `drop_last` keeps every batch exactly `batch_size` (great for a fixed
        compiled forward), but only when there's at least one full batch, so tiny configs still train.
        Tensors already live on the GPU, so the default `num_workers=0` / no pinning is correct."""
        obs = torch.cat([g[0] for g in self.gens])
        pi = torch.cat([g[1] for g in self.gens])
        z = torch.cat([g[2] for g in self.gens])
        ds = TensorDataset(obs, pi, z)
        return DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=len(ds) >= batch_size)

    def reset(self):
        self.gens, self.t = [], 0

    def __len__(self):
        return sum(g[0].shape[0] for g in self.gens)
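The keep mask in end_generation is a reverse cumulative OR. A ply is kept if its row has a done at that ply or any later one, which means the game it belongs to finished within the generation. Here is a small made-up example of the same expression:

```python
import torch

# Toy rollout: 2 game slots x 6 plies (made-up data).
# Row 0: a game ends at t=2 and the next game ends at t=5, so every ply belongs to a finished game.
# Row 1: a game ends at t=3; the next game starts at t=4 and is unfinished at t=5.
dones = torch.tensor([[0, 0, 1, 0, 0, 1],
                      [0, 0, 0, 1, 0, 0]], dtype=torch.bool)

# flip time, running count of dones, flip back: > 0 means "a done at this ply or later"
keep = dones.int().flip(-1).cumsum(-1).flip(-1) > 0
print(keep.tolist())
# [[True, True, True, True, True, True], [True, True, True, True, False, False]]
```

The dropped plies in row 1 would get z computed from a running value of 0 with no real outcome behind it. That is why they must not be used as value targets.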

Evaluating against a perfect solver (given)

Recall that Connect-4 is a solved game. Fast solvers exist that can report the optimal move for any position in milliseconds, with the help of a precomputed opening book.

We use Pascal Pons' perfect Connect-4 solver to pregenerate ~6700 board positions. Each is labelled with its optimal move(s), and the set is pre-filtered to keep only decisive positions (i.e. positions where not every move leads to the same result, whether win, loss or draw, under optimal play by both sides).
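Here is a hypothetical sketch of the filter. It assumes each legal move has already been labelled with its game-theoretic outcome for the player to move (+1 win, 0 draw, -1 loss). The real pipeline in pascal_pons works from the solver's scores and may differ, and the function names are illustrative only:

```python
def is_decisive(move_outcomes: dict[int, int]) -> bool:
    """True if the move choice matters: the legal moves do not all share one outcome."""
    return len(set(move_outcomes.values())) > 1


def optimal_moves(move_outcomes: dict[int, int]) -> list[int]:
    """All columns that achieve the best outcome for the player to move."""
    best = max(move_outcomes.values())
    return sorted(col for col, outcome in move_outcomes.items() if outcome == best)


position_a = {0: -1, 1: -1, 2: 0, 3: 1, 4: 1, 5: 0, 6: -1}   # two winning columns
position_b = {col: -1 for col in range(7)}                  # lost whatever you play

print(is_decisive(position_a), optimal_moves(position_a))  # True [3, 4]
print(is_decisive(position_b))                              # False
```

A position like position_b says nothing about which move to prefer, so it is filtered out. Position_a has two optimal moves, which is why the ce metric sums probability over a set of optimal moves.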

The given evaluate_policy runs the network over this set in one batched forward pass, and reports back the metrics:

  • acc — fraction of positions where the policy's top move is solver-optimal,
  • ce — the policy cross-entropy against the optimal-move set: $ - \log \sum_{a : a \text{ is optimal}} p_\theta(a|s) $, a smoother signal than acc.
  • val_signacc — fraction of decisive positions where the value head has the right sign (predicts win vs loss correctly).

This eval is far cheaper than playing out games as it requires only one batched forward pass, while providing a lot of feedback on how well the policy is performing. You should see acc climb (and ce fall) over a handful of generations.
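For intuition, here is a hypothetical re-implementation of the three metrics from precomputed tensors. It is not the code of evaluate_policy. It assumes an (N, 7) boolean mask of solver-optimal moves and, for val_signacc, a true outcome of +1/0/-1 per position. As a further assumption, it counts only positions whose true outcome is a win or a loss:

```python
import torch


def pons_style_metrics(logits, optimal_mask, value, true_value):
    """Hypothetical sketch of acc / ce / val_signacc (not the real evaluate_policy).

    logits:       (N, 7) policy logits
    optimal_mask: (N, 7) bool, True for every solver-optimal move
    value:        (N,)   value-head outputs in [-1, 1]
    true_value:   (N,)   solver outcome for the mover: +1 win, 0 draw, -1 loss
    """
    probs = torch.softmax(logits, dim=-1)
    mask = optimal_mask.float()
    top = probs.argmax(dim=-1, keepdim=True)               # (N, 1) greedy move
    acc = mask.gather(1, top).mean()                       # is the greedy move optimal?
    ce = -torch.log((probs * mask).sum(-1)).mean()         # -log of the mass on optimal moves
    won_or_lost = true_value != 0                          # assumption: draws excluded from sign accuracy
    signacc = (torch.sign(value[won_or_lost]) == torch.sign(true_value[won_or_lost])).float().mean()
    return {"acc": acc.item(), "ce": ce.item(), "val_signacc": signacc.item()}


logits = torch.tensor([[0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0],    # greedy move 3 (optimal)
                       [4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])   # greedy move 0 (not optimal)
optimal_mask = torch.zeros(2, 7, dtype=torch.bool)
optimal_mask[0, 3] = True
optimal_mask[1, 2] = True
value = torch.tensor([0.7, 0.2])
true_value = torch.tensor([1.0, -1.0])
print(pons_style_metrics(logits, optimal_mask, value, true_value))
```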

If you like, have a look in the pascal_pons directory (included in the repo for convenience) to see how the dataset was generated. The script downloads the opening book and Pascal Pons' solver, plays a series of games between epsilon-greedy optimal policies, and labels each position with its optimal move. It then filters out non-decisive and duplicate positions and saves the rest to a file.

The dataset is also hosted on HuggingFace as a fallback, and evaluate_policy will auto-download it the first time you run it.

from pascal_pons.eval_pons import evaluate_policy

Exercise - implement self_play_step

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-15 minutes on this exercise.

self_play_step should run one ply of self-play. Given the current boards obs and who is to move to_move (all num_games games at once):

  1. Search: run the batched MCTS to get the root visit counts.
  2. Policy target: turn the visit counts into the target policy pi (normalised visit counts).
  3. Act: sample an action from the tree policy (self.sample_actions) and step the environment.
  4. Record the ply into the given replay buffer: self.buffer.write(obs_canon, pi, done, reward), where obs_canon = canonicalise_obs(obs, to_move) is the board from the mover's perspective.

Return (next_obs, done). The given self_play loop calls this moves_per_gen times, uses done to advance to_move (auto-reset), and calls self.buffer.end_generation() at the end. The buffer handles the value targets, masking out unfinished states, the rolling replay, and serving minibatches.

Exercise - implement sample_actions

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 10-15 minutes on this exercise.

Implement the sample_actions function, which samples actions from the root visit counts using the temperature parameter.

$$ \pi(a | s) = \frac{N(s, a)^{1 / \tau}}{\sum_{a'} N(s, a')^{1 / \tau}} $$
Use torch.multinomial to sample an action from the distribution.
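To see what the temperature does, here is a made-up visit-count vector pushed through the formula above. Since $0^{1/\tau} = 0$ for any $\tau > 0$, unvisited moves keep probability zero and torch.multinomial never samples them. The helper tempered_probs is illustrative only:

```python
import torch


def tempered_probs(root_N: torch.Tensor, tau: float) -> torch.Tensor:
    """pi(a|s) proportional to N(s, a) ** (1 / tau), normalised over the 7 columns."""
    w = root_N ** (1 / tau)
    return w / w.sum(-1, keepdim=True)


root_N = torch.tensor([[0.0, 10.0, 30.0, 60.0, 0.0, 0.0, 0.0]])   # visit counts for one game

for tau in (1.0, 0.5, 2.0):
    print(tau, [round(p, 3) for p in tempered_probs(root_N, tau)[0].tolist()])
# 1.0 [0.0, 0.1, 0.3, 0.6, 0.0, 0.0, 0.0]
# 0.5 [0.0, 0.022, 0.196, 0.783, 0.0, 0.0, 0.0]
# 2.0 [0.0, 0.193, 0.334, 0.473, 0.0, 0.0, 0.0]

torch.manual_seed(0)
samples = torch.multinomial(tempered_probs(root_N, 1.0).repeat(1000, 1), num_samples=1)  # (1000, 1)
print(sorted(set(samples.flatten().tolist())))   # only visited columns appear
```

Lowering τ sharpens the distribution towards the most-visited move, and raising it flattens it across the visited moves.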

Exercise - implement training_step

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-15 minutes on this exercise.

training_step should take one optimiser step on a minibatch (obs, pi, z). Forward the net, compute the AlphaZero loss with compute_az_loss, zero the gradients, backpropagate, step the optimiser, and return the loss as a Python float. Optionally, you can also clip the gradient norm to cfg.grad_clip with nn.utils.clip_grad_norm_ between the backward pass and the optimiser step.
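If you haven't used gradient clipping before, here is the order of operations on a toy model. The model is a hypothetical nn.Linear stand-in, not the Connect-4 network, and the loss is arbitrary and scaled up so the gradient is large. nn.utils.clip_grad_norm_ rescales the gradients in place and returns the total norm measured before clipping:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)              # hypothetical stand-in for the Connect-4 network
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
x, y = torch.randn(8, 4), torch.randn(8, 2)

loss = 100 * ((model(x) - y) ** 2).mean()          # scaled up so the gradient norm is large
opt.zero_grad(set_to_none=True)                    # 1. clear old gradients
loss.backward()                                    # 2. backprop
pre_clip = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)   # 3. clip (returns pre-clip norm)
post_clip = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
opt.step()                                         # 4. update with the clipped gradients
print(f"grad norm before clip {pre_clip:.2f}, after clip {post_clip:.4f}")
```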

class AlphaZeroTrainer:
    """Owns the full per-generation loop pictured above. Each generation `train()` runs:

    1. `self_play()`: `moves_per_gen` plies of batched self-play with the frozen net
       (your `self_play_step` per ply, writing `(obs_canon, π, done, reward)` into `self.buffer`),
       then `buffer.end_generation()` turns the rollout into flat `(obs, π, z)` training rows.
    2. one supervised pass over `buffer.get_dataloader(minibatch)` (your `training_step` per batch),
    3. a cosine LR-schedule step and (periodically) `evaluate()` against the Pons solver.

    Attributes: `env` (Connect4Env), `cfg` (AZConfig), `model` (Connect4Model), `opt` (AdamW),
    `mcts` (BatchedMCTS — built from cfg's sims/c_puct/dirichlet settings), `buffer` (ReplayBuffer).
    """
    def __init__(self, env, cfg, model):
        self.env = env
        self.cfg = cfg
        self.device = env.device
        self.model = model
        self.opt = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.mcts = BatchedMCTS(env, MCTSConfig(
            sims=cfg.sims, c_puct=cfg.c_puct, max_depth=cfg.max_depth,
            dirichlet_alpha=cfg.dirichlet_alpha, dirichlet_eps=cfg.dirichlet_eps))
        self.buffer = ReplayBuffer(cfg, self.device)


    def sample_actions(self, root_N: Float[Tensor, "B 7"], temperature: float = 1.0) -> Float[Tensor, "B"]:
        """Sample one action per game from the tree policy π(a) ∝ N(s,a)^(1/τ).

        Args:
            root_N:      (B, 7) root visit counts from the MCTS search
            temperature: τ; 1.0 samples proportionally to visit counts, →0 approaches greedy
                         argmax, larger flattens the distribution

        Returns:
            (B, 1) sampled column indices (one per game). Actions with zero visits have
            probability 0 and are never sampled. (No need to special-case tiny temperatures here —
            the trainer always calls this with moderate τ; `sample_tree_policy` in section 2 is the one
            that must handle the greedy τ→0 limit.)
        """
        raise NotImplementedError()

    @torch.no_grad()
    def self_play_step(self, obs, to_move):
        """One ply of self-play for all `num_games` games at once: MCTS -> policy target -> record the
        ply into `self.buffer` -> sample -> step.

        Args:
            obs:     (B, 3, 6, 7) current (absolute) boards
            to_move: (B,) whether player-1 (red) is to move in each game

        Returns:
            next_obs: (B, 3, 6, 7) boards after the move
            done:     (B,) whether the move ended each game (the loop uses this to flip `to_move`)
        """
        raise NotImplementedError()

    @torch.no_grad()
    def self_play(self, progress: bool = True):
        """Play one generation: `num_games` games for `moves_per_gen` plies, calling your
        `self_play_step` each ply (which `write`s into `self.buffer`), then finalise the generation in
        the buffer. The buffer handles the value targets, masking and replay -- nothing to return.

        The ply loop is wrapped in a tqdm bar with `unit_scale` (like the 2.2 SPS bar) so you get a
        live, SI-formatted env-steps/sec readout during self-play rather than one update per generation."""
        B, T = self.cfg.num_games, self.cfg.moves_per_gen
        obs = self.env.reset(B)
        to_move = torch.ones((B,), dtype=torch.bool, device=self.device)
        self.model.eval()
        # each ply runs `sims` batched env.steps over all B games (one per MCTS simulation), so the
        # generation does B * T * sims env transitions in total -- tqdm turns that into a live rate,
        # and the postfix shows cumulative MCTS simulations done out of `moves_per_gen * sims`.
        total_sims = T * self.cfg.sims
        bar = tqdm(total=B * T * self.cfg.sims, unit=" env steps", unit_scale=True,
                   desc="self-play", leave=False, disable=not progress)
        for ply in range(T):
            obs, done = self.self_play_step(obs, to_move)
            to_move = torch.where(done, torch.ones_like(to_move), ~to_move)   # auto-reset -> player 1
            bar.update(B * self.cfg.sims)
            bar.set_postfix_str(f"sims {(ply + 1) * self.cfg.sims}/{total_sims}")
        bar.close()
        self.buffer.end_generation()

    def training_step(self, obs, pi, z):
        """One optimiser step on a single minibatch `(obs, pi, z)`: forward the net, compute the
        AlphaZero loss (`compute_az_loss`), zero the grads (`self.opt.zero_grad(set_to_none=True)`),
        backprop, clip the gradient norm to `cfg.grad_clip`, and step the optimiser. AlphaZero's update is just supervised learning -- regress the value head onto
        `z` and the policy head onto `pi`.

        Args:
            obs: (mb, 3, 6, 7) mover-canonical boards
            pi:  (mb, 7) MCTS policy targets
            z:   (mb,) value targets

        Returns:
            float: the minibatch loss
        """
        raise NotImplementedError()

    @torch.no_grad()
    def evaluate(self) -> dict:
        """Given. Score the current network against the frozen Pons solver set (policy accuracy /
        cross-entropy / value sign-accuracy). One cached forward pass; see the section above."""
        return evaluate_policy(self.model, self.env)

    def train(self, num_generations=None, eval_every=1):
        """Given. The full training loop -- you don't need to touch this. Each generation: run
        self-play into the buffer, do one supervised pass over the buffer's `DataLoader`
        (calling your `training_step`), step the cosine LR schedule, and periodically `evaluate`
        against the Pons solver. Logs loss / lr / eval to a tqdm bar (and to wandb if `cfg.use_wandb`)."""
        import time
        num_generations = num_generations or self.cfg.num_generations
        # cosine-decay the LR; schedule over a >=10-gen horizon so a short quick run doesn't crater the
        # LR before it finishes (a 6-gen T_max=6 cosine would decay to lr_min by the last gen and stall).
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=max(num_generations, 10), eta_min=self.cfg.lr_min)
        # effective env transitions/sec for the tqdm bar: each generation steps the env
        # num_games * moves_per_gen * sims times (one batched env.step per MCTS simulation). It's just a
        # counter + a clock, so it adds no per-step overhead -- handy for spotting when self-play bottlenecks.
        t0, env_steps = time.time(), 0
        steps_per_gen = self.cfg.num_games * self.cfg.moves_per_gen * self.cfg.sims
        if self.cfg.use_wandb:
            import wandb
            wandb.init(project=self.cfg.wandb_project, config=asdict(self.cfg))
        metrics = {}
        bar = tqdm(range(1, num_generations + 1))
        for gen in bar:
            self.self_play()                                       # fill + roll the replay buffer
            env_steps += steps_per_gen
            self.model.train()
            loader = self.buffer.get_dataloader(self.cfg.minibatch)
            total_loss, n_batches = 0.0, 0
            tbar = tqdm(loader, desc="train", leave=False)        # one supervised pass over the buffer
            for obs, pi, z in tbar:
                total_loss += self.training_step(obs, pi, z)
                n_batches += 1
                tbar.set_postfix_str(f"loss={total_loss / n_batches:.3f}")   # live running-mean loss
            loss = total_loss / max(n_batches, 1)
            sched.step()
            if eval_every and gen % eval_every == 0:
                metrics = self.evaluate()                          # Pons solver benchmark
            lr = sched.get_last_lr()[0]
            sps = env_steps / max(time.time() - t0, 1e-9)          # effective env steps/sec (cumulative)
            bar.set_postfix_str(f"loss={loss:.3f}  acc={metrics.get('pons/acc', float('nan')):.3f}  "
                                f"ce={metrics.get('pons/ce', float('nan')):.3f}  env/s={fmt_si(sps)}")
            if self.cfg.use_wandb:
                wandb.log({"generation": gen, "loss": loss, "lr": lr, "env_steps_per_sec": sps, **metrics})
        if self.cfg.use_wandb:
            wandb.finish()
        return self.model


tests.test_sample_actions(AlphaZeroTrainer)
tests.test_self_play_step(AlphaZeroTrainer)
tests.test_training_step(AlphaZeroTrainer)
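The given train loop schedules the cosine decay over max(num_generations, 10) generations. The sketch below shows why. It steps CosineAnnealingLR six times with the AZConfig defaults lr=5e-3 and lr_min=2e-5, using a throwaway SGD optimiser since only the learning rate matters, and compares the result with the closed-form cosine schedule. With T_max=6 the learning rate reaches lr_min after the sixth step, whereas T_max=10 leaves it around 1.7e-3:

```python
import math
import torch

LR, LR_MIN = 5e-3, 2e-5   # AZConfig defaults


def lr_after(t_max: int, steps: int) -> float:
    """Learning rate reported by CosineAnnealingLR after `steps` scheduler steps."""
    param = torch.zeros(1, requires_grad=True)
    opt = torch.optim.SGD([param], lr=LR)   # throwaway optimiser; no gradients, so step() changes nothing
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=t_max, eta_min=LR_MIN)
    for _ in range(steps):
        opt.step()     # keep the optimiser-then-scheduler call order PyTorch expects
        sched.step()
    return sched.get_last_lr()[0]


def cosine_closed_form(t_max: int, steps: int) -> float:
    return LR_MIN + (LR - LR_MIN) * (1 + math.cos(math.pi * steps / t_max)) / 2


for t_max in (6, 10):
    print(t_max, f"{lr_after(t_max, 6):.3e}", f"{cosine_closed_form(t_max, 6):.3e}")
```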
Solution
class AlphaZeroTrainer:
    """Owns the full per-generation loop pictured above. Each generation `train()` runs:

    1. `self_play()`: `moves_per_gen` plies of batched self-play with the frozen net
       (your `self_play_step` per ply, writing `(obs_canon, π, done, reward)` into `self.buffer`),
       then `buffer.end_generation()` turns the rollout into flat `(obs, π, z)` training rows.
    2. one supervised pass over `buffer.get_dataloader(minibatch)` (your `training_step` per batch),
    3. a cosine LR-schedule step and (periodically) `evaluate()` against the Pons solver.

    Attributes: `env` (Connect4Env), `cfg` (AZConfig), `model` (Connect4Model), `opt` (AdamW),
    `mcts` (BatchedMCTS — built from cfg's sims/c_puct/dirichlet settings), `buffer` (ReplayBuffer).
    """
    def __init__(self, env, cfg, model):
        self.env = env
        self.cfg = cfg
        self.device = env.device
        self.model = model
        self.opt = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        self.mcts = BatchedMCTS(env, MCTSConfig(
            sims=cfg.sims, c_puct=cfg.c_puct, max_depth=cfg.max_depth,
            dirichlet_alpha=cfg.dirichlet_alpha, dirichlet_eps=cfg.dirichlet_eps))
        self.buffer = ReplayBuffer(cfg, self.device)


    def sample_actions(self, root_N: Float[Tensor, "B 7"], temperature: float = 1.0) -> Float[Tensor, "B"]:
        """Sample one action per game from the tree policy π(a) ∝ N(s,a)^(1/τ).

        Args:
            root_N:      (B, 7) root visit counts from the MCTS search
            temperature: τ; 1.0 samples proportionally to visit counts, →0 approaches greedy
                         argmax, larger flattens the distribution

        Returns:
            (B, 1) sampled column indices (one per game). Actions with zero visits have
            probability 0 and are never sampled. (No need to special-case tiny temperatures here —
            the trainer always calls this with moderate τ; `sample_tree_policy` in section 2 is the one
            that must handle the greedy τ→0 limit.)
        """
        temp_visits = root_N ** (1 / temperature)
        probs = temp_visits / temp_visits.sum(-1, keepdim=True)
        action = torch.multinomial(probs, num_samples=1)
        return action

    @torch.no_grad()
    def self_play_step(self, obs, to_move):
        """One ply of self-play for all `num_games` games at once: MCTS -> policy target -> record the
        ply into `self.buffer` -> sample -> step.

        Args:
            obs:     (B, 3, 6, 7) current (absolute) boards
            to_move: (B,) whether player-1 (red) is to move in each game

        Returns:
            next_obs: (B, 3, 6, 7) boards after the move
            done:     (B,) whether the move ended each game (the loop uses this to flip `to_move`)
        """
        root_N = self.mcts.search(self.model, obs, to_move, add_noise=True)   # root noise -> exploration
        pi = root_N / root_N.sum(-1, keepdim=True)
        obs_canon = canonicalise_obs(obs, to_move)
        action = self.sample_actions(root_N, self.cfg.temperature)
        next_obs, done, reward = self.env.step(obs, action, to_move)
        self.buffer.write(obs_canon, pi, done, reward)   # record this ply (mover-canonical board + targets)
        return next_obs, done

    @torch.no_grad()
    def self_play(self, progress: bool = True):
        """Play one generation: `num_games` games for `moves_per_gen` plies, calling your
        `self_play_step` each ply (which `write`s into `self.buffer`), then finalise the generation in
        the buffer. The buffer handles the value targets, masking and replay -- nothing to return.

        The ply loop is wrapped in a tqdm bar with `unit_scale` (like the 2.2 SPS bar) so you get a
        live, SI-formatted env-steps/sec readout during self-play rather than one update per generation."""
        B, T = self.cfg.num_games, self.cfg.moves_per_gen
        obs = self.env.reset(B)
        to_move = torch.ones((B,), dtype=torch.bool, device=self.device)
        self.model.eval()
        # each ply runs `sims` batched env.steps over all B games (one per MCTS simulation), so the
        # generation does B * T * sims env transitions in total -- tqdm turns that into a live rate,
        # and the postfix shows cumulative MCTS simulations done out of `moves_per_gen * sims`.
        total_sims = T * self.cfg.sims
        bar = tqdm(total=B * T * self.cfg.sims, unit=" env steps", unit_scale=True,
                   desc="self-play", leave=False, disable=not progress)
        for ply in range(T):
            obs, done = self.self_play_step(obs, to_move)
            to_move = torch.where(done, torch.ones_like(to_move), ~to_move)   # auto-reset -> player 1
            bar.update(B * self.cfg.sims)
            bar.set_postfix_str(f"sims {(ply + 1) * self.cfg.sims}/{total_sims}")
        bar.close()
        self.buffer.end_generation()

    def training_step(self, obs, pi, z):
        """One optimiser step on a single minibatch `(obs, pi, z)`: forward the net, compute the
        AlphaZero loss (`compute_az_loss`), zero the grads (`self.opt.zero_grad(set_to_none=True)`),
        backprop, clip the gradient norm to `cfg.grad_clip`, and step the optimiser. AlphaZero's update is just supervised learning -- regress the value head onto
        `z` and the policy head onto `pi`.

        Args:
            obs: (mb, 3, 6, 7) mover-canonical boards
            pi:  (mb, 7) MCTS policy targets
            z:   (mb,) value targets

        Returns:
            float: the minibatch loss
        """
        value, logits = self.model(obs.contiguous())
        loss = compute_az_loss(value, logits, pi, z, self.cfg.value_coef)
        self.opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
        self.opt.step()
        return float(loss.item())

    @torch.no_grad()
    def evaluate(self) -> dict:
        """Given. Score the current network against the frozen Pons solver set (policy accuracy /
        cross-entropy / value sign-accuracy). One cached forward pass; see the section above."""
        return evaluate_policy(self.model, self.env)

    def train(self, num_generations=None, eval_every=1):
        """Given. The full training loop -- you don't need to touch this. Each generation: run
        self-play into the buffer, do one supervised pass over the buffer's `DataLoader`
        (calling your `training_step`), step the cosine LR schedule, and periodically `evaluate`
        against the Pons solver. Logs loss / lr / eval to a tqdm bar (and to wandb if `cfg.use_wandb`)."""
        import time
        num_generations = num_generations or self.cfg.num_generations
        # cosine-decay the LR; schedule over a >=10-gen horizon so a short quick run doesn't crater the
        # LR before it finishes (a 6-gen T_max=6 cosine would decay to lr_min by the last gen and stall).
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=max(num_generations, 10), eta_min=self.cfg.lr_min)
        # effective env transitions/sec for the tqdm bar: each generation steps the env
        # num_games * moves_per_gen * sims times (one batched env.step per MCTS simulation). It's just a
        # counter + a clock, so it adds no per-step overhead -- handy for spotting when self-play bottlenecks.
        t0, env_steps = time.time(), 0
        steps_per_gen = self.cfg.num_games * self.cfg.moves_per_gen * self.cfg.sims
        if self.cfg.use_wandb:
            import wandb
            wandb.init(project=self.cfg.wandb_project, config=asdict(self.cfg))
        metrics = {}
        bar = tqdm(range(1, num_generations + 1))
        for gen in bar:
            self.self_play()                                       # fill + roll the replay buffer
            env_steps += steps_per_gen
            self.model.train()
            loader = self.buffer.get_dataloader(self.cfg.minibatch)
            total_loss, n_batches = 0.0, 0
            tbar = tqdm(loader, desc="train", leave=False)        # one supervised pass over the buffer
            for obs, pi, z in tbar:
                total_loss += self.training_step(obs, pi, z)
                n_batches += 1
                tbar.set_postfix_str(f"loss={total_loss / n_batches:.3f}")   # live running-mean loss
            loss = total_loss / max(n_batches, 1)
            sched.step()
            if eval_every and gen % eval_every == 0:
                metrics = self.evaluate()                          # Pons solver benchmark
            lr = sched.get_last_lr()[0]
            sps = env_steps / max(time.time() - t0, 1e-9)          # effective env steps/sec (cumulative)
            bar.set_postfix_str(f"loss={loss:.3f}  acc={metrics.get('pons/acc', float('nan')):.3f}  "
                                f"ce={metrics.get('pons/ce', float('nan')):.3f}  env/s={fmt_si(sps)}")
            if self.cfg.use_wandb:
                wandb.log({"generation": gen, "loss": loss, "lr": lr, "env_steps_per_sec": sps, **metrics})
        if self.cfg.use_wandb:
            wandb.finish()
        return self.model


tests.test_sample_actions(AlphaZeroTrainer)
tests.test_self_play_step(AlphaZeroTrainer)
tests.test_training_step(AlphaZeroTrainer)
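For reference, here's a minimal sketch of what a loss like compute_az_loss can look like. It assumes value has shape (mb,), logits and pi have shape (mb, 7), and that value_coef weights the value term. The name az_loss_sketch is hypothetical, and your compute_az_loss may differ in details such as the value-head shape or masking of illegal columns. The important part is the policy term: it is a cross-entropy against the whole MCTS visit distribution rather than a single chosen move, which is how the network distills the tree search.

```python
import torch
import torch.nn.functional as F


def az_loss_sketch(value, logits, pi, z, value_coef=1.0):
    """Hypothetical stand-in for compute_az_loss (shapes assumed, see lead-in).

    value:  (mb,)   value-head outputs
    logits: (mb, 7) policy-head logits
    pi:     (mb, 7) MCTS visit distributions (each row sums to 1)
    z:      (mb,)   game-outcome targets from the mover's perspective
    """
    # value term: mean squared error between predicted value and eventual outcome
    value_loss = F.mse_loss(value, z)
    # policy term: cross-entropy against the *whole* search distribution (soft targets)
    policy_loss = -(pi * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
    return policy_loss + value_coef * value_loss


# Toy minibatch: uniform logits vs. one-hot targets give a policy loss of log(7).
logits = torch.zeros(4, 7)
pi = torch.zeros(4, 7)
pi[:, 3] = 1.0
z = torch.tensor([1.0, -1.0, 0.0, 1.0])
print(az_loss_sketch(z.clone(), logits, pi, z).item())  # ~1.9459 (= log 7; value term is 0)
```

Because pi is a probability distribution, the policy term equals F.cross_entropy(logits, pi) with probability targets (supported in recent PyTorch versions). Writing the sum out just makes the soft-target nature explicit.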

Train your agent!

Put it all together: build the config, the model and the trainer, then call trainer.train(). The given training loop handles self-play, the supervised passes over the buffer (your training_step), the cosine LR schedule, the periodic Pons eval, and the tqdm / wandb logging, so there's nothing to wire up here. Set use_wandb=True in the config to stream loss / lr / pons/* metrics to Weights & Biases.
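The comment in train about scheduling over a >=10-generation horizon is easiest to see numerically. Below is a sketch using the closed-form cosine-annealing formula from the PyTorch CosineAnnealingLR documentation. The peak and floor LRs (1e-3 and 1e-5) are hypothetical values chosen for illustration, not necessarily AZConfig's. Since train calls sched.step() after each generation, generation g trains at scheduler step g - 1.

```python
import math


def cosine_lr(t, t_max, lr_max, lr_min):
    """Closed-form cosine annealing at scheduler step t (formula from the CosineAnnealingLR docs)."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / t_max))


def lrs_per_generation(num_generations, lr_max, lr_min, horizon_floor=10):
    """LR used during each generation, mirroring train(): T_max = max(num_generations, floor),
    and sched.step() runs *after* each generation, so generation g trains at step g - 1."""
    t_max = max(num_generations, horizon_floor)
    return [cosine_lr(g - 1, t_max, lr_max, lr_min) for g in range(1, num_generations + 1)]


# Hypothetical LRs, chosen only for illustration.
with_floor = lrs_per_generation(6, lr_max=1e-3, lr_min=1e-5)                 # T_max = 10
no_floor = lrs_per_generation(6, lr_max=1e-3, lr_min=1e-5, horizon_floor=0)  # T_max = 6
print(f"with floor: {with_floor[-1]:.2e}, without: {no_floor[-1]:.2e}")
# with floor: 5.05e-04, without: 7.63e-05
```

With the floor, a 6-generation run still trains its last generation at about half the peak LR. With T_max=6 it would already sit within roughly 7% of the range above lr_min, which is the stall the comment warns about.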

The defaults are a fast recipe (4096 games/gen, 16 sims, 12 generations) that reaches ~85% Pons-solver accuracy in ~4-5 min on a GPU. Raise sims to 64 and num_generations to ~50 for a stronger agent, or dial them down to go faster. On a GPU you should see acc (policy accuracy against the solver) climb and ce (cross-entropy against the solver) fall each generation.
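To get a feel for what the stronger recipe costs, recall that self-play does one batched env.step per MCTS simulation, which is why train counts num_games * moves_per_gen * sims env steps per generation. Here is a back-of-the-envelope sketch, assuming num_games and moves_per_gen stay fixed between the two recipes. The helper names are hypothetical.

```python
def env_steps_per_ply(num_games, sims):
    """Self-play env steps for one ply: one batched env.step per simulation, across all games."""
    return num_games * sims


def relative_selfplay_cost(sims_a, gens_a, sims_b, gens_b):
    """Ratio of total self-play env steps, recipe b vs. recipe a.

    Assumes num_games and moves_per_gen are unchanged, so they cancel and the
    total scales with sims * generations.
    """
    return (sims_b * gens_b) / (sims_a * gens_a)


print(env_steps_per_ply(4096, 16))             # 65536 env steps per ply (fast recipe)
print(env_steps_per_ply(4096, 64))             # 262144 env steps per ply (64 sims)
print(relative_selfplay_cost(16, 12, 64, 50))  # ~16.7x more self-play env steps overall
```

This only counts self-play env steps. Wall-clock time also depends on the network forward passes and on the supervised passes over the buffer, so treat the ratio as a rough guide rather than a timing prediction.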

cfg = AZConfig()   # fast recipe (4096 games, 16 sims, 12 gens) -> ~85% Pons acc in ~4-5 min on a GPU
model = Connect4Model(device)

if TRAINING:
    # `BatchedMCTS` is built later, in the Vectorized-MCTS bonus. So that training runs even if you
    # haven't done that bonus yet, import the finished class from `solutions` when it isn't defined.
    if "BatchedMCTS" not in globals():
        from solutions import BatchedMCTS
    trainer = AlphaZeroTrainer(env, cfg, model)
    trainer.train()   # eval + logging handled inside; set cfg.use_wandb=True to log to wandb

Play against your agent

Time to see how good it is! utils gives you two ready-made ways to play against your trained model:

  • Web app (recommended): play_web(model, env, port=8080) serves a little browser game at http://localhost:8080. In VSCode / Cursor the port is auto-forwarded, so it should open straight away. You (🔴) move first and the agent (🟡) replies with MCTS. Interrupt the cell (■ / Ctrl-C) to stop.
  • Terminal (fallback): play_cli(model, env) plays the same game in the cell output — type a legal column (0-6) on your turn.
from part5_mcts_alphazero.utils import play_web, play_cli

play_web(model, env, port=8080)   # browser game on http://localhost:8080 (auto-forwarded); ■/Ctrl-C to stop
# play_cli(model, env)            # ...or play in the terminal instead