5️⃣ Bonus
Claude has some suggestions for you. I personally haven't vetted the below, so take with a grain of salt.
Some directions if you have time:
- Dirichlet exploration noise at the root. Classic AlphaZero mixes a little Dirichlet noise into the root prior on every search, $P(s_0, a) = (1-\epsilon)\, p_\theta(s_0,a) + \epsilon\, \eta$ with $\eta \sim \mathrm{Dir}(\alpha)$, so self-play occasionally tries moves the current policy underrates instead of collapsing onto the prior's favourite. The provided search already implements this behind the add_noise flag (with dirichlet_eps / dirichlet_alpha on the config), but it's off by default: on Connect 4 at this scale the agent trains basically fine without it. In an ablation (noise on vs off, same seed) it gave only a modest, noisy edge in the Pons optimal-move accuracy; the no-noise run stalled mid-training but had caught up by the end. Turn it on (pass add_noise=True in self_play's search call), sweep dirichlet_eps and dirichlet_alpha, and measure whether it actually helps. Does the benefit grow on a bigger board, with more simulations, or with more training generations?
- Temperature schedule. AlphaZero samples with temperature 1 for the first few moves of each game (for opening diversity), then plays greedily. Add a per-move temperature schedule to self_play and see whether it helps.
- Tune the search. How does strength (Pons optimal-move accuracy) change with sims (simulations per move) and c_puct? Plot it; see the "strength vs search budget" bonus below. (More play-time sims at evaluation make the agent stronger without any retraining.)
- Subtree reuse. Between consecutive moves of one game, the new root is a child of the old root, so its subtree is already partly searched. Reuse it instead of starting from scratch.
- Bigger network. Add more residual blocks or channels. Where are the diminishing returns?
- Play it yourself. The research code ships a terminal UI and a browser-based UI (play_cli.py, play_web.py); load your trained checkpoint and try to beat it. Can you?
- Compare to PPO self-play. How does AlphaZero compare to training the same network with the PPO self-play from [2.3]? Which is more sample-efficient here, and why?
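To make the first two bullets concrete, here is a minimal sketch in plain Python (standard library only, no PyTorch) of the root-noise mix and a per-move temperature schedule. All function names, the 8-move warm-up, and the eps / alpha values are illustrative assumptions; they are not the provided search's implementation or its config defaults.

```python
import random


def sample_dirichlet(alpha: float, n: int, rng: random.Random) -> list[float]:
    """Draw one sample from a symmetric Dirichlet(alpha) over n outcomes.

    Standard construction: normalise n independent Gamma(alpha, 1) draws.
    """
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(n)]
    total = sum(draws)
    return [d / total for d in draws]


def mix_root_prior(prior: list[float], eps: float, alpha: float, rng: random.Random) -> list[float]:
    """Return (1 - eps) * prior + eps * eta, with eta ~ Dir(alpha)."""
    eta = sample_dirichlet(alpha, len(prior), rng)
    return [(1 - eps) * p + eps * e for p, e in zip(prior, eta)]


def temperature_for_move(move_index: int, warmup_moves: int = 8) -> float:
    """Hypothetical schedule: temperature 1 for the first warmup_moves moves, then greedy (0)."""
    return 1.0 if move_index < warmup_moves else 0.0


def pick_move(visit_counts: list[int], temperature: float, rng: random.Random) -> int:
    """Sample a move with probability proportional to N^(1/T); T == 0 means argmax."""
    if temperature == 0.0:
        return max(range(len(visit_counts)), key=lambda a: visit_counts[a])
    weights = [n ** (1.0 / temperature) for n in visit_counts]
    return rng.choices(range(len(visit_counts)), weights=weights, k=1)[0]


rng = random.Random(0)
prior = [0.02, 0.05, 0.10, 0.66, 0.10, 0.05, 0.02]  # made-up 7-column prior
noisy = mix_root_prior(prior, eps=0.25, alpha=1.0, rng=rng)
print([round(p, 3) for p in noisy])  # a valid distribution: entries still sum to 1

visits = [1, 3, 10, 40, 8, 2, 0]  # made-up root visit counts
print(pick_move(visits, temperature_for_move(2), rng))   # early move: sampled from the visit counts
print(pick_move(visits, temperature_for_move(20), rng))  # late move: greedy, the most-visited column
```

Because the mix is a convex combination of two distributions, the noisy prior is still a distribution; eps controls how far it can move away from the network's prior, and alpha controls how spiky the noise itself is.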
Exercise - data augmentation by mirror symmetry
Connect 4 is left-right mirror-symmetric: reflecting the board across the centre column gives a
strategically identical position. So every self-play example (obs, pi, z) comes with a free twin —
reflect the board, reverse the action distribution column-wise (column $c \leftrightarrow 6 - c$),
and keep the value unchanged. Training on both doubles your data at zero self-play cost. (This is a
standard AlphaZero trick; AlphaGo Zero exploited all 8 symmetries of the Go board.)
Implement augment_with_mirror, returning the batch concatenated with its mirror image. The
simplest place to apply it is at the top of training_step (doubling each minibatch); augmenting
when the buffer builds its DataLoader works too, since the trainer just sees bigger batches either
way. Then check whether the agent reaches a given strength in fewer self-play games.
def augment_with_mirror(
    obs: Float[Tensor, "batch 3 H W"],
    pi: Float[Tensor, "batch 7"],
    z: Float[Tensor, "batch"],
) -> tuple[Float[Tensor, "b2 3 H W"], Float[Tensor, "b2 7"], Float[Tensor, "b2"]]:
    """Concatenate (obs, pi, z) with their left-right mirror image (Connect-4's only symmetry).

    Args:
        obs: (B, 3, H, W) boards
        pi: (B, 7) policy targets
        z: (B,) value targets
    Returns:
        obs: (2B, 3, H, W) original + width-flipped boards
        pi: (2B, 7) original + column-reversed policies
        z: (2B,) value targets, duplicated unchanged
    """
    raise NotImplementedError()

tests.test_augment_with_mirror(augment_with_mirror)
Solution
def augment_with_mirror(
    obs: Float[Tensor, "batch 3 H W"],
    pi: Float[Tensor, "batch 7"],
    z: Float[Tensor, "batch"],
) -> tuple[Float[Tensor, "b2 3 H W"], Float[Tensor, "b2 7"], Float[Tensor, "b2"]]:
    """Concatenate (obs, pi, z) with their left-right mirror image (Connect-4's only symmetry).

    Args:
        obs: (B, 3, H, W) boards
        pi: (B, 7) policy targets
        z: (B,) value targets
    Returns:
        obs: (2B, 3, H, W) original + width-flipped boards
        pi: (2B, 7) original + column-reversed policies
        z: (2B,) value targets, duplicated unchanged
    """
    obs_m = obs.flip(dims=[-1])  # reflect the board across the centre column (width is the last dim)
    pi_m = pi.flip(dims=[-1])  # column c <-> column 6 - c
    return torch.cat([obs, obs_m]), torch.cat([pi, pi_m]), torch.cat([z, z])

tests.test_augment_with_mirror(augment_with_mirror)
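As a quick sanity check of the index mapping, the sketch below reproduces the mirror on plain Python lists. It assumes a simplified encoding (one integer per cell rather than the 3-plane obs tensor), and mirror_policy / mirror_board are hypothetical helpers used only for illustration.

```python
WIDTH = 7  # Connect 4 has 7 columns


def mirror_policy(pi: list[float]) -> list[float]:
    """Reverse a per-column distribution, so entry c moves to column WIDTH - 1 - c."""
    return pi[::-1]


def mirror_board(board: list[list[int]]) -> list[list[int]]:
    """Reflect each row left-to-right (each row is a list of WIDTH cells)."""
    return [row[::-1] for row in board]


pi = [0.5, 0.2, 0.1, 0.1, 0.05, 0.05, 0.0]  # made-up policy target, deliberately asymmetric
pi_m = mirror_policy(pi)

# Entry c of the mirrored policy equals entry 6 - c of the original.
assert all(pi_m[c] == pi[WIDTH - 1 - c] for c in range(WIDTH))
# Mirroring twice is the identity, so augmentation yields exactly one twin per example.
assert mirror_policy(pi_m) == pi

row = [1, 0, 0, 2, 0, 0, 0]  # one made-up board row: 0 empty, 1 and 2 the two players
assert mirror_board([row]) == [[0, 0, 0, 2, 0, 0, 1]]
```

The same two properties (index map c to 6 - c, and flip-twice-is-identity) are cheap things to assert on the tensor version as well.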
Bonus - strength vs search budget
A trained AlphaZero net can be made stronger at play time just by searching more — no retraining.
With M = 0 simulations the agent plays its raw policy head (no planning — exactly the cheap
eval we run each generation); with M > 0 it runs MCTS for M sims per move. The given
evaluate_with_search (from pascal_pons.eval_pons) replays the frozen Pons set and, for each
budget M, measures how often the agent's chosen move lands in the solver's optimal set — i.e.
strength against perfect play, a sharper signal than win-rate against any fixed bot. The sweep over
M ∈ {0, 1, 2, 4, 8, 16, 32, 64} runs MCTS over the whole set at each budget, so it's SLOW — set
SLOW = True at the top to run it. You should see accuracy climb with M.
from pascal_pons.eval_pons import evaluate_with_search

if SLOW:  # slow (runs MCTS over the whole Pons set at each budget); set SLOW=True at the top to enable
    import matplotlib.pyplot as plt

    sims_list = [0, 1, 2, 4, 8, 16, 32, 64]
    curve = evaluate_with_search(trainer.model, env, sims_list=sims_list)
    for M in sims_list:
        tag = " (raw policy, no planning)" if M == 0 else ""
        print(f"M={M:3d} sims{tag:<27}: optimal-move acc vs solver = {curve[M]['acc']:.3f}")

    accs = [curve[M]["acc"] for M in sims_list]
    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(range(len(sims_list)), accs, "o-")
    ax.set_xticks(range(len(sims_list))); ax.set_xticklabels(sims_list)
    ax.set_xlabel("MCTS simulations per move (M=0 → raw policy, no planning)")
    ax.set_ylabel("optimal-move accuracy vs solver"); ax.set_ylim(0, 1)
    ax.grid(alpha=0.3); ax.set_title("Strength scales with search budget (no retraining)")
    fig.tight_layout()
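For intuition about the metric itself: optimal-move accuracy is the fraction of positions where the agent's move is one of the moves the solver considers optimal (a position can have several). The sketch below uses made-up toy data, and optimal_move_accuracy is a hypothetical helper; it is not the code inside evaluate_with_search.

```python
def optimal_move_accuracy(chosen: list[int], optimal_sets: list[set[int]]) -> float:
    """Fraction of positions where the chosen column is one of the solver-optimal columns."""
    if not chosen or len(chosen) != len(optimal_sets):
        raise ValueError("need one chosen move per position, and at least one position")
    hits = sum(move in best for move, best in zip(chosen, optimal_sets))
    return hits / len(chosen)


# Made-up toy data: 4 positions, the agent's move in each, and the solver's optimal columns.
chosen = [3, 2, 0, 4]
optimal_sets = [{3}, {2, 4}, {3}, {1, 4, 5}]
print(optimal_move_accuracy(chosen, optimal_sets))  # 3 of the 4 moves are optimal -> 0.75
```

Counting a move as correct when it lands anywhere in the optimal set avoids penalising the agent for picking a different, equally good move.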
Bonus - the AlphaZero scaling law: Elo vs log(search)
The curve above shows accuracy vs a perfect solver, which saturates once the agent plays near-optimally. A cleaner way to see how much search alone is worth is a self-play ladder: take the same trained network and have it play itself at different simulation budgets, then fit an Elo rating to the round-robin results. Plotting Elo against $\log_2(\text{sims})$ reproduces the well-known AlphaZero result that playing strength is roughly linear in the log of the search budget: every doubling of thinking time buys a roughly constant Elo gain, with no change to the weights.
(This is SLOW: it runs a full round-robin of MCTS-vs-MCTS matches. Set SLOW = True to run it,
ideally on a strong network — load one of the pretrained checkpoints/az_step_*.pt into trainer.model.)
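If Elo is unfamiliar, the sketch below shows the logistic model that links a rating gap to an expected score on the usual 400-point scale; the sigmoid of (R_i - R_j) * ln(10) / 400 used in fit_elo is the same function written differently. The helper names here are illustrative and not part of the course code.

```python
import math


def expected_score(rating_a: float, rating_b: float) -> float:
    """Expected score of A against B under the logistic Elo model (400-point scale)."""
    return 1.0 / (1.0 + 10.0 ** (-(rating_a - rating_b) / 400.0))


def elo_gap_from_score(score: float) -> float:
    """Invert the model: the rating gap that would produce `score` (requires 0 < score < 1)."""
    return 400.0 * math.log10(score / (1.0 - score))


print(round(expected_score(100, 0), 3))    # 0.64
print(round(expected_score(400, 0), 3))    # 0.909
print(round(elo_gap_from_score(0.75), 1))  # 190.8
```

So a player who scores about 64% against another is roughly 100 Elo stronger, and scores of exactly 0 or 1 correspond to an unbounded gap, which is one reason to fit all pairings jointly rather than invert each score on its own.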
@torch.no_grad()
def _ladder_action(model, env, obs, is_player1, sims):
    """Move for the side to move: raw policy if sims == 0, else MCTS with `sims` simulations."""
    if sims == 0:
        return greedy_policy_action(model, canonicalise_obs(obs, is_player1))
    return BatchedMCTS(env, MCTSConfig(sims=sims)).search(model, obs, is_player1, add_noise=False).argmax(-1)


@torch.no_grad()
def ladder_match(model, env, sims_a, sims_b):
    """Player A (sims_a) vs player B (sims_b), same network, over all 98 openings (A as both
    colours). Returns A's score (win + ½·draw) in [0, 1]."""
    obs, is_player1, a_is_red = two_ply_positions(env)
    N = obs.shape[0]
    finished = torch.zeros(N, dtype=torch.bool, device=env.device)
    result = torch.zeros(N, device=env.device)  # +1 = A won, -1 = B won, 0 = draw or unfinished
    for _ in range(42):  # a Connect 4 game lasts at most 42 moves
        if bool(finished.all()):
            break
        a_to_move = (is_player1 == a_is_red)
        move = torch.where(a_to_move,
                           _ladder_action(model, env, obs, is_player1, sims_a),
                           _ladder_action(model, env, obs, is_player1, sims_b))
        nobs, done, rew = env.step(obs, move, is_player1)
        newly = done & (~finished)
        win = newly & (rew > 0.5)
        result = torch.where(win & a_to_move, torch.ones_like(result), result)
        result = torch.where(win & (~a_to_move), -torch.ones_like(result), result)
        finished = finished | newly
        obs = nobs
        is_player1 = ~is_player1
    wins = int((result > 0.5).sum())
    losses = int((result < -0.5).sum())
    draws = N - wins - losses
    return (wins + 0.5 * draws) / N


import math  # used by fit_elo and the plotting cell; harmless if already imported above


def fit_elo(score_matrix, iters=3000, lr=10.0):
    """Least-squares Elo fit to a pairwise score matrix (score[i,j] = i's score vs j), centred at 0."""
    S = score_matrix.shape[0]
    R = torch.zeros(S, requires_grad=True)
    P = torch.as_tensor(score_matrix, dtype=torch.float32)
    off = ~torch.eye(S, dtype=torch.bool)  # ignore the diagonal (a level never plays itself)
    opt = torch.optim.Adam([R], lr=lr)
    for _ in range(iters):
        # Logistic Elo model: expected score of i vs j is 1 / (1 + 10^(-(R_i - R_j) / 400))
        pred = torch.sigmoid((R[:, None] - R[None, :]) * (math.log(10) / 400))
        loss = ((pred - P)[off] ** 2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
    return (R.detach() - R.detach().mean())


if SLOW:
    import matplotlib.pyplot as plt

    levels = [1, 2, 4, 8, 16, 32, 64]
    S = len(levels)
    score = torch.full((S, S), 0.5)
    for i in range(S):
        for j in range(S):
            if i != j:
                score[i, j] = ladder_match(trainer.model, env, levels[i], levels[j])

    elo = fit_elo(score.numpy())
    elo = elo - elo.min()  # anchor the weakest at 0 for readability
    for M, e in zip(levels, elo.tolist()):
        print(f"{M:3d} sims: Elo {e:6.0f}")

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot([math.log2(M) for M in levels], elo.tolist(), "o-")
    ax.set_xticks([math.log2(M) for M in levels]); ax.set_xticklabels(levels)
    ax.set_xlabel("MCTS simulations per move (log scale)")
    ax.set_ylabel("Elo (self-play ladder)")
    ax.set_title("Strength is ~linear in log(search): the AlphaZero scaling law")
    ax.grid(alpha=0.3); fig.tight_layout()
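One way to put a number on "a roughly constant Elo gain per doubling" is to fit a straight line to Elo against log2(sims) and read off the slope. The sketch below does this with an ordinary least-squares fit in plain Python; elo_per_doubling is a hypothetical helper, and the ratings are made-up, perfectly linear toy values, not measured results.

```python
import math


def elo_per_doubling(sims: list[int], elo: list[float]) -> float:
    """Least-squares slope of Elo against log2(sims): the Elo gained per doubling of search."""
    xs = [math.log2(s) for s in sims]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(elo) / len(elo)
    cov = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, elo))
    var = sum((x - x_mean) ** 2 for x in xs)
    return cov / var


# Made-up ratings for illustration only (not measured results).
toy_levels = [1, 2, 4, 8, 16, 32, 64]
toy_elo = [0.0, 60.0, 120.0, 180.0, 240.0, 300.0, 360.0]
print(elo_per_doubling(toy_levels, toy_elo))  # slope of the toy data: about 60 Elo per doubling
```

On real ladder results, passing levels and elo.tolist() from the cell above would give the fitted slope; comparing it across checkpoints is one way to see whether the gain per doubling changes as the network gets stronger.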