2️⃣ Logit Attribution

Learning Objectives
  • Perform direct logit attribution to figure out which heads are writing to the residual stream in a significant way
  • Learn how to use different TransformerLens helper functions, which decompose the residual stream in different ways

Direct Logit Attribution

The easiest part of the model to understand is the output: this is what the model is trained to optimize, so it can always be directly interpreted! Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and then work backwards. (You will have seen this if you went through the balanced bracket classifier task. If you did, this section will probably feel quite familiar, so feel free to just skim through it.) The main technique used to do this is called direct logit attribution.

Background: The central object of a transformer is the residual stream. This is the sum of the outputs of each layer and of the original token and positional embedding. Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See A Mathematical Framework for Transformer Circuits for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer).

The logits of a model are logits=Unembed(LayerNorm(final_residual_stream)). The Unembed is a linear map, and LayerNorm is approximately a linear map, so we can decompose the logits into the sum of the contributions of each component, and look at which components contribute the most to the logit of the correct token! This is called direct logit attribution. Here we look at the direct attribution to the logit difference!
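To make the decomposition concrete, here is a minimal numpy sketch that ignores LayerNorm entirely. The component outputs, sizes and token index are random, hypothetical stand-ins (not values from GPT-2); the point is only that, because the unembedding is linear, the per-component contributions to a logit add up exactly to the full logit.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_vocab, n_components = 16, 50, 5

# Hypothetical per-component outputs (e.g. embeddings, attention and MLP outputs)
components = rng.normal(size=(n_components, d_model))
W_U = rng.normal(size=(d_model, d_vocab))

# Ignoring LayerNorm, the final residual stream is just the sum of the components
final_resid = components.sum(axis=0)
logits = final_resid @ W_U

token = 7  # hypothetical index of the correct token
# Direct contribution of each component to that token's logit
per_component_logit = components @ W_U[:, token]  # shape (n_components,)

# The direct contributions add up exactly to the full logit
assert np.isclose(per_component_logit.sum(), logits[token])
```

In the real model the LayerNorm scale breaks exact linearity, which is why the approximation discussed under "Ignoring LayerNorm" is needed.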

Background and motivation of the logit difference

Logit difference is a really elegant metric, and is a particularly nice aspect of the setup of Indirect Object Identification. In general, there are two natural ways to interpret the model's outputs: the output logits, or the output log probabilities (or probabilities).

The logits are much nicer and easier to understand, as noted above. However, the model is trained to optimize the cross-entropy loss (the average negative log probability of the correct token). This means it does not directly optimize the logits, and indeed if the model adds an arbitrary constant to every logit, the log probabilities are unchanged.

But we have:

log_probs == logits.log_softmax(dim=-1) == logits - logsumexp(logits)

and because they differ by a constant, we have:

log_probs(" Mary") - log_probs(" John") = logits(" Mary") - logits(" John")

so the freedom to add an arbitrary constant to every logit cancels out!
Technical details (if this equivalence doesn't seem obvious to you)

Let $\vec{\textbf{x}}$ be the logits, $\vec{\textbf{L}}$ be the log probs, and $\vec{\textbf{p}}$ be the probs. Then we have the following relations:

$$ p_i = \operatorname{softmax}(\vec{\textbf{x}})_i = \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}} $$

and:

$$ L_i = \log p_i $$

Combining these, we get:

$$ L_i = \log \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}} = x_i - \log \sum_{j=1}^n e^{x_j} $$

Notice that the sum term on the right hand side is the same for all $i$, so we get:

$$ L_i - L_j = x_i - x_j $$

In other words, the logit diff $x_i - x_j$ is the same as the log prob diff. This motivates the choice of logit diff as our metric (since the model is directly trained to make the log prob of the correct token large, and all other log probs small).
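As a quick numerical check of this derivation, the sketch below implements log-softmax by hand in numpy on a toy logit vector (the values and the choice of which index stands for " Mary" or " John" are arbitrary assumptions). It confirms that the log prob difference equals the logit difference, and that shifting every logit by a constant leaves the log probs unchanged.

```python
import numpy as np

def log_softmax(x):
    # Subtracting the max is itself a constant shift, so it doesn't change the result;
    # it just keeps exp() numerically stable
    shifted = x - x.max()
    return shifted - np.log(np.exp(shifted).sum())

# Toy logits; treat index 0 as " Mary" and index 3 as " John" (arbitrary choice)
logits = np.array([2.0, -1.0, 0.5, 3.0])
log_probs = log_softmax(logits)

# Log prob difference equals logit difference
assert np.isclose(log_probs[0] - log_probs[3], logits[0] - logits[3])

# Adding a constant to every logit leaves the log probs unchanged
assert np.allclose(log_softmax(logits + 10.0), log_probs)
```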

Further, the metric helps us isolate the precise capability we care about - figuring out which name is the Indirect Object. There are many other components of the task - deciding whether to return an article (the) or pronoun (her) or name, realising that the sentence wants a person next at all, etc. By taking the logit difference we control for all of that.

Our metric is further refined, because each prompt appears twice, once for each possible indirect object. This controls for irrelevant behaviour such as the model learning that John is a more frequent token than Mary (this actually happens! The final LayerNorm bias increases the John logit by 1 relative to the Mary logit). Another way to handle this would be to use a large enough dataset (with names randomly chosen) that this effect is averaged out, which is what we'll do in section 3.

Ignoring LayerNorm

LayerNorm is a normalization technique that transformers use, analogous to BatchNorm (but friendlier to massive parallelization). Every time a transformer layer reads information from the residual stream, it applies a LayerNorm to normalize the vector at each position (translating to set the mean to 0 and scaling to set the variance to 1) and then applies a learned vector of weights and biases to scale and translate the normalized vector. This is almost a linear map, apart from the scaling step, because that divides by the norm of the vector and the norm is not a linear function. (The fold_ln flag when loading a model factors out all the linear parts.)

But if we fixed the scale factor, the LayerNorm would be fully linear. The scale of the residual stream is a global property that's a function of all components of the stream, while in practice there are normally just a few directions relevant to any particular component, so this is an acceptable approximation. So when doing direct logit attribution, we use the apply_ln flag on the cache to apply the global LayerNorm scaling factor to each component. See [my clean GPT-2 implementation](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=Clean_Transformer_Implementation) for more on LayerNorm.
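The "freeze the scale" idea can be checked with a small numpy sketch. This is an illustration of the principle only, not TransformerLens's implementation of apply_ln: the components and weight vector are random, and the LayerNorm bias is set to zero because it would be added once to the total rather than once per component (in practice it gets folded into the unembedding bias). Once the scale is computed from the whole residual stream and held fixed, centering, dividing and multiplying by the weight are all linear, so the per-component pieces sum exactly to the full LayerNorm output.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_components = 8, 4
eps = 1e-5

components = rng.normal(size=(n_components, d_model))
w = rng.normal(size=d_model)  # learned LayerNorm weight
resid = components.sum(axis=0)

def layer_norm(x, w, b):
    # Center, divide by the (nonlinear) scale, then apply learned weight and bias
    centered = x - x.mean()
    scale = np.sqrt((centered ** 2).mean() + eps)
    return centered / scale * w + b

# Full LayerNorm applied to the whole residual stream (zero bias, see lead-in)
full = layer_norm(resid, w, np.zeros(d_model))

# Freeze the scale computed from the *whole* residual stream...
global_scale = np.sqrt(((resid - resid.mean()) ** 2).mean() + eps)
# ...then apply the now-linear map to each component separately
per_component = (components - components.mean(axis=-1, keepdims=True)) / global_scale * w

# The per-component pieces sum exactly to the full LayerNorm output
assert np.allclose(per_component.sum(axis=0), full)
```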

Logit diff directions

Getting an output logit is equivalent to projecting onto a direction in the residual stream, and the same is true for getting the logit diff.

If it's not clear what is meant by this statement, read this dropdown.

Suppose our final value in the residual stream for a single sequence and a position within that sequence is $x$ (i.e. $x$ is a vector of length $d_{model}$). Then (ignoring layernorm - see the point above for why it's okay to do this), we get logits by multiplying by the unembedding matrix $W_U$ (which has shape $(d_{model}, d_{vocab})$):

$$ \text{output} = x^T W_U $$

Now, remember that we want the logit diff, which is $\text{output}_{IO} - \text{output}_{S}$ (the difference between the logits for our indirect object and subject). We can write this as:

$$ \text{logit diff} = (x^T W_U)_{IO} - (x^T W_U)_{S} = x^T (u_{IO} - u_{S}) $$

where $u_{IO}$ and $u_S$ are the columns of the unembedding matrix $W_U$ corresponding to the indirect object and subject tokens respectively.

To summarize, we've written the logit diff as a dot product between the vector in the residual stream and a constant vector (which is a function of the model's unembedding matrix). We call this vector $u_{IO} - u_{S}$ the logit difference direction (because it "points in the direction of largest logit difference"). To put it another way, if $x$ is a vector of fixed magnitude, then it maximises the logit difference when it is pointing in the same direction as the vector $u_{IO} - u_{S}$. We use the term "projection" synonymously with "dot product" here.

(If you've completed the exercise where we interpret a transformer on balanced / unbalanced bracket strings, this is basically the same principle. The only difference here is that we actually have a much larger unembedding vocabulary than just the classifications {balanced, unbalanced}, but since we're only interested in comparing the model's prediction for IO vs S, and the logits for these two tokens are usually larger than most others, this method is still well-justified).
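The identity above can be verified numerically. In this numpy sketch, the unembedding matrix, residual vector and the IO/S token ids are random, hypothetical stand-ins, and LayerNorm is ignored as in the derivation. It checks that the logit diff equals a single dot product with $u_{IO} - u_{S}$, and that, for a fixed norm, a vector pointing along that direction gets at least as large a logit diff.

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, d_vocab = 16, 50
W_U = rng.normal(size=(d_model, d_vocab))
x = rng.normal(size=d_model)  # final residual stream vector (LayerNorm ignored)

io_token, s_token = 3, 11  # hypothetical token ids for IO and S

logits = x @ W_U
logit_diff = logits[io_token] - logits[s_token]

# The same quantity as a single dot product with the logit difference direction
direction = W_U[:, io_token] - W_U[:, s_token]
assert np.isclose(x @ direction, logit_diff)

# For a fixed norm, x maximises the logit diff when it points along `direction`
best_x = direction / np.linalg.norm(direction) * np.linalg.norm(x)
assert best_x @ direction >= x @ direction
```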

We use model.tokens_to_residual_directions to map the answer tokens to their residual stream directions, and then convert these into a logit difference direction for each prompt in the batch.

answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)  # [batch 2 d_model]
print("Answer residual directions shape:", answer_residual_directions.shape)

correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(
    dim=1
)
logit_diff_directions = (
    correct_residual_directions - incorrect_residual_directions
)  # [batch d_model]
print("Logit difference directions shape:", logit_diff_directions.shape)
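For intuition about the shapes, here is a hedged numpy sketch of the batched version. It assumes (as the derivation above suggests) that a token's residual direction is simply its column of W_U; everything else (sizes, random W_U, random answer_tokens) is a hypothetical stand-in rather than the real model. It then checks that projecting a batch of final residual vectors onto logit_diff_directions gives the same per-prompt logit diffs as computing the full logits.

```python
import numpy as np

rng = np.random.default_rng(2)
batch, d_model, d_vocab = 4, 16, 50
W_U = rng.normal(size=(d_model, d_vocab))
# Hypothetical answer tokens: column 0 = correct (IO), column 1 = incorrect (S)
answer_tokens = rng.integers(0, d_vocab, size=(batch, 2))

# Assumption: a token's residual direction is its column of W_U
answer_residual_directions = W_U.T[answer_tokens]  # (batch, 2, d_model)
correct_dirs = answer_residual_directions[:, 0]
incorrect_dirs = answer_residual_directions[:, 1]
logit_diff_directions = correct_dirs - incorrect_dirs  # (batch, d_model)

# Compare against computing full logits for a batch of final residual vectors
final_resid = rng.normal(size=(batch, d_model))
logits = final_resid @ W_U  # (batch, d_vocab)
rows = np.arange(batch)
expected = logits[rows, answer_tokens[:, 0]] - logits[rows, answer_tokens[:, 1]]
projected = (final_resid * logit_diff_directions).sum(axis=-1)

assert answer_residual_directions.shape == (batch, 2, d_model)
assert np.allclose(projected, expected)
```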

To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and verify that we get the same answer.

Technical details

logits = Unembed(LayerNorm(final_residual_stream)), so we technically need to account for the centering, and then the learned translation and scaling of the LayerNorm, not just the variance-1 scaling.

The centering is accounted for with the preprocessing flag center_writing_weights which ensures that every weight matrix writing to the residual stream has mean zero.

The learned scaling is folded into the unembedding weights model.unembed.W_U via W_U_fold = layer_norm.weights[:, None] * unembed.W_U

The learned translation is folded into model.unembed.b_U, a bias added to the logits (note that GPT-2 is not trained with an existing b_U). This roughly represents unigram statistics. But we can ignore this, because each prompt occurs twice with the names in the opposite order, so this perfectly cancels out.
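Here is a small numpy sketch of why the bias cancels. The bias vector, token ids and "bias-free" logit diffs are all hypothetical numbers chosen for illustration. In one prompt of a pair the bias contributes b_U[Mary] - b_U[John] to the logit diff, in the swapped prompt it contributes the negation, so the pair average is unaffected.

```python
import numpy as np

rng = np.random.default_rng(3)
d_vocab = 10
b_U = rng.normal(size=d_vocab)  # hypothetical unembedding bias (unigram-like preferences)
john, mary = 2, 5  # hypothetical token ids

# Hypothetical bias-free logit diffs for a prompt pair with the names swapped
clean_diff_pair = np.array([3.0, 2.5])
# Prompt 0: answer is Mary (IO=mary, S=john); prompt 1: answer is John (IO=john, S=mary)
bias_contrib = np.array([b_U[mary] - b_U[john], b_U[john] - b_U[mary]])
observed = clean_diff_pair + bias_contrib

# Averaging over the pair, the bias terms cancel exactly
assert np.isclose(observed.mean(), clean_diff_pair.mean())
```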

Note that rather than using LayerNorm scaling, we could just study cache["ln_final.hook_normalized"].

The code below does the following:

  • Get the final residual stream values from the cache object (which you should already have defined above).
  • Apply LayerNorm scaling to these values.
    • This is done by cache.apply_ln_to_stack, a helpful function which takes a stack of residual stream values (e.g. a batch, or the residual stream decomposed into components), treats them as the input to a specific layer, and applies the LayerNorm scaling of that layer to them.
    • The keyword arguments here indicate that our input is the residual stream values for the last sequence position, and we want to apply the final LayerNorm in the model.
  • Project them along the logit difference directions (you've already defined these above, as logit_diff_directions).

# Cache syntax: resid_post is the residual stream at the end of the layer, -1 gets the final layer.
# The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream: Float[Tensor, "batch seq d_model"] = cache["resid_post", -1]
print(f"Final residual stream shape: {final_residual_stream.shape}")
final_token_residual_stream: Float[Tensor, "batch d_model"] = final_residual_stream[:, -1, :]

# Apply LayerNorm scaling (to just the final sequence position)
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream, logit_diff_directions, "batch d_model, batch d_model ->"
) / len(prompts)

print(f"Calculated average logit diff: {average_logit_diff:.10f}")
print(f"Original logit difference:     {original_average_logit_diff:.10f}")

t.testing.assert_close(average_logit_diff, original_average_logit_diff)

Logit Lens

We can now decompose the residual stream! First we apply a technique called the logit lens: this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequent layers.
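Conceptually (and ignoring LayerNorm), the logit lens is a cumulative sum followed by a projection. In this numpy sketch the components and the direction are random stand-ins; the ordering (embeddings, then attention and MLP output for each layer) is chosen to mirror the 0_pre, 0_mid, 1_pre, ... labels used later, but it is an illustration, not TransformerLens's accumulated_resid.

```python
import numpy as np

rng = np.random.default_rng(4)
n_layers, d_model = 3, 8
# Hypothetical outputs: embeddings, then attn_out and mlp_out for each layer
components = rng.normal(size=(1 + 2 * n_layers, d_model))
direction = rng.normal(size=d_model)  # stand-in for a logit difference direction

# Accumulated residual stream after each component: index 0 ~ 0_pre, 1 ~ 0_mid, 2 ~ 1_pre, ...
accumulated = np.cumsum(components, axis=0)
# Logit diff we would get if every later component were deleted
logit_lens = accumulated @ direction

# The last entry is the full model's logit diff
assert np.isclose(logit_lens[-1], components.sum(axis=0) @ direction)
```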

Exercise - implement residual_stack_to_logit_diff

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise. Again, make sure you understand what the output of this function represents.

This function should look a lot like your code immediately above. residual_stack is a tensor of shape (..., batch, d_model) containing the residual stream values for the final sequence position. You should apply the final layernorm to these values, then project them in the logit difference directions.

def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    """
    Gets the avg logit difference between the correct and incorrect answer for a given stack of
    components in the residual stream.
    """
    raise NotImplementedError()


# Test function by checking that it gives the same result as the original logit difference
t.testing.assert_close(
    residual_stack_to_logit_diff(final_token_residual_stream, cache), original_average_logit_diff
)
Solution
def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    """
    Gets the avg logit difference between the correct and incorrect answer for a given stack of
    components in the residual stream.
    """
    batch_size = residual_stack.size(-2)
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    return (
        einops.einsum(
            scaled_residual_stack, logit_diff_directions, "... batch d_model, batch d_model -> ..."
        )
        / batch_size
    )

Once you have the solution, you can plot your results.

Details on accumulated_resid

Key for the plot below: n_pre means the residual stream at the start of layer n, and n_mid means the residual stream after the attention part of layer n (n_post is the same as n+1_pre, so it is not included).

The arguments to cache.accumulated_resid are:

  • layer: the layer whose input residual stream we want (this is used to identify which LayerNorm scaling factor to apply).
  • incl_mid: whether to include the residual stream in the middle of a layer, i.e. after attention and before the MLP.
  • pos_slice: the subset of the positions used. See utils.Slice for details on the syntax.
  • return_labels: whether to return the labels for each component returned (useful for plotting).

accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
# accumulated_residual has shape (component, batch, d_model)

logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(
    accumulated_residual, cache
)

line(
    logit_lens_logit_diffs,
    hovermode="x unified",
    title="Logit Difference From Accumulated Residual Stream",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    width=800,
)
Question - what is the interpretation of this plot? What does this tell you about how the model solves this task?

Fascinatingly, we see that the model is utterly unable to do the task until layer 7, almost all performance comes from attention layer 9, and performance actually decreases from there.

This tells us that there must be something going on (primarily in layers 7, 8 and 9) which writes to the residual stream in the correct way to solve the IOI task. This allows us to narrow in our focus, and start asking questions about what kind of computation is going on in those layers (e.g. the contribution of attention layers vs MLPs, and which attention heads are most important).

Layer Attribution

We can repeat the above analysis, but for each layer (this is equivalent to taking the differences between adjacent residual streams).

Note: Annoying terminology overload - layer k of a transformer means the kth transformer block, but each block consists of an attention layer (to move information around) and an MLP layer (to process information).
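The equivalence between per-layer contributions and differences of adjacent accumulated residual streams follows from the cumulative-sum structure. This numpy sketch uses random stand-in components and ignores LayerNorm; note that the real decompose_resid also separates embed and pos_embed, which accumulated_resid groups together, so the correspondence holds up to how the embeddings are split.

```python
import numpy as np

rng = np.random.default_rng(5)
n_components, d_model = 7, 8
components = rng.normal(size=(n_components, d_model))
direction = rng.normal(size=d_model)  # stand-in for a logit difference direction

accumulated = np.cumsum(components, axis=0)
# Differences between adjacent accumulated residual streams recover each component
recovered = np.diff(accumulated, axis=0, prepend=np.zeros((1, d_model)))
assert np.allclose(recovered, components)

# Because projection is linear, the same holds for the logit diffs themselves
per_layer_logit_diffs = components @ direction
assert np.allclose(np.diff(accumulated @ direction, prepend=0.0), per_layer_logit_diffs)
```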

per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)

line(
    per_layer_logit_diffs,
    hovermode="x unified",
    title="Logit Difference From Each Layer",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    width=800,
)
Question - what is the interpretation of this plot? What does this tell you about how the model solves this task?

We see that only attention layers matter, which makes sense! The IOI task is about moving information around (i.e. moving the correct name and not the incorrect name), and less about processing it. And again we note that attention layer 9 improves things a lot, while attention layers 10 and 11 decrease performance.

Head Attribution

We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.

Decomposing attention output into sums of heads

The standard way to compute the output of an attention layer is by concatenating the mixed values of each head, and multiplying by a big output weight matrix. But as described in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) this is equivalent to splitting the output weight matrix into a per-head output (here model.blocks[k].attn.W_O) and adding them up (including an overall bias term for the entire layer).
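The equivalence between "concatenate and multiply by one big W_O" and "sum of per-head outputs" can be checked in a few lines of numpy. The shapes mirror the per-head W_O layout described above, but the values are random stand-ins for a single position, not real model weights.

```python
import numpy as np

rng = np.random.default_rng(6)
n_heads, d_head, d_model = 4, 5, 20
z = rng.normal(size=(n_heads, d_head))             # each head's mixed values at one position
W_O = rng.normal(size=(n_heads, d_head, d_model))  # per-head output matrices
b_O = rng.normal(size=d_model)                     # one bias for the whole layer

# Standard view: concatenate heads, multiply by one big (n_heads * d_head, d_model) matrix
concat_out = z.reshape(-1) @ W_O.reshape(n_heads * d_head, d_model) + b_O

# Per-head view: head h writes z[h] @ W_O[h] into the residual stream
per_head = np.einsum("hd,hdm->hm", z, W_O)  # (n_heads, d_model)
assert np.allclose(per_head.sum(axis=0) + b_O, concat_out)
```

The reshape works because both z and W_O are flattened head-major, so row block h of the big matrix lines up with head h's slice of the concatenated values.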

per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_residual = einops.rearrange(
    per_head_residual, "(layer head) ... -> layer head ...", layer=model.cfg.n_layers
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)

imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
    width=600,
)

We see that only a few heads really matter - heads 9.6 and 9.9 contribute a lot positively (explaining why attention layer 9 is so important), while heads 10.7 and 11.10 contribute a lot negatively (explaining why attention layer 10 and layer 11 are actively harmful). These correspond to (some of) the name movers and negative name movers discussed in the paper. There are also several heads that matter positively or negatively but less strongly (other name movers and backup name movers).

There are a few meta observations worth making here. Our model has 144 heads, yet we could localise this behaviour to a handful of specific heads using straightforward, general techniques. This supports the claim in A Mathematical Framework that attention heads are the right level of abstraction to understand attention. It is also really surprising that there are negative heads, e.g. 10.7 makes the incorrect answer 7x more likely (relative to the correct answer). I'm not sure what's going on there, though the paper discusses some possibilities.
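Since the logit diff equals the log prob diff, a contribution of delta to the logit diff multiplies the probability ratio p(IO) / p(S) by exp(delta). The arithmetic below only converts a "7x" ratio change into logit units; it does not reproduce the 10.7 measurement itself.

```python
import math

# A change of delta in the logit diff multiplies the ratio p(IO) / p(S) by exp(delta)
factor = 7.0
delta = math.log(factor)
print(f"{factor}x ratio change corresponds to a logit diff shift of {delta:.2f}")

# A negative contribution of -delta shrinks the ratio by the same factor
assert math.isclose(math.exp(-delta), 1 / factor)
```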

Recap of useful functions from this section

Here, we take stock of all the functions from TransformerLens which you might not have seen previously.

  • cache.apply_ln_to_stack
    • Apply layernorm scaling to a stack of residual stream values.
    • We used this to help us go from "final value in residual stream" to "projection of logits in logit difference directions", without getting the code too messy!
  • cache.accumulated_resid(layer=None)
    • Returns the accumulated residual stream up to layer layer (or up to the final value of residual stream if layer is None), i.e. a stack of previous residual streams up to that layer's input.
    • Useful when studying the logit lens.
    • First dimension of output is (0_pre, 0_mid, 1_pre, 1_mid, ..., final_post)
  • cache.decompose_resid(layer)
    • Decomposes the residual stream input to layer layer into a stack of the output of previous layers. The sum of these is the input to layer layer.
    • First dimension of output is (embed, pos_embed, 0_attn_out, 0_mlp_out, ...).
  • cache.stack_head_results(layer)
    • Returns a stack of all head results (i.e. residual stream contribution) up to layer layer
    • (i.e. like decompose_resid except it splits each attention layer by head rather than splitting each layer by attention/MLP)
    • First dimension of output is layer * head (we needed to rearrange to (layer, head) to plot it).

Attention Analysis

Attention heads are particularly fruitful to study because we can look directly at their attention patterns and study which positions they move information from and to. This is particularly useful here: since we're looking at the direct effect on the logits, we only need to look at the attention patterns from the final token.

We use the circuitsvis library (developed from Anthropic's PySvelte library) to visualize the attention patterns! We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt (as an illustration).

Interpreting Attention Patterns

A common mistake when looking at attention patterns is to think that they must convey information about the token attended to (perhaps accounting for its context). But actually, all we can confidently say is that the head moves information from the residual stream position corresponding to that input token. Especially later on in the model, there may be components in the residual stream that have nothing to do with the input token! E.g. the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether the sentence ends in ".", "!" or "?".

def topk_of_Nd_tensor(tensor: Float[Tensor, "rows cols"], k: int):
    """
    Helper function: does same as tensor.topk(k).indices, but works over 2D tensors.
    Returns a list of indices, i.e. shape [k, tensor.ndim].

    Example: if tensor is 2D array of values for each head in each layer, this will
    return a list of heads.
    """
    i = t.topk(tensor.flatten(), k).indices
    return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()


k = 3

for head_type in ["Positive", "Negative"]:
    # Get the heads with largest (or smallest) contribution to the logit difference
    top_heads = topk_of_Nd_tensor(
        per_head_logit_diffs * (1 if head_type == "Positive" else -1), k
    )

    # Get all their attention patterns
    attn_patterns_for_important_heads: Float[Tensor, "head q k"] = t.stack(
        [cache["pattern", layer][:, head][0] for layer, head in top_heads]
    )

    # Display results
    display(HTML(f"<h2>Top {k} {head_type} Logit Attribution Heads</h2>"))
    display(
        cv.attention.attention_patterns(
            attention=attn_patterns_for_important_heads,
            tokens=model.to_str_tokens(tokens[0]),
            attention_head_names=[f"{layer}.{head}" for layer, head in top_heads],
        )
    )

Reminder: you can use attention_patterns or attention_heads for these visuals. The former lets you see the actual values, while the latter lets you hover over tokens in a printed sentence (and provides other useful features like locking on tokens, or a superposition of all heads in the display). Both can be useful in different contexts, although I'd usually recommend attention_patterns, since in most cases it's more useful for quickly getting a sense of attention patterns.

Try replacing attention_patterns above with attention_heads, and compare the output.

Help - my attention_heads plots are behaving weirdly.

This seems to be a bug in circuitsvis - on VSCode, the attention head plots continually shrink in size.

Until this is fixed, one way to get around it is to open the plots in your browser. You can do this inline with the webbrowser library:

import webbrowser

attn_heads = cv.attention.attention_heads(
    attention=attn_patterns_for_important_heads,
    tokens=model.to_str_tokens(tokens[0]),
    attention_head_names=[f"{layer}.{head}" for layer, head in top_heads],
)
# Save the plot as an HTML file, then open it in your default browser
path = "attn_heads.html"
with open(path, "w") as f:
    f.write(str(attn_heads))
webbrowser.open(path)

To check exactly where this is getting saved, you can print your current working directory with os.getcwd().

From these plots, you might want to start thinking about the algorithm being implemented. In particular, for the attention heads with high positive attribution scores, which tokens is " to" attending to? How might this head be affecting the logit diff score?

We'll save a full hypothesis for how the model works until the end of the next section.