1️⃣ Model & Task Setup

Learning Objectives
  • Understand the IOI task, and why the authors chose to study it
  • Build functions to demonstrate the model's performance on this task

Loading our model

The first step is to load in our model, GPT-2 Small, using HookedTransformer.from_pretrained. It is a 12-layer transformer with roughly 85M non-embedding parameters (about 124M including embeddings). The various flags apply simplifications that preserve the model's output while making its internals easier to analyse.

model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
Note on refactor_factored_attn_matrices (optional)

This argument means we redefine the matrices $W_Q$, $W_K$, $W_V$ and $W_O$ in the model (without changing the model's actual behaviour).

For example, we know that instead of working with $W_Q$ and $W_K$ individually, the only matrix we actually need to use in the model is the low-rank matrix $W_Q W_K^T$ (note that I'm using the convention of matrix multiplication on the right, which matches the code in TransformerLens and previous exercises in this series, but doesn't match Anthropic's "A Mathematical Framework for Transformer Circuits" paper). So if we perform singular value decomposition $W_Q W_K^T = U S V^T$, then we see that we can just as easily define $W_Q = U \sqrt{S}$ and $W_K = V \sqrt{S}$ and use these instead, since $U \sqrt{S} (V \sqrt{S})^T = U S V^T$. This means that $W_Q$ and $W_K$ both have orthogonal columns with matching norms. You can investigate this yourself (e.g. using the code below). This is arguably a more interpretable setup, because now there's no obvious asymmetry between the keys and queries.

There's also some fiddliness with how biases are handled in this factorisation, which is why the comments above don't hold absolutely (see the documentation for more info).

# Show column norms are the same (except first few, for fiddly bias reasons)
line([model.W_Q[0, 0].pow(2).sum(0), model.W_K[0, 0].pow(2).sum(0)])
# Show columns are orthogonal (except first few, again)
W_Q_dot_products = einops.einsum(
    model.W_Q[0, 0], model.W_Q[0, 0], "d_model d_head_1, d_model d_head_2 -> d_head_1 d_head_2"
)
imshow(W_Q_dot_products)

In a similar way, since $W_{OV} = W_V W_O = U S V^T$, we can define $W_V = U S$ and $W_O = V^T$. This is arguably a more interpretable setup, because now $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the result of the head.
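To make both factorisations concrete without loading the model, here is a minimal NumPy sketch on random toy matrices. It follows the description above (ignoring biases) rather than the library's exact implementation, and the sizes and variable names are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head = 16, 4  # toy sizes (assumption), not GPT-2's real dimensions

# Random stand-ins for a single head's weights, with no biases
W_Q = rng.normal(size=(d_model, d_head))
W_K = rng.normal(size=(d_model, d_head))
W_V = rng.normal(size=(d_model, d_head))
W_O = rng.normal(size=(d_head, d_model))


def truncated_svd(M, rank):
    """Return U, S, Vt keeping only the top `rank` singular directions."""
    U, S, Vt = np.linalg.svd(M, full_matrices=False)
    return U[:, :rank], S[:rank], Vt[:rank]


# QK circuit: W_Q W_K^T = U S V^T  ->  W_Q' = U sqrt(S), W_K' = V sqrt(S)
U, S, Vt = truncated_svd(W_Q @ W_K.T, d_head)
W_Q_new = U * np.sqrt(S)
W_K_new = Vt.T * np.sqrt(S)

# OV circuit: W_V W_O = U S V^T  ->  W_V' = U S, W_O' = V^T
U_ov, S_ov, Vt_ov = truncated_svd(W_V @ W_O, d_head)
W_V_new = U_ov * S_ov
W_O_new = Vt_ov

# The products (and hence the head's behaviour) are unchanged
qk_preserved = np.allclose(W_Q_new @ W_K_new.T, W_Q @ W_K.T)
ov_preserved = np.allclose(W_V_new @ W_O_new, W_V @ W_O)

# Columns of the new W_Q are mutually orthogonal (Gram matrix is diagonal)
gram_Q = W_Q_new.T @ W_Q_new
q_cols_orthogonal = np.allclose(gram_Q, np.diag(np.diag(gram_Q)))

# Column norms of the new W_Q and W_K match
norms_match = np.allclose(
    np.linalg.norm(W_Q_new, axis=0), np.linalg.norm(W_K_new, axis=0)
)

# The new W_O has orthonormal rows, so it preserves the norm of z
z = rng.normal(size=d_head)
norm_preserved = np.isclose(np.linalg.norm(z @ W_O_new), np.linalg.norm(z))

print(qk_preserved, ov_preserved, q_cols_orthogonal, norms_match, norm_preserved)
```

In the real model the first few columns deviate from this picture because of the bias handling mentioned above.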

Note on fold_ln, center_unembed and center_writing_weights (optional)

See link [here](https://github.com/neelnanda-io/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln) for comments.

The next step is to verify that the model can actually do the task! Here we use utils.test_prompt, and see that the model is significantly better at predicting Mary than John!

Asides

Note: If we were being careful, we'd want to run the model on a range of prompts and find the average performance. We'll do more stuff like this in the fourth section (when we try to replicate some of the paper's results, and take a more rigorous approach).

prepend_bos is a flag to add a BOS (beginning of sequence) token to the start of the prompt. GPT-2 was not trained with this, but I find that it often makes model behaviour more stable, as the first token is treated weirdly.
# Here is where we test on a single prompt
# Result: 70% probability on Mary, as we expect

example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']

Performance on answer token:
Rank: 0        Logit: 18.09 Prob: 70.07% Token: | Mary|

Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|

Ranks of the answer tokens: [(' Mary', 0)]
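As a quick sanity check on how the logit and probability columns relate, here is a small sketch using only numbers copied from the output above. It relies on the standard softmax identity that the ratio of two probabilities equals the exponential of their logit gap; the variable names are just for illustration.

```python
import math

# Logits and probabilities copied from the test_prompt output above
logit_mary, prob_mary = 18.09, 0.7007
logit_john, prob_john = 15.35, 0.0454

# Under softmax, p_a / p_b = exp(logit_a - logit_b)
logit_gap = logit_mary - logit_john
ratio_from_logits = math.exp(logit_gap)
ratio_from_probs = prob_mary / prob_john

print(f"logit gap:         {logit_gap:.2f}")
print(f"ratio from logits: {ratio_from_logits:.1f}")
print(f"ratio from probs:  {ratio_from_probs:.1f}")
```

The two ratios agree up to the rounding in the printed output, with " Mary" roughly 15 times more likely than " John" on this prompt.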

We now want to find a reference prompt to run the model on. Even though our ultimate goal is to reverse engineer how this behaviour is done in general, often the best way to start out in mechanistic interpretability is by zooming in on a concrete example and understanding it in detail, and only then zooming out and verifying that our analysis generalises. In section 3, we'll work with a dataset similar to the one used by the paper authors, but this probably wouldn't be the first thing we reached for if we were just doing initial investigations.

We'll run the model on 4 instances of this task, each prompt given twice - one with the first name as the indirect object, one with the second name. To make our lives easier, we'll carefully choose prompts with single token names and the corresponding names in the same token positions.

Aside on tokenization

We want models that can take in arbitrary text, but models need to have a fixed vocabulary. So the solution is to define a vocabulary of tokens and to deterministically break up arbitrary text into tokens. Tokens are, essentially, subwords, and are determined by finding the most frequent substrings - this means that tokens vary a lot in length and frequency!

Tokens are a massive headache and are one of the most annoying things about reverse engineering language models... Different names will be different numbers of tokens, different prompts will have the relevant tokens at different positions, different prompts will have different total numbers of tokens, etc. Language models often devote significant amounts of parameters in early layers to convert inputs from tokens to a more sensible internal format (and do the reverse in later layers). You really, really want to avoid needing to think about tokenization wherever possible when doing exploratory analysis (though, of course, it's relevant later when trying to flesh out your analysis and make it rigorous!). HookedTransformer comes with several helper methods to deal with tokens: to_tokens, to_string, to_str_tokens, to_single_token, get_token_position

Exercise: I recommend using model.to_str_tokens to explore how the model tokenizes different strings. In particular, try adding or removing spaces at the start, or changing capitalization - these change tokenization!
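To see why leading spaces and capitalization matter, here is a deliberately tiny toy tokenizer. It is not GPT-2's byte-pair encoding: the vocabulary and the greedy longest-match rule are hypothetical, chosen only to show that " John", "John" and " john" can map to different token sequences. For the real behaviour, use model.to_str_tokens.

```python
# Hypothetical toy vocabulary (assumption); GPT-2's real vocabulary has 50,257 entries
TOY_VOCAB = {" John", "John", " Mary", " gave", " to", " j", "ohn"}


def toy_tokenize(text, vocab=TOY_VOCAB):
    """Greedy longest-match tokenizer; unmatched characters become single-char tokens."""
    max_len = max(len(tok) for tok in vocab)
    tokens, i = [], 0
    while i < len(text):
        # Try the longest candidate first, shrinking until something matches
        for length in range(min(max_len, len(text) - i), 0, -1):
            piece = text[i : i + length]
            if piece in vocab or length == 1:
                tokens.append(piece)
                i += length
                break
    return tokens


for text in [" John", "John", " john", "john"]:
    print(repr(text), "->", toy_tokenize(text))
```

Even in this toy setting, a name that is one token with a leading space and a capital letter splits into several pieces once either is changed, which is why the prompts below use single-token names.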
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
name_pairs = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# Define 8 prompts, in 4 groups of 2 (with adjacent prompts having answers swapped)
prompts = [
    prompt.format(name)
    for (prompt, names) in zip(prompt_format, name_pairs)
    for name in names[::-1]
]
# Define the answers for each prompt, in the form (correct, incorrect)
answers = [names[::i] for names in name_pairs for i in (1, -1)]
# Define the answer tokens (same shape as the answers)
answer_tokens = t.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])

rprint(prompts)
rprint(answers)
rprint(answer_tokens)

table = Table("Prompt", "Correct", "Incorrect", title="Prompts & Answers:")

for prompt, answer in zip(prompts, answers):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]))

rprint(table)
[
    'When John and Mary went to the shops, John gave the bag to',
    'When John and Mary went to the shops, Mary gave the bag to',
    'When Tom and James went to the park, James gave the ball to',
    'When Tom and James went to the park, Tom gave the ball to',
    'When Dan and Sid went to the shops, Sid gave an apple to',
    'When Dan and Sid went to the shops, Dan gave an apple to',
    'After Martin and Amy went to the park, Amy gave a drink to',
    'After Martin and Amy went to the park, Martin gave a drink to'
]

[
    (' Mary', ' John'),
    (' John', ' Mary'),
    (' Tom', ' James'),
    (' James', ' Tom'),
    (' Dan', ' Sid'),
    (' Sid', ' Dan'),
    (' Martin', ' Amy'),
    (' Amy', ' Martin')
]

tensor([[ 5335,  1757],
        [ 1757,  5335],
        [ 4186,  3700],
        [ 3700,  4186],
        [ 6035, 15686],
        [15686,  6035],
        [ 5780, 14235],
        [14235,  5780]], device='cuda:0')

                                   Prompts & Answers:                                    
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Prompt                                                        ┃ Correct   ┃ Incorrect ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩
│ When John and Mary went to the shops, John gave the bag to    │ ' Mary'   │ ' John'   │
│ When John and Mary went to the shops, Mary gave the bag to    │ ' John'   │ ' Mary'   │
│ When Tom and James went to the park, James gave the ball to   │ ' Tom'    │ ' James'  │
│ When Tom and James went to the park, Tom gave the ball to     │ ' James'  │ ' Tom'    │
│ When Dan and Sid went to the shops, Sid gave an apple to      │ ' Dan'    │ ' Sid'    │
│ When Dan and Sid went to the shops, Dan gave an apple to      │ ' Sid'    │ ' Dan'    │
│ After Martin and Amy went to the park, Amy gave a drink to    │ ' Martin' │ ' Amy'    │
│ After Martin and Amy went to the park, Martin gave a drink to │ ' Amy'    │ ' Martin' │
└───────────────────────────────────────────────────────────────┴───────────┴───────────┘
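The slicing in the cell above is compact, so here is a standalone pure-Python sketch that rebuilds the prompts and answers (no model needed) and checks the property the task relies on: in each prompt the correct answer (the indirect object) is mentioned once and the incorrect answer (the subject) twice. The mention_counts name is introduced here for illustration.

```python
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
name_pairs = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# names[::-1] reverses the pair, so the second name is used as the subject first
prompts = [
    prompt.format(name)
    for (prompt, names) in zip(prompt_format, name_pairs)
    for name in names[::-1]
]
# names[::1] keeps the pair, names[::-1] swaps it, giving (correct, incorrect)
answers = [names[::i] for names in name_pairs for i in (1, -1)]

# For each prompt, count mentions of the (correct, incorrect) names
mention_counts = [
    (prompt.count(correct), prompt.count(incorrect))
    for prompt, (correct, incorrect) in zip(prompts, answers)
]
print(mention_counts)
```

Every entry should be (1, 2); any other value would mean the prompts and answers have drifted out of alignment.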
Aside - the rich library

The outputs above were created by rich, a fun library which prints things in nice formats. It has functions like rich.table.Table, which are very easy to use and produce visually clear outputs.

You can also color the columns of a table, by using the rich.table.Column argument with the style parameter:

cols = [
    "Prompt",
    Column("Correct", style="rgb(0,200,0) bold"),
    Column("Incorrect", style="rgb(255,0,0) bold"),
]
table = Table(*cols, title="Prompts & Answers:")
for prompt, answer in zip(prompts, answers):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]))
rprint(table)

We now run the model on these prompts and use run_with_cache to get both the logits and a cache of all internal activations for later analysis.

tokens = model.to_tokens(prompts, prepend_bos=True)
# Move the tokens to the GPU
tokens = tokens.to(device)
# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

We'll later be evaluating how model performance differs upon performing various interventions, so it's useful to have a metric to measure model performance. Our metric here will be the logit difference: the difference in logit between the indirect object's name and the subject's name (e.g. logit(Mary) - logit(John)).

Exercise - implement the performance evaluation function

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-15 minutes on this exercise. It's important to understand exactly what this function is computing, and why it matters.

This function should take in your model's logit output (shape (batch, seq, d_vocab)) and the array of answer tokens (shape (batch, 2), containing the token ids of correct and incorrect answers respectively for each sequence), and return the logit difference as described above. If per_prompt is False, it should take the mean over the batch dimension; otherwise it should return an array of length batch.

def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False,
) -> Float[Tensor, "*batch"]:
    """
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    """
    raise NotImplementedError()


tests.test_logits_to_ave_logit_diff(logits_to_ave_logit_diff)

original_per_prompt_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
print("Per prompt logit difference:", original_per_prompt_diff)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff)

cols = [
    "Prompt",
    Column("Correct", style="rgb(0,200,0) bold"),
    Column("Incorrect", style="rgb(255,0,0) bold"),
    Column("Logit Difference", style="bold"),
]
table = Table(*cols, title="Logit differences")

for prompt, answer, logit_diff in zip(prompts, answers, original_per_prompt_diff):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]), f"{logit_diff.item():.3f}")

rprint(table)
Expected output
                                             Logit differences                                              
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃ Prompt                                                        ┃ Correct   ┃ Incorrect ┃ Logit Difference ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ When John and Mary went to the shops, John gave the bag to    │ ' Mary'   │ ' John'   │ 3.337            │
│ When John and Mary went to the shops, Mary gave the bag to    │ ' John'   │ ' Mary'   │ 3.202            │
│ When Tom and James went to the park, James gave the ball to   │ ' Tom'    │ ' James'  │ 2.709            │
│ When Tom and James went to the park, Tom gave the ball to     │ ' James'  │ ' Tom'    │ 3.797            │
│ When Dan and Sid went to the shops, Sid gave an apple to      │ ' Dan'    │ ' Sid'    │ 1.720            │
│ When Dan and Sid went to the shops, Dan gave an apple to      │ ' Sid'    │ ' Dan'    │ 5.281            │
│ After Martin and Amy went to the park, Amy gave a drink to    │ ' Martin' │ ' Amy'    │ 2.601            │
│ After Martin and Amy went to the park, Martin gave a drink to │ ' Amy'    │ ' Martin' │ 5.767            │
└───────────────────────────────────────────────────────────────┴───────────┴───────────┴──────────────────┘
Solution
def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False,
) -> Float[Tensor, "*batch"]:
    """
    Returns logit difference between the correct and incorrect answer.
    If per_prompt=True, return the array of differences rather than the average.
    """
    # Only the final logits are relevant for the answer
    final_logits: Float[Tensor, "batch d_vocab"] = logits[:, -1, :]
    # Get the logits corresponding to the indirect object / subject tokens respectively
    answer_logits: Float[Tensor, "batch 2"] = final_logits.gather(dim=-1, index=answer_tokens)
    # Find logit difference
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()
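If you want to check the indexing logic by hand, here is a NumPy analogue of the solution run on a tiny hand-written batch. It is a sketch that mirrors the PyTorch version (np.take_along_axis plays the role of gather); the function name and the toy numbers are assumptions made up for illustration.

```python
import numpy as np


def ave_logit_diff_np(logits, answer_tokens, per_prompt=False):
    """NumPy analogue of logits_to_ave_logit_diff, for hand-checkable inputs."""
    # Only the final position's logits are relevant for the answer
    final_logits = logits[:, -1, :]
    # Pick out the logits of the (correct, incorrect) tokens for each prompt
    answer_logits = np.take_along_axis(final_logits, answer_tokens, axis=-1)
    diff = answer_logits[:, 0] - answer_logits[:, 1]
    return diff if per_prompt else diff.mean()


# Toy batch: 2 prompts, 3 positions, vocabulary of 5 tokens
toy_logits = np.zeros((2, 3, 5))
toy_logits[0, -1] = [0.0, 4.0, 1.0, 0.0, 0.0]
toy_logits[1, -1] = [2.0, 0.0, 0.0, 0.5, 0.0]
toy_answer_tokens = np.array([[1, 2], [3, 0]])  # (correct, incorrect) token ids

toy_per_prompt = ave_logit_diff_np(toy_logits, toy_answer_tokens, per_prompt=True)
toy_average = ave_logit_diff_np(toy_logits, toy_answer_tokens)
print(toy_per_prompt, toy_average)
```

The first toy prompt gives 4.0 - 1.0 = 3.0 and the second gives 0.5 - 2.0 = -1.5, so the average is 0.75. A negative per-prompt value means the incorrect name is preferred on that prompt.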

Brainstorm What's Actually Going On

Before diving into running experiments, it's often useful to spend some time actually reasoning about how the behaviour in question could be implemented in the transformer. This is optional, and you'll likely get the most out of engaging with this section if you have a decent understanding already of what a transformer is and how it works!

You don't have to do this, and forming hypotheses after exploration is also reasonable, but I think it's often easier to explore and interpret results with some grounding in what you might find. In this particular case, I'm cheating somewhat, since I know the answer, but I'm trying to simulate the process of reasoning about it!

Note that your hypothesis will often be wrong in some ways, and will sometimes be completely off. We're doing science here, and the goal is to understand how the model actually works, and to form true beliefs! There are two separate traps here, at two extremes, that are worth tracking:

  • Confusion: having no hypotheses at all, getting a lot of data and not knowing what to do with it, and just floundering around.
  • Dogmatism: being overconfident in an incorrect hypothesis and being unwilling to let go of it when reality contradicts you, or flinching away from running the experiments that might disconfirm it.

Exercise: Spend some time thinking through how you might imagine this behaviour being implemented in a transformer. Try to think through this for yourself before reading through my thoughts!

(*) My reasoning

Brainstorming:

So, what's hard about the task? Let's focus on the concrete example of the first prompt, "When John and Mary went to the shops, John gave the bag to" -> " Mary".

A good starting point is thinking through whether a tiny model could do this, e.g. a 1L Attn-Only model. I'm pretty sure the answer is no! Attention is really good at the primitive operations of looking nearby, or copying information. I can believe a tiny model could figure out that at " to" it should look for names and predict that those names come next (e.g. the skip trigram " John...to -> John"). But it's much harder to tell how many of each previous name there are: attending to each copy of " John" will look exactly the same as attending to a single " John" token. So this will be pretty hard to figure out on the " to" token!
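The counting problem can be illustrated with a toy calculation. This is a sketch under strong assumptions: positional information is ignored, so both copies of " John" get the same attention score and value vector, and the head attends only to " John" tokens. The vectors and scores are made up.

```python
import numpy as np


def attention_output(scores, values):
    """Softmax over scores, then a weighted average of the value vectors."""
    weights = np.exp(scores - scores.max())
    weights = weights / weights.sum()
    return weights @ values


v_john = np.array([1.0, 0.0, 2.0])  # hypothetical value vector for " John"
name_score = 3.0                    # hypothetical attention score for a name token

# One copy of " John" in the context
one_copy = attention_output(np.array([name_score]), np.stack([v_john]))
# Two copies of " John", each with the same score and value vector
two_copies = attention_output(
    np.array([name_score, name_score]), np.stack([v_john, v_john])
)

print(one_copy, two_copies, np.allclose(one_copy, two_copies))
```

Because softmax attention takes a weighted average rather than a sum, the head's output is the same vector in both cases, so under these assumptions it carries no information about how many copies there were.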

The natural place to break this symmetry is on the second " John" token - telling whether there is an earlier copy of the current token should be a much easier task. So I might expect there to be a head which detects duplicate tokens on the second " John" token, and then another head which moves that information from the second " John" token to the " to" token.

The model then needs to learn to predict " Mary" and not " John". I can see two natural ways to do this:

1. Detect all preceding names and move this information to " to", then delete any name corresponding to the duplicate token feature. This feels easier done with a non-linearity, since precisely cancelling out vectors is hard, so I'd imagine an MLP layer deletes the " John" direction of the residual stream.
2. Have a head which attends to all previous names, but where the duplicate token features inhibit it from attending to specific names, so that it only attends to " Mary". The output of this head then maps to the logits.

Spoiler - which one of these two is correct

It's the second one.

Experiment Ideas

A test that could distinguish these two is to look at which components of the model add directly to the logits: if it's mostly attention heads which attend to `" Mary"` and to neither `" John"` token, it's probably hypothesis 2; if it's mostly MLPs, it's probably hypothesis 1.

And we should be able to identify duplicate token heads by finding ones which attend from `" John"` to `" John"`, and whose outputs are then moved to the `" to"` token by V-Composition with another head (Spoiler: it's more complicated than that!).

Note that all of the above reasoning is very simplistic and could easily break in a real model! There'll be significant parts of the model that figure out whether to use this circuit at all (we don't want to inhibit duplicated names when, e.g., figuring out what goes at the start of the next sentence), and there may be parts towards the end of the model that do "post-processing" just before the final output. But it's a good starting point for thinking about what's going on.