1️⃣ The Environment & Network
Learning Objectives
- Use the provided vectorised Connect-4 environment, and understand how the board is encoded.
- Build the AlphaZero policy-value network.
⚠️ Note: The exercises in this section are very similar to the ones we did on CNN/ResNets day. If you're already comfortable and don't want more practice building nn.Modules from a blueprint, feel free to just read the description of the environment and the actor/critic network, then skip the exercises and move on.
The Connect-4 environment (given)
Connect4Env (in game.py) is a fully vectorised environment: it operates on a batch of
N boards at once. The interface:
class Connect4Env:
    """
    Vectorized, GPU-friendly Connect 4 environment.
    - Board shape: height x width (default 6 x 7)
    - Observation: (N, 3, H, W) float32, channels = [empty, green, blue]
    Reward scheme (mover's perspective): +1 win, -2 illegal move (worse than a
    normal loss), 0 for a draw or a game that continues. ``step`` only ever checks
    whether the *mover* just won or drew: you cannot lose on your own move.
    Environments auto-reset in-place whenever done or illegal, but the returned
    `done` indicates the terminal transition.
    """

    def reset(self, num_env: int) -> Float[Tensor, "num_env 3 H W"]:
        """Return an initial observation tensor of shape (num_env, 3, H, W), channels [empty, green, blue]."""

    @torch.no_grad()
    def step(
        self,
        obs: Float[Tensor, "... 3 H W"],
        actions: Int[Tensor, "..."] | int,
        is_player1: Bool[Tensor, "..."] | bool
    ) -> Tuple[Float[Tensor, "... 3 H W"], Bool[Tensor, "..."] | bool, Float[Tensor, "..."] | float]:
        """
        Apply one move by the current player on a batch or single Connect 4 board.
        If `is_player1` is True, the mover is green (channel 1); if False, blue (channel 2).
        Args:
            obs (Tensor): Input observation; either shape (N, 3, H, W) or (3, H, W).
            actions (int or Tensor): Actions to apply; either scalar int or (N,) long tensor.
            is_player1 (bool or Tensor): Indicates the player's color; either scalar bool or (N,) bool tensor.
        Returns:
            Tuple[Float[Tensor, "... 3 H W"], Bool[Tensor, "..."] | bool, Float[Tensor, "..."] | float]:
                Next observation : Float[Tensor, "... 3 H W"],
                done mask(s) : Bool[Tensor, "..."] | bool,
                reward(s) : Float[Tensor, "..."] | float,
            If a single board is provided, returns scalars for done and reward.
        Raises:
            AssertionError: If the input `obs` does not have valid shape or size.
        """

    @torch.no_grad()
    def legal_action_mask(self,
        obs: Float[Tensor, "... 3 H W"]
    ) -> Bool[Tensor, "... W"]:
        """Boolean mask of columns with space.
        obs: (N, 3, H, W) -> (N, W), or (3, H, W) -> (W,).
        """
Let's look at a board:
env = Connect4Env(device=device)
obs = env.reset(1)
obs, _, _ = env.step(obs, torch.tensor([3], device=device), torch.tensor([True], device=device))
obs, _, _ = env.step(obs, torch.tensor([3], device=device), torch.tensor([False], device=device))
print(render_board(obs, is_player1=True))
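To make the encoding concrete, here is a minimal standalone sketch that builds a single (3, H, W) observation by hand and derives a "columns with space" mask from it. It does not use Connect4Env: the integer cell codes, and the choice of the last row as the bottom of the board, are assumptions made for this illustration only, and the real environment may lay out rows differently.

```python
import torch
import torch.nn.functional as F

H, W = 6, 7

# Hypothetical cell codes for this sketch: 0 = empty, 1 = green (player 1), 2 = blue (player 2).
cells = torch.zeros(H, W, dtype=torch.long)
cells[H - 1, 3] = 1  # green piece in column 3 (assumes the last row is the bottom)
cells[H - 2, 3] = 2  # blue piece stacked on top of it
cells[:, 0] = torch.tensor([1, 2, 1, 2, 1, 2])  # fill column 0 completely

# One-hot over the cell code gives (H, W, 3); move the channel axis first -> (3, H, W).
obs = F.one_hot(cells, num_classes=3).permute(2, 0, 1).float()

# Every square belongs to exactly one of the channels [empty, green, blue].
assert torch.all(obs.sum(dim=0) == 1)

# A column has space iff at least one of its squares is still empty (channel 0).
legal = obs[0].sum(dim=0) > 0  # shape (W,)
print(legal.tolist())  # [False, True, True, True, True, True, True]
```

Only the full column 0 is masked out; column 3 holds two pieces but still has room.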
The mover's perspective: eval_net
The network sees a board from the perspective of the player to move: its own pieces in channel 1, the opponent's in channel 2. But the environment stores boards in absolute order: player 1's pieces in channel 1, player 2's in channel 2 (and empty in channel 0). So before calling the network we canonicalise the inputs: if the mover is player 2, swap channels 1 and 2. This simplifies things, as the network essentially only ever needs to learn to play as one colour (we swap the colours on the opponent's turn).
Exercise - implement canonicalise_obs
Implement the canonicalise_obs function, which swaps the player channels based on the is_player1 boolean. This function is essentially a vectorized version of the following code:
def canonicalise_obs(obs_abs : Float[Tensor, "3 H W"],
                     is_player1 : bool
                    ) -> Float[Tensor, "3 H W"]:
    if is_player1:
        return obs_abs
    else:
        # reorder channels [empty, p1, p2] -> [empty, p2, p1] (works batched or not)
        return obs_abs[..., [0, 2, 1], :, :]
Hint: Use torch.where to conditionally swap the channels.
def canonicalise_obs(obs : Float[Tensor, "batch 3 H W"],
                     is_player1 : Bool[Tensor, "batch"] | None = None
                    ) -> Float[Tensor, "batch 3 H W"]:
    """
    Canonicalise the observation for the mover's perspective.
    Returns the same tensor as input, but with obs[b,1,:,:] and obs[b,2,:,:] swapped iff is_player1[b] is False, for all b.
    If is_player1 is None, return the input tensor unchanged.
    """
    raise NotImplementedError()
tests.test_canonicalise_obs(canonicalise_obs)
Solution
def canonicalise_obs(obs : Float[Tensor, "batch 3 H W"],
                     is_player1 : Bool[Tensor, "batch"] | None = None
                    ) -> Float[Tensor, "batch 3 H W"]:
    """
    Canonicalise the observation for the mover's perspective.
    Returns the same tensor as input, but with obs[b,1,:,:] and obs[b,2,:,:] swapped iff is_player1[b] is False, for all b.
    If is_player1 is None, return the input tensor unchanged.
    """
    if is_player1 is None:
        return obs
    # Add trailing unit axes so the (batch,) mask broadcasts against (batch, 3, H, W)
    is_player1 = einops.repeat(is_player1, "batch -> batch 1 1 1")
    # Same boards with the two player channels exchanged
    swap_obs = obs[:, [0, 2, 1]]
    # Keep the original where player 1 is to move, the swapped copy otherwise
    obs_canon = torch.where(is_player1, obs, swap_obs)
    return obs_canon
tests.test_canonicalise_obs(canonicalise_obs)
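A quick way to build intuition for the swap is to check two properties on a toy batch: boards where player 1 moves are untouched, and applying the swap twice gets you back where you started. The sketch below is standalone; `canonicalise_sketch` is a hypothetical stand-in that mirrors the solution above (using `view` instead of einops so it only needs torch), and the piece positions are arbitrary.

```python
import torch

def canonicalise_sketch(obs: torch.Tensor, is_player1: torch.Tensor) -> torch.Tensor:
    """Stand-in for canonicalise_obs: obs is (B, 3, H, W), is_player1 is (B,) bool."""
    swapped = obs[:, [0, 2, 1]]           # exchange the two player channels
    keep = is_player1.view(-1, 1, 1, 1)   # broadcast the mask over channels and squares
    return torch.where(keep, obs, swapped)

# Two copies of the same absolute board: one player-1 piece, one player-2 piece.
obs = torch.zeros(2, 3, 6, 7)
obs[:, 0] = 1.0                                 # start with every square empty
obs[:, 0, 5, 3] = 0.0; obs[:, 1, 5, 3] = 1.0    # player 1's piece
obs[:, 0, 4, 3] = 0.0; obs[:, 2, 4, 3] = 1.0    # player 2's piece

is_p1 = torch.tensor([True, False])
canon = canonicalise_sketch(obs, is_p1)

assert torch.equal(canon[0], obs[0])            # player 1 to move: unchanged
assert torch.equal(canon[1, 1], obs[1, 2])      # player 2 to move: own pieces now in channel 1
assert torch.equal(canon[1, 2], obs[1, 1])      # ...and the opponent's in channel 2
assert torch.equal(canonicalise_sketch(canon, is_p1), obs)  # swapping twice is the identity
```

The empty channel (index 0) is never touched, so the set of legal columns is the same in both views.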
With canonicalise_obs in hand, eval_net (given) is just a thin wrapper: canonicalise the board
to the mover's perspective, run the network, and return the value (a (B,) tensor, from the
mover's perspective) and the column logits (B, 7).
def eval_net(
    model: nn.Module,
    obs_abs: Float[Tensor, "batch 3 H W"],
    is_player1: Bool[Tensor, "batch"],
) -> tuple[Float[Tensor, "batch"], Float[Tensor, "batch 7"]]:
    """Run the network on absolute observations, canonicalised to the mover's perspective.
    Args:
        model: the Connect4Model
        obs_abs: (B, 3, H, W) absolute boards (channels [empty, p1, p2])
        is_player1: (B,) whether player-1 is to move (selects the canonical view)
    Returns:
        value: (B,) the position's value for the mover, in [-1, 1] (tanh-squashed)
        logits: (B, 7) one policy logit per column
    """
    obs_canon = canonicalise_obs(obs_abs, is_player1)
    value, logits = model(obs_canon.contiguous())
    return value.reshape(-1), logits
The network architecture
The network is a small residual CNN with a shared trunk and two heads: an actor head (a prior over the 7 columns) and a critic head (how good the position is for the mover):
flowchart TD
I["obs (B, 3, 6, 7)<br/>channels: empty, mover, opponent"] --> C["initial Conv2d 3 to 128<br/>3x3, pad 1, then BatchNorm, ReLU"]
C --> R1["ResBlock(128)"]
R1 --> R2["ResBlock(128)"]
R2 --> VH["critic"]
R2 --> PH["actor"]
VH --> V["value (B,)<br/>mover's expected result"]
PH --> P["logits (B, 7)<br/>one score per column"]
Each residual block adds its input back after two conv layers (the skip connection), which keeps deep stacks easy to train:
flowchart TD
X(["x"]) --> A["Conv 3x3, BN, ReLU"]
A --> B["Conv 3x3, BN"]
X -. skip .-> S(("+"))
B --> S
S --> RO["ReLU"]
RO --> O["out"]
The two heads each collapse the 128-channel trunk down to their output:
flowchart TD
subgraph "critic (value head)"
direction TB
XV["trunk<br/>(B, 128, 6, 7)"] --> AV["Conv 1x1<br/>128 to 3"] --> BV["BN, ReLU"] --> FV["flatten"] --> LV["Linear<br/>3*6*7 to 32"] --> RV["ReLU"] --> OV["Linear<br/>32 to 1"] --> TV["tanh"] --> VV["value<br/>(B,)"]
end
subgraph "actor (policy head)"
direction TB
XP["trunk<br/>(B, 128, 6, 7)"] --> AP["Conv 1x1<br/>128 to 32"] --> BP["BN, ReLU"] --> FP["flatten"] --> LP["Linear<br/>32*6*7 to 7"] --> OP["logits<br/>(B, 7)"]
end
(Note the two heads shrink the trunk to a different number of channels: 3 for the critic, 32 for the actor.)
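Shrinking the channels with a 1×1 conv before flattening saves a lot of parameters, and a little arithmetic shows how much. The sketch below counts the parameters of the critic's first stage as drawn above (1×1 conv 128 to 3, BatchNorm, Linear 3*6*7 to 32) and compares it with a hypothetical alternative that flattens all 128 channels straight into a Linear with 32 outputs. The alternative is not part of the network; it is only here as a baseline.

```python
# Sizes taken from the diagrams above.
C, H, W = 128, 6, 7   # trunk channels and board size
conv_out = 3          # channels after the critic's 1x1 conv
hidden = 32           # width of the critic's hidden Linear

# Hypothetical baseline: flatten the whole trunk and feed it to Linear(C*H*W -> hidden).
direct_params = C * H * W * hidden + hidden  # weights + biases

# Critic as drawn: 1x1 conv (with bias), BatchNorm (scale + shift), then a Linear.
conv_params = C * conv_out + conv_out
bn_params = 2 * conv_out
linear_params = conv_out * H * W * hidden + hidden
shrunk_params = conv_params + bn_params + linear_params

print(direct_params)   # 172064
print(shrunk_params)   # 4457
print(round(direct_params / shrunk_params, 1))  # 38.6
```

The same calculation with conv_out = 32 and 7 outputs gives the actor's first-stage size; both heads stay tiny next to the 128-channel trunk.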
Building the network
We'll build the network in four small pieces, each with its own test: the ResBlock the trunk
stacks, the Critic (value head) and Actor (policy head), and finally the Connect4Model that
wires the shared trunk and the two heads together. This is just like how we built CNNs in [0.2]. Throughout: 3×3 convs use padding=1, the 1×1 convs in the heads use padding=0, and
each conv is followed by BatchNorm.
Exercise - implement ResBlock
A residual block runs its input through two 3×3 conv→BatchNorm layers and adds the original input back
before the final ReLU (the skip connection). The block only has to learn a residual, which keeps
deep stacks easy to train.
class ResBlock(nn.Module):
    """A residual block, shape-preserving on (B, channels, H, W):
    x -> Conv2d 3×3 -> BatchNorm -> ReLU -> Conv2d 3×3 -> BatchNorm -> (+ x) -> ReLU
    Both convs are `channels -> channels`, 3×3, padding=1, `bias=False`: each conv is followed
    immediately by a BatchNorm, whose learned shift makes a conv bias redundant.
    """

    def __init__(self, channels: int):
        super().__init__()
        raise NotImplementedError()

    def forward(self, x: Float[Tensor, "B C H W"]) -> Float[Tensor, "B C H W"]:
        """Two conv-BN layers (ReLU between), then add the input back (skip) and ReLU.
        Args:
            x: (B, C, H, W) input feature map
        Returns:
            (B, C, H, W) output feature map (shape preserved)
        """
        raise NotImplementedError()
tests.test_resblock(ResBlock)
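The docstring's claim that BatchNorm makes a conv bias redundant can be checked numerically without giving away the exercise. In this minimal sketch (layer sizes are arbitrary, chosen only for the demonstration), we shift the conv bias by a large constant and observe that the BatchNorm output in training mode does not change, because BatchNorm subtracts the per-channel batch mean and the shift goes with it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 8, 3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)          # modules start in training mode: normalise with batch statistics
x = torch.randn(4, 3, 6, 7)

with torch.no_grad():
    out_before = bn(conv(x))
    conv.bias.add_(5.0)         # add a constant offset to every output channel
    out_after = bn(conv(x))

# The offset is removed by the mean subtraction, so both outputs agree up to float error.
assert torch.allclose(out_before, out_after, atol=1e-4)
print((out_before - out_after).abs().max().item())
```

Any offset the layer does need is supplied by BatchNorm's own learned shift, which is why the ResBlock convs can use bias=False.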
Solution
class ResBlock(nn.Module):
    """A residual block, shape-preserving on (B, channels, H, W):
    x -> Conv2d 3×3 -> BatchNorm -> ReLU -> Conv2d 3×3 -> BatchNorm -> (+ x) -> ReLU
    Both convs are `channels -> channels`, 3×3, padding=1, `bias=False`: each conv is followed
    immediately by a BatchNorm, whose learned shift makes a conv bias redundant.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: Float[Tensor, "B C H W"]) -> Float[Tensor, "B C H W"]:
        """Two conv-BN layers (ReLU between), then add the input back (skip) and ReLU.
        Args:
            x: (B, C, H, W) input feature map
        Returns:
            (B, C, H, W) output feature map (shape preserved)
        """
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)) + residual)
        return x
tests.test_resblock(ResBlock)
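To see what "the block only has to learn a residual" means in practice, we can switch the residual branch off and check that the block reduces to ReLU(x). The sketch below uses `ResBlockSketch`, a compact copy of the solution above so the snippet runs on its own. Zeroing the second BatchNorm's scale is a trick used here purely for the demonstration, not something the course's training code is known to do.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlockSketch(nn.Module):
    """Compact copy of the ResBlock solution, for a standalone demonstration."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)  # skip connection, then the final ReLU

torch.manual_seed(0)
block = ResBlockSketch(16)

# BatchNorm computes scale * normalised + shift. The shift starts at 0, so setting the
# scale to 0 makes the residual branch output exactly zero.
nn.init.zeros_(block.bn2.weight)

x = torch.randn(2, 16, 6, 7)
out = block(x)

assert out.shape == x.shape              # the block preserves shape
assert torch.allclose(out, F.relu(x))    # with a silent branch, only the skip path remains
```

With the branch silenced the input passes through the skip path alone, so stacking more blocks can never make it harder to represent "do nothing"; training then only has to learn the correction on top.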
Exercise - implement Critic (the value head)
The critic maps the shared trunk to a single scalar giving the value of the position for the side
to move. It shrinks the 128-channel trunk with a 1×1 conv, then flattens and runs a small MLP down
to one number. Finish with a tanh, which squashes the output into [-1, 1], the same range as the
game-outcome value targets z (−1 loss / 0 draw / +1 win). This keeps the head's range matched to
its target and stops it from drifting to large magnitudes. Output shape: (B,).
class Critic(nn.Module):
    """The value head, (B, in_channels, H, W) trunk features -> (B,) values in [-1, 1]:
    Conv2d 1×1 `in_channels -> conv_out` (bias=True) -> BatchNorm -> ReLU -> Flatten
    -> Linear(conv_out*height*width -> 32) -> ReLU -> Linear(32 -> 1) -> Tanh -> squeeze to (B,)
    The 1×1 conv shrinks the trunk to `conv_out` channels before flattening (a per-square shared
    Linear), and the final tanh bounds the value to [-1, 1] to match the game-outcome targets z.
    The conv keeps `bias=True` (the default) even though a BatchNorm follows; the tests load
    reference weights into your layers, so match these bias settings exactly.
    """

    def __init__(self, in_channels=128, conv_out=3, height=6, width=7):
        super().__init__()
        raise NotImplementedError()

    def forward(self, x: Float[Tensor, "B C H W"]) -> Float[Tensor, "B"]:
        """Map the shared trunk to a scalar value for the side to move.
        Args:
            x: (B, C, 6, 7) shared-trunk features
        Returns:
            (B,) the position's value for the mover, in [-1, 1] (tanh-squashed)
        """
        return self.net(x).squeeze(-1)  # (B, 1) -> (B,)
tests.test_critic(Critic)
Solution
class Critic(nn.Module):
    """The value head, (B, in_channels, H, W) trunk features -> (B,) values in [-1, 1]:
    Conv2d 1×1 `in_channels -> conv_out` (bias=True) -> BatchNorm -> ReLU -> Flatten
    -> Linear(conv_out*height*width -> 32) -> ReLU -> Linear(32 -> 1) -> Tanh -> squeeze to (B,)
    The 1×1 conv shrinks the trunk to `conv_out` channels before flattening (a per-square shared
    Linear), and the final tanh bounds the value to [-1, 1] to match the game-outcome targets z.
    The conv keeps `bias=True` (the default) even though a BatchNorm follows; the tests load
    reference weights into your layers, so match these bias settings exactly.
    """

    def __init__(self, in_channels=128, conv_out=3, height=6, width=7):
        super().__init__()
        # The 1x1 conv is a shared per-cell Linear: it maps each square's `in_channels`-vector down
        # to `conv_out` channels with the *same* weights at every square, shrinking the trunk before
        # we flatten and run the small MLP. Far fewer params than flattening all 128 channels straight
        # into a Linear, and it keeps the board's spatial layout intact.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, conv_out, 1, bias=True),
            nn.BatchNorm2d(conv_out),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(conv_out * height * width, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Tanh(),  # squash to [-1, 1] so the value head's range matches the game-outcome targets z
        )

    def forward(self, x: Float[Tensor, "B C H W"]) -> Float[Tensor, "B"]:
        """Map the shared trunk to a scalar value for the side to move.
        Args:
            x: (B, C, 6, 7) shared-trunk features
        Returns:
            (B,) the position's value for the mover, in [-1, 1] (tanh-squashed)
        """
        return self.net(x).squeeze(-1)  # (B, 1) -> (B,)
tests.test_critic(Critic)
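The comment in the solution describes the 1×1 conv as a Linear layer shared across squares. That equivalence is easy to verify directly: copy the conv's weights into an nn.Linear, apply it to each square's channel vector, and compare. The sketch below is a standalone check with the critic's sizes (128 to 3); the variable names are ours.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(128, 3, kernel_size=1)   # weight shape (3, 128, 1, 1)
lin = nn.Linear(128, 3)                   # weight shape (3, 128)

with torch.no_grad():
    lin.weight.copy_(conv.weight.view(3, 128))  # drop the two trailing 1x1 kernel axes
    lin.bias.copy_(conv.bias)

x = torch.randn(2, 128, 6, 7)             # (B, C, H, W) trunk-like features

y_conv = conv(x)                          # (B, 3, H, W)

# nn.Linear acts on the last axis, so move channels last, apply, then move them back.
y_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

assert y_conv.shape == y_lin.shape == (2, 3, 6, 7)
assert torch.allclose(y_conv, y_lin, atol=1e-4)
```

Because the same 128-to-3 map is applied at all 42 squares, the conv mixes channels but never mixes squares; combining information across the board is left to the Linear layers after the flatten.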
Exercise - implement Actor (the policy head)
The actor maps the shared trunk to 7 logits, one for each column on the board. Same
1×1-conv → flatten → Linear pattern as the critic, but the final Linear produces width outputs.
Output shape: (B, 7).
class Actor(nn.Module):
    """The policy head, (B, in_channels, H, W) trunk features -> (B, width) column logits:
    Conv2d 1×1 `in_channels -> conv_out` (bias=True) -> BatchNorm -> ReLU -> Flatten
    -> Linear(conv_out*height*width -> width)
    Same shrink-then-flatten pattern as the Critic (note the different `conv_out` default: 32
    here vs 3 there), but the final Linear emits one logit per column and there is no squashing:
    these are raw logits, softmaxed later. Conv bias=True (the default), as in the Critic.
    """

    def __init__(self, in_channels=128, conv_out=32, height=6, width=7):
        super().__init__()
        raise NotImplementedError()

    def forward(self, x: Float[Tensor, "B C H W"]) -> Float[Tensor, "B 7"]:
        """Map the shared trunk to one policy logit per column.
        Args:
            x: (B, C, 6, 7) shared-trunk features
        Returns:
            (B, 7) one logit per column
        """
        return self.net(x)
tests.test_actor(Actor)
Solution
class Actor(nn.Module):
    """The policy head, (B, in_channels, H, W) trunk features -> (B, width) column logits:
    Conv2d 1×1 `in_channels -> conv_out` (bias=True) -> BatchNorm -> ReLU -> Flatten
    -> Linear(conv_out*height*width -> width)
    Same shrink-then-flatten pattern as the Critic (note the different `conv_out` default: 32
    here vs 3 there), but the final Linear emits one logit per column and there is no squashing:
    these are raw logits, softmaxed later. Conv bias=True (the default), as in the Critic.
    """

    def __init__(self, in_channels=128, conv_out=32, height=6, width=7):
        super().__init__()
        # 1x1 conv = shared per-cell Linear (see Critic), shrinking the trunk before the flatten + FC.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, conv_out, 1, bias=True),
            nn.BatchNorm2d(conv_out),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(conv_out * height * width, width),
        )

    def forward(self, x: Float[Tensor, "B C H W"]) -> Float[Tensor, "B 7"]:
        """Map the shared trunk to one policy logit per column.
        Args:
            x: (B, C, 6, 7) shared-trunk features
        Returns:
            (B, 7) one logit per column
        """
        return self.net(x)
tests.test_actor(Actor)
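The actor's logits only become a move distribution once they are softmaxed, and a full column should get zero probability. One common way to combine the two, sketched below, is to set the logits of illegal columns to -inf before the softmax. This is an illustration under our own assumptions: the logits and mask are made-up numbers standing in for the outputs of the actor and of legal_action_mask, and the later sections may handle masking differently.

```python
import torch

# Made-up stand-ins: one board, 7 column logits, column 2 is full.
logits = torch.tensor([[0.5, 1.0, -0.3, 2.0, 0.0, 0.1, -1.0]])
legal = torch.tensor([[True, True, False, True, True, True, True]])

# exp(-inf) = 0, so masked columns receive exactly zero probability.
masked_logits = logits.masked_fill(~legal, float("-inf"))
probs = torch.softmax(masked_logits, dim=-1)

assert probs[0, 2].item() == 0.0
assert torch.allclose(probs.sum(dim=-1), torch.ones(1))
print(probs.argmax(dim=-1).item())  # 3
```

This relies on at least one column being legal: if every entry were -inf the softmax would return NaNs, so terminal boards need to be handled before this step.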
Exercise - implement Connect4Model
Now assemble the full network: a stem (3×3 conv → BN → ReLU) lifting the 3-channel board to
channels feature maps (128 by default), two ResBlocks, then the critic and actor heads on the
shared trunk. forward returns (value, logits).
class Connect4Model(nn.Module):
    """The full AlphaZero network: a shared convolutional trunk and the two heads.
    Architecture (build the modules in this order; see note below):
    - `self.features` (the trunk): Conv2d 3×3 `3 -> channels` (padding=1, bias=True) -> BatchNorm
      -> ReLU, then two `ResBlock(channels)`.
    - `self.critic = Critic(channels, critic_conv_out, height, width)` on the trunk output.
    - `self.actor = Actor(channels, actor_conv_out, height, width)` on the trunk output.
    Note: create trunk, then critic, then actor, in that order. The tests check your model is
    functionally identical to the reference by copying the reference weights in by
    *parameter-creation order*, so a correct architecture built in a different order fails them.
    """

    def __init__(self,
        device,
        channels: int = 128,
        critic_conv_out: int = 3,
        actor_conv_out: int = 32,
        height: int = 6,
        width: int = 7,
    ):
        super().__init__()
        raise NotImplementedError()
        self.to(device)

    def forward(
        self, x: Float[Tensor, "B 3 6 7"]
    ) -> tuple[Float[Tensor, "B"], Float[Tensor, "B 7"]]:
        """Run the shared trunk then both heads on a canonical board batch.
        Args:
            x: (B, 3, 6, 7) canonical board (channels [empty, mover, opponent])
        Returns:
            value: (B,) the position's value for the mover, in [-1, 1] (tanh-squashed)
            logits: (B, 7) one policy logit per column
        """
        raise NotImplementedError()
summary(Connect4Model(device), input_size=(5, 3, 6, 7))
tests.test_connect4_model(Connect4Model)
Solution
class Connect4Model(nn.Module):
    """The full AlphaZero network: a shared convolutional trunk and the two heads.
    Architecture (build the modules in this order; see note below):
    - `self.features` (the trunk): Conv2d 3×3 `3 -> channels` (padding=1, bias=True) -> BatchNorm
      -> ReLU, then two `ResBlock(channels)`.
    - `self.critic = Critic(channels, critic_conv_out, height, width)` on the trunk output.
    - `self.actor = Actor(channels, actor_conv_out, height, width)` on the trunk output.
    Note: create trunk, then critic, then actor, in that order. The tests check your model is
    functionally identical to the reference by copying the reference weights in by
    *parameter-creation order*, so a correct architecture built in a different order fails them.
    """

    def __init__(self,
        device,
        channels: int = 128,
        critic_conv_out: int = 3,
        actor_conv_out: int = 32,
        height: int = 6,
        width: int = 7,
    ):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, channels, 3, padding=1, bias=True),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            ResBlock(channels),
            ResBlock(channels),
        )
        self.critic = Critic(channels, critic_conv_out, height, width)
        self.actor = Actor(channels, actor_conv_out, height, width)
        self.to(device)

    def forward(
        self, x: Float[Tensor, "B 3 6 7"]
    ) -> tuple[Float[Tensor, "B"], Float[Tensor, "B 7"]]:
        """Run the shared trunk then both heads on a canonical board batch.
        Args:
            x: (B, 3, 6, 7) canonical board (channels [empty, mover, opponent])
        Returns:
            value: (B,) the position's value for the mover, in [-1, 1] (tanh-squashed)
            logits: (B, 7) one policy logit per column
        """
        x = self.features(x)
        return self.critic(x), self.actor(x)
summary(Connect4Model(device), input_size=(5, 3, 6, 7))
tests.test_connect4_model(Connect4Model)
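The docstring's warning about creation order follows from how PyTorch enumerates parameters: `parameters()` yields them in the order the submodules were registered in `__init__`. The sketch below shows this with two hypothetical toy modules that contain identical layers assigned in opposite orders. It does not reproduce how the course's tests actually copy weights (that code is not shown here), only why any position-by-position copy would be sensitive to order.

```python
import torch.nn as nn

class TrunkFirst(nn.Module):
    """Toy module: registers the trunk, then the head."""
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

class HeadFirst(nn.Module):
    """Same layers, but registered in the opposite order."""
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(8, 2)
        self.trunk = nn.Linear(4, 8)

shapes_a = [tuple(p.shape) for p in TrunkFirst().parameters()]
shapes_b = [tuple(p.shape) for p in HeadFirst().parameters()]

print(shapes_a)  # [(8, 4), (8,), (2, 8), (2,)]
print(shapes_b)  # [(2, 8), (2,), (8, 4), (8,)]

# Pairing parameters by position would match a (8, 4) weight with a (2, 8) one.
assert shapes_a != shapes_b
assert sorted(shapes_a) == sorted(shapes_b)  # same set of parameters, different order
```

Copying by name via state_dict would not have this problem, but it would instead require your attribute names to match the reference exactly.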