4️⃣ Training a Probe

Learning Objectives
  • Learn how to set up and train a linear probe
  • See how to train multiple probes at once, and log their performance to Weights & Biases

In this final section, we'll return to the linear probe from earlier, but discuss how you might go about training it from scratch.

We won't be doing a full-scale training run here; instead, we'll just look at a small example involving the board_seqs_square_small.npy datasets we've already used (the actual probe we've been using in these exercises was trained on the larger board_seqs_int.pth dataset).

Note - if you're an ARENA participant who did the material on model training during the first week (or in the "transformer from scratch" material), this should all be familiar to you. I'd recommend doing that section before this one, unless you already have experience writing standard ML training loops.

One important thing to make clear: we're not actually training our transformer model here. If this were a standard training loop, we'd run our model in training mode and update its weights in a way which reduces the cross entropy loss between its logit output and the true next-move labels. Instead, we're running our model in inference mode, caching its residual stream values, applying our probe to these values, and then updating the weights of our probe in a way which reduces the cross entropy loss between our probe's output and the true mine/theirs/blank labels.
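To make the distinction concrete, here's a minimal numpy sketch of the pattern (not the code used in this section). The activations are computed once and then treated as constants, standing in for cached residual stream values, and gradient descent only ever updates the probe. The synthetic data, sizes, learning rate and the helper probe_loss_and_grad are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, d_model, n_classes = 2000, 32, 3

# Stand-in for cached residual stream activations: computed once, then treated as constants
cached_acts = rng.normal(size=(n_samples, d_model))
# Hypothetical labels produced by a hidden linear rule, so a linear probe can fit them
hidden_W = rng.normal(size=(d_model, n_classes))
labels = (cached_acts @ hidden_W).argmax(axis=-1)


def probe_loss_and_grad(probe, acts, labels):
    """Mean cross-entropy of a linear probe, and its gradient w.r.t. the probe weights only."""
    logits = acts @ probe
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    logprobs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    rows = np.arange(len(labels))
    loss = -logprobs[rows, labels].mean()
    dlogits = np.exp(logprobs)  # softmax probabilities
    dlogits[rows, labels] -= 1.0  # softmax minus one-hot
    grad = acts.T @ dlogits / len(labels)
    return loss, grad


# Randomly initialised probe, scaled by 1/sqrt(d_model)
probe = rng.normal(size=(d_model, n_classes)) / np.sqrt(d_model)
initial_loss, _ = probe_loss_and_grad(probe, cached_acts, labels)
for _ in range(1000):
    _, grad = probe_loss_and_grad(probe, cached_acts, labels)
    probe -= 0.5 * grad  # only the probe changes; cached_acts are never touched
final_loss, _ = probe_loss_and_grad(probe, cached_acts, labels)
accuracy = ((cached_acts @ probe).argmax(axis=-1) == labels).mean()
print(f"loss {initial_loss:.3f} -> {final_loss:.3f}, train accuracy {accuracy:.3f}")
```

In the real exercise, PyTorch's autograd computes the gradient for us, and running the model under inference mode is what guarantees it stays frozen.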

utils.plot_board_values(
    focus_states[0, :16],
    boards_per_row=8,
    board_titles=[f"Move {i}" for i in range(1, 17)],
    title="First 16 moves of first game",
    width=1400,
    height=440,
)

Now, we'll create a dataclass to store our probe training args. This is a great way to keep all our variables in one place (and it also works well with VSCode's autocompletion features!).

We've also included a setup_linear_probe method, which will give us a randomly initialized probe with appropriately normalized weights.
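Here, "appropriately normalized" means dividing the random weights by sqrt(d_model). As a rough intuition, and assuming activations with roughly unit variance per dimension (real residual stream activations need not look like this), this scaling keeps the initial probe logits at roughly unit variance rather than variance d_model. A quick numpy check, with an illustrative d_model of 512:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_options, n_samples = 512, 3, 10_000

# Hypothetical unit-variance activations, used only to illustrate the scaling argument
acts = rng.normal(size=(n_samples, d_model))

# Same recipe as setup_linear_probe (randn / sqrt(d_model)), versus no scaling at all
probe_scaled = rng.normal(size=(d_model, n_options)) / np.sqrt(d_model)
probe_unscaled = rng.normal(size=(d_model, n_options))

var_scaled = (acts @ probe_scaled).var()
var_unscaled = (acts @ probe_unscaled).var()
print(f"logit variance with 1/sqrt(d_model) scaling: {var_scaled:.2f}")
print(f"logit variance without scaling: {var_unscaled:.1f}")
```

With unscaled weights, the initial softmax would be extremely peaked, which tends to make the first steps of training erratic.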

@dataclass
class ProbeTrainingArgs:
    # Determine the activations we'll train the probe on
    layer: int = 6
    pos_start: int = 5
    pos_end: int = -5  # i.e. we slice [pos_start: model.n_ctx + pos_end]

    # Game state (options are blank/mine/theirs)
    options: int = 3
    rows: int = 8
    cols: int = 8

    # Standard training hyperparams
    epochs: int = 3
    num_games: int = 10_000

    # Hyperparams for optimizer
    batch_size: int = 32
    lr: float = 1e-3  # high LR for quick convergence in these exercises; you may want to reduce
    betas: tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 0.01

    # Saving & logging
    use_wandb: bool = False
    wandb_project: str | None = "othellogpt-probe"
    wandb_name: str | None = None

    # Code to get randomly initialized probe
    def setup_linear_probe(self, model: HookedTransformer):
        linear_probe = t.randn(model.cfg.d_model, self.rows, self.cols, self.options, device=device) / np.sqrt(
            model.cfg.d_model
        )
        linear_probe.requires_grad = True
        return linear_probe

A reminder of what some of these mean:

  • options refers to the three possible states of each board square. We train the probe directly in the "empty", "theirs" and "mine" basis (in that order, i.e. labels 0, 1, 2), rather than in terms of "empty", "black" and "white".
  • modes (which only appears in MultiProbeTrainingArgs, in the bonus exercise below) refers to which moves a probe is trained on: "even moves/white to play", "odd moves/black to play", or "all moves". In the previous exercises, we only ever used "black to play" (this choice didn't really matter, since the model is detecting "my/their color" rather than "black/white").
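To make the pos_start and pos_end fields concrete, here's a small pure-Python sketch of which sequence positions the slice [pos_start : n_ctx + pos_end] keeps. It assumes n_ctx = 59 (each 60-move game with its final move dropped, as in the training code), and the helper probe_positions is hypothetical rather than part of the exercise code.

```python
def probe_positions(pos_start: int, pos_end: int, n_ctx: int) -> range:
    """Hypothetical helper: sequence positions kept by the slice [pos_start : n_ctx + pos_end]."""
    # Adding n_ctx turns a negative offset into an absolute index, and makes pos_end=0 mean
    # "up to the end" instead of producing the empty slice [pos_start:0]
    return range(pos_start, n_ctx + pos_end)


n_ctx = 59  # assumption: 60-move games with the last move dropped before running the model

default_positions = probe_positions(5, -5, n_ctx)
full_positions = probe_positions(5, 0, n_ctx)
print(default_positions[0], default_positions[-1], len(default_positions))  # 5 53 49
print(full_positions[0], full_positions[-1], len(full_positions))  # 5 58 54
```

So with the defaults, the first and last 5 positions are excluded, which avoids training on the very start and end of the game.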

Now for our main block of code - we'll write a class for training our linear probe.

Exercise - fill in the missing code below

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 30-40 minutes on this exercise. There are several steps to this exercise, so we recommend looking at the solution if you're still stuck after trying for a while.

We've just left the training_step function incomplete, so that's the one you need to fill in. It should return the loss which you backpropagate on (you can see the code which actually performs the backprop algorithm below).

To make this exercise easier we're only having you train a single probe mode, on even moves (as a bonus exercise you can try training 3 modes at once - if you do this then remember to flip the labels for odd and even moves when you compute loss for the "even & odd" probe mode).
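If it helps to keep the parity bookkeeping straight, here's a small sketch of the rule the solution relies on: black moves first, so the move at an even (0-indexed) position was made by black, which means white is to play and "mine" means white. The helper colors_at_position is hypothetical, and it assumes no passes occur (a pass would flip the parity).

```python
BLACK, WHITE = 1, -1  # matches the board state encoding {0: empty, 1: black, -1: white}


def colors_at_position(pos: int) -> dict:
    """Hypothetical helper: which color is 'theirs' and which is 'mine' after the move at `pos`.

    Assumes black moves first and no passes, so parity alone tells us who just moved.
    """
    just_moved = BLACK if pos % 2 == 0 else WHITE
    to_play = -just_moved
    # The probe's basis is relative to the player to move: mine = to_play, theirs = just_moved
    return {"theirs": just_moved, "mine": to_play}


print(colors_at_position(6))  # {'theirs': 1, 'mine': -1}: black just played, white to play
print(colors_at_position(7))  # {'theirs': -1, 'mine': 1}: white just played, black to play
```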

Your training_step function should:

  • Use model.run_with_cache to get the cached residual stream values for the layer you're training on.
    • Tip: you can use the stop_at_layer argument to prevent unnecessary computation
    • Tip: you can also use the names_filter argument to only get the activations you need (this can be a hook name or a function mapping hook names to bools).
    • Tip: remember to use inference mode here, since we're only running our model to get the activations to feed into our probe.
    • Tip: remember that the games_id tensor you're given has shape (batch_size, 60), so you'll need to slice off the last move before passing it into the model.
  • Slice the activations appropriately (args.pos_start and args.pos_end tell you which positions to use, and we're only using the even positions for now).
  • Compute the probe output by taking an inner product over the d_model dimension (keeping the batch and position dimensions separate).
  • Convert the probe output to logprobs, and get the logprobs of the correct labels by indexing into it.
    • Tip: we're using even positions (black has just played, so white is to play), which means we convert from board state to our probe basis using "mine = white, theirs = black".
    • Tip: this indexing will require taking code from earlier which computes the board state from the game sequence. We recommend just using get_board_states_and_legal_moves and keeping the state object it returns. Remember that this state object has values 0, 1, -1 corresponding to empty, black and white respectively.
  • Compute the loss by summing over the row & column dimensions and averaging over the batch & seqpos dimensions (the latter are data batch dimensions, whereas the former are effectively probe batch dimensions, since we're training a separate probe for each board square at once).
  • Return the loss (this code is given to you).
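The core of these steps can be sketched in numpy as follows. This is not the solution code: np.einsum and np.take_along_axis stand in for einops.einsum and eindex, the shapes are small and hypothetical, and the labels are random. Feeding it an all-zero probe reproduces the uniform-guessing loss of 64 ln(3) mentioned just after this list.

```python
import numpy as np


def probe_loss(resid, probe, state):
    """Sketch of the probe loss.

    resid: (batch, pos, d_model) cached activations
    probe: (d_model, rows, cols, options) probe weights
    state: (batch, pos, rows, cols) integer labels in {0: empty, 1: theirs, 2: mine}
    """
    # Inner product over d_model only; batch and position dimensions are kept
    logits = np.einsum("bpd,drco->bprco", resid, probe)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    logprobs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    # Pick out the logprob of the correct option for every (batch, pos, row, col)
    correct = np.take_along_axis(logprobs, state[..., None], axis=-1)[..., 0]
    # Mean over the data dimensions (batch, pos), sum over the 64 per-square probes
    return -correct.mean(axis=(0, 1)).sum()


rng = np.random.default_rng(0)
batch, n_pos, d_model = 4, 10, 16
resid = rng.normal(size=(batch, n_pos, d_model))
state = rng.integers(0, 3, size=(batch, n_pos, 8, 8))

# An all-zero probe predicts a uniform distribution over the 3 options on every square
zero_probe = np.zeros((d_model, 8, 8, 3))
print(probe_loss(resid, zero_probe, state))  # ≈ 70.31, i.e. 64 * ln(3)
```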

Once you've finished implementing this function, you can run the code below. Your training loss should quickly drop below 64 * ln(3) ≈ 70 (which is the loss you'd get from uniform random guesses about what occupies each square). This loss should fall to around 10-20 by the end of training, if you use the default hyperparameters.
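To get a feel for these numbers, here's the arithmetic behind the baseline, plus one rough way of reading a summed loss: dividing by 64 gives the average loss per square in nats, and exp(-that) is the geometric mean of the probability the probe assigns to the correct option on each square.

```python
import math

n_squares, n_options = 64, 3

# Loss (summed over squares) of a probe that guesses uniformly on every square
uniform_loss = n_squares * math.log(n_options)
print(f"uniform baseline: {uniform_loss:.2f}")  # 70.31

# Per-square reading of a summed loss: nats per square, and the geometric-mean
# probability assigned to the correct option
for total_loss in (10.0, 20.0):
    per_square = total_loss / n_squares
    print(f"loss {total_loss:.0f}: {per_square:.3f} nats/square, p(correct) ~ {math.exp(-per_square):.2f}")
```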

We recommend you set use_wandb=False while working on this code, until you're getting it running without errors.

class LinearProbeTrainer:
    def __init__(self, model: HookedTransformer, args: ProbeTrainingArgs):
        self.model = model
        self.args = args
        self.linear_probe = args.setup_linear_probe(model)

    def training_step(self, indices: Int[Tensor, "n_games"]) -> Float[Tensor, ""]:
        # Use indices to slice our batch of games (remember, games_id = token IDs
        # from 1 to 60, and games_square = indices of squares in board)
        indices_cpu = indices.cpu()
        games_id = board_seqs_id[indices_cpu]  # shape [batch n_moves=60]
        games_square = board_seqs_square[indices_cpu]  # shape [batch n_moves=60]

        # Define seqpos slicing (note, we add n_ctx to pos_end to deal with the zero case)
        pos_start = self.args.pos_start
        pos_end = self.args.pos_end + self.model.cfg.n_ctx

        # YOUR CODE HERE - define loss

        if self.args.use_wandb:
            wandb.log(dict(loss=loss.item()), step=self.step)
        self.step += 1

        return loss

    def shuffle_training_indices(self):
        """
        Returns the tensors you'll use to index into the training data.
        """
        n_indices = self.args.num_games - (self.args.num_games % self.args.batch_size)
        full_train_indices = t.randperm(self.args.num_games)[:n_indices]
        full_train_indices = einops.rearrange(
            full_train_indices,
            "(batch_idx game_idx) -> batch_idx game_idx",
            game_idx=self.args.batch_size,
        )
        return full_train_indices

    def train(self):
        self.step = 0
        if self.args.use_wandb:
            wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)

        optimizer = t.optim.AdamW(
            [self.linear_probe],
            lr=self.args.lr,
            betas=self.args.betas,
            weight_decay=self.args.weight_decay,
        )

        for epoch in range(self.args.epochs):
            print(f"Epoch {epoch + 1}/{self.args.epochs}")
            full_train_indices = self.shuffle_training_indices()
            progress_bar = tqdm(full_train_indices)
            for indices in progress_bar:
                loss = self.training_step(indices)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                progress_bar.set_description(f"Loss = {loss:.4f}")

        if self.args.use_wandb:
            wandb.finish()


t.set_grad_enabled(True)

args = ProbeTrainingArgs()
trainer = LinearProbeTrainer(model, args)
trainer.train()
Solution
class LinearProbeTrainer:
    def __init__(self, model: HookedTransformer, args: ProbeTrainingArgs):
        self.model = model
        self.args = args
        self.linear_probe = args.setup_linear_probe(model)

    def training_step(self, indices: Int[Tensor, "n_games"]) -> Float[Tensor, ""]:
        # Use indices to slice our batch of games (remember, games_id = token IDs
        # from 1 to 60, and games_square = indices of squares in board)
        indices_cpu = indices.cpu()
        games_id = board_seqs_id[indices_cpu]  # shape [batch n_moves=60]
        games_square = board_seqs_square[indices_cpu]  # shape [batch n_moves=60]

        # Define seqpos slicing (note, we add n_ctx to pos_end to deal with the zero case)
        pos_start = self.args.pos_start
        pos_end = self.args.pos_end + self.model.cfg.n_ctx

        # Cache resid_post from all our games (ignoring the last move)
        with t.inference_mode():
            _, cache = self.model.run_with_cache(
                games_id[:, :-1].to(device),
                return_type=None,
                names_filter=lambda name: name.endswith("resid_post"),
            )

        # We slice from the first even index (on or after pos_start), since we're just
        # looking at predictions made after black has played a move (i.e. white to play).
        pos_start_even = pos_start + (pos_start % 2)
        seqpos_indices = np.arange(pos_start_even, pos_end, 2)
        resid_post = cache["resid_post", self.args.layer][:, seqpos_indices]

        # Get probe output, i.e. probe_logits[g, p, r, c] = the 3-vector of logit
        # predictions for what color is in square [r, c] AFTER the p-th move is
        # played in game g.
        probe_logits = einops.einsum(
            resid_post,
            self.linear_probe,
            "batch pos d_model, d_model rows cols options -> batch pos rows cols options",
        )
        probe_logprobs = probe_logits.log_softmax(-1)

        # Get the actual game state. The original state has {0: empty, 1: black, -1: white} and
        # we want our probe to be in the basis {0: empty, 1: theirs, 2: mine}. We're only training
        # on even moves i.e. black just played and mine = white, so we just need to map -1 -> 2.
        state = get_board_states_and_legal_moves(games_square)[0]  # shape [batch moves 8 8]
        state = state[:, seqpos_indices]  # shape [batch pos 8 8]
        state[state == -1] = 2

        # Index into probe logprobs to get the logprobs for correct board state, and then
        # return loss as the mean over games & posns, and sum over rows & cols (since each
        # square is effectively an independent probe).
        correct_probe_logprobs = eindex(probe_logprobs, state, "game pos row col [game pos row col]")
        loss = -einops.reduce(correct_probe_logprobs, "game pos row col -> row col", "mean").sum()

        if self.args.use_wandb:
            wandb.log(dict(loss=loss.item()), step=self.step)
        self.step += 1

        return loss

    def shuffle_training_indices(self):
        """
        Returns the tensors you'll use to index into the training data.
        """
        n_indices = self.args.num_games - (self.args.num_games % self.args.batch_size)
        full_train_indices = t.randperm(self.args.num_games)[:n_indices]
        full_train_indices = einops.rearrange(
            full_train_indices,
            "(batch_idx game_idx) -> batch_idx game_idx",
            game_idx=self.args.batch_size,
        )
        return full_train_indices

    def train(self):
        self.step = 0
        if self.args.use_wandb:
            wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)

        optimizer = t.optim.AdamW(
            [self.linear_probe],
            lr=self.args.lr,
            betas=self.args.betas,
            weight_decay=self.args.weight_decay,
        )

        for epoch in range(self.args.epochs):
            print(f"Epoch {epoch + 1}/{self.args.epochs}")
            full_train_indices = self.shuffle_training_indices()
            progress_bar = tqdm(full_train_indices)
            for indices in progress_bar:
                loss = self.training_step(indices)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                progress_bar.set_description(f"Loss = {loss:.4f}")

        if self.args.use_wandb:
            wandb.finish()

Finally, let's make the same accuracy plot from before, and see how well our probe works. Note that we're not constructing a new probe by averaging the even and odd mode probes this time; we're just taking our single even-mode probe and using it as-is (there's a bonus exercise below where you can train all 3 modes at once).

# Getting the probe's output, and then its predictions
probe_out = einops.einsum(
    focus_cache["resid_post", args.layer],
    trainer.linear_probe,
    "game move d_model, d_model row col options -> game move row col options",
)
probe_out_value = probe_out.argmax(dim=-1).cpu()

# See what the accuracy was in 3 cases: odd moves, even moves, and aggregate moves
is_correct = probe_out_value == focus_states_theirs_vs_mine[:, :-1]
accuracies_odd = einops.reduce(is_correct[:, 5:-5:2].float(), "game move row col -> row col", "mean")
accuracies_even = einops.reduce(is_correct[:, 6:-6:2].float(), "game move row col -> row col", "mean")
accuracies_all = einops.reduce(is_correct[:, 5:-5].float(), "game move row col -> row col", "mean")

utils.plot_board_values(
    1 - t.stack([accuracies_odd, accuracies_even, accuracies_all], dim=0),
    title="Average Error Rate of Linear Probe",
    board_titles=["Black to play", "White to play", "All Moves"],
    zmax=0.25,
    zmin=-0.25,
    height=400,
    width=900,
)
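As a quick check on the slicing in the plotting code above, here's which positions each slice picks out, assuming 59 positions (0 to 58) along the move axis. Odd positions come after white's move (black to play), and even positions come after black's move (white to play).

```python
n_moves = 59  # assumption: positions 0..58, i.e. a 60-move game with the last move dropped
positions = list(range(n_moves))

odd_positions = positions[5:-5:2]   # "Black to play": white has just moved
even_positions = positions[6:-6:2]  # "White to play": black has just moved
all_positions = positions[5:-5]

print(odd_positions[0], odd_positions[-1], len(odd_positions))     # 5 53 25
print(even_positions[0], even_positions[-1], len(even_positions))  # 6 52 24
```

Together, the odd and even slices cover exactly the positions used for "All Moves".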

Exercise - train all 3 modes at once (bonus)

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵⚪⚪⚪⚪
You should spend up to 20-40 minutes on this exercise.

This exercise isn't super important or conceptually deep, but we include it for the sake of completeness. It's quite fiddly to get the indexing right here! What you should do is:

  • Use a new probe training args class, where your linear probe has an extra mode dimension (we've given you code for this below).
  • Rewrite your training_step function (in the new LinearMultiProbeTrainer class) to train all 3 modes in parallel.

We've also given you some sample code to run at the end, which will plot the accuracy of each of your 3 probes (as well as their ability to transfer to parities which they weren't trained on). You should see that your training loss for each of the 3 modes is about the same size as the total loss was in your previous probe, and they all reduce at approximately the same rate.
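Adding a mode dimension doesn't change the probe computation much: the mode axis is simply carried through the einsum, and each slice along it behaves exactly like a single probe. Here's a small numpy sketch with hypothetical shapes. Note that once the mode axis comes first, the position axis is no longer axis 1, which is an easy place to slip up when slicing out even or odd positions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_modes, n_games, n_pos, d_model = 3, 2, 6, 8

resid = rng.normal(size=(n_games, n_pos, d_model))
multi_probe = rng.normal(size=(n_modes, d_model, 8, 8, 3)) / np.sqrt(d_model)

# The mode dimension is carried through: one set of logits per probe mode
logits = np.einsum("gpd,mdrco->mgprco", resid, multi_probe)
print(logits.shape)  # (3, 2, 6, 8, 8, 3)

# Each mode matches the corresponding single probe applied on its own
single = np.einsum("gpd,drco->gprco", resid, multi_probe[1])
print(np.allclose(logits[1], single))  # True

# Positions come after the mode and game axes, so slice them with [mode, :, start::2]
even_mode_logits = logits[0, :, 0::2]  # assumes relative position 0 is an even absolute position
print(even_mode_logits.shape)  # (2, 3, 8, 8, 3)
```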

Help - I'm confused about how to rewrite the training_step function.

Previously, we just trained one linear probe on only the even moves. To break down the steps for this, we had to do the following:

  • Compute the probe logprobs, with shape (games, posn, rows=8, cols=8, options=3), where posn is the number of even positions we kept.
  • Remap the board state from the basis {0: empty, 1: black, -1: white} to {0: empty, 1: theirs, 2: mine}.
    • Since we were only using the even positions, we could just slice them out and perform the mapping {0, 1, -1} -> {0, 1, 2}.
  • Index into the logprobs with the basis-mapped state to get the correct logprobs, which had shape (games, posn, rows, cols).
  • Compute the loss as the negative mean of the correct logprobs over games and positions, summed over rows and cols (only including the even positions).

Now, we want to train 3 linear probes at once, one for even moves only, one for odd moves only, and one for both at once. This will look like:

  • Compute the probe logprobs, with shape (modes=3, games, posn, rows=8, cols=8, options=3).
  • Remap the board state from the basis {0: empty, 1: black, -1: white} to {0: empty, 1: theirs, 2: mine}.
    • Since we're using both odd and even positions, this means mapping even positions with {0, 1, -1} -> {0, 1, 2} and odd positions with {0, 1, -1} -> {0, 2, 1}.
  • Index into the logprobs with the basis-mapped state to get the correct logprobs for each probe, which will have shape (modes, games, posn, rows, cols).
  • Compute each loss as the negative mean of the correct logprobs (only including the even positions for our even probe, the odd positions for our odd probe, and all positions for our "both" probe).

In other words, it looks very similar, but we need to use the entire state tensor (with a different basis mapping for odd and even moves), and then make sure we're computing each loss over the correct set of sequence positions.
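Here's a minimal numpy sketch of the parity-dependent remapping described above. The function name to_theirs_mine is hypothetical, and like the exercise it assumes black moves first and no passes occur, so that the parity of a position alone tells us who just moved.

```python
import numpy as np


def to_theirs_mine(state: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map {0: empty, 1: black, -1: white} to {0: empty, 1: theirs, 2: mine}.

    state has shape (games, moves, rows, cols), with moves indexed from 0 and black moving first.
    """
    out = np.zeros_like(state)
    even = np.zeros(state.shape[1], dtype=bool)
    even[::2] = True
    even = even[None, :, None, None]  # broadcast over games, rows, cols
    # Even positions: black just moved, so black = theirs (1) and white = mine (2)
    out[(state == 1) & even] = 1
    out[(state == -1) & even] = 2
    # Odd positions: white just moved, so white = theirs (1) and black = mine (2)
    out[(state == -1) & ~even] = 1
    out[(state == 1) & ~even] = 2
    return out


# One game, two moves, a single row of three squares: [black, white, empty]
state = np.array([[[[1, -1, 0]], [[1, -1, 0]]]])  # shape (1, 2, 1, 3)
mapped = to_theirs_mine(state)
print(mapped[0, 0, 0])  # [1 2 0]  (even position: black = theirs, white = mine)
print(mapped[0, 1, 0])  # [2 1 0]  (odd position: black = mine, white = theirs)
```

Building a fresh output array (rather than overwriting in place) avoids any risk of a value being remapped twice.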

@dataclass
class MultiProbeTrainingArgs(ProbeTrainingArgs):
    modes: int = 3  # even, odd, both (i.e. the data we train on)

    def setup_linear_probe(self, model: HookedTransformer):
        linear_probe = t.randn(
            self.modes,
            model.cfg.d_model,
            self.rows,
            self.cols,
            self.options,
            device=device,
        ) / np.sqrt(model.cfg.d_model)
        linear_probe.requires_grad = True
        return linear_probe


class LinearMultiProbeTrainer(LinearProbeTrainer):
    def training_step(self, indices: Int[Tensor, "n_games"]) -> Float[Tensor, ""]:
        indices_cpu = indices.cpu()
        games_id = board_seqs_id[indices_cpu]  # shape [batch n_moves=60]
        games_square = board_seqs_square[indices_cpu]  # shape [batch n_moves=60]

        pos_start = self.args.pos_start
        pos_end = self.args.pos_end + self.model.cfg.n_ctx

        # YOUR CODE HERE - define loss_even, loss_odd, loss_both
        loss = loss_even + loss_odd + loss_both

        if self.args.use_wandb:
            wandb.log(
                dict(
                    loss=loss.item(),
                    loss_even=loss_even.item(),
                    loss_odd=loss_odd.item(),
                    loss_both=loss_both.item(),
                ),
                step=self.step,
            )
        self.step += 1

        return loss


t.set_grad_enabled(True)

args = MultiProbeTrainingArgs(epochs=5)
trainer = LinearMultiProbeTrainer(model, args)
trainer.train()
# Here, we test out each of our 3 probe modes (even / odd / both) on each of these 3 settings
# (even / odd / both). Hopefully we should see all 3 probes generalize!

probe_out = einops.einsum(
    focus_cache["resid_post", args.layer],
    trainer.linear_probe,
    "game move d_model, mode d_model row col options -> mode game move row col options",
)
probe_out_value = probe_out.argmax(dim=-1).cpu()  # mode game move row col

# For each mode, get the accuracy on even / odd / both
is_correct = probe_out_value == focus_states_theirs_vs_mine[:, :-1]  # mode game move row col
# The move axis comes after the mode and game axes, so we slice dimension 2
accuracies_even = einops.reduce(is_correct[:, :, 6:-6:2].float(), "mode game move row col -> mode row col", "mean")
accuracies_odd = einops.reduce(is_correct[:, :, 5:-5:2].float(), "mode game move row col -> mode row col", "mean")
accuracies_all = einops.reduce(is_correct[:, :, 5:-5].float(), "mode game move row col -> mode row col", "mean")

# Get all 3x3 accuracies, stacked over first dim
accuracies_stacked = t.concat([accuracies_even, accuracies_odd, accuracies_all], dim=0)

# Plot results!
board_titles = [
    f"{probe_mode} probe on {data_mode} data"
    for data_mode in ["even", "odd", "all"]
    for probe_mode in ["even", "odd", "both"]
]

utils.plot_board_values(
    1 - accuracies_stacked,
    title="Average Error Rate of Linear Probe",
    board_titles=board_titles,
    boards_per_row=3,
    zmax=0.25,
    zmin=-0.25,
    height=1000,
    width=900,
)
Solution
@dataclass
class MultiProbeTrainingArgs(ProbeTrainingArgs):
    modes: int = 3  # even, odd, both (i.e. the data we train on)

    def setup_linear_probe(self, model: HookedTransformer):
        linear_probe = t.randn(
            self.modes,
            model.cfg.d_model,
            self.rows,
            self.cols,
            self.options,
            device=device,
        ) / np.sqrt(model.cfg.d_model)
        linear_probe.requires_grad = True
        return linear_probe


class LinearMultiProbeTrainer(LinearProbeTrainer):
    def training_step(self, indices: Int[Tensor, "n_games"]) -> Float[Tensor, ""]:
        indices_cpu = indices.cpu()
        games_id = board_seqs_id[indices_cpu]  # shape [batch n_moves=60]
        games_square = board_seqs_square[indices_cpu]  # shape [batch n_moves=60]

        pos_start = self.args.pos_start
        pos_end = self.args.pos_end + self.model.cfg.n_ctx

        # Cache resid_post from all our games (ignoring the last move)
        with t.inference_mode():
            _, cache = self.model.run_with_cache(
                games_id[:, :-1].to(device),
                return_type=None,
                names_filter=lambda name: name.endswith("resid_post"),
            )

        # We're training on all modes, so we slice all resid values in our range.
        resid_post = cache["resid_post", self.args.layer][:, pos_start:pos_end]

        # Get probe output, i.e. probe_logits[m, g, p, r, c] = the 3-vector of logit predictions from
        # mode-m probe, for what color is in square [r, c] AFTER the p-th move is played in game g.
        probe_logits = einops.einsum(
            resid_post,
            self.linear_probe,
            "game pos d_model, mode d_model rows cols options -> mode game pos rows cols options",
        )
        probe_logprobs = probe_logits.log_softmax(-1)

        # Get the actual game state. The original state has {0: empty, 1: black, -1: white} and
        # we want our probe to be in the basis {0: empty, 1: theirs, 2: mine}. For even moves,
        # mine = white, so we map -1 -> 2. For odd moves, mine = black, so we map {1, -1} -> {2, 1}.
        state = get_board_states_and_legal_moves(games_square)[0]
        state[:, ::2][state[:, ::2] == -1] = 2
        state[:, 1::2][state[:, 1::2] == 1] = 2
        state[:, 1::2][state[:, 1::2] == -1] = 1
        state = state[:, pos_start:pos_end]

        # Index into the probe logprobs with the correct indices (note, each of our 3 probe modes
        # gives us a different tensor of logprobs).
        correct_probe_logprobs = eindex(
            probe_logprobs,
            state,
            "mode game pos row col [game pos row col]",  # -> shape [mode game pos row col]
        )
        # Get the logprobs we'll be using for our 3 different probes. Remember that for the even
        # and odd probes we need to take only the even/odd moves respectively (and also that we've
        # already sliced logprobs from pos_start: pos_end). The position axis comes after the game axis.
        pos_start_even, pos_start_odd = (0, 1) if pos_start % 2 == 0 else (1, 0)
        even_probe_logprobs = correct_probe_logprobs[0, :, pos_start_even::2]
        odd_probe_logprobs = correct_probe_logprobs[1, :, pos_start_odd::2]
        both_probe_logprobs = correct_probe_logprobs[2]
        # Get our 3 different loss functions
        loss_even = -einops.reduce(even_probe_logprobs, "game pos row col -> row col", "mean").sum()
        loss_odd = -einops.reduce(odd_probe_logprobs, "game pos row col -> row col", "mean").sum()
        loss_both = -einops.reduce(both_probe_logprobs, "game pos row col -> row col", "mean").sum()
        # We backprop on the sum of all 3 losses
        loss = loss_even + loss_odd + loss_both

        if self.args.use_wandb:
            wandb.log(
                dict(
                    loss=loss.item(),
                    loss_even=loss_even.item(),
                    loss_odd=loss_odd.item(),
                    loss_both=loss_both.item(),
                ),
                step=self.step,
            )
        self.step += 1

        return loss

As a bonus exercise, you can try the following:

  • Add an evaluation loop to the training code - this is helpful if you don't want to wait until the end of training to look at your probe classification accuracy!
  • Make a hybrid probe by averaging over your even and odd mode probes (like the way we made probes from the provided probe in earlier exercises). Is this probe's accuracy higher than either the even or odd mode probe? Is it higher than the probe trained on both even and odd moves?
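For the hybrid-probe bonus, a minimal numpy sketch is below. It assumes the probe tensor has the layout used by MultiProbeTrainingArgs, (mode, d_model, rows, cols, options) with modes ordered (even, odd, both); a random array stands in for trainer.linear_probe. Since every mode is trained directly in the {empty, theirs, mine} basis, the even and odd probes' option axes already line up, so a plain average is a reasonable starting point.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 16
# Stand-in for trainer.linear_probe: (mode, d_model, rows, cols, options), modes = (even, odd, both)
multi_probe = rng.normal(size=(3, d_model, 8, 8, 3))

# Average the even-mode and odd-mode probes; both already use the {empty, theirs, mine} basis
hybrid_probe = (multi_probe[0] + multi_probe[1]) / 2
print(hybrid_probe.shape)  # (16, 8, 8, 3)
```

You can then evaluate hybrid_probe with the same accuracy code used for the single even-mode probe.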