[2.5] - MCTS & AlphaZero

Colab: exercises | solutions

Please report any problems / bugs in the #errata channel in the Slack group, and ask any questions in the dedicated channels for this chapter of material.

If you want to change to dark mode, you can do this by clicking the three horizontal lines in the top-right, then navigating to Settings → Theme.

Links to all other chapters: (0) Fundamentals, (1) Transformer Interpretability, (2) RL.

Introduction

Up until now we've been dealing primarily with model-free methods: those that have no explicit model of how the world works, and need to learn the rules of the game from experience. This is wasteful when the environment is already known and cheap to simulate (like a board game). Today we introduce Monte Carlo Tree Search (MCTS), a model-based method where we have access to a simulator of the environment that we can use for planning, and combine this with deep learning to create an agent that plays a strong game of Connect 4, learning only from self-play. This is the same method that AlphaGo Zero used to become superhuman at Go.

The main idea is as follows:

  • We use a neural network to guide the tree search.
  • We select actions based on which nodes were the most visited during the tree search.
  • We train the policy network to mimic the tree search, distilling the planning into the policy network.

This feedback loop (policy iteration via search) is what took AlphaZero from random play to superhuman in hours.
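As a rough illustration of the second and third bullets, the sketch below turns root visit counts into a distribution over actions, which can be both sampled from (to pick a move) and used as the target the policy network learns to mimic. The function name `visit_policy`, the `temperature` parameter, and the example counts are hypothetical and not taken from the exercise code; it is written in plain Python rather than PyTorch to keep it dependency-free.

```python
import random

def visit_policy(visit_counts, temperature=1.0):
    """Turn root visit counts into a probability distribution over actions.

    Hypothetical helper for illustration only; the exercises may differ.
    """
    if temperature == 0:
        # Greedy limit: all probability mass on the most-visited action
        # (the first such action if there is a tie).
        best = max(range(len(visit_counts)), key=lambda a: visit_counts[a])
        return [1.0 if a == best else 0.0 for a in range(len(visit_counts))]
    # Sharpen (temperature < 1) or flatten (temperature > 1) the counts.
    weights = [n ** (1.0 / temperature) for n in visit_counts]
    total = sum(weights)
    return [w / total for w in weights]

# Made-up visit counts for the 7 columns of a Connect 4 board.
counts = [2, 10, 50, 30, 8, 0, 0]
pi = visit_policy(counts)            # training target for the policy head
action = random.choices(range(len(counts)), weights=pi)[0]  # move to play
greedy = visit_policy(counts, temperature=0)
```

Unvisited actions (here the last two columns) get zero probability, so illegal moves are never sampled as long as the search never visits them.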

The rough steps for today:

  1. Build the network (a small ResNet with two heads).
  2. Build a simple single-game MCTS in pure Python to understand the algorithm.
  3. Vectorize the MCTS to run hundreds of games at once on the GPU.
  4. Build the PUCT sampler that turns search into training data.
  5. Train the network to mimic the tree search.

We've provided a vectorized implementation of Connect 4 in part5_mcts_alphazero/connect4.py, as well as a fast, net-independent evaluation built on Pascal Pons' perfect Connect-4 solver — we score the agent on how often it picks a provably optimal move. At the end, you'll have a model that trains to a strong level in under five minutes on a GPU.

Attributions: Part of the codebase was built upon implementations of AlphaZero by

Content & Learning Objectives

0️⃣ MCTS & AlphaZero — Theory

A non-exercise section introducing Monte Carlo Tree Search and how AlphaZero turns it into a learning algorithm.

Learning Objectives
  • Understand the four phases of MCTS (selection, expansion, simulation, backup).
  • See how AlphaZero replaces random rollouts with a value net and uses the policy as a prior via PUCT.
  • Understand the self-play loop and loss function for the network.
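To make the PUCT objective above more concrete, here is a minimal sketch of the selection score in the form commonly given for AlphaZero: `Q(a) + c_puct * P(a) * sqrt(sum_b N(b)) / (1 + N(a))`. The function name `puct_scores`, the default `c_puct=1.5`, and the example numbers are assumptions for illustration; the exercises may use a different constant or variant, and would also need to mask illegal moves.

```python
import math

def puct_scores(q, prior, visits, c_puct=1.5):
    """PUCT score for each child of a node (illustrative sketch).

    q      : mean value estimate of each child, from the parent's perspective
    prior  : policy-network probability of each action
    visits : visit count of each child
    c_puct : exploration constant (1.5 is an assumed value, not from the course)
    """
    total_visits = sum(visits)
    return [
        q_a + c_puct * p_a * math.sqrt(total_visits) / (1 + n_a)
        for q_a, p_a, n_a in zip(q, prior, visits)
    ]

# One child visited 9 times with a decent value; two children unvisited.
q = [0.0, 0.5, 0.0]
prior = [0.2, 0.3, 0.5]
visits = [0, 9, 0]
scores = puct_scores(q, prior, visits)
best_action = max(range(len(scores)), key=lambda a: scores[a])
```

In this example the unvisited action with the largest prior wins the selection, even though another action already has a higher value estimate: the exploration bonus shrinks as an action's visit count grows, so the search gradually shifts from following the prior to following the measured values.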

1️⃣ The Environment & Network

We meet the provided Connect-4 environment and build the policy-value network. The network is a small ResNet with two heads: an actor (policy) and a critic (value).

Learning Objectives
  • Use the provided vectorized Connect-4 environment, and understand how the board is encoded.
  • Build the AlphaZero policy-value network.

2️⃣ Single-Game MCTS

Implement MCTS with an explicit tree, on a single board, in pure Python. No prizes for speed here, but it helps to write the sequential version first.

Learning Objectives
  • Implement a Node class, PUCT selection, expansion, and backup.
  • Assemble the full search loop and verify it finds tactical wins and blocks.

3️⃣ Simulating Vectorized MCTS

Understand the demands of the vectorized version, and see how we can batch the forward passes of the network, using our sequential version as a reference.

Learning Objectives
  • Understand how we can restructure the sequential version to batch the evaluate stage of the MCTS algorithm.
  • Understand both Root Parallelization and Leaf Parallelization, how they work, and why Root Parallelization is well suited to implementation in PyTorch.
  • Build a simulated version of Root Parallelization that uses the sequential version as the backend.

4️⃣ Self-Play & Training

Close the loop: turn search into training data and train an agent.

Learning Objectives
  • Implement the self-play sampler: the tree policy, the network policy, and the use of the critic to estimate the value of rollouts.
  • Understand the loss function for the network and how it distills the planning provided by the tree search.
  • Train an agent to match a perfect solver's moves (and hopefully beat you too!)
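As a hedged sketch of what the loss in the second objective typically looks like, the snippet below computes the per-example AlphaZero-style loss as usually stated: cross-entropy between the search policy and the network policy, plus squared error between the game outcome and the predicted value (weight decay omitted). The function name `alphazero_loss` and its arguments are hypothetical, and plain Python is used instead of PyTorch so the example has no dependencies; the loss actually used in the exercises may be weighted or batched differently.

```python
import math

def alphazero_loss(policy_logits, value, target_policy, target_value):
    """Per-example AlphaZero-style loss (illustrative sketch, no weight decay).

    policy_logits : raw network outputs, one per action
    value         : network value prediction, assumed to lie in [-1, 1]
    target_policy : search policy (e.g. normalized visit counts)
    target_value  : game outcome from the current player's perspective
    """
    # Numerically stable log-softmax over the logits.
    m = max(policy_logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in policy_logits))
    log_probs = [l - log_z for l in policy_logits]
    # Cross-entropy pulls the network policy towards the search policy.
    policy_loss = -sum(t * lp for t, lp in zip(target_policy, log_probs))
    # Squared error pulls the value head towards the eventual outcome.
    value_loss = (target_value - value) ** 2
    return policy_loss + value_loss

loss = alphazero_loss(
    policy_logits=[0.0, 0.0],
    value=1.0,
    target_policy=[0.5, 0.5],
    target_value=1.0,
)
```

With uniform logits, a uniform target, and a perfect value prediction, the loss reduces to the entropy of the target policy, `log(2)`. The policy term is where the distillation happens: the search policy is usually sharper than the raw network policy, so minimizing the cross-entropy moves the network towards what the search found.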

☆ Bonus - Vectorized MCTS

Implement the vectorized version of the MCTS algorithm.

Learning Objectives
  • Understand how we can store trees as tensors so that tree search can be performed as parallel, lock-step operations on the GPU.
  • Parallelize the selection, expansion and backup stages of the MCTS algorithm.
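One way to picture the "trees as tensors" idea is a flat, index-based layout: every node is a row, and parent/child links are integer indices rather than Python object references. The sketch below uses plain Python lists for a single tree; the class `ArrayTree`, its field names, and the sign-flipping backup convention are assumptions for illustration, not the layout used in the exercises. A vectorized version would presumably hold the same fields as tensors with an extra leading batch dimension, one tree per game.

```python
class ArrayTree:
    """A search tree stored as flat arrays (hypothetical layout)."""

    def __init__(self, max_nodes, num_actions):
        # children[node][action] is the child's index, or -1 if unexpanded.
        self.children = [[-1] * num_actions for _ in range(max_nodes)]
        self.parent = [-1] * max_nodes      # -1 marks the root
        self.visits = [0] * max_nodes
        self.value_sum = [0.0] * max_nodes
        self.num_nodes = 1                  # node 0 is the root

    def expand(self, node, action):
        """Allocate the next free row as the child reached via `action`."""
        child = self.num_nodes
        self.children[node][action] = child
        self.parent[child] = node
        self.num_nodes += 1
        return child

    def backup(self, leaf, value):
        """Propagate a leaf value to the root, flipping sign at each ply.

        Assumed convention: the sign alternates because the two players
        take turns, so a good outcome for one is bad for the other.
        """
        node = leaf
        while node != -1:
            self.visits[node] += 1
            self.value_sum[node] += value
            value = -value
            node = self.parent[node]

tree = ArrayTree(max_nodes=8, num_actions=7)
a = tree.expand(0, action=3)   # root -> child via column 3
b = tree.expand(a, action=2)   # child -> grandchild via column 2
tree.backup(b, value=1.0)
```

Because every field is a fixed-size array indexed by node number, selection, expansion and backup become gather and scatter operations, which is what makes running many trees in lock-step feasible.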

☆ Bonus

Various other extensions to MCTS and things to try, including data augmentation to exploit board symmetries and measuring the Elo of the model as a function of MCTS simulations.

Readings

Setup code

try:
    get_ipython().run_line_magic("load_ext", "autoreload")
    get_ipython().run_line_magic("autoreload", "2")
except Exception:
    pass
import einops
from eindex import eindex
import math
import sys
from pathlib import Path
from dataclasses import dataclass, field, asdict
from jaxtyping import Float, Bool, Int
from torch import Tensor
import torch
import torch as t   # ARENA convention: other chapters use `t.`; this file spells out `torch.`, both work
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchinfo import summary
from tqdm.auto import tqdm

# Make sure exercises are in the path
chapter = "chapter2_rl"
section = "part5_mcts_alphazero"
exercises_dir = next(p for p in Path.cwd().parents if (p / chapter).exists()) / chapter / "exercises"
section_dir = exercises_dir / section
for _p in (str(section_dir), str(exercises_dir)):   # section_dir for bare imports; exercises_dir so
    if _p not in sys.path:                          # `from part5_mcts_alphazero.solutions import ...` resolves
        sys.path.append(_p)
import part5_mcts_alphazero.tests as tests
import part5_mcts_alphazero.utils as utils
from part5_mcts_alphazero.utils import (
    Connect4Env, MCTSConfig, legal_mask_from_obs, fmt_si,
    render_board, place_piece, plot_board_and_policy, print_mcts_tree, plot_mcts_tree,
    two_ply_positions, greedy_policy_action,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAIN = __name__ == "__main__"
SLOW = False   # set True to run the slow bonus demos (strength-vs-sims, Elo-vs-search budget)
TRAINING = True    # set False to skip section 4 self-play training (e.g. while iterating on section 2/section 3)