5️⃣ Bonus

Claude has some suggestions for you. I personally haven't vetted the ones below, so take them with a grain of salt.


Some directions if you have time:

  • Dirichlet exploration noise at the root. Classic AlphaZero mixes a little Dirichlet noise into the root prior on every search, $P(s_0, a) = (1-\epsilon)\, p_\theta(s_0,a) + \epsilon\, \eta_a$ with $\eta \sim \mathrm{Dir}(\alpha)$, so that self-play occasionally tries moves the current policy underrates instead of collapsing onto the prior's favourite. The provided search already implements this behind the add_noise flag (with dirichlet_eps / dirichlet_alpha on the config), but it's off by default: on Connect 4 at this scale the agent trains fine without it. In an ablation (noise on vs off, same seed) it gave only a modest, noisy edge in Pons optimal-move accuracy; the no-noise run stalled mid-training but had caught up by the end. Turn it on (pass add_noise=True in self_play's search call), sweep dirichlet_eps and dirichlet_alpha, and measure whether it actually helps. Does the benefit grow on a bigger board, with more simulations, or with more training generations?
  • Temperature schedule. AlphaZero samples moves in proportion to the root visit counts (temperature 1) for the first few moves of each game, for opening diversity, then plays greedily. Add a per-move temperature schedule to self_play and see whether it helps.
  • Tune the search. How does strength (Pons optimal-move accuracy) change with sims (simulations per move) and c_puct? Plot it; the "strength vs search budget" bonus below does this for sims. (More play-time sims at evaluation make the agent stronger without any retraining.)
  • Subtree reuse. Between consecutive moves of one game, the new root is a child of the old root — its subtree is already partly searched. Reuse it instead of starting from scratch.
  • Bigger network. Add more residual blocks or channels. Where are the diminishing returns?
  • Play it yourself. The research code ships terminal and browser-based UIs (play_cli.py, play_web.py): load your trained checkpoint and try to beat it. Can you?
  • Compare to PPO self-play. How does AlphaZero compare to training the same network with the PPO self-play from [2.3]? Which is more sample-efficient here, and why?
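For the Dirichlet-noise bullet, here is a minimal standalone sketch of the root-prior mixing formula in numpy. It is not the provided search's implementation (that lives behind add_noise); the function name add_root_dirichlet_noise, the choice to draw noise over legal moves only, and the defaults are assumptions. eps = 0.25 and alpha = 0.3 are the values reported for chess in the AlphaZero work; a good alpha for Connect 4's 7 columns may well differ.

```python
import numpy as np


def add_root_dirichlet_noise(prior, legal, eps=0.25, alpha=0.3, rng=None):
    """Hypothetical helper: P = (1 - eps) * p + eps * eta, with eta ~ Dir(alpha) over legal moves.

    prior: (A,) policy-head probabilities at the root
    legal: (A,) boolean mask of legal moves
    """
    rng = np.random.default_rng() if rng is None else rng
    prior = np.asarray(prior, dtype=np.float64)
    legal = np.asarray(legal, dtype=bool)
    eta = np.zeros_like(prior)
    # Sample one Dirichlet vector over the legal moves only (assumption: illegal moves get no noise).
    eta[legal] = rng.dirichlet(np.full(int(legal.sum()), alpha))
    mixed = (1.0 - eps) * prior + eps * eta
    mixed[~legal] = 0.0            # never put mass on full columns
    return mixed / mixed.sum()     # renormalise in case the prior leaked mass onto illegal moves
```

Smaller alpha makes the noise spikier (most of eps lands on one random move); larger alpha spreads it more evenly, which is one reason to sweep both knobs.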
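For the temperature-schedule bullet, a small sketch of the usual AlphaGo Zero-style rule: turn root visit counts $N(a)$ into a move distribution $\pi(a) \propto N(a)^{1/\tau}$, with $\tau = 1$ early and greedy play afterwards. The names visit_policy and temperature_schedule are hypothetical, and explore_moves=8 is an arbitrary placeholder, not a tuned value.

```python
import numpy as np


def visit_policy(counts, temperature):
    """pi(a) proportional to N(a) ** (1 / T); T = 0 means greedy on the most-visited move."""
    counts = np.asarray(counts, dtype=np.float64)
    if temperature == 0:
        pi = np.zeros_like(counts)
        pi[np.argmax(counts)] = 1.0    # greedy: all mass on the most-visited move
        return pi
    # Work in log space so small temperatures do not overflow N ** (1 / T).
    logits = np.log(np.maximum(counts, 1e-12)) / temperature
    logits -= logits.max()
    pi = np.exp(logits)
    pi[counts == 0] = 0.0              # unvisited moves stay at zero probability
    return pi / pi.sum()


def temperature_schedule(move_number, explore_moves=8):
    """Hypothetical schedule: T = 1 for the first `explore_moves` plies, then greedy (T = 0)."""
    return 1.0 if move_number < explore_moves else 0.0
```

Inside self_play you would then sample each move from visit_policy(counts, temperature_schedule(ply)). Whether the stored training target should use the tempered or the untempered visit distribution is a separate design choice worth testing.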
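For the subtree-reuse bullet, the idea in its simplest form: after a move is played, the child node reached by that move becomes the next root, keeping its accumulated visit counts and value sums. The sketch below uses a hypothetical pointer-based Node; the provided BatchedMCTS may store its tree as tensors, in which case reuse means copying the child's slice of those tensors instead.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Minimal MCTS node (hypothetical; not the provided BatchedMCTS data structure)."""
    prior: float = 0.0
    visits: int = 0
    value_sum: float = 0.0
    children: dict = field(default_factory=dict)   # action -> Node


def advance_root(root: Node, action: int) -> Node:
    """Reuse the searched subtree: the child reached by `action` becomes the next root.

    Its statistics are kept, so the next search starts from them rather than from an
    empty tree. Sibling subtrees are dropped and can be garbage-collected.
    """
    child = root.children.get(action)
    return child if child is not None else Node()   # unexplored move: fall back to a fresh tree
```

In self-play both players share one tree, so you would advance after every ply. If root noise is on, it should be re-applied to the new root's priors, since that node was originally expanded without it.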

Exercise - data augmentation by mirror symmetry

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-20 minutes on this exercise.

Connect 4 is left-right mirror-symmetric: reflecting the board across the centre column gives a strategically identical position. So every self-play example (obs, pi, z) comes with a free twin: reflect the board, reverse the action distribution column-wise (column $c \leftrightarrow 6 - c$), and keep the value unchanged. Training on both doubles your data at zero self-play cost. (This is a standard trick: AlphaGo Zero augmented its training data with all 8 symmetries of the Go board, although the game-agnostic AlphaZero dropped augmentation because chess and shogi are not symmetric.)

Implement augment_with_mirror, returning the batch concatenated with its mirror image. Then call it on each batch inside the trainer: the simplest place is the top of training_step (doubling each minibatch), but augmenting when the buffer builds its DataLoader works too, since the trainer just sees bigger batches either way. Finally, check whether the agent reaches a given strength in fewer self-play games.

def augment_with_mirror(
    obs: Float[Tensor, "batch 3 H W"],
    pi: Float[Tensor, "batch 7"],
    z: Float[Tensor, "batch"],
) -> tuple[Float[Tensor, "b2 3 H W"], Float[Tensor, "b2 7"], Float[Tensor, "b2"]]:
    """Concatenate (obs, pi, z) with their left-right mirror image (Connect-4's only symmetry).

    Args:
        obs: (B, 3, H, W) boards
        pi:  (B, 7) policy targets
        z:   (B,) value targets

    Returns:
        obs: (2B, 3, H, W) original + width-flipped boards
        pi:  (2B, 7) original + column-reversed policies
        z:   (2B,) value targets, duplicated unchanged
    """
    raise NotImplementedError()


tests.test_augment_with_mirror(augment_with_mirror)
Solution
def augment_with_mirror(
    obs: Float[Tensor, "batch 3 H W"],
    pi: Float[Tensor, "batch 7"],
    z: Float[Tensor, "batch"],
) -> tuple[Float[Tensor, "b2 3 H W"], Float[Tensor, "b2 7"], Float[Tensor, "b2"]]:
    """Concatenate (obs, pi, z) with their left-right mirror image (Connect-4's only symmetry).

    Args:
        obs: (B, 3, H, W) boards
        pi:  (B, 7) policy targets
        z:   (B,) value targets

    Returns:
        obs: (2B, 3, H, W) original + width-flipped boards
        pi:  (2B, 7) original + column-reversed policies
        z:   (2B,) value targets, duplicated unchanged
    """
    obs_m = obs.flip(dims=[-1])   # reflect the board across the centre column (width is the last dim)
    pi_m = pi.flip(dims=[-1])     # column c <-> column 6 - c
    return torch.cat([obs, obs_m]), torch.cat([pi, pi_m]), torch.cat([z, z])


tests.test_augment_with_mirror(augment_with_mirror)
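Two properties are worth checking on any mirror augmentation: a piece in column $c$ must move to column $6 - c$ together with its policy mass, and mirroring twice must give back the original. The sketch below checks both on a numpy analogue of the torch.flip calls above; the 6×7 board with the bottom row at index 5, and the channel chosen for the piece, are assumptions for illustration.

```python
import numpy as np


def mirror(obs, pi):
    """numpy analogue of the flip in augment_with_mirror: reverse the width axis and the 7 policy columns."""
    return obs[..., ::-1], pi[..., ::-1]


obs = np.zeros((1, 3, 6, 7))
obs[0, 0, 5, 0] = 1.0                  # one piece on the bottom row, leftmost column (assumed layout)
pi = np.array([[0.7, 0.1, 0.1, 0.1, 0.0, 0.0, 0.0]])

obs_m, pi_m = mirror(obs, pi)
assert obs_m[0, 0, 5, 6] == 1.0        # column 0 -> column 6 - 0 = 6
assert pi_m[0, 6] == 0.7               # policy mass follows the piece
obs_mm, pi_mm = mirror(obs_m, pi_m)
assert np.array_equal(obs_mm, obs) and np.array_equal(pi_mm, pi)   # mirroring twice is the identity
```

The same assertions work on the torch version if you swap the slicing for `.flip(dims=[-1])`. Note that a position which is already symmetric (for example a lone piece in the centre column) simply gets a duplicate of itself.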

Bonus - strength vs search budget

A trained AlphaZero net can be made stronger at play time just by searching more — no retraining. With M = 0 simulations the agent plays its raw policy head (no planning — exactly the cheap eval we run each generation); with M > 0 it runs MCTS for M sims per move. The given evaluate_with_search (from pascal_pons.eval_pons) replays the frozen Pons set and, for each budget M, measures how often the agent's chosen move lands in the solver's optimal set — i.e. strength against perfect play, a sharper signal than win-rate against any fixed bot. The sweep over M ∈ {0, 1, 2, 4, 8, 16, 32, 64} runs MCTS over the whole set at each budget, so it's SLOW — set SLOW = True at the top to run it. You should see accuracy climb with M.

from pascal_pons.eval_pons import evaluate_with_search

if SLOW:   # slow (runs MCTS over the whole Pons set at each budget); set SLOW=True at the top to enable
    import matplotlib.pyplot as plt

    sims_list = [0, 1, 2, 4, 8, 16, 32, 64]
    curve = evaluate_with_search(trainer.model, env, sims_list=sims_list)
    for M in sims_list:
        tag = "  (raw policy, no planning)" if M == 0 else ""
        print(f"M={M:3d} sims{tag:<27}: optimal-move acc vs solver = {curve[M]['acc']:.3f}")

    accs = [curve[M]["acc"] for M in sims_list]
    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(range(len(sims_list)), accs, "o-")
    ax.set_xticks(range(len(sims_list))); ax.set_xticklabels(sims_list)
    ax.set_xlabel("MCTS simulations per move  (M=0 → raw policy, no planning)")
    ax.set_ylabel("optimal-move accuracy vs solver"); ax.set_ylim(0, 1)
    ax.grid(alpha=0.3); ax.set_title("Strength scales with search budget (no retraining)")
    fig.tight_layout()
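The metric the cell above reports can be stated in a few lines. This is a hypothetical sketch of it as the text describes it, not evaluate_with_search's actual code; it assumes the solver's verdict is available as a boolean mask of optimal columns per position. A move counts as correct if it is any of the optimal columns, so positions with several winning moves are not penalised for picking a different one.

```python
import numpy as np


def optimal_move_accuracy(chosen, optimal_mask):
    """Fraction of positions where the agent's column is in the solver's optimal set.

    chosen:       (N,) int array of columns picked by the agent
    optimal_mask: (N, 7) bool array, True where the solver rates that column optimal
    """
    chosen = np.asarray(chosen)
    optimal_mask = np.asarray(optimal_mask, dtype=bool)
    hits = optimal_mask[np.arange(len(chosen)), chosen]   # pick each row's entry at the chosen column
    return float(hits.mean())
```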

Bonus - the AlphaZero scaling law: Elo vs log(search)

The curve above shows accuracy against a perfect solver, which saturates once the agent plays near-optimally. A cleaner way to see how much search alone is worth is a self-play ladder: take the same trained network, have it play itself at different simulation budgets, then fit an Elo rating to the round-robin results. Plotting Elo against $\log_2(\text{sims})$ reproduces the well-known AlphaZero result that playing strength is roughly linear in the log of the search budget: every doubling of thinking time buys a roughly constant Elo gain, with no change to the weights.

(This is SLOW: it runs a full round-robin of MCTS-vs-MCTS matches. Set SLOW = True to run it, ideally on a strong network — load one of the pretrained checkpoints/az_step_*.pt into trainer.model.)
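A note on the model fit_elo uses: it predicts player $i$'s score against $j$ as $\sigma\big((R_i - R_j)\ln 10 / 400\big)$, which is exactly the standard Elo expected score $1 / (1 + 10^{-(R_i - R_j)/400})$, because $10^{-d/400} = e^{-d \ln 10 / 400}$. The short check below confirms the identity numerically. Under this model, "a constant Elo gain per doubling" means that $2M$ sims beats $M$ sims with the same expected score at every budget.

```python
import math


def elo_expected_score(diff):
    """Standard Elo expected score for a player rated `diff` points above the opponent."""
    return 1.0 / (1.0 + 10 ** (-diff / 400))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# fit_elo's sigmoid(diff * ln(10) / 400) traces the same curve as the base-10 Elo formula.
for d in [-400, -100, 0, 100, 200, 400]:
    assert math.isclose(elo_expected_score(d), sigmoid(d * math.log(10) / 400))

print(elo_expected_score(0), elo_expected_score(400))   # 0.5 and 10/11 ≈ 0.909
```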

import math


@torch.no_grad()
def _ladder_action(model, env, obs, is_player1, sims):
    """Move for the side to move: raw policy if sims == 0, else MCTS with `sims` simulations."""
    if sims == 0:
        return greedy_policy_action(model, canonicalise_obs(obs, is_player1))
    return BatchedMCTS(env, MCTSConfig(sims=sims)).search(model, obs, is_player1, add_noise=False).argmax(-1)


@torch.no_grad()
def ladder_match(model, env, sims_a, sims_b):
    """Player A (sims_a) vs player B (sims_b), same network, over all 98 openings (A as both
    colours). Returns A's score (win + ½·draw) in [0, 1]."""
    obs, is_player1, a_is_red = two_ply_positions(env)
    N = obs.shape[0]
    finished = torch.zeros(N, dtype=torch.bool, device=env.device)
    result = torch.zeros(N, device=env.device)
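    for _ in range(42):                       # a Connect 4 game lasts at most 42 plies
        if bool(finished.all()):
            break
        a_to_move = (is_player1 == a_is_red)  # which games have A on move this ply
        # Both players' moves are computed for every game; torch.where keeps the right one.
        move = torch.where(a_to_move,
                           _ladder_action(model, env, obs, is_player1, sims_a),
                           _ladder_action(model, env, obs, is_player1, sims_b))
        nobs, done, rew = env.step(obs, move, is_player1)
        newly = done & (~finished)            # score each game only once, on the ply it ends
        win = newly & (rew > 0.5)             # rew > 0.5: the side that just moved won
        result = torch.where(win & a_to_move, torch.ones_like(result), result)
        result = torch.where(win & (~a_to_move), -torch.ones_like(result), result)
        finished = finished | newly           # draws finish with result left at 0
        obs = nobs
        is_player1 = ~is_player1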
    w = int((result > 0.5).sum()); l = int((result < -0.5).sum()); d = N - w - l
    return (w + 0.5 * d) / N


def fit_elo(score_matrix, iters=3000, lr=10.0):
    """Least-squares Elo fit to a pairwise score matrix (score[i,j] = i's score vs j), centred at 0."""
    S = score_matrix.shape[0]
    R = torch.zeros(S, requires_grad=True)
    P = torch.as_tensor(score_matrix, dtype=torch.float32)
    off = ~torch.eye(S, dtype=torch.bool)
    opt = torch.optim.Adam([R], lr=lr)
    for _ in range(iters):
        pred = torch.sigmoid((R[:, None] - R[None, :]) * (math.log(10) / 400))
        loss = ((pred - P)[off] ** 2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
    return (R.detach() - R.detach().mean())


if SLOW:
    import matplotlib.pyplot as plt

    levels = [1, 2, 4, 8, 16, 32, 64]
    S = len(levels)
    score = torch.full((S, S), 0.5)
    for i in range(S):
        for j in range(S):
            if i != j:
                score[i, j] = ladder_match(trainer.model, env, levels[i], levels[j])
    elo = fit_elo(score.numpy())
    elo = elo - elo.min()   # anchor the weakest at 0 for readability
    for M, e in zip(levels, elo.tolist()):
        print(f"{M:3d} sims:  Elo {e:6.0f}")

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot([math.log2(M) for M in levels], elo.tolist(), "o-")
    ax.set_xticks([math.log2(M) for M in levels]); ax.set_xticklabels(levels)
    ax.set_xlabel("MCTS simulations per move (log scale)")
    ax.set_ylabel("Elo (self-play ladder)")
    ax.set_title("Strength is ~linear in log(search) — the AlphaZero scaling law")
    ax.grid(alpha=0.3); fig.tight_layout()
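To turn the ladder plot into one number, you can fit a straight line to Elo against $\log_2(\text{sims})$: the slope is the average Elo gained per doubling of search. The helper name elo_per_doubling is hypothetical; after running the ladder cell you would call it as `elo_per_doubling(levels, elo.tolist())`. Large residuals at the highest budgets are a sign that the log-linear trend is starting to bend.

```python
import numpy as np


def elo_per_doubling(sims_levels, elos):
    """Least-squares slope of Elo against log2(sims): average Elo gained per doubling of search."""
    x = np.log2(np.asarray(sims_levels, dtype=np.float64))
    y = np.asarray(elos, dtype=np.float64)
    slope, _intercept = np.polyfit(x, y, deg=1)   # degree-1 fit returns [slope, intercept]
    return float(slope)
```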