2️⃣ Single-Game MCTS

Learning Objectives
  • Implement a Node class, PUCT selection, expansion, and backup.
  • Assemble the full search loop and verify it finds tactical wins and blocks.

Before we build a fully batched version of Monte-Carlo Tree Search, let's build a simpler version of MCTS in plain Python that operates on a single board.

We store statistics on the edges of each node: a node holds per-action arrays N (visit counts) and W (value sums, from the perspective of this node's mover), plus the network priors P and a dict of child Nodes created lazily. The search runs on the provided Connect4Env with a batch of size 1, so transitions are identical to the batched version (this matters for section 3).

Action = int # 0, 1, 2, 3, 4, 5, 6

@dataclass
class Node:
    """One node of the single-game search tree = one board state. Statistics live on the **edges**
    (per-action length-7 arrays), filled in as the search runs.

    Attributes:
        obs            (1, 3, H, W) the connect4 board for this state
        is_player1     (1,) Bool[Tensor, "1"] whose turn it is to play
        is_terminal    bool : True if this state is game-over. Set by `expand`.
        terminal_value float : result from perspective of `is_player1` (= -reward of the move that
                       created the node). Only meaningful when `is_terminal=True`. Set by `expand`.
        P              (7,) network prior over each move. Set by `evaluate`.
        legal          (7,) bool legal-column mask. Set by `evaluate`.
        N              (7,) per-edge visit counts N(s, a). Updated by `backup`.
        W              (7,) per-edge value sums W(s, a), accumulated from the perspective of THIS
                       node's mover (`is_player1`): backup negates each child value once per step on
                       the way up, so what lands here is already in this node's convention. Updated by `backup`.
        children       dict[Action, Node] : child nodes per played action, created lazily by `expand`.
        parent         Node | None : the node we were expanded from (`None` at the root).
        parent_action  Action | None : the column that led from `parent` to this node. Updated by `expand`.

    Properties:
        Q              (7,) per-edge mean value W / max(N, 1).
    """
    obs: Float[Tensor, "1 3 H W"] #lives on GPU
    is_player1: Bool[Tensor, "1"]
    num_actions: int = 7
    is_terminal: bool = False
    terminal_value: float = 0.0
    P: Tensor | None = None
    legal: Tensor | None = None
    N: Tensor | None = None  # filled with zeros in __post_init__
    W: Tensor | None = None  # filled with zeros in __post_init__
    children: dict[Action, 'Node'] = field(default_factory=dict)
    parent: 'Node | None' = None
    parent_action: Action | None = None

    def __post_init__(self):
        if self.N is None:
            self.N = torch.zeros(self.num_actions) # N(s,a_0), ..., N(s,a_6)
        if self.W is None:
            self.W = torch.zeros(self.num_actions) # W(s,a_0), ..., W(s,a_6)

    @property
    def Q(self):
        return self.W / torch.maximum(self.N, torch.ones_like(self.N))


# Sentinel "no action" returned by `select` when the selection walk stops on a terminal node.
# We use -100 rather than None because None silently inserts a dim when used to index a tensor
# (`t[None]`), whereas -100 is always an illegal column and raises loudly if ever misused as an index.
NULL_ACTION = -100

tests.test_mcts_node(Node)

Exercise - implement select_child (PUCT)

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 10-15 minutes on this exercise.

Return the action maximising the PUCT score

$$ \text{PUCT}(s,a) = Q(s,a) + c_\text{puct} \, p_\theta(a|s) \frac{ \sqrt{1 + \sum_{a'} N(s,a')} }{1 + N(s,a)} $$
Illegal moves always have $\text{PUCT}(s,a) = -\infty$. You can use node.legal to obtain a legal move mask for the board state at that node.

Use the same form in the batched version so the two agree exactly. You may assume there will always be at least one legal action.

def select_child(node : Node, c_puct: float) -> Action:
    """Return the action (column index) with the highest PUCT score at `node`.

    Everything you need is already stored on the node (see the `Node` class above):
        node.Q      (7,) per-edge mean values Q(s,a) = W / max(N, 1)  — a property, kept up to date
        node.N      (7,) per-edge visit counts N(s,a); its sum is the parent visit total Σ_a' N(s,a')
        node.P      (7,) the network prior p_θ(a|s), set when the node was evaluated
        node.legal  (7,) bool mask of legal columns

    Compute the PUCT score per action, set illegal actions to -inf, and take the argmax.

    Args:
        node:   the node to pick an action from (its edge statistics are already populated)
        c_puct: the exploration/exploitation trade-off constant c

    Returns:
        int: the PUCT-maximising column 0-6 (you may assume at least one action is legal).
        Return a plain Python int — `int(scores.argmax())`, not a 0-dim Tensor. A Tensor looks
        identical in comparisons but breaks `action in node.children` dict lookups in `select`
        (tensors hash by identity, not value).
    """
    raise NotImplementedError()


tests.test_select_child(select_child, Node)
Solution
def select_child(node : Node, c_puct: float) -> Action:
    """Return the action (column index) with the highest PUCT score at `node`.

    Everything you need is already stored on the node (see the `Node` class above):
        node.Q      (7,) per-edge mean values Q(s,a) = W / max(N, 1)  — a property, kept up to date
        node.N      (7,) per-edge visit counts N(s,a); its sum is the parent visit total Σ_a' N(s,a')
        node.P      (7,) the network prior p_θ(a|s), set when the node was evaluated
        node.legal  (7,) bool mask of legal columns

    Compute the PUCT score per action, set illegal actions to -inf, and take the argmax.

    Args:
        node:   the node to pick an action from (its edge statistics are already populated)
        c_puct: the exploration/exploitation trade-off constant c

    Returns:
        int: the PUCT-maximising column 0-6 (you may assume at least one action is legal).
        Return a plain Python int — `int(scores.argmax())`, not a 0-dim Tensor. A Tensor looks
        identical in comparisons but breaks `action in node.children` dict lookups in `select`
        (tensors hash by identity, not value).
    """
    sumN = node.N.sum()
    puct_score = node.Q + c_puct * node.P * torch.sqrt(sumN + 1.0) / (1.0 + node.N)
    legal_score = torch.where(node.legal, puct_score, -torch.inf)
    return int(legal_score.argmax())


tests.test_select_child(select_child, Node)
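
To check the PUCT formula by hand, here is a minimal sketch on plain Python lists (no torch, no Node). The three-action node and all of its statistics are invented for illustration, and `puct_scores` is a hypothetical helper, not part of the course code.

```python
import math

def puct_scores(Q, N, P, legal, c_puct):
    """PUCT score per action on plain lists; illegal actions score -inf."""
    sqrt_total = math.sqrt(1 + sum(N))
    return [
        q + c_puct * p * sqrt_total / (1 + n) if ok else -math.inf
        for q, n, p, ok in zip(Q, N, P, legal)
    ]

# Hypothetical 3-action node (numbers made up for the example).
N = [3.0, 0.0, 0.0]            # only action 0 has been visited
W = [0.375, 0.0, 0.0]          # so Q(s, 0) = 0.375 / 3 = 0.125
Q = [w / max(n, 1.0) for w, n in zip(W, N)]
P = [0.5, 0.25, 0.25]          # prior favours action 0
legal = [True, True, False]    # action 2 is illegal

scores = puct_scores(Q, N, P, legal, c_puct=1.0)
best = max(range(len(scores)), key=scores.__getitem__)
print(scores, best)
```

With these numbers, action 0 scores 0.125 + 0.5 * 2 / 4 = 0.375, while the unvisited action 1 scores 0.25 * 2 / 1 = 0.5, so the exploration bonus makes the search try action 1 next even though the prior prefers action 0.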

Exercise - implement select

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 5-10 minutes on this exercise.

Starting at the root, take the PUCT-best action and step into that child, over and over, until you reach a node that is either terminal or whose PUCT-best action has not yet been expanded (no child). Return (node, action): the node to expand and the action to play. If you stopped on a terminal node there's nothing to expand, so return (node, NULL_ACTION).

You may assume root is never terminal. Use select_child to select the next action.

def select(root: Node, 
           c_puct: float
) -> tuple[Node, Action]:
    """Walks down the tree from root: at each node take the PUCT-best action (`select_child`);
    if that action has a child, step into it and repeat. Halt at the first node whose PUCT-best
    action has NO child yet (`action not in node.children`), or on a terminal node.

    Note the halting rule is a property of the node's best *action*, not of the node itself: a
    node may already have children on other actions (and visits on them) while its current
    PUCT-best action is still unexpanded — that node is where the walk stops.

    Args:
        root: The root node of the search tree (never terminal).
        c_puct: The PUCT constant.

    Returns:
        * `(node, action)`: the node to expand and the action to play, OR
        * `(node, NULL_ACTION)` if the walk stopped on a terminal node (nothing to expand).
    """
    assert not root.is_terminal, "select: root must not be terminal"
    raise NotImplementedError()


tests.test_select(select, Node)
Solution
def select(root: Node, 
           c_puct: float
) -> tuple[Node, Action]:
    """Walks down the tree from root: at each node take the PUCT-best action (`select_child`);
    if that action has a child, step into it and repeat. Halt at the first node whose PUCT-best
    action has NO child yet (`action not in node.children`), or on a terminal node.

    Note the halting rule is a property of the node's best *action*, not of the node itself: a
    node may already have children on other actions (and visits on them) while its current
    PUCT-best action is still unexpanded — that node is where the walk stops.

    Args:
        root: The root node of the search tree (never terminal).
        c_puct: The PUCT constant.

    Returns:
        * `(node, action)`: the node to expand and the action to play, OR
        * `(node, NULL_ACTION)` if the walk stopped on a terminal node (nothing to expand).
    """
    assert not root.is_terminal, "select: root must not be terminal"
    node = root
    while not node.is_terminal:
        a = select_child(node, c_puct)
        if a not in node.children:
            return node, a            # found unexpanded action
        node = node.children[a]       # descend into the chosen child
    return node, NULL_ACTION          # hit terminal node


tests.test_select(select, Node)
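
The halting rule is easy to get subtly wrong, so here is a small sketch of the two stopping cases on a hand-built tree. `ToyNode` and `toy_select` are hypothetical stand-ins: each toy node carries a fixed `best` action that plays the role of whatever `select_child` would return.

```python
from dataclasses import dataclass, field

NULL_ACTION = -100  # same sentinel value as in the text above

@dataclass
class ToyNode:
    """Hypothetical stand-in for Node; `best` mimics the answer of select_child."""
    best: int = 0
    is_terminal: bool = False
    children: dict = field(default_factory=dict)

def toy_select(root):
    node = root
    while not node.is_terminal:
        a = node.best                  # stand-in for select_child(node, c_puct)
        if a not in node.children:
            return node, a             # best action has no child yet: stop here
        node = node.children[a]        # otherwise keep walking down
    return node, NULL_ACTION           # walked onto a terminal node

# root --2--> mid --0--> terminal leaf; mid also has an expanded sibling edge 5
leaf = ToyNode(is_terminal=True)
mid = ToyNode(best=0, children={0: leaf, 5: ToyNode()})
root = ToyNode(best=2, children={2: mid})

first = toy_select(root)    # follows 2 then 0 and ends on the terminal leaf
mid.best = 3                # pretend PUCT now prefers an unexpanded action at mid
second = toy_select(root)   # stops at mid, although mid already has children
```

The second call shows the point made in the docstring: the walk stops at `mid` because its best action is unexpanded, not because `mid` has no children.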

Exercise - implement expand

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-15 minutes on this exercise.
  1. Play the given action on the current board state node.obs using env.step to receive a new board state.
  2. Construct a new Node object with the result, and store it as a child of the current node.
  3. Mark the child terminal if the move ended the game, and set its terminal_value to the negated reward (-reward) received from the environment.
  4. Set parent/parent_action so we can follow the path from the child back to its parent later on.

Why terminal_value = -reward?

env.step reports the reward to the player who just moved, but the child's mover is the opponent, so from the child's perspective that value is negated.
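
The sign bookkeeping can be traced in two lines. This is only a sketch: the helper names are hypothetical, and the reward of +1.0 for a winning move is an assumed convention for illustration.

```python
def terminal_value_for_child(reward_to_mover: float) -> float:
    """Value stored on a game-over child: the mover's reward seen from the opponent's side."""
    return -reward_to_mover

def parent_edge_value(child_value: float) -> float:
    """One backup step: negate the child's value to get the value for the parent's edge."""
    return -child_value

# Assumed reward convention (illustrative only): +1.0 for a winning move.
reward = 1.0
child_value = terminal_value_for_child(reward)  # -1.0: the player to move at the child has lost
edge_value = parent_edge_value(child_value)     # +1.0: credited to the winner's edge at the parent
```

The two negations cancel, so the edge that played the winning move ends up credited with the original reward.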

Don't forget that node.is_player1 is a Bool[Tensor, "1"] rather than a Python bool, so you need to negate it with ~ rather than not.

@torch.no_grad()
def expand(node: Node, 
           action: int, 
           env: Connect4Env) -> Node:
    """
    Plays the given action on the current board state, constructs a new node with the result,
    stores it as a child of the current node, and returns it.

    Side effects:
        * Mutates the current node to add a new child node.

    Args:
        node: The current node in the search tree.
        action: The action to play.
        env: The Connect4Env instance

    Returns:
        The newly created child node after the action has been played.
    """
    assert not node.is_terminal, "expand: cannot expand a terminal node"
    raise NotImplementedError()


tests.test_expand(expand)
Solution
@torch.no_grad()
def expand(node: Node, 
           action: int, 
           env: Connect4Env) -> Node:
    """
    Plays the given action on the current board state, constructs a new node with the result,
    stores it as a child of the current node, and returns it.

    Side effects:
        * Mutates the current node to add a new child node.

    Args:
        node: The current node in the search tree.
        action: The action to play.
        env: The Connect4Env instance

    Returns:
        The newly created child node after the action has been played.
    """
    assert not node.is_terminal, "expand: cannot expand a terminal node"
    next_obs, done, reward = env.step(node.obs, action, node.is_player1)
    child = Node(obs = next_obs,
                 is_player1 = ~node.is_player1,
                 is_terminal = done,
                 terminal_value = -reward,
                 parent = node,
                 parent_action = action)
    node.children[action] = child
    return child


tests.test_expand(expand)

Exercise - implement evaluate

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-15 minutes on this exercise.

If the node is terminal, we already know its value: it was stored in node.terminal_value when the node was created. If the node is non-terminal, we estimate the value using the critic, and we also set the node.P and node.legal attributes using the actor. Recall that the model returns a tuple of (value, logits) for the critic and actor heads.

@torch.no_grad()
def evaluate(node: Node, model: nn.Module, env: Connect4Env) -> float:
    """
    Estimates the value of a node from its mover's perspective.

    Side effects:
        If the node is non-terminal, mutates the node to 
        set the `node.P` and `node.legal` attributes.

    Args:
        node: The node to estimate the value of.
        model: The model to use.
        env: The Connect4Env instance

    Returns:
        For terminal nodes: the stored `terminal_value` (already from the mover's perspective).
        For non-terminal nodes: the critic head's value estimate, from the mover's perspective.
            Use `eval_net(model, obs, is_player1) -> (value, logits)` for the forward pass; the
            returned tensors live on the GPU, while the tree's per-node stats live in CPU memory.
    """
    raise NotImplementedError()


tests.test_evaluate(evaluate)
Solution
@torch.no_grad()
def evaluate(node: Node, model: nn.Module, env: Connect4Env) -> float:
    """
    Estimates the value of a node from its mover's perspective.

    Side effects:
        If the node is non-terminal, mutates the node to 
        set the `node.P` and `node.legal` attributes.

    Args:
        node: The node to estimate the value of.
        model: The model to use.
        env: The Connect4Env instance

    Returns:
        For terminal nodes: the stored `terminal_value` (already from the mover's perspective).
        For non-terminal nodes: the critic head's value estimate, from the mover's perspective.
            Use `eval_net(model, obs, is_player1) -> (value, logits)` for the forward pass; the
            returned tensors live on the GPU, while the tree's per-node stats live in CPU memory.
    """
    if node.is_terminal:
        return node.terminal_value                                   
    value, logits = eval_net(model, node.obs, node.is_player1)
    value = value.squeeze().cpu()
    logits = logits.squeeze().cpu()
    node.legal = env.legal_action_mask(node.obs)[0].cpu()
    legal_logits = torch.where(node.legal, logits, -torch.inf)
    node.P = torch.softmax(legal_logits, dim=-1)
    return float(value)


tests.test_evaluate(evaluate)
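
The masking step in evaluate is worth seeing in isolation: setting illegal logits to -inf before the softmax gives those moves a prior of exactly zero. Below is a minimal sketch on plain Python lists; `masked_softmax` is a hypothetical helper, and it assumes at least one action is legal.

```python
import math

def masked_softmax(logits, legal):
    """Softmax over the legal entries only; illegal entries get probability 0.

    Assumes at least one entry of `legal` is True.
    """
    masked = [x if ok else -math.inf for x, ok in zip(logits, legal)]
    m = max(masked)                               # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in masked]      # math.exp(-inf) is 0.0
    total = sum(exps)
    return [e / total for e in exps]

logits = [0.0, 0.0, 5.0, 0.0]         # the network strongly prefers action 2...
legal = [True, True, False, True]     # ...but action 2 is illegal
probs = masked_softmax(logits, legal)
print(probs)
```

The large logit on the illegal action has no effect: the remaining three legal actions share the probability mass equally.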

Exercise - implement backup

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 5-10 minutes on this exercise.

Backup walks from a node back along the parent pointers to the root, negating the value at every step (negamax), and adding one to the visit count N and the value to the value sum W at each edge. Mutates N/W in place; returns nothing.

Don't forget to negate the value at every step (negamax)!

def backup(leaf: Node, value: float) -> None:
    """Walks from a node back up the tree to the root, updating the visit counts and value sums at each edge.

    Side effects:
        * Mutates the visit counts and value sums at each edge.

    Args:
        leaf: The leaf node to backup from.
        value: The value of the leaf node from its mover's perspective.

    Returns:
        None
    """
    raise NotImplementedError()


tests.test_backup(backup, Node)
Solution
def backup(leaf: Node, value: float) -> None:
    """Walks from a node back up the tree to the root, updating the visit counts and value sums at each edge.

    Side effects:
        * Mutates the visit counts and value sums at each edge.

    Args:
        leaf: The leaf node to backup from.
        value: The value of the leaf node from its mover's perspective.

    Returns:
        None
    """
    v, curr = value, leaf
    while curr.parent is not None:
        v = -v
        curr.parent.N[curr.parent_action] += 1.0
        curr.parent.W[curr.parent_action] += v
        curr = curr.parent


tests.test_backup(backup, Node)
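
To see the alternating signs, here is a sketch of the same walk on a three-node chain, using a hypothetical `ChainNode` with list-based N and W instead of tensors.

```python
from dataclasses import dataclass, field

@dataclass
class ChainNode:
    """Hypothetical minimal node: just the fields that backup touches."""
    parent: "ChainNode | None" = None
    parent_action: "int | None" = None
    N: list = field(default_factory=lambda: [0.0] * 7)
    W: list = field(default_factory=lambda: [0.0] * 7)

def toy_backup(leaf, value):
    v, curr = value, leaf
    while curr.parent is not None:
        v = -v                                      # switch to the parent's perspective
        curr.parent.N[curr.parent_action] += 1.0
        curr.parent.W[curr.parent_action] += v
        curr = curr.parent

root = ChainNode()
child = ChainNode(parent=root, parent_action=3)
grandchild = ChainNode(parent=child, parent_action=1)

# The mover at `grandchild` is losing (value -1 from its own perspective).
toy_backup(grandchild, -1.0)
print(child.W[1], root.W[3])
```

The edge into the losing position is good for `child`'s mover (+1 on `child.W[1]`) and therefore bad for `root`'s mover (-1 on `root.W[3]`); each edge on the path gains exactly one visit.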

Exercise - implement sample_tree_policy

Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 5-10 minutes on this exercise.

Once the search has finished, we define the tree policy from the root's visit counts. This is the policy actually used to play the game.

$$ \pi(a|s) = \frac{N(s,a)^{1/\tau}}{\sum_{a'} N(s,a')^{1/\tau}} $$
where $\tau$ is the temperature parameter.

Manually handle very small temperatures (temperature < 1e-8) by returning the greedy action.

Use torch.multinomial to sample from the distribution.

def sample_tree_policy(visits: Float[Tensor, "7"], 
                       temperature: float = 1.0,
) -> Action:
    """Samples an action from the tree policy. `temperature` -> 0 is greedy argmax; `temperature` = 1
    samples proportional to visits; larger temperatures flatten the distribution.

    Args:
        visits: The visit counts for each action, shape `(7,)`.
        temperature: The temperature parameter.

    Returns:
        int: The sampled action.
    """
    raise NotImplementedError()


tests.test_sample_tree_policy(sample_tree_policy)
Solution
def sample_tree_policy(visits: Float[Tensor, "7"], 
                       temperature: float = 1.0,
) -> Action:
    """Samples an action from the tree policy. `temperature` -> 0 is greedy argmax; `temperature` = 1
    samples proportional to visits; larger temperatures flatten the distribution.

    Args:
        visits: The visit counts for each action, shape `(7,)`.
        temperature: The temperature parameter.

    Returns:
        int: The sampled action.
    """
    if temperature < 1e-8:
        return int(visits.argmax())   # tiny temperature -> greedy (avoids visits ** (1/temp) overflow)

    visits_temp = visits ** (1 / temperature)

    probs = visits_temp / visits_temp.sum()
    return int(torch.multinomial(probs, 1))


tests.test_sample_tree_policy(sample_tree_policy)
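
A quick worked example of the temperature formula, as a sketch on plain Python lists with made-up visit counts. `tree_policy` is a hypothetical helper that returns the probabilities only and leaves out the sampling step.

```python
def tree_policy(visits, temperature):
    """pi(a) proportional to N(a) ** (1 / temperature); assumes temperature > 0."""
    powered = [n ** (1.0 / temperature) for n in visits]
    total = sum(powered)
    return [p / total for p in powered]

visits = [1.0, 3.0, 0.0, 0.0]          # invented visit counts

pi_one = tree_policy(visits, 1.0)      # proportional to visits
pi_sharp = tree_policy(visits, 0.5)    # exponent 2: sharpens towards the most-visited action
pi_flat = tree_policy(visits, 2.0)     # exponent 1/2: flattens the distribution
print(pi_one, pi_sharp, pi_flat)
```

At temperature 1 the policy is (0.25, 0.75, 0, 0); at 0.5 the counts are squared to (1, 9), giving (0.1, 0.9, 0, 0); at 2 the gap narrows. Unvisited actions keep probability zero at every positive temperature.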

Exercise - implement mcts_search

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 25-35 minutes on this exercise.

Each simulation grows the tree by at most one node (exactly one, unless selection ends on a terminal node), via the four MCTS phases you just built. The root is created and evaluated for you; we then loop over the four phases cfg.sims times.

  1. SELECT. Walk from the root down to the node where the search stops, together with the action to play there.
  2. EXPAND. If that node is non-terminal, expand it by adding a child for the selected action; the new child is the leaf. If it is terminal, the node itself is the leaf.
  3. EVALUATE. Evaluate the leaf to get its value.
  4. BACKUP. Back the leaf's value up the tree to the root.

After cfg.sims iterations, the root's visit counts root.N are returned (the return is given; passing return_root=True also hands back the root node, which we use below to draw the tree).

@torch.no_grad()
def mcts_search(
    root_obs: Float[Tensor, "1 3 H W"],
    root_is_player1: Bool[Tensor, "1"],
    model: nn.Module,
    env: Connect4Env,
    cfg: MCTSConfig,
    add_noise: bool = False,
    return_root: bool = False,
) -> Float[Tensor, "7"]:
    """Run `cfg.sims` MCTS simulations from the root; return the root's visit counts `(7,)` — or
    `(visit_counts, root)` when `return_root=True` (e.g. to inspect / visualise the search tree).
    """
    root = Node(root_obs, root_is_player1)
    evaluate(root, model, env) #required to set root.P and root.legal
    raise NotImplementedError()
    return (root.N, root) if return_root else root.N

# First check the search logic in isolation, with a dummy (uniform-policy, zero-value) network:
# a forced win-in-one must be found purely from the terminal reward backing up the tree.
tests.test_mcts_search(mcts_search)
# Then confirm the same search drives the real network correctly:
tests.test_mcts_search(mcts_search, Connect4Model(device).eval())
Solution
@torch.no_grad()
def mcts_search(
    root_obs: Float[Tensor, "1 3 H W"],
    root_is_player1: Bool[Tensor, "1"],
    model: nn.Module,
    env: Connect4Env,
    cfg: MCTSConfig,
    add_noise: bool = False,
    return_root: bool = False,
) -> Float[Tensor, "7"]:
    """Run `cfg.sims` MCTS simulations from the root; return the root's visit counts `(7,)` — or
    `(visit_counts, root)` when `return_root=True` (e.g. to inspect / visualise the search tree).
    """
    root = Node(root_obs, root_is_player1)
    evaluate(root, model, env) #required to set root.P and root.legal
    for _ in range(cfg.sims):
        node, action = select(root, cfg.c_puct)
        if not node.is_terminal:
            leaf = expand(node, action, env)
        else:
            leaf = node
        leaf_value = evaluate(leaf, model, env)
        backup(leaf, leaf_value)
    return (root.N, root) if return_root else root.N

# First check the search logic in isolation, with a dummy (uniform-policy, zero-value) network:
# a forced win-in-one must be found purely from the terminal reward backing up the tree.
tests.test_mcts_search(mcts_search)
# Then confirm the same search drives the real network correctly:
tests.test_mcts_search(mcts_search, Connect4Model(device).eval())
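
To trace the loop by hand, it can help to run the same four phases on a game small enough to hold in your head. The sketch below is not the course implementation: it swaps Connect 4 for a hypothetical two-stone Nim game (remove 1 or 2 stones, taking the last stone wins), uses a uniform prior and zero value in place of a network, and every name in it is invented for illustration.

```python
import math

C_PUCT = 1.0  # assumed exploration constant for this toy

class NimNode:
    def __init__(self, stones, parent=None, parent_action=None, terminal_value=0.0):
        self.stones = stones
        self.is_terminal = stones == 0
        self.terminal_value = terminal_value
        self.legal = [a + 1 <= stones for a in range(2)]   # action a removes a + 1 stones
        k = sum(self.legal)
        self.P = [1.0 / k if ok else 0.0 for ok in self.legal] if k else [0.0, 0.0]
        self.N, self.W = [0.0, 0.0], [0.0, 0.0]
        self.children = {}
        self.parent, self.parent_action = parent, parent_action

def select_child(node):
    sqrt_total = math.sqrt(1 + sum(node.N))
    scores = [
        (node.W[a] / max(node.N[a], 1.0) + C_PUCT * node.P[a] * sqrt_total / (1 + node.N[a]))
        if node.legal[a] else -math.inf
        for a in range(2)
    ]
    return max(range(2), key=scores.__getitem__)

def search(stones, sims):
    root = NimNode(stones)
    for _ in range(sims):
        node = root
        while not node.is_terminal:                 # SELECT
            a = select_child(node)
            if a not in node.children:              # EXPAND
                left = node.stones - (a + 1)
                # Taking the last stone wins (+1 to the mover), so the child stores -1.
                tv = -1.0 if left == 0 else 0.0
                node.children[a] = NimNode(left, node, a, terminal_value=tv)
                node = node.children[a]
                break
            node = node.children[a]
        v = node.terminal_value if node.is_terminal else 0.0   # EVALUATE (zero-value "net")
        while node.parent is not None:              # BACKUP
            v = -v
            node.parent.N[node.parent_action] += 1.0
            node.parent.W[node.parent_action] += v
            node = node.parent
    return root.N

visits = search(stones=2, sims=16)
print(visits)
```

With two stones, action 1 (take both) wins immediately, and the visit counts concentrate on it purely because the terminal reward is backed up, which mirrors the dummy-network test above.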

Watch it find a win

To show the search does the tactical work, we hand MCTS a dummy network that returns value 0 and a uniform prior for every board. So whatever the search concentrates on comes purely from the tree policy backing up terminal rewards, not from the net.

Below, it's Red (X) to move in a crowded mid-game position where Red already has a diagonal three, (5,1)-(4,2)-(3,3). A piece dropped in column 4 falls to (2,4) and completes the / diagonal. With a uniform prior, only the visit counts can single this move out.

The right-hand bars are the visit-count policy $\pi(a) = N(s,a) / \sum_{a'} N(s,a')$: the improved policy that we will train the policy network to mimic.

class DummyNet(nn.Module):

    def forward(self, x):
        b = x.shape[0]
        values = torch.zeros(b, device=x.device)
        logits = torch.zeros(b, 7, device=x.device)
        return values, logits

model = DummyNet()
obs, red = tests.diagonal_win_red()

print("Starting position (X = Red to move):")
print(render_board(obs, is_player1=True))

visits, root = mcts_search(obs, torch.tensor([red], device=device), model, env,
                           MCTSConfig(sims=32), return_root=True)
print("\nMCTS visit counts per column:", visits.int().tolist())
chosen = int(visits.argmax())
print(f"Most-visited column: {chosen}  ({int(visits[chosen])} of {int(visits.sum())} visits)")

obs_after = place_piece(obs, chosen, is_player1=True)
print(f"\nBoard after X plays column {chosen}  (completes the diagonal):")
print(render_board(obs_after))

# board + the visit-count policy pi(a) = N(s,a) / sum_a' N(s,a'), chosen column highlighted
plot_board_and_policy(obs, visits / visits.sum(), chosen_action=chosen,
                      title="MCTS finds the diagonal win")
# the search tree the simulations grew: edge thickness ~ visit count, terminal leaves in yellow
plot_mcts_tree(root, max_depth=2, title="MCTS search tree (after 32 sims)")
assert chosen == 4, "MCTS should find the diagonal win"