☆ Bonus - Vectorized MCTS

Learning Objectives
  • Understand how to store trees as tensors so that tree search can run as parallel, lock-step operations on the GPU.
  • Parallelize the selection, expansion and backup stages of the MCTS algorithm.

⚠️ Note: This section is seriously hard, and is mostly a tricky engineering task of getting tree updates to work with vectorized operations. There are no new conceptual challenges as far as MCTS goes; it's just about speeding up the SimulatedBatchedMCTS class presented earlier. Attempt at your peril, and please give feedback on how to improve this section.

Admittedly, Claude is a better engineer than I am: after I got a sequential solution working, it did most of the hard work of vectorizing the code. So it's not clear how important it is to have hardcore vectorization skills, as opposed to just understanding the high-level algorithm.

Recall that we are trying to implement Root Parallelisation. Have a look at the previous sections if you need a reminder.

How to store trees on the GPU?

For each game b we keep a fixed pool of MAXN node slots, stored as flat tensors indexed by [game, node, …]. Every simulation adds at most one new node to the tree, so MAXN = cfg.sims + 2 slots (the root plus at most one node per simulation, with one to spare) always suffice, and we never run out of room. Depth is bounded separately: since a piece is added on every move, no path from the root is longer than height * width = 42 moves on a standard Connect-4 board. We allocate this memory only once and reuse it for every search, which improves throughput because we never allocate or free memory inside the search.

The tensors (each has MAXN + 1 node rows per game; the extra row is the dustbin slot described below) are:

  • obs_pool : Float[Tensor, "B nodes 3 height width"]: the board state at each node
  • tomove : Bool[Tensor, "B nodes"]: whether player 1 is to move at each node
  • terminal : Bool[Tensor, "B nodes"]: whether each node is a finished position
  • term_val : Float[Tensor, "B nodes"]: the terminal value of each node, from its mover's perspective
  • legal : Bool[Tensor, "B nodes 7"]: the legal-column mask at each node
  • child : Long[Tensor, "B nodes 7"]: the child node-id per action, or -1 if not yet expanded
  • parent : Long[Tensor, "B nodes"]: each node's parent slot (-1 at the root)
  • parent_act : Long[Tensor, "B nodes"]: the column that led from the parent into each node
  • N : Float[Tensor, "B nodes 7"]: per-edge visit counts
  • W : Float[Tensor, "B nodes 7"]: per-edge value sums
  • P : Float[Tensor, "B nodes 7"]: per-edge priors
  • nptr : Long[Tensor, "B"]: the next free node slot per game; node 0 is the root, so it starts at 1.
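Every per-node read in this section goes through a per-game row index. Here is a tiny sketch with hypothetical sizes (not the real Tree) of why: X[ar, node] pairs game b with its own node[b], while X[:, node] crosses every game with every game's node.

```python
import torch

B, NODES, A = 3, 5, 7                       # hypothetical sizes: 3 games, 5 node slots, 7 columns
N = torch.arange(B * NODES * A, dtype=torch.float32).reshape(B, NODES, A)

ar = torch.arange(B)                        # per-game row index, as in Tree.ar
node = torch.tensor([0, 2, 4])              # the current node of each game

per_game = N[ar, node]                      # (B, 7): game b's edges at its own node[b]
crossed = N[:, node]                        # (B, B, 7): every game x every game's node

# The paired gather is exactly the "diagonal" of the crossed one.
assert torch.equal(per_game, crossed[ar, ar])
```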

The parent/parent_act arrays are how backup finds its way home: each is written once, when a node is created (in expand_batch), and backup follows them from a leaf up to the root instead of recording a path during selection. This works because root-parallel MCTS builds a strict tree (one parent per node); it would break under transpositions (a position reached by several move orders), which we don't use.
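As a minimal single-game sketch of how parent pointers replace a recorded path, here is a plain Python loop over a made-up four-node pool. It converts tensors to Python ints as it goes, so it is a CPU illustration, not something for the hot loop; path_to_root is a hypothetical helper, not one of the given functions.

```python
import torch

# Hypothetical single-game pool: root 0 -(col 3)-> 1 -(col 2)-> 2, plus a sibling 3 under the root.
parent = torch.tensor([-1, 0, 1, 0])        # parent slot of each node (-1 at the root)
parent_act = torch.tensor([0, 3, 2, 5])     # column that led from the parent into each node

def path_to_root(leaf: int) -> list[tuple[int, int]]:
    """Return the (parent, column) edges from `leaf` up to the root, leaf end first."""
    edges = []
    node = leaf
    while parent[node] >= 0:                # the root's sentinel parent is -1
        edges.append((int(parent[node]), int(parent_act[node])))
        node = int(parent[node])
    return edges

print(path_to_root(2))   # [(1, 2), (0, 3)]
```

Nothing was recorded during selection: the route is recovered entirely from arrays written once at node creation.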

Handling variable length games

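One annoyance is that although no descent can be longer than height * width = 42 moves, individual games stop at different depths, and on any given simulation some games have nothing to expand (for example, because selection re-reached a terminal node). We handle this with a dustbin: a throwaway node slot (DUST_N = MAXN, the extra row in every pool) to which the writes of those games are redirected, over and over, so every game can run the same scatter without branching. One could optimize further by relaunching games as soon as they terminate, but for simplicity we don't bother, and just waste some extra compute on already-dead games.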
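A minimal sketch of the dustbin trick, assuming a simplified pool that only has a parent array. The names mirror Tree, but the values are made up and this is not the given expand_batch: expanding games write into their next free slot, everyone else writes into slot DUST_N, and a single unconditional scatter serves all games.

```python
import torch

B, MAXN = 4, 6                              # hypothetical: 4 games, real slots [0, MAXN)
DUST_N = MAXN                               # extra row that absorbs unwanted writes
parent = torch.full((B, MAXN + 1), -1, dtype=torch.long)
nptr = torch.tensor([1, 3, 2, 5])           # next free slot per game
leaf_parent = torch.tensor([0, 1, 0, 2])    # node each game would expand from
has_expand = torch.tensor([True, False, True, False])

ar = torch.arange(B)
# Expanding games target their next free slot; the rest are redirected to the dustbin.
new_ids = torch.where(has_expand, nptr, torch.full_like(nptr, DUST_N))
parent[ar, new_ids] = leaf_parent           # one scatter for every game, no Python branch
nptr = nptr + has_expand.long()             # only expanding games consume a slot
```

Games 1 and 3 overwrite the dustbin row, which nothing ever reads, and their real slots stay untouched.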

"sync-free" code

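We avoid methods like .item() in the hot loop, because each call copies a value to the CPU and stalls until the GPU has finished all the work queued so far. The batched MCTS is built almost entirely from gather/scatter/where/argmax operations, so the search runs as a nearly uninterrupted stream of GPU kernels (the early-exit check bool(done.all()) in select_batch is a deliberate exception, trading one sync for skipping useless descent steps). All the parallel rollouts move in lockstep, so there is no need to synchronize between threads or wait for individual threads to finish.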
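To make the contrast concrete, here is a small sketch (run on CPU with made-up values) of the same per-game update written both ways. On a GPU, the loop version would force a device-to-host copy on every .item() call, while torch.where is a single elementwise kernel.

```python
import torch

v = torch.tensor([0.5, -0.2, 0.9])
at_root = torch.tensor([False, True, False])

# Sync-heavy: a Python branch per game; each `.item()` is a GPU -> CPU round-trip on a GPU tensor.
slow = v.clone()
for b in range(len(v)):
    if not at_root[b].item():
        slow[b] = -slow[b]

# Sync-free: the same update as one elementwise operation, with no host round-trip.
fast = torch.where(at_root, v, -v)
```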

Rather than one giant search function, we factor the search into small, separately testable free functions that each take a Tree: the batched mirror of section 2, where every phase took a Node. The per-phase wrappers (expand_root, select_batch, expand_batch, evaluate_batch) and the root-noise helper dirichlet_root_noise are given; you implement the five numeric kernels that these wrappers and batched_search call, each with its own unit test:

  • masked_softmax_prior: the legal-masked policy prior
  • puct_select: the batched PUCT score and argmax
  • step_descent: one PUCT descent step
  • batched_backup: negamax backup along parent pointers
  • get_leaf_value: which leaf value to back up

Each kernel's test checks it against the corresponding section 2 single-game function looped over the batch. This "single ↔ batched equivalence" is what anchors section 3.

@dataclass
class Tree:
    """Flat-tensor store for `B` independent root-parallel MCTS trees (one per game) — the batched
    analogue of `Node`. Node `0` of each game is its root; real nodes occupy slots `[0, MAXN)`; the
    extra slot `DUST_N` (= `MAXN`, the `+1` row) is a **dustbin** that absorbs writes from games not
    expanding this simulation, so all games move in lockstep without `.item()` syncs. Statistics are
    per-EDGE (`N`/`W`/`P`/`child`, shape `(B, MAXN+1, 7)`) or per-NODE (the rest). Backup walks
    `parent`/`parent_act` home, so SELECT records no path. Indexing convention for all per-edge
    tensors: `N[b, n, a]` = the stat for action `a` at node `n` of game `b`.

    Fields:
        B           int — number of independent games / trees.
        MAXN        int — node-pool capacity per game (`cfg.sims + 2`); real nodes are slots [0, MAXN).
        DUST_N      int — dustbin slot index (`= MAXN`); dead/non-expanding games write here.
        MAXD        int — max descent/backup depth (`cfg.max_depth`).
        ar          (B,) long — `arange(B)`, so `X[ar, node]` gathers each game's own row.
        obs_pool    (B, MAXN+1, 3, 6, 7) float — board per node.
        tomove      (B, MAXN+1) bool -- True where player 1 is to move (cf. `root_is_player1`).
        terminal    (B, MAXN+1) bool — node is a finished position.
        term_val    (B, MAXN+1) float — terminal value from this node's mover's perspective (`-reward`).
        legal       (B, MAXN+1, 7) bool — legal-column mask per node.
        P           (B, MAXN+1, 7) float — per-edge prior (legal-masked softmax; root may be noised).
        child       (B, MAXN+1, 7) long — child node-id per action, `-1` if the edge is unexpanded.
        parent      (B, MAXN+1) long — parent slot of each node (`-1` at the root).
        parent_act  (B, MAXN+1) long — the column that led from `parent` into this node.
        N           (B, MAXN+1, 7) float — per-edge visit counts (updated in BACKUP).
        W           (B, MAXN+1, 7) float — per-edge value sums, this node's mover's perspective.
        nptr        (B,) long — next free node slot per game; starts at 1 (node 0 is the root).
    """
    B: int
    MAXN: int
    DUST_N: int
    MAXD: int
    ar: Int[Tensor, "B"]
    obs_pool: Float[Tensor, "B nodes 3 6 7"]
    tomove: Bool[Tensor, "B nodes"]
    terminal: Bool[Tensor, "B nodes"]
    term_val: Float[Tensor, "B nodes"]
    legal: Bool[Tensor, "B nodes 7"]
    P: Float[Tensor, "B nodes 7"]
    child: Int[Tensor, "B nodes 7"]
    parent: Int[Tensor, "B nodes"]
    parent_act: Int[Tensor, "B nodes"]
    N: Float[Tensor, "B nodes 7"]
    W: Float[Tensor, "B nodes 7"]
    nptr: Int[Tensor, "B"]

    @classmethod
    def alloc(cls, B: int, cfg: MCTSConfig, device) -> "Tree":
        """Allocate all pools once (node 0 = each game's root, `nptr` starts at 1). A spare dustbin
        slot (`DUST_N`, the `+1` index) absorbs writes from games not expanding this simulation."""
        MAXN = cfg.sims + 2
        z = lambda *shape, dtype=torch.float32: torch.zeros((B, MAXN + 1, *shape), dtype=dtype, device=device)
        return cls(
            B=B, MAXN=MAXN, DUST_N=MAXN, MAXD=cfg.max_depth,
            ar=torch.arange(B, device=device),
            obs_pool=z(3, 6, 7),
            tomove=z(dtype=torch.bool),
            terminal=z(dtype=torch.bool),
            term_val=z(),
            legal=z(7, dtype=torch.bool),
            P=z(7),
            child=torch.full((B, MAXN + 1, 7), -1, dtype=torch.long, device=device),
            parent=torch.full((B, MAXN + 1), -1, dtype=torch.long, device=device),
            parent_act=z(dtype=torch.long),
            N=z(7),
            W=z(7),
            nptr=torch.ones((B,), dtype=torch.long, device=device),
        )

Exercise - implement masked_softmax_prior

Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to ~5 minutes on this exercise.

The policy head returns raw logits of shape (B, 7), but some columns are full (illegal). Turn the logits into a normalised prior P(a) over the legal columns only: set illegal columns to -torch.inf, so they get probability exactly 0 (since exp(-inf) = 0), then apply a softmax. This is used at the root and at every newly-expanded leaf.

def masked_softmax_prior(
    logits: Float[Tensor, "B 7"], legal: Bool[Tensor, "B 7"]
) -> Float[Tensor, "B 7"]:
    """Softmax of the policy logits over the legal columns only; used at the root and every new leaf.

    Args:
        logits: (B, 7) raw policy-head scores
        legal:  (B, 7) legal-column mask

    Returns:
        (B, 7) prior P(a): zero on illegal columns, summing to 1 over the legal ones
    """
    raise NotImplementedError()


tests.test_masked_softmax_prior(masked_softmax_prior)
Solution
def masked_softmax_prior(
    logits: Float[Tensor, "B 7"], legal: Bool[Tensor, "B 7"]
) -> Float[Tensor, "B 7"]:
    """Softmax of the policy logits over the legal columns only; used at the root and every new leaf.

    Args:
        logits: (B, 7) raw policy-head scores
        legal:  (B, 7) legal-column mask

    Returns:
        (B, 7) prior P(a): zero on illegal columns, summing to 1 over the legal ones
    """
    legal_logits = torch.where(legal, logits, -torch.inf)
    return torch.softmax(legal_logits, dim=-1)


tests.test_masked_softmax_prior(masked_softmax_prior)
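A quick sanity check you might run on the solution's formula, with made-up logits. The second row also shows an edge case worth keeping in mind: if a row has no legal columns at all (a full board), every entry is -inf and the softmax comes out as NaN rather than zeros.

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.0, -1.0, 3.0, 0.5, 0.0],
                       [0.0, 0.0, 0.0,  0.0, 0.0, 0.0, 0.0]])
legal = torch.tensor([[True, True, False, True, False, True, True],
                      [False] * 7])                    # second row: full board, nothing legal

prior = torch.softmax(torch.where(legal, logits, -torch.inf), dim=-1)

# Row 0: exact zeros on illegal columns, and the legal ones sum to 1.
# Row 1: every entry is -inf, so softmax computes (-inf) - (-inf) = NaN.
```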

Exercise - implement puct_select (batched PUCT)

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 10-15 minutes on this exercise.

The batched twin of select_child: given the current node's per-edge statistics for all B games at once, return the legal action maximising the PUCT score, per game,

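$$ Q(s,a) + c_\text{puct}\, P(s,a)\, \frac{\sqrt{1 + \sum_{a'} N(s,a')}}{1 + N(s,a)} $$
with $Q(s,a) = W(s,a) / \max(N(s,a), 1)$, where the sum runs over the 7 actions at $s$ (we write $a'$ rather than $b$ to avoid a clash with the game index $b$). Mask illegal columns to a PUCT score of -torch.inf before the argmax, and keep the same $\sqrt{1 + \sum_{a'} N(s,a')}$ form as the single-game version so the two agree exactly. The result only matters at nodes with at least one legal action; if every column is illegal (a full-board terminal node), the argmax over an all -torch.inf row simply returns column 0, which the caller discards.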

def puct_select(
    node_N: Float[Tensor, "B 7"],
    node_W: Float[Tensor, "B 7"],
    node_P: Float[Tensor, "B 7"],
    node_legal: Bool[Tensor, "B 7"],
    c_puct: float,
) -> Int[Tensor, "B"]:
    """Batched PUCT selection: pick the legal action with the highest PUCT score, per game.

    The score trades off exploitation `Q = W / max(N, 1)` against exploration
    `c_puct * P * sqrt(1 + N.sum(-1, keepdim=True)) / (1 + N)` — the sum is over the 7 actions at
    this node, computed independently per game; illegal columns are masked out before the argmax.
    All inputs are the flat-tree slices at the current node of each of the `B` games.

    Args:
        node_N:     (B, 7) per-edge visit counts
        node_W:     (B, 7) per-edge value sums
        node_P:     (B, 7) per-edge priors P(a)
        node_legal: (B, 7) legal-column mask
        c_puct:     exploration constant

    Returns:
        (B,) the chosen legal action (column index) for each game
    """
    raise NotImplementedError()


tests.test_puct_select(puct_select)
Solution
def puct_select(
    node_N: Float[Tensor, "B 7"],
    node_W: Float[Tensor, "B 7"],
    node_P: Float[Tensor, "B 7"],
    node_legal: Bool[Tensor, "B 7"],
    c_puct: float,
) -> Int[Tensor, "B"]:
    """Batched PUCT selection: pick the legal action with the highest PUCT score, per game.

    The score trades off exploitation `Q = W / max(N, 1)` against exploration
    `c_puct * P * sqrt(1 + N.sum(-1, keepdim=True)) / (1 + N)` — the sum is over the 7 actions at
    this node, computed independently per game; illegal columns are masked out before the argmax.
    All inputs are the flat-tree slices at the current node of each of the `B` games.

    Args:
        node_N:     (B, 7) per-edge visit counts
        node_W:     (B, 7) per-edge value sums
        node_P:     (B, 7) per-edge priors P(a)
        node_legal: (B, 7) legal-column mask
        c_puct:     exploration constant

    Returns:
        (B,) the chosen legal action (column index) for each game
    """
    sumN = node_N.sum(-1, keepdim=True)
    Q = node_W / node_N.clamp_min(1.0)
    U = c_puct * node_P * torch.sqrt(sumN + 1.0) / (1.0 + node_N)
    legal_score = torch.where(node_legal, Q + U, -torch.inf)
    return legal_score.argmax(-1)


tests.test_puct_select(puct_select)
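Here is a small worked example of the score with made-up statistics for a single game (B = 1) and c_puct = 1.5, an arbitrary choice. Columns 0 and 1 both have Q = 0.5, but the never-visited column 2 wins on exploration alone: with $\sum_{a'} N = 7$, its bonus is $1.5 \cdot 0.3 \cdot \sqrt{8} / 1 \approx 1.27$.

```python
import math
import torch

c_puct = 1.5                                          # arbitrary value for illustration
node_N = torch.tensor([[4.0, 1.0, 0.0, 2.0, 0.0, 0.0, 0.0]])
node_W = torch.tensor([[2.0, 0.5, 0.0, -1.0, 0.0, 0.0, 0.0]])
node_P = torch.tensor([[0.4, 0.2, 0.3, 0.1, 0.0, 0.0, 0.0]])
node_legal = torch.tensor([[True, True, True, True, False, False, False]])

Q = node_W / node_N.clamp_min(1.0)                    # [0.5, 0.5, 0.0, -0.5, 0, 0, 0]
U = c_puct * node_P * torch.sqrt(node_N.sum(-1, keepdim=True) + 1.0) / (1.0 + node_N)
score = torch.where(node_legal, Q + U, -torch.inf)    # ~[0.839, 0.924, 1.273, -0.359, -inf, -inf, -inf]

best = score.argmax(-1)                               # tensor([2])
```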

Selection: follow PUCT from the root to a leaf

For a single game, selection is a short walk down the tree: from the root, repeatedly take the PUCT-best action and step into that child, until you either fall off the tree (an unexpanded edge) or reach a finished position (a terminal node). You just remember the leaf you stopped on — the tree's parent pointers already record the route, so backup can walk it home without a separate path.

# single-game selection
node = root
while True:
    if node.is_terminal:
        leaf, leaf_value = node, node.term_val   # stop -- nothing to expand
        break
    a = puct_select(node)                        # PUCT-best legal action
    if node.child[a] is None:                    # unexpanded edge -> expand; the new child is our leaf
        leaf, leaf_value = expand_and_eval(node, a)
        break
    node = node.child[a]                         # descend into an existing child

The batched version runs B of these walks in lockstep: each iteration of a for d in range(MAXD) loop takes one PUCT step for all games at once. The only wrinkle is that the walks have different lengths: each game stops at its own depth while the loop keeps going for games that are still descending. We track this with a done : Bool[Tensor, "B"] mask (a game flips to done the moment it hits a terminal node or an unexpanded edge) and ignore finished games' results on later iterations.
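Stripped of PUCT, the lockstep pattern looks like the sketch below. In this made-up example every game simply follows action 0, so each game's child row collapses to one column (child0 is a hypothetical name). The done mask freezes games that have fallen off the tree while the others keep descending.

```python
import torch

# child0[b, n] = next node along action 0 in game b, or -1 if that edge is unexpanded.
child0 = torch.tensor([[1, -1, -1, -1],     # game 0: 0 -> 1, then stops (depth 1)
                       [1,  2,  3, -1],     # game 1: 0 -> 1 -> 2 -> 3, then stops (depth 3)
                       [-1, -1, -1, -1]])   # game 2: stops immediately at the root (depth 0)
B, MAXD = child0.shape[0], 4
ar = torch.arange(B)

node = torch.zeros(B, dtype=torch.long)     # everyone starts at the root
done = torch.zeros(B, dtype=torch.bool)
for d in range(MAXD):                       # fixed trip count instead of a per-game while loop
    nxt = child0[ar, node]                  # one gather for all games
    stop = ~done & (nxt < 0)                # games falling off the tree at this depth
    done = done | stop
    node = torch.where(done, node, nxt)     # only still-active games move

print(node)   # tensor([1, 3, 0])
```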

Selection records nothing

Because every node stores its parent (and the parent_act that reached it), selection doesn't build any per-game path buffer. It only reports where each game stopped (the leaf, and the parent/action to expand). Backup then reconstructs each route by following parent pointers up to the root.

Exercise - implement step_descent

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to ~5-10 minutes on this exercise.

This is the body of one descent step, factored out so the loop in select_batch reads cleanly. You're handed the per-edge slices at each game's current node, which are exactly the (B, 7) tensors puct_select already consumes, plus that node's (B, 7) row of the child table. You need to:

  1. pick the PUCT-best action a for every game (just call your puct_select),
  2. look up the child each a points to.

The child lookup is the new bit. node_child : (B, 7) holds the child id per action, so for every game b you want child[b] = node_child[b, a[b]].

How do I do child[b] = node_child[b, a[b]] in a vectorized fashion?

Options include: node_child.gather(1, a.unsqueeze(1)).squeeze(1) or node_child[torch.arange(B), a] or eindex(node_child, a, "batch [batch] -> batch") using Callum's eindex library.
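A quick check, with made-up values, that the first two options return the same thing. eindex is left out so the snippet only needs PyTorch.

```python
import torch

B = 4
node_child = torch.tensor([[5, -1, 7, -1, -1, 2, -1],
                           [-1, 3, -1, -1, 9, -1, -1],
                           [1, 1, 1, 1, 1, 1, 1],
                           [-1, -1, -1, -1, -1, -1, 8]])
a = torch.tensor([2, 4, 0, 6])                                  # one chosen column per game

via_gather = node_child.gather(1, a.unsqueeze(1)).squeeze(1)   # index dim 1 with a (B, 1) index
via_index = node_child[torch.arange(B), a]                      # paired advanced indexing

print(via_gather)   # tensor([7, 9, 1, 8])
```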

You can ignore the bookkeeping around games that have already stopped — select_batch masks that out. In particular puct_select is sometimes called here on a terminal node with no legal moves; the action it returns is then meaningless, but select_batch discards it, so you don't need to special-case it.

def step_descent(
    node_N: Float[Tensor, "B 7"],
    node_W: Float[Tensor, "B 7"],
    node_P: Float[Tensor, "B 7"],
    node_child: Int[Tensor, "B 7"],
    node_legal: Bool[Tensor, "B 7"],
    c_puct: float,
) -> tuple[Int[Tensor, "B"], Int[Tensor, "B"]]:
    """One level of PUCT descent for all `B` games at once: pick the PUCT-best legal action at each
    game's current node, then follow it to the child it points at.

    All inputs are the flat-tree slices at the current node of each of the `B` games (the same slices
    `puct_select` takes, plus the child row). Pure per-node work -- the caller masks out games that
    have already stopped descending.

    Args:
        node_N:     (B, 7) per-edge visit counts
        node_W:     (B, 7) per-edge value sums
        node_P:     (B, 7) per-edge priors P(a)
        node_child: (B, 7) child node-id per action, or -1 if that edge is unexpanded
        node_legal: (B, 7) legal-column mask
        c_puct:     exploration constant

    Returns:
        a:     (B,) the PUCT-chosen action (column) at each game's node
        child: (B,) the child node id along `a`, or -1 if that edge is not yet expanded
    """
    raise NotImplementedError()


tests.test_step_descent(step_descent)
Solution
def step_descent(
    node_N: Float[Tensor, "B 7"],
    node_W: Float[Tensor, "B 7"],
    node_P: Float[Tensor, "B 7"],
    node_child: Int[Tensor, "B 7"],
    node_legal: Bool[Tensor, "B 7"],
    c_puct: float,
) -> tuple[Int[Tensor, "B"], Int[Tensor, "B"]]:
    """One level of PUCT descent for all `B` games at once: pick the PUCT-best legal action at each
    game's current node, then follow it to the child it points at.

    All inputs are the flat-tree slices at the current node of each of the `B` games (the same slices
    `puct_select` takes, plus the child row). Pure per-node work -- the caller masks out games that
    have already stopped descending.

    Args:
        node_N:     (B, 7) per-edge visit counts
        node_W:     (B, 7) per-edge value sums
        node_P:     (B, 7) per-edge priors P(a)
        node_child: (B, 7) child node-id per action, or -1 if that edge is unexpanded
        node_legal: (B, 7) legal-column mask
        c_puct:     exploration constant

    Returns:
        a:     (B,) the PUCT-chosen action (column) at each game's node
        child: (B,) the child node id along `a`, or -1 if that edge is not yet expanded
    """
    a = puct_select(node_N, node_W, node_P, node_legal, c_puct)
    child = node_child.gather(1, a.unsqueeze(1)).squeeze(1)
    # Readable einops-style alternative (bit-identical, but ~35x slower per call on CPU, because eindex
    # re-parses the pattern string on every call):
    #   from eindex import eindex
    #   child = eindex(node_child, a, "batch [batch] -> batch")
    return a, child


tests.test_step_descent(step_descent)

Exercise - implement batched_backup (negamax backup)

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 10-15 minutes on this exercise.

This is the batched twin of the single-game parent-pointer backup. Recall that loop: start at the leaf with its value, then keep stepping up to the parent, flipping the sign each step (negamax — a position good for the mover is bad for its parent) and adding a visit + the signed value to the edge you came up:

v, nd = leaf_value, leaf
while nd.parent is not None:       # leaf -> root
    v = -v                         # negamax: flip the sign each step up
    nd.parent.N[nd.parent_action] += 1
    nd.parent.W[nd.parent_action] += v
    nd = nd.parent

We do exactly this for all B games at once, walking the flat parent / parent_act arrays instead of .parent links. Each game starts at its own leaf slot leaf_node : Int[Tensor, "B"] carrying value leaf_value. The games reach the root at different depths, so instead of a per-game while loop we take a fixed max_depth steps up (the deepest any leaf can be); a game that has already reached the root just idles for the remaining steps.

We give you ar = torch.arange(B), the per-game row index, so that each game b updates its own node: N[ar, p, a] selects N[b, p[b], a[b]] for every b, whereas a plain N[:, p, a] would pair every game with every other game's (p, a) and produce a (B, B) result. Each of the max_depth steps:

  • at_root = (node == 0) — the root (slot 0) has no incoming edge, so these games are finished;
  • read the edge you came up: p = parent[ar, node], a = parent_act[ar, node];
  • flip the sign of v, but not for already-rooted games: v = torch.where(at_root, v, -v);
  • update N and W at (ar, p, a) — add 1 to N and the signed v to W, gated by ~at_root (use p.clamp_min(0) / a.clamp_min(0) so the unused root edge indexes safely);
  • hop up: node = torch.where(at_root, node, p).

Update N and W in place, vectorised over the B games — the only Python loop is the fixed max_depth sweep.
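If you want a per-game oracle to compare your batched_backup against, here is a hedged sketch of the loop above on flat Python lists. The function name backup_one_game is hypothetical (it is not one of the given helpers), and the three-node chain is made up.

```python
def backup_one_game(N, W, parent, parent_act, leaf, leaf_value):
    """Plain-Python negamax backup for ONE game on flat lists; N, W are [node][action], updated in place."""
    v, node = leaf_value, leaf
    while parent[node] >= 0:               # the root's parent is the -1 sentinel
        v = -v                             # flip to the parent's mover's perspective
        p, a = parent[node], parent_act[node]
        N[p][a] += 1                       # one more visit on the edge we came up
        W[p][a] += v                       # plus the signed value
        node = p

# Hypothetical chain: root 0 -(col 3)-> 1 -(col 2)-> 2, leaf value +1 for node 2's mover.
parent, parent_act = [-1, 0, 1], [0, 3, 2]
N = [[0] * 7 for _ in range(3)]
W = [[0.0] * 7 for _ in range(3)]
backup_one_game(N, W, parent, parent_act, leaf=2, leaf_value=1.0)
print(W[1][2], W[0][3])   # -1.0 1.0
```

Node 1's mover is the opponent of node 2's mover, so its edge receives -1; two plies up, the root's mover is node 2's mover again, so its edge receives +1.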

def batched_backup(
    N: Float[Tensor, "batch max_nodes 7"],
    W: Float[Tensor, "batch max_nodes 7"],
    parent: Int[Tensor, "batch max_nodes"],
    parent_act: Int[Tensor, "batch max_nodes"],
    leaf_node: Int[Tensor, "batch"],
    leaf_value: Float[Tensor, "batch"],
    max_depth: int,
) -> None:
    """Negamax backup by following parent pointers from each leaf up to the root; updates N, W in-place.

    The batched twin of the single-game `while nd.parent is not None` backup: each game starts at its
    own leaf slot and walks up the flat `parent`/`parent_act` arrays, flipping the value's sign at every
    edge (negamax -- good for the mover is bad for its parent) and adding one visit + the signed value.
    A fixed `max_depth` steps cover the deepest possible leaf; games that reach the root early idle on
    `at_root`: the root has `parent == -1` (a sentinel, not a real slot), so there is no incoming edge
    left to update and no further sign flip — gate the writes with a "still live" mask, and clamp the
    parent index (`p.clamp_min(0)`) only so the idle games' gathers stay in-bounds. Valid because the
    tree is strict (one parent per node) -- it would be wrong under
    transpositions (a position reachable by several move orders), which root-parallel MCTS doesn't use.

    Args:
        N:          (batch, max_nodes, 7) per-edge visit counts -- updated in place
        W:          (batch, max_nodes, 7) per-edge value sums   -- updated in place
        parent:     (batch, max_nodes) each node's parent slot (-1 at the root)
        parent_act: (batch, max_nodes) the column that led from the parent into each node
        leaf_node:  (batch,) slot each game's backup starts from (its leaf)
        leaf_value: (batch,) value of that leaf, from the leaf mover's perspective
        max_depth:  number of steps up to take (>= the deepest leaf's depth)

    Returns:
        None -- mutates N and W **in-place**.
    """
    B = N.shape[0]
    ar = torch.arange(B, device=N.device)      # per-game row index (so we update game b's own node)
    node = leaf_node.clone()                   # each game walks up from its own leaf...
    v = leaf_value.clone()                     # ...carrying that leaf's value
    for _ in range(max_depth):                 # fixed number of hops; rooted games idle on `at_root`
        raise NotImplementedError()


tests.test_batched_backup(batched_backup)
Solution
def batched_backup(
    N: Float[Tensor, "batch max_nodes 7"],
    W: Float[Tensor, "batch max_nodes 7"],
    parent: Int[Tensor, "batch max_nodes"],
    parent_act: Int[Tensor, "batch max_nodes"],
    leaf_node: Int[Tensor, "batch"],
    leaf_value: Float[Tensor, "batch"],
    max_depth: int,
) -> None:
    """Negamax backup by following parent pointers from each leaf up to the root; updates N, W in-place.

    The batched twin of the single-game `while nd.parent is not None` backup: each game starts at its
    own leaf slot and walks up the flat `parent`/`parent_act` arrays, flipping the value's sign at every
    edge (negamax -- good for the mover is bad for its parent) and adding one visit + the signed value.
    A fixed `max_depth` steps cover the deepest possible leaf; games that reach the root early idle on
    `at_root`: the root has `parent == -1` (a sentinel, not a real slot), so there is no incoming edge
    left to update and no further sign flip — gate the writes with a "still live" mask, and clamp the
    parent index (`p.clamp_min(0)`) only so the idle games' gathers stay in-bounds. Valid because the
    tree is strict (one parent per node) -- it would be wrong under
    transpositions (a position reachable by several move orders), which root-parallel MCTS doesn't use.

    Args:
        N:          (batch, max_nodes, 7) per-edge visit counts -- updated in place
        W:          (batch, max_nodes, 7) per-edge value sums   -- updated in place
        parent:     (batch, max_nodes) each node's parent slot (-1 at the root)
        parent_act: (batch, max_nodes) the column that led from the parent into each node
        leaf_node:  (batch,) slot each game's backup starts from (its leaf)
        leaf_value: (batch,) value of that leaf, from the leaf mover's perspective
        max_depth:  number of steps up to take (>= the deepest leaf's depth)

    Returns:
        None -- mutates N and W **in-place**.
    """
    B = N.shape[0]
    ar = torch.arange(B, device=N.device)      # per-game row index (so we update game b's own node)
    node = leaf_node.clone()                   # each game walks up from its own leaf...
    v = leaf_value.clone()                     # ...carrying that leaf's value
    for _ in range(max_depth):                 # fixed number of hops; rooted games idle on `at_root`
        at_root = node == 0                                   # the root has no incoming edge
        p = parent[ar, node]                                  # the slot we came from
        a = parent_act[ar, node]                              # the column we took to get here
        v = torch.where(at_root, v, -v)                       # negamax, but not "above" the root
        live = (~at_root).float()
        N[ar, p.clamp_min(0), a.clamp_min(0)] += live         # +1 visit on the edge (rooted games add 0)
        W[ar, p.clamp_min(0), a.clamp_min(0)] += v * live     # + the signed value
        node = torch.where(at_root, node, p)                  # hop up; rooted games stay put


tests.test_batched_backup(batched_backup)

Exercise - implement get_leaf_value

Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to ~5 minutes on this exercise.

When a simulation reaches a leaf we need its value, from that leaf's mover's perspective, to back up. In the single-game search there are three cases:

if node.is_terminal:                 # we re-reached an already-terminal node
    leaf_value = node.terminal_value
elif child.is_terminal:              # the move we just expanded ends the game
    leaf_value = -reward             # reward goes to the player who just moved -> negate
else:                                # an ordinary new leaf
    leaf_value = net_value           # ask the network

The batched version gets these cases as three mutually exclusive masks, leaf_is_term, term_new and eval_new, with the matching values term_value (the stored value of the re-reached terminal node), new_reward (the env reward from the expansion step) and net_value (the network estimate at the new leaf). You can solve this with nested torch.where calls, but since exactly one mask is true for each game, you can also do a masked sum with the masks cast to float.
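Both approaches agree on clean inputs, as this sketch with made-up values shows. There is one design difference worth knowing: a masked sum multiplies the unused entries by 0, and 0 * NaN is still NaN, so if the unused entries could ever hold NaN or inf, torch.where is the safer choice.

```python
import torch

leaf_is_term = torch.tensor([True, False, False])
term_new     = torch.tensor([False, True, False])
eval_new     = torch.tensor([False, False, True])
term_value   = torch.tensor([-1.0, 0.0, 0.0])
new_reward   = torch.tensor([0.0, 1.0, 0.0])
net_value    = torch.tensor([0.0, 0.0, 0.3])

# Nested torch.where: the final fallback is net_value because the masks partition the games.
v_where = torch.where(leaf_is_term, term_value, torch.where(term_new, -new_reward, net_value))
# Masked sum: valid because exactly one mask is True per game. Both give [-1.0, -1.0, 0.3].
v_sum = (leaf_is_term.float() * term_value
         + term_new.float() * (-new_reward)
         + eval_new.float() * net_value)

# Caveat: garbage in an unused slot leaks through the sum but not through torch.where.
net_bad = torch.tensor([float("nan"), 0.0, 0.3])
leaky = (leaf_is_term.float() * term_value + term_new.float() * (-new_reward)
         + eval_new.float() * net_bad)
safe = torch.where(leaf_is_term, term_value, torch.where(term_new, -new_reward, net_bad))
```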

def get_leaf_value(
    leaf_is_term: Bool[Tensor, "batch"],
    term_value: Float[Tensor, "batch"],
    term_new: Bool[Tensor, "batch"],
    new_reward: Float[Tensor, "batch"],
    eval_new: Bool[Tensor, "batch"],
    net_value: Float[Tensor, "batch"],
) -> Float[Tensor, "batch"]:
    """The value to back up for each game's leaf, from the leaf mover's perspective.

    The three masks partition the games (each game is in exactly one): a re-reached terminal node uses
    its stored `term_value`, a newly-terminal leaf uses `-new_reward`, and an ordinary new leaf uses
    the network's `net_value`.

    Args:
        leaf_is_term: (batch,) leaf was an already-terminal node
        term_value:   (batch,) that terminal node's stored value
        term_new:     (batch,) leaf is a newly-terminal node
        new_reward:   (batch,) env reward at expansion (mover's perspective)
        eval_new:     (batch,) leaf is a newly-evaluated (non-terminal) node
        net_value:    (batch,) network value estimate at the new leaf

    Returns:
        (batch,) the leaf value to back up
    """
    raise NotImplementedError()


tests.test_get_leaf_value(get_leaf_value)
Solution
def get_leaf_value(
    leaf_is_term: Bool[Tensor, "batch"],
    term_value: Float[Tensor, "batch"],
    term_new: Bool[Tensor, "batch"],
    new_reward: Float[Tensor, "batch"],
    eval_new: Bool[Tensor, "batch"],
    net_value: Float[Tensor, "batch"],
) -> Float[Tensor, "batch"]:
    """The value to back up for each game's leaf, from the leaf mover's perspective.

    The three masks partition the games (each game is in exactly one): a re-reached terminal node uses
    its stored `term_value`, a newly-terminal leaf uses `-new_reward`, and an ordinary new leaf uses
    the network's `net_value`.

    Args:
        leaf_is_term: (batch,) leaf was an already-terminal node
        term_value:   (batch,) that terminal node's stored value
        term_new:     (batch,) leaf is a newly-terminal node
        new_reward:   (batch,) env reward at expansion (mover's perspective)
        eval_new:     (batch,) leaf is a newly-evaluated (non-terminal) node
        net_value:    (batch,) network value estimate at the new leaf

    Returns:
        (batch,) the leaf value to back up
    """
    return (leaf_is_term.float() * term_value
            + term_new.float() * (-new_reward)
            + eval_new.float() * net_value)


tests.test_get_leaf_value(get_leaf_value)

The phases, as functions on a Tree

Each phase is a short given free function taking a Tree (the batched mirror of section 2, where every phase took a Node): expand_root writes the root prior; select_batch descends all games to their leaves (calling your step_descent); expand_batch plays one batched env step and links the new nodes; evaluate_batch runs one network forward over the new leaves. batched_search strings them together with your get_leaf_value and batched_backup. A 3-line BatchedMCTS wrapper allocates a Tree and calls batched_search, so callers keep a tidy .search(...) API.

@torch.no_grad()
def expand_root(tree: Tree, root_obs: Float[Tensor, "B 3 6 7"], root_is_player1: Bool[Tensor, "B"],
                model: nn.Module, cfg: MCTSConfig, add_noise: bool) -> None:
    """ROOT: evaluate the network at the root and write its (optionally noised) prior into `tree.P[:, 0]`.
    The batched mirror of the `evaluate(root, ...)` call that opens `mcts_search`."""
    tree.obs_pool[:, 0] = root_obs
    tree.tomove[:, 0] = root_is_player1
    _, logits0 = eval_net(model, root_obs, root_is_player1)
    lm0 = legal_mask_from_obs(root_obs)
    tree.legal[:, 0] = lm0
    pri0 = masked_softmax_prior(logits0, lm0)
    if add_noise:
        pri0 = dirichlet_root_noise(pri0, lm0, cfg.dirichlet_alpha, cfg.dirichlet_eps)
    tree.P[:, 0] = pri0


def select_batch(tree: Tree, c_puct: float) -> tuple:
    """SELECTION (batched `descend`): from each root, follow PUCT down to a leaf (an unexpanded edge or a
    terminal node). Reports where each game stopped; backup later walks the tree's own `parent` pointers
    home (set in `expand_batch`), so there's no path buffer.

    Returns per-game tensors `(leaf_is_term, term_leaf_node, leaf_parent, leaf_act, has_expand)`, each (B,).
    """
    B, ar, MAXD, dev = tree.B, tree.ar, tree.MAXD, tree.ar.device
    node  = torch.zeros((B,), dtype=torch.long, device=dev)                  # current node (root = 0)
    done  = torch.zeros((B,), dtype=torch.bool, device=dev)                  # stopped descending?
    leaf_is_term   = torch.zeros((B,), dtype=torch.bool, device=dev)
    term_leaf_node = torch.zeros((B,), dtype=torch.long, device=dev)
    leaf_parent = torch.zeros((B,), dtype=torch.long, device=dev)
    leaf_act    = torch.zeros((B,), dtype=torch.long, device=dev)
    has_expand  = torch.zeros((B,), dtype=torch.bool, device=dev)

    # One PUCT step for all games per iteration; games finish at different depths (tracked by `done`).
    # `ar` is arange(B): `X[ar, node]` gathers each game's own row.
    for d in range(MAXD):
        # your `step_descent`: PUCT-best action + the child it points to (results for masked games unused)
        a, child = step_descent(tree.N[ar, node], tree.W[ar, node], tree.P[ar, node],
                                tree.child[ar, node], tree.legal[ar, node], c_puct)

        active  = ~done                                          # still descending coming into this step
        is_term = tree.terminal[ar, node] & active               # landed on an existing terminal -> stop
        leaf_is_term   = leaf_is_term | is_term
        term_leaf_node = torch.where(is_term, node, term_leaf_node)

        step_taken = active & (~is_term)                         # games that walk a real edge at depth d
        is_unexp = step_taken & (child < 0)                      # walked an unexpanded edge -> our leaf
        leaf_parent = torch.where(is_unexp, node, leaf_parent)
        leaf_act    = torch.where(is_unexp, a,    leaf_act)
        has_expand  = has_expand | is_unexp

        done = done | is_term | is_unexp                         # both stop conditions end the descent
        node = torch.where(step_taken & (~is_unexp), child, node)  # else descend into the existing child
        if d >= 1 and bool(done.all()):
            break
    return leaf_is_term, term_leaf_node, leaf_parent, leaf_act, has_expand


tests.test_select_batch(select_batch, Tree)
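Stripped of PUCT and terminal handling, the descent loop in select_batch is a masked lock-step walk: every game computes a step on every iteration, and `done` plus `where` decide whose state actually changes. A toy NumPy sketch of just that mechanism, in which a hypothetical fixed table `best[b, node]` stands in for your `step_descent` and terminal nodes are ignored:

```python
import numpy as np

# Toy pools for B=2 games, MAXN=4 slots, A=2 actions. child[b, n, a] = -1 means "unexpanded".
child = np.full((2, 4, 2), -1)
child[0, 0, 1] = 1          # game 0: root --1--> node 1
child[0, 1, 0] = 2          #         node 1 --0--> node 2 (node 2 has no children yet)
child[1, 0, 0] = 1          # game 1: root --0--> node 1, whose edges are all unexpanded
best = np.array([[1, 0, 0, 0],   # hypothetical stand-in for PUCT: fixed action at each node
                 [0, 1, 0, 0]])

B, MAXD = 2, 4
ar = np.arange(B)
node = np.zeros(B, dtype=int)
done = np.zeros(B, dtype=bool)
leaf_parent = np.zeros(B, dtype=int)
leaf_act = np.zeros(B, dtype=int)
for d in range(MAXD):
    a = best[ar, node]                        # every game picks an action, even finished ones
    nxt = child[ar, node, a]
    active = ~done
    is_unexp = active & (nxt < 0)             # walked onto an unexpanded edge: this is the leaf
    leaf_parent = np.where(is_unexp, node, leaf_parent)
    leaf_act = np.where(is_unexp, a, leaf_act)
    done = done | is_unexp
    node = np.where(active & ~is_unexp, nxt, node)   # finished games stay frozen in place
    if done.all():
        break
```

Game 1 stops one iteration earlier than game 0, yet its row keeps being indexed afterwards; the masks simply discard those results, which is what lets games of different depths share one loop.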


@torch.no_grad()
def expand_batch(tree: Tree, leaf_parent: Int[Tensor, "B"], leaf_act: Int[Tensor, "B"],
                 has_expand: Bool[Tensor, "B"], env: Connect4Env) -> tuple:
    """EXPANSION (batched `expand`): one batched env step from each leaf's parent along `leaf_act`; store
    the resulting node in the pool and link it in. Games not expanding write to the dustbin slot.

    Returns `(new_ids, nrew, term_new, eval_new)`, each (B,)."""
    ar = tree.ar
    # play the chosen edge in the env, for ALL games at once (one batched step)
    parent_obs = tree.obs_pool[ar, leaf_parent]
    parent_tomove = tree.tomove[ar, leaf_parent]
    nobs, ndone, nrew = env.step(parent_obs, leaf_act, parent_tomove)
    # store the resulting board as a fresh node at the next free slot `nptr`; games that aren't
    # expanding this step write to the dustbin slot (DUST_N) so they leave the real tree untouched
    new_ids = tree.nptr
    slot = torch.where(has_expand, new_ids, torch.full_like(new_ids, tree.DUST_N))
    tree.obs_pool[ar, slot] = nobs
    tree.tomove[ar, slot] = ~parent_tomove
    tree.terminal[ar, slot] = ndone
    tree.term_val[ar, slot] = -nrew            # value to the parent's mover if this move ended the game
    tree.parent[ar, slot] = leaf_parent        # remember where we came from...
    tree.parent_act[ar, slot] = leaf_act       # ...and the edge taken (backup walks these home)
    # link parent --leaf_act--> new node, and advance the free-slot pointer for games that expanded
    tree.child[ar, leaf_parent, leaf_act] = torch.where(
        has_expand, new_ids, tree.child[ar, leaf_parent, leaf_act])
    tree.nptr = tree.nptr + has_expand.long()
    term_new = has_expand & ndone              # the new node ends the game
    eval_new = has_expand & (~ndone)           # the new node needs a network evaluation
    return new_ids, nrew, term_new, eval_new


tests.test_expand_batch(expand_batch, Tree)
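The dustbin slot is what lets expand_batch write unconditionally for every game. A toy NumPy sketch of the same bookkeeping; the sizes, indices and the choice `DUST_N = MAXN` (one spare slot at the end of the pool) are assumptions for illustration. Games that are not expanding still execute the writes, but those writes land in the spare slot, and their `child` links and `nptr` are left unchanged:

```python
import numpy as np

B, MAXN = 3, 4
DUST_N = MAXN                               # one extra slot at the end absorbs unwanted writes
ar = np.arange(B)
parent = np.full((B, MAXN + 1), -1)
child = np.full((B, MAXN + 1, 7), -1)
nptr = np.ones(B, dtype=int)                # slot 0 is the root, so the next free slot is 1

leaf_parent = np.array([0, 0, 0])
leaf_act = np.array([3, 5, 2])
has_expand = np.array([True, False, True])  # game 1 stopped on a terminal node this simulation

new_ids = nptr                              # safe alias: nptr is rebound below, not mutated
slot = np.where(has_expand, new_ids, DUST_N)
parent[ar, slot] = leaf_parent              # unconditional write: game 1 hits its dustbin slot
child[ar, leaf_parent, leaf_act] = np.where(has_expand, new_ids, child[ar, leaf_parent, leaf_act])
nptr = nptr + has_expand                    # only expanding games consume a slot
```

The conditional `child` write uses the "write back the old value" trick instead of a dustbin, because the destination is the parent's edge rather than the new node.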


@torch.no_grad()
def evaluate_batch(tree: Tree, new_ids: Int[Tensor, "B"], eval_new: Bool[Tensor, "B"],
                   model: nn.Module) -> Float[Tensor, "B"]:
    """EVALUATION (batched `evaluate`): one network forward over all `B` new leaves; write the
    prior/legal mask for the leaves that need it (non-terminal new nodes). Returns (B,) leaf values."""
    ar = tree.ar
    lobs = tree.obs_pool[ar, new_ids]
    ltm = tree.tomove[ar, new_ids]
    val, logits = eval_net(model, lobs, ltm)
    lm = legal_mask_from_obs(lobs)
    pri = masked_softmax_prior(logits, lm)
    ne = eval_new.unsqueeze(-1)
    tree.legal[ar, new_ids] = torch.where(ne, lm, tree.legal[ar, new_ids])
    tree.P[ar, new_ids] = torch.where(ne, pri, tree.P[ar, new_ids])
    return val


tests.test_evaluate_batch(evaluate_batch, Tree)
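evaluate_batch relies on the given `masked_softmax_prior` helper and on a masked row write. As a rough illustration of what a masked softmax usually means (illegal logits pushed to -inf before normalising), here is a hypothetical NumPy version, `masked_softmax_sketch`, together with the `where(mask, new, old)` pattern used to write priors only for games whose new node needs them. Torch's `eval_new.unsqueeze(-1)` corresponds to `eval_new[:, None]` in NumPy: it turns the (B,) mask into (B, 1) so it broadcasts across the actions.

```python
import numpy as np

def masked_softmax_sketch(logits, legal):
    """Hypothetical stand-in for `masked_softmax_prior`: softmax over legal actions only.
    In this sketch a row with no legal action would come out as NaN."""
    z = np.where(legal, logits, -np.inf)
    z = z - z.max(axis=1, keepdims=True)           # subtract the row max for numerical stability
    e = np.exp(z)                                  # exp(-inf) = 0, so illegal actions get no mass
    return e / e.sum(axis=1, keepdims=True)

logits = np.array([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
legal = np.array([[True, False, True], [True, True, True]])
pri = masked_softmax_sketch(logits, legal)

P_old = np.zeros((2, 3))                           # priors currently stored for these nodes
eval_new = np.array([True, False])                 # only game 0's new node needs a prior
P_new = np.where(eval_new[:, None], pri, P_old)    # (B, 1) mask broadcasts over actions
```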


@torch.no_grad()
def batched_search(tree: Tree, root_obs: Float[Tensor, "B 3 6 7"], root_is_player1: Bool[Tensor, "B"],
                   model: nn.Module, env: Connect4Env, cfg: MCTSConfig, add_noise: bool = False,
                   ) -> Float[Tensor, "B 7"]:
    """Run `cfg.sims` simulations of root-parallel MCTS on `tree` (the batched mirror of `mcts_search`):
    SELECT a leaf, EXPAND it (one env step), EVALUATE it (one net forward), BACK UP via parent pointers.
    Returns (B, 7) root visit counts `N[:, 0]` — the per-game policy target."""
    expand_root(tree, root_obs, root_is_player1, model, cfg, add_noise)
    for _ in range(cfg.sims):
        leaf_is_term, term_leaf_node, leaf_parent, leaf_act, has_expand = select_batch(tree, cfg.c_puct)
        new_ids, nrew, term_new, eval_new = expand_batch(tree, leaf_parent, leaf_act, has_expand, env)
        val = evaluate_batch(tree, new_ids, eval_new, model)
        term_value = tree.term_val[tree.ar, term_leaf_node]   # stored value if the leaf was terminal
        leaf_value = get_leaf_value(leaf_is_term, term_value, term_new, nrew, eval_new, val)
        # where backup starts: the new node if we expanded, else the terminal node SELECT stopped on
        leaf_node = torch.where(has_expand, new_ids, term_leaf_node)
        batched_backup(tree.N, tree.W, tree.parent, tree.parent_act, leaf_node, leaf_value, tree.MAXD)
    return tree.N[:, 0]  # root visit counts (B,7)


class BatchedMCTS:
    """Thin wrapper: holds env/cfg, and `.search(model, ...)` allocates a fresh `Tree` and runs
    `batched_search`. Same interface as `SimulatedBatchedMCTS` — a drop-in replacement, just fully
    vectorised (and far faster). The model is passed per-search so the trainer can own it."""
    def __init__(self, env, cfg):
        self.env, self.cfg = env, cfg

    @torch.no_grad()
    def search(self, model, root_obs: Float[Tensor, "B 3 6 7"], root_is_player1: Bool[Tensor, "B"],
               add_noise: bool = False) -> Float[Tensor, "B 7"]:
        tree = Tree.alloc(root_obs.shape[0], self.cfg, self.env.device)
        return batched_search(tree, root_obs, root_is_player1, model, self.env, self.cfg, add_noise)
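Backup is where the parent pointers set in expand_batch pay off: starting from `leaf_node`, every game climbs `parent` one level per iteration, updating the edge it came through, until it reaches the root (whose parent is -1). Below is a hypothetical NumPy sketch, `batched_backup_sketch`, taking its arguments in the same order as the `batched_backup` call in batched_search. The sign convention is an assumption: `leaf_value` is taken from the viewpoint of the player who moved into the leaf, and `W[b, n, a]` from the viewpoint of the player to move at node `n`. If your single-game backup uses the other convention, flip the initial sign.

```python
import numpy as np

def batched_backup_sketch(N, W, parent, parent_act, leaf_node, leaf_value, max_depth):
    """Hypothetical negamax backup along parent pointers for all B games at once (in place).

    ASSUMED convention: leaf_value[b] is from the viewpoint of the player who moved INTO
    leaf_node[b]; W[b, n, a] is from the viewpoint of the player to move at node n.
    """
    ar = np.arange(leaf_node.shape[0])
    node = leaf_node.copy()
    v = leaf_value.astype(float)
    for _ in range(max_depth):
        par = parent[ar, node]                 # -1 once a game has reached its root
        act = parent_act[ar, node]
        active = par >= 0
        if not active.any():
            break
        g, p, a = ar[active], par[active], act[active]
        N[g, p, a] += 1                        # one more visit on the edge p --a--> node
        W[g, p, a] += v[active]                # credited to the player who chose that edge
        node = np.where(active, par, node)     # climb one level; finished games stay at the root
        v = -v                                 # negamax: the player to move alternates every ply

# Toy trees: game 0 is root --2--> node 1 --4--> node 2; game 1 is root --0--> node 1.
parent = np.full((2, 4), -1)
parent_act = np.zeros((2, 4), dtype=int)
parent[0, 1], parent_act[0, 1] = 0, 2
parent[0, 2], parent_act[0, 2] = 1, 4
parent[1, 1], parent_act[1, 1] = 0, 0
N = np.zeros((2, 4, 7), dtype=int)
W = np.zeros((2, 4, 7))
batched_backup_sketch(N, W, parent, parent_act,
                      leaf_node=np.array([2, 1]), leaf_value=np.array([1.0, -1.0]), max_depth=6)
```

Game 1 reaches its root after one step while game 0 needs two; as in selection, the `active` mask lets both share the same loop.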

The payoff: single ↔ batched equivalence

Because the batched search runs the same algorithm as your single-game version (same PUCT, same negamax backup, same transitions), with add_noise=False the two must produce exactly the same visit counts. That makes this comparison an excellent debugging tool: if any phase of your batched version is wrong, the visit counts diverge and this test flags it immediately.
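To see why exact equality (not just approximate agreement) is a realistic target, here is a toy illustration of the same testing pattern. It is not the course's test: it is a one-level PUCT "bandit" (fixed per-arm rewards, no tree, no network) written once as a scalar loop and once vectorised over B games. The names `single_bandit`, `batched_bandit` and the constant `C` are hypothetical. Both versions perform the same floating-point operations in the same order and both break ties on the first maximum, so their visit counts match exactly.

```python
import math
import numpy as np

C = 1.5  # toy exploration constant

def single_bandit(P, R, sims):
    """Scalar reference for ONE game: PUCT over A arms with fixed per-arm rewards R."""
    A = len(P)
    N, W = [0] * A, [0.0] * A
    for _ in range(sims):
        s = math.sqrt(sum(N))
        score = [(W[i] / N[i] if N[i] > 0 else 0.0) + C * P[i] * s / (1 + N[i]) for i in range(A)]
        best = max(range(A), key=lambda i: score[i])   # first maximum wins ties
        N[best] += 1
        W[best] += R[best]
    return N

def batched_bandit(P, R, sims):
    """The same computation for B games in lock-step; P and R have shape (B, A)."""
    B, A = P.shape
    ar = np.arange(B)
    N = np.zeros((B, A), dtype=np.int64)
    W = np.zeros((B, A))
    for _ in range(sims):
        s = np.sqrt(N.sum(axis=1, keepdims=True))
        Q = np.where(N > 0, W / np.maximum(N, 1), 0.0)
        score = Q + C * P * s / (1 + N)                # same operations, in the same order
        best = score.argmax(axis=1)                    # argmax also returns the first maximum
        N[ar, best] += 1
        W[ar, best] += R[ar, best]
    return N

rng = np.random.default_rng(0)
P = rng.dirichlet(np.ones(7), size=4)
R = rng.uniform(-1.0, 1.0, size=(4, 7))
counts = batched_bandit(P, R, sims=50)
for b in range(4):
    assert counts[b].tolist() == single_bandit(P[b].tolist(), R[b].tolist(), sims=50)
```

Exact agreement in this toy depends on both sides breaking ties identically, so if a real single-versus-batched comparison disagrees, a tie-breaking mismatch is one cause worth ruling out.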

model = Connect4Model(device).eval()
cfg = MCTSConfig(sims=64, c_puct=1.5)
batched = BatchedMCTS(env, cfg)
tests.test_batched_mcts(lambda o, tm, add_noise=False: batched.search(model, o, tm, add_noise), model)