3️⃣ TransformerLens: Hooks

Learning Objectives
  • Understand what hooks are, and how they are used in TransformerLens
  • Use hooks to access activations, process the results, and write them to an external tensor
  • Build tools to perform attribution, i.e. detecting which components of your model are responsible for performance on a given task
  • Understand how hooks can be used to perform basic interventions like ablation

What are hooks?

One of the great things about interpreting neural networks is that we have full control over our system. From a computational perspective, we know exactly what operations are going on inside (even if we don't know what they mean!). And we can make precise, surgical edits and see how the model's behaviour and other internals change. This is an extremely powerful tool, because it can let us e.g. set up careful counterfactuals and causal intervention to easily understand model behaviour.

Accordingly, being able to do this is a pretty core operation, and this is one of the main things TransformerLens supports! The key feature here is hook points. Every activation inside the transformer is surrounded by a hook point, which allows us to edit or intervene on it.

We do this by adding a hook function to that activation, and then calling model.run_with_hooks.

(Terminology note - because basically all the activations in our model have an associated hook point, we'll sometimes use the terms "hook" and "activation" interchangeably.)

Hook functions

Hook functions take two arguments: activation_value and hook_point. The activation_value is a tensor representing some activation in the model, just like the values in our ActivationCache. The hook_point is an object which gives us methods like hook.layer() or attributes like hook.name that are sometimes useful to call within the function.

If we're using hooks to edit activations, then the hook function should return a tensor of the same shape as the activation value. But we can also just have our hook function access the activation, do some processing, and write the results to some external variable (in which case our hook function should just not return anything).

An example hook function for changing the attention patterns at a particular layer might look like:

def hook_function(
    attn_pattern: Float[Tensor, "batch heads seq_len seq_len"],
    hook: HookPoint
) -> Float[Tensor, "batch heads seq_len seq_len"]:

    # modify attn_pattern (can be inplace)
    return attn_pattern
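
And here's a sketch of a hook function that only reads an activation rather than editing it (the variable stored_patterns is our own, not something defined elsewhere in these exercises). It writes a copy of the attention pattern to an external list and returns nothing, so the activation passes through unchanged:

stored_patterns = []

def store_pattern_hook(
    attn_pattern: Float[Tensor, "batch heads seq_len seq_len"],
    hook: HookPoint
) -> None:
    # Copy the pattern so later (possibly inplace) hooks can't change what we stored
    stored_patterns.append(attn_pattern.detach().clone())
    # Returning None leaves the activation unchanged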

Running with hooks

Once you've defined a hook function (or functions), you should call model.run_with_hooks. A typical call to this function might look like:

loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[
        ('blocks.1.attn.hook_pattern', hook_function)
    ]
)

Let's break this code down.

  • tokens represents our model's input.
  • return_type="loss" is used here because we're modifying our activations and seeing how this affects the loss.
    • We could also return the logits, or just use return_type=None if we only want to access the intermediate activations and we don't care about the output.
  • fwd_hooks is a list of 2-tuples of (hook name, hook function).
    • The hook name is a string that specifies which activation we want to hook.
    • The hook function gets run with the corresponding activation as its first argument.

A bit more about hooks

Here are a few extra notes for how to squeeze even more functionality out of hooks. If you'd prefer, you can jump ahead to see an actual example of hooks being used, and come back to this section later.

Resetting hooks

model.run_with_hooks has the default parameter reset_hooks_end=True which resets all hooks at the end of the run (including both those that were added before and during the run). Despite this, it's possible to shoot yourself in the foot with hooks, e.g. if there's an error in one of your hooks so the function never finishes. In this case, you can use model.reset_hooks() to reset all hooks.

If you don't want to reset hooks (i.e. you want to keep them between forward passes), you can either set reset_hooks_end=False in the run_with_hooks function, or just add the hooks directly using the add_hook method before your forward passes (this way they won't reset automatically).
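
As a quick sketch of that second option (other_tokens below just stands for some second batch of inputs):

# Add a hook that persists across forward passes
model.add_hook('blocks.1.attn.hook_pattern', hook_function)

loss_1 = model(tokens, return_type="loss")        # hook runs here
loss_2 = model(other_tokens, return_type="loss")  # ...and here

# Remove it manually once you're done
model.reset_hooks()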

Adding multiple hooks at once

Including more than one tuple in the fwd_hooks list is one way to add multiple hooks:

loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[
        ('blocks.0.attn.hook_pattern', hook_function),
        ('blocks.1.attn.hook_pattern', hook_function)
    ]
)

Another way is to use a name filter rather than a single name:

loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[
        (lambda name: name.endswith("pattern"), hook_function)
    ]
)

utils.get_act_name

When we were indexing the cache in the previous section, we found we could use strings like cache['blocks.0.attn.hook_pattern'], or use the shorthand of cache['pattern', 0]. The reason the second one works is that it calls the function utils.get_act_name under the hood, i.e. we have:

utils.get_act_name('pattern', 0) == 'blocks.0.attn.hook_pattern'

Using utils.get_act_name in your forward hooks is often easier than using the full string, since the only thing you need to remember is the activation name (you can refer back to the diagram in the previous section for this).
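
For example, the hook on layer 1's attention pattern from earlier could equivalently be written as:

loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[
        (utils.get_act_name("pattern", 1), hook_function)
    ]
)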

Using functools.partial to create variations on hooks

A useful trick is to define a hook function with more arguments than it needs, and then use functools.partial to fill in the extra arguments. For instance, if you want a hook function which only modifies a particular head, but you want to run it on all heads separately (rather than just adding all the hooks and having them all run on the next forward pass), then you can do something like:

def hook_all_attention_patterns(
    attn_pattern: Float[Tensor, "batch heads seq_len seq_len"],
    hook: HookPoint,
    head_idx: int
) -> Float[Tensor, "batch heads seq_len seq_len"]:
    # modify attn_pattern inplace, at head_idx
    return attn_pattern


for head_idx in range(12):
    temp_hook_fn = functools.partial(hook_all_attention_patterns, head_idx=head_idx)
    model.run_with_hooks(tokens, fwd_hooks=[('blocks.1.attn.hook_pattern', temp_hook_fn)])

And here are some points of interest, which aren't vital to understand:

Relationship to PyTorch hooks

[PyTorch hooks](https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/) are a great and underrated, yet incredibly janky, feature. They can act on a layer, and edit the input or output of that layer, or the gradient when applying autodiff. The key difference is that hook points act on activations, not layers. This means that you can intervene within a layer on each activation, and don't need to care about the precise layer structure of the transformer. And it's immediately clear exactly how the hook's effect is applied. This adjustment was shamelessly inspired by [Garcon's use of ProbePoints](https://transformer-circuits.pub/2021/garcon/index.html).

They also come with a range of other quality of life improvements. PyTorch's hooks are global state, which can be a massive pain if you accidentally leave a hook on a model. TransformerLens hooks are also global state, but run_with_hooks tries to create an abstraction where these are local state by removing all hooks at the end of the function (and they come with a helpful model.reset_hooks() method to remove all hooks).
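
For comparison, here's roughly what the raw PyTorch pattern looks like (a generic sketch, not TransformerLens code): you register a hook on a module rather than an activation, and you're responsible for removing it yourself via the returned handle.

def pytorch_hook(module, inputs, output):
    # Runs after the module's forward pass; return a new value to replace the output,
    # or None to leave it unchanged
    print(f"Forward hook fired on {module.__class__.__name__}")

handle = model.blocks[1].attn.register_forward_hook(pytorch_hook)
model(tokens)
handle.remove()  # if you forget this, the hook silently persists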

How are TransformerLens hooks actually implemented?

They are implemented as modules with the identity function as their forward method:

class HookPoint(nn.Module):
    ...
    def forward(self, x):
        return x

but also with special features for adding and removing hook functions. This is why you see hooks when you print a HookedTransformer model, because all its modules are recursively printed.

When you run the model normally, hook modules won't change the model's behaviour (since applying the identity function does nothing). It's only once you add functions to the hook modules (e.g. a function which ablates any inputs into the hook module) that the model's behaviour changes.
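
Since every HookPoint is just a submodule, you can list them like any other module. For instance, here's a quick sketch (using the standard named_modules method) that prints the names of the hook points in block 0:

block0_hook_names = [
    name for name, module in model.named_modules()
    if isinstance(module, HookPoint) and name.startswith("blocks.0")
]
print(block0_hook_names)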

Hooks: Accessing Activations

In later sections, we'll write some code to intervene on hooks, which is really the core feature that makes them so useful for interpretability. But for now, let's just look at how to access them without changing their value. This can be achieved by having the hook function write to a global variable, and return nothing (rather than modifying the activation in place).

Why might we want to do this? It turns out to be useful for things like:

  • Extracting activations for a specific task
  • Doing some long-running calculation across many inputs, e.g. finding the text that most activates a specific neuron

Note that, in theory, this could all be done using the run_with_cache function we used in the previous section, combined with post-processing of the cache result. But using hooks can be more intuitive and memory efficient.

Exercise - calculate induction scores with hooks

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You shouldn't spend more than 15-20 minutes on this exercise. This is our first exercise with hooks, which are an absolutely vital TransformerLens tool. Use the hints if you're stuck.

To start with, we'll look at how hooks can be used to get the same results as from the previous section (where we ran our induction head detector functions on the values in the cache).

Most of the code has already been provided for you below; the only thing you need to do is implement the induction_score_hook function. As mentioned, this function takes two arguments: the activation value (which in this case will be our attention pattern) and the hook object (which gives us some useful methods and attributes that we can access in the function, e.g. hook.layer() to return the layer, or hook.name to return the name, which is the same as the name in the cache).

Your function should do the following:

  • Calculate the induction score for the attention pattern pattern, using the same methodology as you used in the previous section when you wrote your induction head detectors.
    • Note that this time, the batch dimension is greater than 1, so you should compute the average attention score over the batch dimension.
    • Also note that you are computing the induction score for all heads at once, rather than one at a time. You might find the arguments dim1 and dim2 of the torch.diagonal function useful.
  • Write this score to the tensor induction_score_store, which is a global variable that we've provided for you. The [i, j]th element of this tensor should be the induction score for the jth head in the ith layer.

seq_len = 50
batch_size = 10
rep_tokens_10 = generate_repeated_tokens(model, seq_len, batch_size)

# We make a tensor to store the induction score for each head.
# We put it on the model's device to avoid needing to move things between the GPU and CPU,
# which can be slow.
induction_score_store = t.zeros(
    (model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device
)


def induction_score_hook(
    pattern: Float[Tensor, "batch head_index dest_pos source_pos"], hook: HookPoint
):
    """
    Calculates the induction score, and stores it in the [layer, head] position of the
    `induction_score_store` tensor.
    """
    raise NotImplementedError()


# We make a boolean filter on activation names, that's true only on attention pattern names
pattern_hook_names_filter = lambda name: name.endswith("pattern")

# Run with hooks (this is where we write to the `induction_score_store` tensor)
model.run_with_hooks(
    rep_tokens_10,
    return_type=None,  # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)],
)

# Plot the induction scores for each head in each layer
imshow(
    induction_score_store,
    labels={"x": "Head", "y": "Layer"},
    title="Induction Score by Head",
    text_auto=".2f",
    width=900,
    height=350,
)
Help - I'm not sure how to implement this function.

To get the induction stripe, you can use:

torch.diagonal(pattern, dim1=-2, dim2=-1, offset=1-seq_len)

since this returns the diagonal of each attention pattern matrix, for every element in the batch and every attention head.

Once you have this, you can then take the mean over the batch and diagonal dimensions, giving you a tensor of length n_heads. You can then write this to the global induction_score_store tensor, using the hook.layer() method to get the correct row number.

Solution
def induction_score_hook(pattern: Float[Tensor, "batch head_index dest_pos source_pos"], hook: HookPoint):
    """
    Calculates the induction score, and stores it in the [layer, head] position of the induction_score_store tensor.
    """
    # Take the diagonal of attention paid from each destination position to the source position (seq_len - 1) tokens back
    # (This only has entries for tokens with index>=seq_len)
    induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1 - seq_len)
    # Get an average score per head
    induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
    # Store the result.
    induction_score_store[hook.layer(), :] = induction_score

If this function has been implemented correctly, you should see a result matching your observations from the previous section: a high induction score (>0.6) for all the heads which you identified as induction heads, and a low score (close to 0) for all others.

Exercise - find induction heads in GPT2-small

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You shouldn't spend more than 10-20 minutes on this exercise. Here, you mostly just need to use previously defined functions and interpret the results, rather than writing new code.

This is your first opportunity to investigate a larger and more extensively trained model, rather than the simple 2-layer model we've been using so far. None of the code required is new (you can copy most of it from previous sections), so these exercises shouldn't take very long.

Perform the same analysis on your gpt2_small. You should observe that some heads, particularly in a couple of the middle layers, have high induction scores. Use CircuitsVis to plot the attention patterns for these heads when run on the repeated token sequences, and verify that they look like induction heads.

Note - you can make CircuitsVis plots (and other visualisations) using hooks rather than plotting directly from the cache. For example, we've given you a hook function which will display the attention patterns at a given hook when you include it in a call to model.run_with_hooks.

def visualize_pattern_hook(
    pattern: Float[Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
):
    print("Layer: ", hook.layer())
    display(
        cv.attention.attention_patterns(
            tokens=gpt2_small.to_str_tokens(rep_tokens[0]), attention=pattern.mean(0)
        )
    )


# YOUR CODE HERE - find induction heads in gpt2_small
Solution
seq_len = 50
batch_size = 10
rep_tokens_batch = generate_repeated_tokens(gpt2_small, seq_len, batch_size)
induction_score_store = t.zeros(
    (gpt2_small.cfg.n_layers, gpt2_small.cfg.n_heads), device=gpt2_small.cfg.device
)
gpt2_small.run_with_hooks(
    rep_tokens_batch,
    return_type=None,  # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)],
)
imshow(
    induction_score_store,
    labels={"x": "Head", "y": "Layer"},
    title="Induction Score by Head",
    text_auto=".1f",
    width=700,
    height=500,
)
# Observation: heads 5.1, 5.5, 6.9, 7.2, 7.10 are all strongly induction-y.
# Confirm observation by visualizing attn patterns for layers 5 through 7:
induction_head_layers = [5, 6, 7]
fwd_hooks = [
    (utils.get_act_name("pattern", induction_head_layer), visualize_pattern_hook)
    for induction_head_layer in induction_head_layers
]
gpt2_small.run_with_hooks(
    rep_tokens,
    return_type=None,
    fwd_hooks=fwd_hooks,
)

Building interpretability tools

In order to develop a mechanistic understanding for how transformers perform certain tasks, we need to be able to answer questions like:

How much of the model's performance on some particular task is attributable to each component of the model?

where "component" here might mean, for example, a specific head in a layer.

There are many ways to approach a question like this. For example, we might look at how a head interacts with other heads in different layers, or we might perform a causal intervention by seeing how well the model performs if we remove the effect of this head. However, we'll keep things simple for now, and ask the question: what are the direct contributions of this head to the output logits?

Direct logit attribution

A consequence of the residual stream is that the output logits are the sum of the contributions of each layer, and thus the sum of the results of each head. This means we can decompose the output logits into a term coming from each head and directly do attribution like this!

A concrete example

Let's say that our model knows that the token Harry is followed by the token Potter, and we want to figure out how it does this. The logits on Harry are residual @ W_U. But this is a linear map, and the residual stream is the sum of the outputs of all previous layers: residual = embed + attn_out_0 + attn_out_1. So logits = (embed @ W_U) + (attn_out_0 @ W_U) + (attn_out_1 @ W_U).

We can be even more specific, and just look at the logit of the Potter token - this corresponds to a column of W_U, and so a direction in the residual stream - our logit is now a single number that is the sum of (embed @ potter_U) + (attn_out_0 @ potter_U) + (attn_out_1 @ potter_U). Even better, we can decompose each attention layer output into the sum of the result of each head, and use this to get many terms.
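
Here's a sketch of that decomposition in code, for our 2-layer attention-only model. (We're assuming " Potter" is a single token for this tokenizer, and we gloss over details like LayerNorm and positional embeddings, which don't affect the story for this particular model.)

# Assume we've already run:
#   logits, cache = model.run_with_cache(" Harry", remove_batch_dim=True)
potter_token = model.to_single_token(" Potter")
potter_dir = model.W_U[:, potter_token]  # a single column of W_U, i.e. a direction in the residual stream

direct_contribution = cache["embed"][-1] @ potter_dir
layer0_contribution = cache["attn_out", 0][-1] @ potter_dir
layer1_contribution = cache["attn_out", 1][-1] @ potter_dir
# These three terms sum to (approximately) the " Potter" logit at the final position,
# i.e. logits[-1, potter_token]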

Your mission here is to write a function to look at how much each component contributes to the correct logit. Your components are:

  • The direct path (i.e. the residual connections from the embedding to unembedding),
  • Each layer 0 head (via the residual connection and skipping layer 1)
  • Each layer 1 head

To emphasise, these are not paths from the start to the end of the model, these are paths from the output of some component directly to the logits - we make no assumptions about how each path was calculated!

A few important notes for this exercise:

  • Here we are just looking at the DIRECT effect on the logits, i.e. the thing that this component writes / embeds into the residual stream - if heads compose with other heads and affect logits like that, or inhibit logits for other tokens to boost the correct one, we will not pick up on this!
  • By looking at just the logits corresponding to the correct token, our data is much lower dimensional because we can ignore all tokens other than the correct next one (dealing with a 50K vocab size is a pain!). But this comes at the cost of missing out on more subtle effects, like a head suppressing other plausible logits to increase the log prob of the correct one.
    • There are other situations where our job might be easier. For instance, in the IOI task (which we'll discuss shortly) we're just comparing the logits of the indirect object to the logits of the direct object, meaning we can use the difference between these logits, and ignore all the other logits.
  • When calculating correct output logits, we will get tensors with a dimension (position - 1,), not (position,) - we remove the final element of the output (logits), and the first element of labels (tokens). This is because we're predicting the next token, and we don't know the token after the final token, so we ignore it.
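
A small sketch of this alignment, for a single sequence of logits with shape (seq, d_vocab) and tokens with shape (seq,):

# logits[i] predicts tokens[i + 1], so the "correct logit" at position i is logits[i, tokens[i + 1]].
# Dropping the last prediction (no label) and the first label (no prediction) lines them up:
predictions = logits[:-1]  # shape (seq - 1, d_vocab)
labels = tokens[1:]        # shape (seq - 1,)
correct_logits = predictions[t.arange(len(labels)), labels]  # shape (seq - 1,)
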
Aside - centering W_U

While we won't worry about this for this exercise, logit attribution is often more meaningful if we first center W_U - i.e. ensure the mean of each row writing to the output logits is zero. Log softmax is invariant when we add a constant to all the logits, so we want to control for a head that just increases all logits by the same amount. We won't do this here for ease of testing.
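
If you did want to center W_U, it's a one-liner (a sketch; W_U has shape (d_model, d_vocab), so we subtract each row's mean over the vocab dimension):

W_U_centered = model.W_U - model.W_U.mean(dim=-1, keepdim=True)  # shape stays (d_model, d_vocab)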

Question - why don't we do this to the log probs instead?

Because log probs aren't linear, they go through log_softmax, a non-linear function.

Exercise - build logit attribution tool

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You shouldn't spend more than 10-15 minutes on this exercise. This exercise is important, but has quite a few messy einsums, so you might get more value from reading the solution than doing the exercises.

You should implement the logit_attribution function below. This should return the contribution of each component in the "correct direction". We've already given you the unembedding vectors for the correct direction, W_U_correct_tokens (note that we take the [1:] slice of tokens, for reasons discussed above).

The code below this function will check your logit attribution function is working correctly, by taking the sum of logit attributions and comparing it to the actual values in the residual stream at the end of your model.

def logit_attribution(
    embed: Float[Tensor, "seq d_model"],
    l1_results: Float[Tensor, "seq nheads d_model"],
    l2_results: Float[Tensor, "seq nheads d_model"],
    W_U: Float[Tensor, "d_model d_vocab"],
    tokens: Int[Tensor, "seq"],
) -> Float[Tensor, "seq-1 n_components"]:
    """
    Inputs:
        embed: the embeddings of the tokens (i.e. token + position embeddings)
        l1_results: the outputs of the attention heads at layer 1 (with head as one of the dims)
        l2_results: the outputs of the attention heads at layer 2 (with head as one of the dims)
        W_U: the unembedding matrix
        tokens: the token ids of the sequence

    Returns:
        Tensor of shape (seq_len-1, n_components)
        represents the concatenation (along dim=-1) of logit attributions from:
            the direct path (seq-1,1)
            layer 0 logits (seq-1, n_heads)
            layer 1 logits (seq-1, n_heads)
        so n_components = 1 + 2*n_heads
    """
    W_U_correct_tokens = W_U[:, tokens[1:]]

    raise NotImplementedError()


text = "We think that powerful, significantly superhuman machine intelligence is more likely than not to be created this century. If current machine learning techniques were scaled up to this level, we think they would by default produce systems that are deceptive or manipulative, and that no solid plans are known for how to avoid this."
logits, cache = model.run_with_cache(text, remove_batch_dim=True)
str_tokens = model.to_str_tokens(text)
tokens = model.to_tokens(text)

with t.inference_mode():
    embed = cache["embed"]
    l1_results = cache["result", 0]
    l2_results = cache["result", 1]
    logit_attr = logit_attribution(embed, l1_results, l2_results, model.W_U, tokens[0])
    # Uses fancy indexing to get a len(tokens[0])-1 length tensor, where the kth entry is the predicted logit for the correct k+1th token
    correct_token_logits = logits[0, t.arange(len(tokens[0]) - 1), tokens[0, 1:]]
    t.testing.assert_close(logit_attr.sum(1), correct_token_logits, atol=1e-3, rtol=0)
    print("Tests passed!")
Solution
def logit_attribution(
    embed: Float[Tensor, "seq d_model"],
    l1_results: Float[Tensor, "seq nheads d_model"],
    l2_results: Float[Tensor, "seq nheads d_model"],
    W_U: Float[Tensor, "d_model d_vocab"],
    tokens: Int[Tensor, "seq"],
) -> Float[Tensor, "seq-1 n_components"]:
    """
    Inputs:
        embed: the embeddings of the tokens (i.e. token + position embeddings)
        l1_results: the outputs of the attention heads at layer 1 (with head as one of the dims)
        l2_results: the outputs of the attention heads at layer 2 (with head as one of the dims)
        W_U: the unembedding matrix
        tokens: the token ids of the sequence
    Returns:
        Tensor of shape (seq_len-1, n_components)
        represents the concatenation (along dim=-1) of logit attributions from:
            the direct path (seq-1,1)
            layer 0 logits (seq-1, n_heads)
            layer 1 logits (seq-1, n_heads)
        so n_components = 1 + 2*n_heads
    """
    W_U_correct_tokens = W_U[:, tokens[1:]]

    direct_attributions = einops.einsum(W_U_correct_tokens, embed[:-1], "emb seq, seq emb -> seq")
    l1_attributions = einops.einsum(
        W_U_correct_tokens, l1_results[:-1], "emb seq, seq nhead emb -> seq nhead"
    )
    l2_attributions = einops.einsum(
        W_U_correct_tokens, l2_results[:-1], "emb seq, seq nhead emb -> seq nhead"
    )
    return t.concat([direct_attributions.unsqueeze(-1), l1_attributions, l2_attributions], dim=-1)

Once you've got the tests working, you can visualise the logit attributions for each path through the model. We've provided you with the helper function plot_logit_attribution, which presents the results in a nice way.

embed = cache["embed"]
l1_results = cache["result", 0]
l2_results = cache["result", 1]
logit_attr = logit_attribution(embed, l1_results, l2_results, model.W_U, tokens.squeeze())

plot_logit_attribution(model, logit_attr, tokens, title="Logit attribution (demo prompt)")

Question - what is the interpretation of this plot?

You should find that the most variation in the logit attribution comes from the direct path. In particular, some of the tokens in the direct path have a very high logit attribution (e.g. tokens 7, 12, 24, 38, 46, 58). Can you guess what gives these tokens such a high logit attribution?

Answer - what is special about these tokens?

The tokens with very high logit attribution are the ones which are the first token in common bigrams. For instance, the highest contribution on the direct path comes from | manip|, because this is very likely to be followed by |ulative| (or presumably a different stem like |ulation|). | super| -> |human| is another example of a bigram formed when the tokenizer splits one word into multiple tokens.

There are also examples that come from two different words, rather than a single word split by the tokenizer. These include:

| more| -> | likely| (12)
| machine| -> | learning| (24)
| by| -> | default| (38)
| how| -> | to| (58)

See later for a discussion of all the ~infuriating~ fun quirks of tokenization!

Another feature of the plot - the heads in layer 1 seem to have much higher contributions than the heads in layer 0. Why do you think this might be?

Hint

Think about what this graph actually represents, in terms of paths through the transformer.

Answer - why might layer-1 heads have higher contributions?

This is because of a point we discussed earlier - this plot doesn't pick up on things like a head's effect in composition with another head. So the attribution for layer-0 heads won't involve any composition, whereas the attributions for layer-1 heads will involve not only the single-head paths through those attention heads, but also the 2-layer compositional paths through heads in layer 0 and layer 1.

Exercise - interpret logit attribution for the induction heads

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You shouldn't spend more than 10-15 minutes on this exercise.

This exercise just involves calling logit_attribution and plot_logit_attribution with appropriate arguments - the important part is interpreting the results. Please do look at the solutions if you're stuck on the code; this part isn't important.

Perform logit attribution for your attention-only model model, on the rep_cache. What do you expect to see?

# YOUR CODE HERE - plot logit attribution for the induction sequence (i.e. using `rep_tokens` and
# `rep_cache`), and interpret the results.
Solution
seq_len = 50
embed = rep_cache["embed"]
l1_results = rep_cache["result", 0]
l2_results = rep_cache["result", 1]
logit_attr = logit_attribution(embed, l1_results, l2_results, model.W_U, rep_tokens.squeeze())
plot_logit_attribution(
    model, logit_attr, rep_tokens.squeeze(), title="Logit attribution (random induction prompt)"
)

What is the interpretation of this plot, in the context of our induction head circuit?

Answer

The first half of the plot is mostly meaningless, because the sequences here are random and carry no predictable pattern, and so there can't be any part of the model that is doing meaningful computation to make predictions.

In the second half, we see that heads 1.4 and 1.10 have a large logit attribution score. This makes sense given our previous observation that these heads seemed to be performing induction (since they both exhibited the characteristic induction pattern). However, it's worth emphasizing that this plot gives us a different kind of evidence than looking at attention patterns does: just observing that some head is attending to a particular token doesn't mean it's necessarily using that information to make a concrete prediction. Note that we see head 1.10 has a larger direct effect than 1.4, which agrees with our attention scores result (where 1.10 also scored higher than 1.4).

Hooks: Intervening on Activations

Now that we've built some tools to decompose our model's output, it's time to start making causal interventions.

Ablations

Let's start with a simple example: ablation. An ablation is a simple causal intervention on a model - we pick some part of it and set it to zero. This is a crude proxy for how much that part matters. Further, if we have some story about how a specific circuit in the model enables some capability, showing that ablating other parts does nothing can be strong evidence of this.

As mentioned in the glossary, there are many ways to do ablation. We'll focus on the simplest: zero-ablation (even though it's somewhat unprincipled).
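
To give a feel for what this looks like as a hook before you start the exercise, here's a sketch that zero-ablates the entire attention output of layer 0 (the per-head version you'll write below works on the z tensor instead):

def zero_ablate_attn_out_hook(
    attn_out: Float[Tensor, "batch seq d_model"],
    hook: HookPoint,
) -> Float[Tensor, "batch seq d_model"]:
    # Replace the whole attention layer's output with zeros
    return t.zeros_like(attn_out)

ablated_loss = model.run_with_hooks(
    rep_tokens,
    return_type="loss",
    fwd_hooks=[(utils.get_act_name("attn_out", 0), zero_ablate_attn_out_hook)],
)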

Exercise - induction head ablation

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should aim to spend 20-35 mins on this exercise.

The code below provides a template for performing zero-ablation on the output vectors at a particular head (i.e. the vectors we get when taking a weighted sum of the value vectors according to the attention probabilities, before projecting them up & adding them back to the residual stream). If you're confused about what different activations mean, you can refer back to the diagram.

You need to do 2 things:

  1. Fill in head_zero_ablation_hook so that it performs zero-ablation on the head given by head_index_to_ablate.
  2. Fill in the missing code in the get_ablation_scores function (i.e. where you see the raise NotImplementedError() line), so that loss_with_ablation is computed as the loss of the model after ablating head head in layer layer.

The rest of the get_ablation_scores function is designed to return a tensor of shape (n_layers, n_heads) containing the increase in loss from ablating each of these heads.

A few notes about this function / tips on how to implement it:

  • You can create a temporary hook function by applying functools.partial to the ablation_function, fixing the head index to a particular value.
  • You can use utils.get_act_name("z", layer) to get the name of the hook point (to see the full diagram of named hook points and how to get the names, you can refer to the streamlit reference page, which can be found on the left hand sidebar after you navigate to the homepage).
  • See that loss_no_ablation is computed with the get_log_probs function, and that we only take the last seq_len - 1 tokens - this is because we're dealing with sequences of length 2 * seq_len + 1 (a BOS token plus 2 repeated random sequences), and we only care about the loss on the second half of the sequence.
  • Note that we call model.reset_hooks() at the start of the function - this is a useful practice in general, to make sure you've not accidentally left in any hooks that might change your model's behaviour.
def head_zero_ablation_hook(
    z: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    head_index_to_ablate: int,
) -> None:
    raise NotImplementedError()


def get_ablation_scores(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"],
    ablation_function: Callable = head_zero_ablation_hook,
) -> Float[Tensor, "n_layers n_heads"]:
    """
    Returns a tensor of shape (n_layers, n_heads) containing the increase in cross entropy loss
    from ablating the output of each head.
    """
    # Initialize an object to store the ablation scores
    ablation_scores = t.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

    # Calculating loss without any ablation, to act as a baseline
    model.reset_hooks()
    seq_len = (tokens.shape[1] - 1) // 2
    logits = model(tokens, return_type="logits")
    loss_no_ablation = -get_log_probs(logits, tokens)[:, -(seq_len - 1) :].mean()

    for layer in tqdm(range(model.cfg.n_layers)):
        for head in range(model.cfg.n_heads):
            raise NotImplementedError()

    return ablation_scores


ablation_scores = get_ablation_scores(model, rep_tokens)
tests.test_get_ablation_scores(ablation_scores, model, rep_tokens)
Solution
def head_zero_ablation_hook(
    z: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    head_index_to_ablate: int,
) -> None:
    z[:, :, head_index_to_ablate, :] = 0.0


def get_ablation_scores(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"],
    ablation_function: Callable = head_zero_ablation_hook,
) -> Float[Tensor, "n_layers n_heads"]:
    """
    Returns a tensor of shape (n_layers, n_heads) containing the increase in cross entropy loss
    from ablating the output of each head.
    """
    # Initialize an object to store the ablation scores
    ablation_scores = t.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

    # Calculating loss without any ablation, to act as a baseline
    model.reset_hooks()
    seq_len = (tokens.shape[1] - 1) // 2
    logits = model(tokens, return_type="logits")
    loss_no_ablation = -get_log_probs(logits, tokens)[:, -(seq_len - 1) :].mean()

    for layer in tqdm(range(model.cfg.n_layers)):
        for head in range(model.cfg.n_heads):
            # Use functools.partial to create a temporary hook function with the head number fixed
            temp_hook_fn = functools.partial(ablation_function, head_index_to_ablate=head)
            # Run the model with the ablation hook
            ablated_logits = model.run_with_hooks(
                tokens, fwd_hooks=[(utils.get_act_name("z", layer), temp_hook_fn)]
            )
            # Calculate the loss difference (= neg correct logprobs), only on the last seq_len tokens
            loss = -get_log_probs(ablated_logits, tokens)[:, -(seq_len - 1) :].mean()
            # Store the result, subtracting the clean loss so that a value of 0 means no loss change
            ablation_scores[layer, head] = loss - loss_no_ablation

    return ablation_scores

Once you've passed the tests, you can plot the results:

imshow(
    ablation_scores,
    labels={"x": "Head", "y": "Layer", "color": "Logit diff"},
    title="Loss Difference After Ablating Heads",
    text_auto=".2f",
    width=900,
    height=350,
)

What is your interpretation of these results?

Interpretation

This tells us not just which heads are responsible for writing output to the residual stream that gets us the correct result, but which heads play an important role in the induction circuit.

This chart tells us that - for sequences of repeated tokens - head 0.7 is by far the most important in layer 0 (which makes sense, since we observed it to be the strongest "previous token head"), and heads 1.4, 1.10 are the most important in layer 1 (which makes sense, since we observed these to be the most induction-y).

This is a good illustration of the kind of result which we can get from ablation, but wouldn't be able to get from something like direct logit attribution, because it isn't a causal intervention.

Exercise - mean ablation

Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should aim to spend 5-15 mins on this exercise.

An alternative to zero-ablation is mean-ablation, where rather than setting values to zero, we set them to be their mean across some suitable distribution (commonly we'll use the mean over some batch dimension). This can be more informative, because zero-ablation takes a model out of its normal distribution, and so the results from it aren't necessarily representative of what you'd get if you "switched off" the effect from some particular component. Mean ablation on the other hand works slightly better (although it does come with its own set of risks). You can read more here or here.

You should fill in the head_mean_ablation_hook function below, and run the code (also make sure in your previous get_ablation_scores function that you were actually using the ablation_function rather than hardcoding the zero ablation function, otherwise your code won't work here). You should see that the results are slightly cleaner, with the unimportant heads having values much closer to zero relative to the important heads.

def head_mean_ablation_hook(
    z: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    head_index_to_ablate: int,
) -> None:
    raise NotImplementedError()


rep_tokens_batch = run_and_cache_model_repeated_tokens(model, seq_len=50, batch_size=10)[0]
mean_ablation_scores = get_ablation_scores(
    model, rep_tokens_batch, ablation_function=head_mean_ablation_hook
)

imshow(
    mean_ablation_scores,
    labels={"x": "Head", "y": "Layer", "color": "Logit diff"},
    title="Loss Difference After Ablating Heads",
    text_auto=".2f",
    width=900,
    height=350,
)
Solution
def head_mean_ablation_hook(
    z: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    head_index_to_ablate: int,
) -> None:
    z[:, :, head_index_to_ablate, :] = z[:, :, head_index_to_ablate, :].mean(0)


rep_tokens_batch = run_and_cache_model_repeated_tokens(model, seq_len=50, batch_size=10)[0]
mean_ablation_scores = get_ablation_scores(
    model, rep_tokens_batch, ablation_function=head_mean_ablation_hook
)

imshow(
    mean_ablation_scores,
    labels={"x": "Head", "y": "Layer", "color": "Logit diff"},
    title="Loss Difference After Ablating Heads",
    text_auto=".2f",
    width=900,
    height=350,
)

Bonus - understand heads 0.4 & 0.11 (very hard!)

There are 2 heads which appeared strongly in our induction ablation experiments, but haven't stood out as much in the other analysis we've done in this section: 0.4 and 0.11. Can you construct causal experiments (i.e. targeted ablations) to try and figure out what these heads are doing?

Note - you might want to attempt this once you've made some headway into the next section, as this will give you a more mechanistic understanding of the induction circuit. Even once you've done that, you might still find this bonus exercise challenging, because it ventures outside of the well-defined induction circuit we've been working with and into potentially more ambiguous results. To restate - the material here is very challenging!

Here's a hint to get you started

Look at the positions that heads 0.4 and 0.11 are attending to. Can you figure out which source positions are important to attend to for the model to perform well?

Partial answer (and some sample code)

Below is some sample code which plots the effect of ablating the inputs to heads 0.4 and 0.11 at all offset positions except for a few (e.g. the first row shows the effect on loss of mean ablating all inputs to the heads except for those that come from self-attention, and the second row shows the effect when we ablate all inputs except for those that come from the token immediately before it in the sequence).

def head_z_ablation_hook(
    z: Float[Tensor, "batch seq n_heads d_head"],
    hook: HookPoint,
    head_index_to_ablate: int,
    seq_posns: list[int],
    cache: ActivationCache,
) -> None:
    """
    We perform ablation at the z vector, by doing the equivalent of mean ablating all the inputs to this attention head
    except for those which come from the tokens n positions back, where n is in the seq_posns list.
    """
    batch, seq = z.shape[:2]
    v = cache["v", hook.layer()][:, :, head_index_to_ablate]  # shape [batch seq_K d_head]
    pattern = cache["pattern", hook.layer()][:, head_index_to_ablate]  # shape [batch seq_Q seq_K]

    # Get a repeated version of v, and mean ablate all but the values at the offsets in seq_posns
    v_repeated = einops.repeat(v, "b sK h -> b sQ sK h", sQ=seq)
    v_ablated = einops.repeat(v_repeated.mean(0), "sQ sK h -> b sQ sK h", b=batch).clone()
    for offset in seq_posns:
        seqQ_slice = t.arange(offset, seq)
        v_ablated[:, seqQ_slice, seqQ_slice - offset] = v_repeated[:, seqQ_slice, seqQ_slice - offset]

    # Take weighted sum of this new v, and use it to edit z inplace.
    z[:, :, head_index_to_ablate] = einops.einsum(v_ablated, pattern, "b sQ sK h, b sQ sK -> b sQ h")


def get_ablation_scores_cache_assisted(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"],
    ablation_function: Callable = head_zero_ablation_hook,
    seq_posns: list[int] = [0],
    layers: list[int] = [0],
) -> Float[Tensor, "n_layers n_heads"]:
    """
    Version of get_ablation_scores which can use the cache to assist with the ablation.
    """
    ablation_scores = t.zeros((len(layers), model.cfg.n_heads), device=model.cfg.device)

    model.reset_hooks()
    seq_len = (tokens.shape[1] - 1) // 2
    logits, cache = model.run_with_cache(tokens, return_type="logits")
    loss_no_ablation = -get_log_probs(logits, tokens)[:, -(seq_len - 1) :].mean()

    for layer in layers:
        for head in range(model.cfg.n_heads):
            temp_hook_fn = functools.partial(ablation_function, head_index_to_ablate=head, cache=cache, seq_posns=seq_posns)
            ablated_logits = model.run_with_hooks(tokens, fwd_hooks=[(utils.get_act_name("z", layer), temp_hook_fn)])
            loss = -get_log_probs(ablated_logits, tokens)[:, -(seq_len - 1) :].mean()
            # Index rows by position in `layers`, so this also works when `layers` doesn't start at 0
            ablation_scores[layers.index(layer), head] = loss - loss_no_ablation

    return ablation_scores


rep_tokens_batch = run_and_cache_model_repeated_tokens(model, seq_len=50, batch_size=50)[0]
offsets = [[0], [1], [2], [3], [1, 2], [1, 2, 3]]
z_ablation_scores = [
    get_ablation_scores_cache_assisted(model, rep_tokens_batch, head_z_ablation_hook, offset).squeeze()
    for offset in tqdm(offsets)
]

imshow(
    t.stack(z_ablation_scores),
    labels={"x": "Head", "y": "Position offset", "color": "Logit diff"},
    title="Loss Difference (ablating heads everywhere except for certain offset positions)",
    text_auto=".2f",
    y=[str(offset) for offset in offsets],
    width=900,
    height=400,
)

Some observations from the result of this code:

- Head 0.7 is truly a previous token head. The second row shows that mean ablating all its inputs except for those that come from the previous token has no effect on loss, so this is all the information it's using.
- Head 0.11 is only a current token head. The first row shows that mean ablating all its inputs except for those that come from self-attending (i.e. to the current token) has no effect on loss, so this is all the information it's using.
- Head 0.4 is only using information from positions 1, 2 or 3 tokens back. This is shown from the 5th row of the plot above - the effect of ablating all inputs except for those that come from tokens 1 or 2 positions back is very small. Note that it's important we draw this conclusion from an ablation experiment, not just from looking at attention patterns - because attending to a token doesn't tell you whether that token is being used in a way that's important in the context of this particular distribution (induction).

Starting with 0.11 - we know that there are heads in layer 1 whose job it is to copy tokens - i.e. in sequences [A][B]...[A][B], they attend from the second [A] back to the first [B] and copy its value to use as a prediction. And if head 0.11 always self-attends, then it actually makes sense to consider (embedding of B) + (output of head 0.11 when it attends to token B) as the "true embedding of B", since this is always the thing that the layer 1 head will be learning to copy. This idea of an extended embedding or effective embedding will come up again later in the course, when we look at GPT2-Small. As for whether the output of 0.11 is more important in the QK circuit of the layer-1 copying head, or the OV copying head, we'll leave as an exercise to the reader!

Next, 0.4 - it's using information from both 1 and 2 tokens back. Using the previous token makes sense, since induction circuits contain previous token heads. But what could it be doing with the information 2 positions back? One theory we might have is that it's also creating an induction circuit, but using 3 tokens rather than 2 tokens! In other words, rather than having sequences like [A][B]...[A][B] where the second [A] attends back to "token that came immediately after the value of this token", we might have sequences like [Z][A][B]...[Z][A][B] where the second [A] attends back to "token that came 2 positions after the value of the previous token". One way to test this would be to construct random induction sequences which have a maximum of 2 repetitions, i.e. they're constructed with the first half being random sequences and the second half being pairs of randomly chosen tokens which appear in the first half adjacent to each other. To illustrate, for vocab size of 10 and half seq len of 10, we might have a sequence like:

0 5 8 3 1 8 2 2 4 6 (5 8) (4 6) (3 1) (2 4) (0 5)

Based on our theory about head 0.4, we should expect that mean ablating it in this kind of sequence should have nearly zero effect on loss (because it's designed to support induction sequences of length at least 3), even though all the other heads which were identified as important in the induction experiment (0.7, 0.11, 1.4, 1.10) should still be important. This is in fact what we find - you can try this for yourself with the code below.

def generate_repeated_tokens_maxrep(
    model: HookedTransformer,
    seq_len: int,
    batch_size: int = 1,
    maxrep: int = 2,
) -> Int[Tensor, "batch_size full_seq_len"]:
    """
    Same as previous function, but contains a max number of allowed repetitions. For example, maxrep=2 means we can have
    sequences like [A][B]...[A][B], but not [A][B][C]...[A][B][C].
    """
    prefix = (t.ones(batch_size, 1) * model.tokenizer.bos_token_id).long()
    rep_tokens_half = t.randint(0, model.cfg.d_vocab, (batch_size, seq_len), dtype=t.int64)
    rep_tokens = t.cat([prefix, rep_tokens_half], dim=-1)
    for _ in range(seq_len // maxrep + 1):
        random_start_posn = t.randint(0, seq_len - 2, (batch_size,)).tolist()
        rep_tokens_repeated = t.stack([rep_tokens_half[b, s : s + maxrep] for b, s in enumerate(random_start_posn)])
        rep_tokens = t.cat([rep_tokens, rep_tokens_repeated], dim=-1)

    return rep_tokens[:, : 2 * seq_len + 1].to(device)


rep_tokens_max2 = generate_repeated_tokens_maxrep(model, seq_len=50, batch_size=50, maxrep=2)
mean_ablation_scores = get_ablation_scores(model, rep_tokens_max2, ablation_function=head_mean_ablation_hook)

imshow(
    mean_ablation_scores,
    labels={"x": "Head", "y": "Layer", "color": "Logit diff"},
    title="Loss Difference After Ablating Heads",
    text_auto=".2f",
    width=900,
    height=350,
)