Exercise Status: All exercises complete and verified

1️⃣ Latent Gradients

Learning Objectives
  • Learn how to compute latent-to-latent gradients between SAE latents in different layers of the transformer
  • Compute token-to-latent gradients to understand which input tokens drive particular latent activations
  • Compute latent-to-logit gradients to understand how latents affect the model's output
  • Use these gradient-based methods to find circuits in attention SAEs (e.g. induction circuits)

Introduction

Previous SAE work (in material 1.3.3) has focused on understanding individual latents. Here, we address a highly important topic - what about circuits of SAE latents? Circuit analysis has already been somewhat successful in language model interpretability (e.g. see Anthropic's work on induction circuits, or the Indirect Object Identification paper), but many attempts to push circuit analysis further have hit speedbumps: most connections in the model are not sparse, and it's very hard to disentangle all the messy cross-talk between different components and residual stream subspaces. Circuits of SAE latents may offer a better path forward, since we should expect that not only are individual latents generally sparse, they are also sparsely connected - any given latent should probably only have a downstream effect on a small number of other latents.

Indeed, if this does end up being true, it's a strong piece of evidence that latents found by SAEs are the fundamental units of computation used by the model, as opposed to just being an interesting clustering algorithm. Of course we do already have some evidence for this (e.g. the effectiveness of latent steering, and the fact that latents have already revealed important information about models which isn't clear when just looking at the basic components), but finding clear latent circuits would be a much stronger piece of evidence.

Latent Gradients

We'll start with an exercise that illustrates the kind of sparsity you can expect in latent connections, as well as many of the ways latent circuit analysis can be challenging. We'll be implementing the latent_to_latent_gradients function, which returns the gradients between all active pairs of latents belonging to SAEs in two different layers (we'll be using two SAEs from our gpt2-small-res-jb release). These exercises will be split up into a few different steps, since computing these gradients is deceptively complex.

What exactly are latent gradients? Well, for any given input, and any 2 latents in different layers, we can compute the derivative of the later latent's activation with respect to the earlier latent's activation. Collected over all such pairs, these derivatives form a matrix of partial derivatives (the Jacobian), i.e. $J_{ij} = \frac{\partial f_i}{\partial x_j}$, where $x_j$ is the activation of the $j$-th earlier latent and $f_i$ is the activation of the $i$-th later latent. This can serve as a linear proxy for how latents in an early layer contribute to latents in a later layer. The pseudocode for computing this is:

# Computed with no gradients, and not patching in SAE reconstructions...
layer_1_latents, layer_2_latents = model.run_with_cache_with_saes(...)

def latent_acts_to_later_latent_acts(layer_1_latents):
    layer_1_resid_acts_recon = SAE_1_decoder(layer_1_latents)
    layer_2_resid_acts_recon = model.blocks[layer_1: layer_2].forward(layer_1_resid_acts_recon)
    layer_2_latents_recon = SAE_2_encoder(layer_2_resid_acts_recon)
    return layer_2_latents_recon

latent_latent_gradients = t.func.jacrev(latent_acts_to_later_latent_acts)(layer_1_latents)

where jacrev is shorthand for "Jacobian reverse-mode differentiation" - it's a PyTorch function that takes in a tensor -> tensor function f(x) = y and returns the Jacobian function, i.e. g s.t. g(x)[i, j] = d(f(x)_i) / d(x_j).
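To make the behaviour of jacrev concrete, here is a minimal sketch on a toy function (a fixed linear map followed by a ReLU). The matrix `W` and the function `f` are made up for illustration and are not part of the exercises.

```python
import torch as t

# Toy weights, chosen by hand for illustration
W = t.tensor([[1.0, 2.0, 0.0], [0.0, -1.0, 3.0]])


def f(x: t.Tensor) -> t.Tensor:
    # Tensor -> tensor function: linear map followed by a ReLU
    return t.relu(W @ x)


x = t.tensor([1.0, 1.0, 1.0])

# jacrev(f) is itself a function; calling it at x gives the Jacobian at x
J = t.func.jacrev(f)(x)

print(J.shape)  # (n_outputs, n_inputs) = (2, 3)
print(J)  # both outputs are active at this x, so the Jacobian equals W
```

Note that the Jacobian has one row per output element and one column per input element, which is why its size grows so quickly when both the inputs and outputs are large.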

If we wanted to get a sense of how latents communicate with each other across our distribution of data, then we might average these results over a large set of prompts. However for now, we're going to stick with a relatively small set of prompts to avoid running into memory issues, and so we can visualise the results more easily.

First, let's load in our model & SAEs, if you haven't already:

gpt2 = HookedSAETransformer.from_pretrained("gpt2-small", device=device, dtype=dtype)

gpt2_saes = {
    layer: SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id=f"blocks.{layer}.hook_resid_pre",
        device=device,
        dtype=dtype,
    )
    for layer in tqdm(range(gpt2.cfg.n_layers))
}

Now, we can start the exercises!

Note - the subsequent 3 exercises are all somewhat involved, and things like the use of Jacobians can be quite fiddly. For that reason, there's a good case to be made for just reading through the solutions and understanding what the code is doing, rather than trying to do it yourself. One option would be to look at the solutions for these 3 exercises and understand how latent-to-latent gradients work, but then try and implement the token_to_latent_gradients function (after the next 3 exercises) yourself.

Exercise (1/3) - implement the SparseTensor class

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵⚪⚪⚪⚪
You should spend up to 10-15 minutes on this exercise (or skip it)

Firstly, we're going to create a SparseTensor class to help us work with sparse tensors (i.e. tensors where most of the elements are zero). This is because we'll need to do forward passes on the dense tensors (i.e. the tensors containing all values, including the zeros) but we'll often want to compute gradients wrt the sparse tensors (just the non-zero values) because otherwise we'd run into memory issues - there are a lot of latents!
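As a rough back-of-envelope illustration of the memory issue, the sketch below compares the size of a Jacobian taken over every (position, latent) pair with one taken only over active latents. The sequence length, SAE width and active-latent counts are assumed values chosen for illustration, not measurements from the exercises.

```python
# Assumed sizes, for illustration only
seq_len = 9  # number of token positions in a short prompt
d_sae = 24_576  # assumed SAE width
bytes_per_float = 4  # float32

# Dense case: every (position, latent) entry is both an input and an output
n_dense = seq_len * d_sae
dense_jacobian_bytes = n_dense * n_dense * bytes_per_float

# Sparse case: hypothetical counts of active latents in the two layers
n_active_prev, n_active_next = 200, 200
sparse_jacobian_bytes = n_active_prev * n_active_next * bytes_per_float

print(f"dense Jacobian:  {dense_jacobian_bytes / 2**30:.1f} GiB")
print(f"sparse Jacobian: {sparse_jacobian_bytes / 2**10:.1f} KiB")
```

Under these assumptions the dense Jacobian needs well over 100 GiB, while the sparse one fits in a few hundred KiB, which is the motivation for only differentiating with respect to the nonzero values.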

You should fill in the from_dense and from_sparse class methods for the SparseTensor class. The testing code is visible to you, and should help you understand how this class is expected to behave.

class SparseTensor:
    """
    Handles tensor data (assumed to be non-negative) in 2 different formats:
        dense:  The full tensor, which contains zeros. Shape is (n1, ..., nk).
        sparse: A tuple of nonzero values with shape (n_nonzero,), nonzero indices with shape
                (n_nonzero, k), and the shape of the dense tensor.
    """

    sparse: tuple[Tensor, Tensor, tuple[int, ...]]
    dense: Tensor

    def __init__(self, sparse: tuple[Tensor, Tensor, tuple[int, ...]], dense: Tensor):
        self.sparse = sparse
        self.dense = dense

    @classmethod
    def from_dense(cls, dense: Tensor) -> "SparseTensor":
        """Creates a SparseTensor from a dense tensor, extracting the positive entries."""
        raise NotImplementedError()

    @classmethod
    def from_sparse(cls, sparse: tuple[Tensor, Tensor, tuple[int, ...]]) -> "SparseTensor":
        """Creates a SparseTensor from a (values, indices, shape) tuple, reconstructing the dense tensor."""
        raise NotImplementedError()

    @property
    def values(self) -> Tensor:
        return self.sparse[0].squeeze()

    @property
    def indices(self) -> Tensor:
        return self.sparse[1].squeeze()

    @property
    def shape(self) -> tuple[int, ...]:
        return self.sparse[2]


# Test `from_dense`
x = t.zeros(10_000)
nonzero_indices = t.randperm(10_000)[:10].sort().values  # distinct indices, so no value gets overwritten
nonzero_values = t.rand(10)
x[nonzero_indices] = nonzero_values
sparse_tensor = SparseTensor.from_dense(x)
t.testing.assert_close(sparse_tensor.sparse[0], nonzero_values)
t.testing.assert_close(sparse_tensor.sparse[1].squeeze(-1), nonzero_indices)
t.testing.assert_close(sparse_tensor.dense, x)

# Test `from_sparse`
sparse_tensor = SparseTensor.from_sparse((nonzero_values, nonzero_indices.unsqueeze(-1), tuple(x.shape)))
t.testing.assert_close(sparse_tensor.dense, x)

# Verify other properties
t.testing.assert_close(sparse_tensor.values, nonzero_values)
t.testing.assert_close(sparse_tensor.indices, nonzero_indices)

tests.test_sparse_tensor(SparseTensor)
Solution
class SparseTensor:
    """
    Handles tensor data (assumed to be non-negative) in 2 different formats:
        dense:  The full tensor, which contains zeros. Shape is (n1, ..., nk).
        sparse: A tuple of nonzero values with shape (n_nonzero,), nonzero indices with shape
                (n_nonzero, k), and the shape of the dense tensor.
    """

    sparse: tuple[Tensor, Tensor, tuple[int, ...]]
    dense: Tensor

    def __init__(self, sparse: tuple[Tensor, Tensor, tuple[int, ...]], dense: Tensor):
        self.sparse = sparse
        self.dense = dense

    @classmethod
    def from_dense(cls, dense: Tensor) -> "SparseTensor":
        """Creates a SparseTensor from a dense tensor, extracting the positive entries."""
        sparse = (dense[dense > 0], t.argwhere(dense > 0), tuple(dense.shape))
        return cls(sparse, dense)

    @classmethod
    def from_sparse(cls, sparse: tuple[Tensor, Tensor, tuple[int, ...]]) -> "SparseTensor":
        """Creates a SparseTensor from a (values, indices, shape) tuple, reconstructing the dense tensor."""
        nonzero_values, nonzero_indices, shape = sparse
        dense = t.zeros(shape, dtype=nonzero_values.dtype, device=nonzero_values.device)
        dense[nonzero_indices.unbind(-1)] = nonzero_values
        return cls(sparse, dense)

    @property
    def values(self) -> Tensor:
        return self.sparse[0].squeeze()

    @property
    def indices(self) -> Tensor:
        return self.sparse[1].squeeze()

    @property
    def shape(self) -> tuple[int, ...]:
        return self.sparse[2]

Exercise (2/3) - implement latent_acts_to_later_latent_acts

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-20 minutes on this exercise.

Next, you should implement the latent_acts_to_later_latent_acts function. This takes latent activations earlier in the model (in a sparse form, i.e. tuple of (values, indices, shape)) and returns the downstream latent activations as a tuple of (sparse_values, (dense_values,)).

Why do we return 2 copies of latent_acts_next in this strange way? The answer is that we'll be wrapping our function with t.func.jacrev(latent_acts_to_later_latent_acts, has_aux=True). The has_aux argument allows us to return a tuple of tensors which won't be differentiated. More precisely, it takes a tensor -> (tensor, tuple_of_tensors) function f(x) = (y, aux) and returns the function g(x) = (J, aux) where J[i, j] = d(y_i) / d(x_j). This means we get both the Jacobian and the actual reconstructed activations from a single call.
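The has_aux pattern can be sketched on a toy function as follows. The weights `W` and the function `f_with_aux` are hypothetical and only meant to show the calling convention, not the exercise solution.

```python
import torch as t

# Toy weights, chosen by hand for illustration
W = t.tensor([[1.0, -1.0], [2.0, 0.5], [0.0, 3.0]])


def f_with_aux(x: t.Tensor) -> tuple[t.Tensor, tuple[t.Tensor]]:
    y = W @ x
    # The first return value is differentiated with respect to x;
    # the tuple in second position is passed through without differentiation
    return y**2, (y,)


x = t.tensor([1.0, 2.0])
J, (y,) = t.func.jacrev(f_with_aux, has_aux=True)(x)

print(J.shape)  # (3, 2): one row per element of y**2, one column per element of x
print(y)  # the intermediate value, returned alongside the Jacobian
```

Returning the auxiliary value this way avoids a second forward pass just to recover the activations that the Jacobian was computed from.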

Note on what gradients we're actually computing

Eagle-eyed readers might have noticed that what we're actually doing here is not computing gradients between later and earlier latents, but computing the gradient between reconstructed later latents and earlier latents. In other words, the later latents we're differentiating are actually a function of the earlier SAE's residual stream reconstruction, rather than the actual residual stream. This is a bit risky when drawing conclusions from the results, because if your earlier SAE isn't very good at reconstructing its input then you might miss out on ways in which downstream latents are affected by upstream activations. A good way to sanity check this is to compare the latent activations (computed downstream of the earlier SAE's reconstructions) with the true latent activations, and make sure they're similar.
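One possible way to run the sanity check described above is sketched here. The helper `compare_latent_acts` is hypothetical (it is not part of the course utilities), and the two small tensors stand in for the true and reconstructed dense latent activations.

```python
import torch as t


def compare_latent_acts(acts_true: t.Tensor, acts_recon: t.Tensor) -> dict[str, float]:
    """Hypothetical helper: compares two dense latent activation tensors of equal shape."""
    active_true = acts_true > 0
    active_recon = acts_recon > 0
    n_both = (active_true & active_recon).sum().item()
    n_either = (active_true | active_recon).sum().item()
    rel_err = (acts_true - acts_recon).norm() / acts_true.norm().clamp_min(1e-8)
    return {
        # Fraction of latents active in either tensor that are active in both
        "jaccard_overlap": n_both / max(n_either, 1),
        # Size of the difference relative to the size of the true activations
        "relative_l2_error": rel_err.item(),
    }


# Stand-in data: rows are positions, columns are latents
acts_true = t.tensor([[0.0, 2.0, 0.0, 1.0], [3.0, 0.0, 0.0, 0.0]])
acts_recon = t.tensor([[0.0, 1.8, 0.1, 0.0], [2.9, 0.0, 0.0, 0.0]])

stats = compare_latent_acts(acts_true, acts_recon)
print(stats)
```

A high overlap and a small relative error suggest the gradients through the reconstruction are a reasonable proxy; what counts as "high" or "small" is a judgment call that depends on the SAE.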

We'll get to applying the Jacobian in the 3rd exercise though - for now, you should just fill in latent_acts_to_later_latent_acts. This should essentially match the pseudocode for latent_acts_to_later_latent_acts which we gave at the start of this section (with the added factor of having to convert tensors to / from their sparse forms). Some guidance on syntax you'll find useful:

  • All SAEs have encode and decode methods, which map from input -> latent activations -> reconstructed input.
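  • All TransformerLens models have a forward method with optional arguments start_at_layer and stop_at_layer. If start_at_layer is supplied, the input is treated as the residual stream entering that layer (rather than as tokens); if stop_at_layer is supplied, the forward pass returns the residual stream just before that layer (i.e. stop_at_layer is exclusive) rather than logits.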
def latent_acts_to_later_latent_acts(
    latent_acts_nonzero: Float[Tensor, " nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Given some latent activations for a residual stream SAE earlier in the model, computes the
    latent activations of a later SAE. It does this by mapping the latent activations through the
    path SAE decoder -> intermediate model layers -> later SAE encoder.

    This function must input & output sparse information (i.e. nonzero values and their indices)
    rather than dense tensors, because latent activations are sparse but jacrev() doesn't support
    gradients on real sparse tensors.
    """
    # ... YOUR CODE HERE ...

    return latent_acts_next_recon.sparse[0], (latent_acts_next_recon.dense,)
Solution
def latent_acts_to_later_latent_acts(
    latent_acts_nonzero: Float[Tensor, " nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Given some latent activations for a residual stream SAE earlier in the model, computes the
    latent activations of a later SAE. It does this by mapping the latent activations through the
    path SAE decoder -> intermediate model layers -> later SAE encoder.

    This function must input & output sparse information (i.e. nonzero values and their indices)
    rather than dense tensors, because latent activations are sparse but jacrev() doesn't support
    gradients on real sparse tensors.
    """
    # Convert to dense, map through SAE decoder
    latent_acts = SparseTensor.from_sparse((latent_acts_nonzero, latent_acts_nonzero_inds, latent_acts_shape)).dense
    resid_stream_from = sae_from.decode(latent_acts)

    # Map through model layers
    resid_stream_next = model.forward(
        resid_stream_from,
        start_at_layer=_get_hook_layer(sae_from),
        stop_at_layer=_get_hook_layer(sae_to),
    )

    # Map through SAE encoder, and turn back into SparseTensor
    latent_acts_next_recon = sae_to.encode(resid_stream_next)
    latent_acts_next_recon = SparseTensor.from_dense(latent_acts_next_recon)

    return latent_acts_next_recon.sparse[0], (latent_acts_next_recon.dense,)

Exercise (3/3) - implement latent_to_latent_gradients

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 20-40 minutes on this exercise.

Finally, we're in the position to implement our full latent_to_latent_gradients function. This function should:

  • Compute the true latent activations for both SAEs, using run_with_cache_with_saes (make sure you set sae_from.use_error_term = True, because you want to compute the true latent activations for the later SAE, not those which are computed from the earlier SAE's reconstructions!)
  • Wrap your function latent_acts_to_later_latent_acts to create a function that will return the Jacobian and the later latent activations (code in a dropdown below if you're confused about what this looks like),
  • Call this function to return the Jacobian and the later latent activations,
  • Return the Jacobian and earlier/later latent activations (the latter as SparseTensor objects).
Code for the Jacobian wrapper
latent_acts_to_later_latent_acts_and_gradients = t.func.jacrev(
    latent_acts_to_later_latent_acts, argnums=0, has_aux=True
)

The argnums=0 argument tells jacrev to take the Jacobian with respect to the first argument of latent_acts_to_later_latent_acts, and the has_aux=True argument tells it to also return the auxiliary outputs of latent_acts_to_later_latent_acts (i.e. the tuple of tensors which are the second output of the base function).

You can call this function using:

latent_latent_gradients, (latent_acts_next_recon_dense,) = latent_acts_to_later_latent_acts_and_gradients(
    *latent_acts_prev.sparse, sae_from, sae_to, model
)
Help - I'm getting OOM errors

OOM errors are common when you pass in tensors which aren't the sparsified versions (because computing a 2D matrix of derivatives of 10k+ elements is pretty memory intensive!). We recommend you look at the amount of memory being asked for when you get errors; if it's 30GB+ then you're almost certainly making this mistake.

If you're still getting errors, we recommend you inspect and clear your memory. In particular, loading large models like Gemma onto the GPU will be taking up space that you no longer need. We've provided some util functions for this purpose (we give examples of how to use them at the very start of this notebook, before the first set of exercises, under the header "A note on memory usage").

If all this still doesn't work (i.e. you still get errors after clearing memory), we recommend you try a virtual machine (e.g. vastai) or Colab notebook.

We've given you code below this function, which will run and create a heatmap of the gradients for you. Note, the plot axes use notation of F{layer}.{latent_idx} for the latents.

Challenge - can you find a pair of latents which seem to form a circuit on bigrams consisting of tokenized words where the first token is " E" ?

def latent_to_latent_gradients(
    tokens: Int[Tensor, "batch seq"],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor, SparseTensor, SparseTensor]:
    """
    Computes the gradients between all active pairs of latents belonging to two SAEs.

    Returns:
        latent_latent_gradients:    The gradients between all active pairs of latents
        latent_acts_prev:           The latent activations of the first SAE
        latent_acts_next:           The latent activations of the second SAE
        latent_acts_next_recon:     The reconstructed latent activations of the second SAE (i.e.
                                    based on the first SAE's reconstructions)
    """
    # ... YOUR CODE HERE ...

    return (
        latent_latent_gradients,
        latent_acts_prev,
        latent_acts_next,
        latent_acts_next_recon,
    )


tests.test_latent_to_latent_gradients(latent_to_latent_gradients, gpt2, gpt2_saes)
Solution
def latent_to_latent_gradients(
    tokens: Int[Tensor, "batch seq"],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor, SparseTensor, SparseTensor]:
    """
    Computes the gradients between all active pairs of latents belonging to two SAEs.

    Returns:
        latent_latent_gradients:    The gradients between all active pairs of latents
        latent_acts_prev:           The latent activations of the first SAE
        latent_acts_next:           The latent activations of the second SAE
        latent_acts_next_recon:     The reconstructed latent activations of the second SAE (i.e.
                                    based on the first SAE's reconstructions)
    """
    acts_prev_name = f"{sae_from.cfg.metadata.hook_name}.hook_sae_acts_post"
    acts_next_name = f"{sae_to.cfg.metadata.hook_name}.hook_sae_acts_post"
    sae_from.use_error_term = True  # so we can get both true latent acts at once

    with t.no_grad():
        # Get the true activations for both SAEs
        _, cache = model.run_with_cache_with_saes(
            tokens,
            names_filter=[acts_prev_name, acts_next_name],
            stop_at_layer=_get_hook_layer(sae_to) + 1,
            saes=[sae_from, sae_to],
            remove_batch_dim=False,
        )
        latent_acts_prev = SparseTensor.from_dense(cache[acts_prev_name])
        latent_acts_next = SparseTensor.from_dense(cache[acts_next_name])

    # Compute jacobian between earlier and later latent activations (and also get the activations
    # of the later SAE which are downstream of the earlier SAE's reconstructions)
    latent_latent_gradients, (latent_acts_next_recon_dense,) = t.func.jacrev(
        latent_acts_to_later_latent_acts, has_aux=True
    )(
        *latent_acts_prev.sparse,
        sae_from,
        sae_to,
        model,
    )

    latent_acts_next_recon = SparseTensor.from_dense(latent_acts_next_recon_dense)

    # Set SAE state back to default
    sae_from.use_error_term = False

    return (
        latent_latent_gradients,
        latent_acts_prev,
        latent_acts_next,
        latent_acts_next_recon,
    )

Before running the code below to visualize the gradient heatmap, think about what you expect to see. The heatmap shows gradients between SAE latents in layer 0 and SAE latents in layer 3, for the prompt "The Eiffel tower is in Paris".

  • Do you expect most entries to be zero or nonzero? Why?
  • Which pairs of (source position, destination position) are most likely to have large gradient values?
  • Would you expect the gradient matrix to be dense or sparse?

Run the cell below and compare with your predictions.

prompt = "The Eiffel tower is in Paris"
tokens = gpt2.to_tokens(prompt)
str_toks = gpt2.to_str_tokens(prompt)
layer_from = 0
layer_to = 3

# Get latent-to-latent gradients
t.cuda.empty_cache()
with t.enable_grad():
    (
        latent_latent_gradients,
        latent_acts_prev,
        latent_acts_next,
        latent_acts_next_recon,
    ) = latent_to_latent_gradients(tokens, gpt2_saes[layer_from], gpt2_saes[layer_to], gpt2)

# Verify that roughly the same latents are active in the true and reconstructed activations
nonzero_latents = [tuple(x) for x in latent_acts_next.indices.tolist()]
nonzero_latents_recon = [tuple(x) for x in latent_acts_next_recon.indices.tolist()]
alive_in_one_not_both = set(nonzero_latents) ^ set(nonzero_latents_recon)
print(f"# nonzero latents (true): {len(nonzero_latents)}")
print(f"# nonzero latents (reconstructed): {len(nonzero_latents_recon)}")
print(f"# latents alive in one but not both: {len(alive_in_one_not_both)}")

fig = px.imshow(
    to_numpy(latent_latent_gradients.T),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    x=[f"F{layer_to}.{latent}, {str_toks[seq]!r} ({seq})" for (_, seq, latent) in latent_acts_next_recon.indices],
    y=[f"F{layer_from}.{latent}, {str_toks[seq]!r} ({seq})" for (_, seq, latent) in latent_acts_prev.indices],
    labels={"x": f"To layer {layer_to}", "y": f"From layer {layer_from}"},
    title=f'Gradients between SAE latents in layer {layer_from} and SAE latents in layer {layer_to}<br><sup>   Prompt: "{"".join(str_toks)}"</sup>',
    width=1600,
    height=1000,
)
fig.show()
Click to see the expected output
# nonzero latents (true): 181
# nonzero latents (reconstructed): 179
# latents alive in one but not both: 8
Some observations

Many of the nonzero gradients are for pairs of latents which fire on the same token. For example, (F0.9449, " Paris") -> (F3.385, " Paris") seems like it could just be a similar feature in 2 different layers:

display_dashboard(sae_id="blocks.0.hook_resid_pre", latent_idx=9449)
display_dashboard(sae_id="blocks.3.hook_resid_pre", latent_idx=385)

There aren't as many cross-token gradients. One of the most notable is (F0.16911, " E") -> (F3.15266, "iff") which seems like it could be a bigram circuit for words which start with " E":

display_dashboard(sae_id="blocks.0.hook_resid_pre", latent_idx=16911)
display_dashboard(sae_id="blocks.3.hook_resid_pre", latent_idx=15266)

Exercise - get token-to-latent gradients

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 30-40 minutes on this exercise.

Now that we've worked through implementing latent-to-latent gradients, let's try doing the whole thing again, but instead computing the gradients between all input tokens and a particular SAE's latents.

You might be wondering what gradients between tokens and latents even mean, because tokens aren't scalar values. The answer is that we'll be multiplying the model's embeddings by some scale factor s (i.e. a vector of different scale factor values for each token in our sequence), and taking the gradient of the SAE's latents wrt these values s, evaluated at s = [1, 1, ..., 1]. This isn't super principled since in practice this kind of embedding vector scaling doesn't happen in our model, but it's a convenient way to get a sense of which tokens are most important for which latents.
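The scale-factor idea can be illustrated with a toy stand-in for the model. Everything below (the random embeddings, the cumulative-sum "mixing" step and the encoder matrix) is made up for illustration; it only mimics the shape of the computation, not a real transformer or SAE.

```python
import torch as t

t.manual_seed(0)
seq, d_model, d_latent = 4, 8, 5

embeds = t.randn(seq, d_model)  # stand-in for token + positional embeddings
W_enc = t.randn(d_model, d_latent)  # stand-in for everything up to the SAE encoder


def scales_to_latents(scales: t.Tensor) -> t.Tensor:
    scaled = embeds * scales[:, None]  # scale each position's embedding by s[pos]
    mixed = scaled.cumsum(dim=0)  # crude causal mixing: position p sees positions <= p
    return t.relu(mixed @ W_enc)  # shape (seq, d_latent)


# Evaluate the Jacobian at s = [1, 1, ..., 1], i.e. at the unmodified embeddings
scales = t.ones(seq)
grads = t.func.jacrev(scales_to_latents)(scales)

# grads[p, l, q] = d latent[p, l] / d s[q]
print(grads.shape)  # (seq, d_latent, seq)
```

Because the toy mixing step is causal, the gradient of a latent at position p with respect to the scale of any later position is exactly zero, which mirrors what causal attention implies in the real model.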

Challenge - take the pair of latents from the previous exercise which seemed to form a circuit on bigrams consisting of tokenized words where the first token is " E". Can you find that circuit again, from this plot?

def tokens_to_latent_acts(
    token_scales: Float[Tensor, "batch seq"],
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Given scale factors for model's embeddings (i.e. scale factors applied after we compute the sum
    of positional and token embeddings), returns the SAE's latents.

    Returns:
        latent_acts_sparse: The SAE's latents in sparse form (i.e. the tensor of values)
        latent_acts_dense:  The SAE's latents in dense tensor, in a length-1 tuple
    """
    # ... YOUR CODE HERE ...

    return sae_latents.sparse[0], (sae_latents.dense,)


def token_to_latent_gradients(
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor]:
    """
    Computes the gradients between an SAE's latents and all input tokens.

    Returns:
        token_latent_grads: The gradients between input tokens and SAE latents
        latent_acts:        The SAE's latent activations
    """
    # ... YOUR CODE HERE ...

    return (token_latent_grads, latent_acts)


tests.test_token_to_latent_gradients(token_to_latent_gradients, gpt2, gpt2_saes)

sae_layer = 3
token_latent_grads, latent_acts = token_to_latent_gradients(tokens, sae=gpt2_saes[sae_layer], model=gpt2)

fig = px.imshow(
    to_numpy(token_latent_grads[0]),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    x=[f"F{sae_layer}.{latent:05}, {str_toks[seq]!r} ({seq})" for (_, seq, latent) in latent_acts.indices],
    y=[f"{str_toks[i]!r} ({i})" for i in range(len(str_toks))],
    labels={"x": f"To layer {sae_layer}", "y": "From tokens"},
    title=f'Gradients between input tokens and SAE latents in layer {sae_layer}<br><sup>   Prompt: "{"".join(str_toks)}"</sup>',
    width=1900,
    height=450,
)
fig.show()
Click to see the expected output
Some observations

In the previous exercise, we saw gradients between (F0.16911, " E") -> (F3.15266, "iff"), which seems like it could be forming a bigram circuit for words which start with " E".

In this plot, we can see a gradient between the " E" token and latent F3.15266, which is what we'd expect based on this.

Solution
def tokens_to_latent_acts(
    token_scales: Float[Tensor, "batch seq"],
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Given scale factors for model's embeddings (i.e. scale factors applied after we compute the sum
    of positional and token embeddings), returns the SAE's latents.

    Returns:
        latent_acts_sparse: The SAE's latents in sparse form (i.e. the tensor of values)
        latent_acts_dense:  The SAE's latents in dense tensor, in a length-1 tuple
    """
    resid_after_embed = model(tokens, stop_at_layer=0)
    resid_after_embed = einops.einsum(resid_after_embed, token_scales, "... seq d_model, ... seq -> ... seq d_model")
    resid_before_sae = model(resid_after_embed, start_at_layer=0, stop_at_layer=_get_hook_layer(sae))

    sae_latents = sae.encode(resid_before_sae)
    sae_latents = SparseTensor.from_dense(sae_latents)

    return sae_latents.sparse[0], (sae_latents.dense,)


def token_to_latent_gradients(
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor]:
    """
    Computes the gradients between an SAE's latents and all input tokens.

    Returns:
        token_latent_grads: The gradients between input tokens and SAE latents
        latent_acts:        The SAE's latent activations
    """
    # Find the gradients from token positions to latents
    token_scales = t.ones(tokens.shape, device=model.cfg.device, requires_grad=True)
    token_latent_grads, (latent_acts_dense,) = t.func.jacrev(tokens_to_latent_acts, has_aux=True)(
        token_scales, tokens, sae, model
    )

    token_latent_grads = einops.rearrange(token_latent_grads, "d_sae_nonzero batch seq -> batch seq d_sae_nonzero")

    latent_acts = SparseTensor.from_dense(latent_acts_dense)

    return (token_latent_grads, latent_acts)

Exercise - get latent-to-logit gradients

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 30-45 minutes on this exercise.

Finally, we'll have you compute the latent-to-logit gradients. This exercise will be quite similar to the first one (i.e. the latent-to-latent gradients), but this time the function you pass through jacrev will map SAE activations to logits, rather than to a later SAE's latents. Note that we've given you the argument k to specify only a certain number of top logits to take gradients for (because otherwise you might be computing a large gradient matrix, which could cause an OOM error).

Why are we bothering to compute latent-to-logit gradients in the first place? Well, one obvious answer is that it completes our end-to-end circuits picture (we've now got tokens -> latents -> other latents -> logits). But to give another answer, we can consider latents as having a dual nature: when looking back towards the input, they are representations, but when looking forward towards the logits, they are actions. We might expect sparsity in both directions, in other words not only should latents sparsely represent the activations produced by the input, they should also sparsely affect the gradients influencing the output. As you'll see when doing these exercises, this is partially the case, although not to the same degree as the extremely sparse token-to-latent and latent-to-latent gradients we saw earlier. One reason for that is that sparsity as representations is already included in the SAE's loss function (the L1 penalty), but there's no explicit penalty term to encourage sparsity in the latent gradients. Anthropic propose what this might look like in their April 2024 update.

However, despite the results for latent-to-logit gradients being less sparse than the last 2 exercises, they can still teach us a lot about which latents are important for a particular input prompt. Fill in the functions below, and then play around with some latent-to-logit gradients yourself!

def latent_acts_to_logits(
    latent_acts_nonzero: Float[Tensor, " nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae: SAE,
    model: HookedSAETransformer,
    token_ids: list[int] | None = None,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Computes the logits as a downstream function of the SAE's reconstructed residual stream. If we
    supply `token_ids`, it means we only compute & return the logits for those specified tokens.
    """
    ...
    return logits_recon[token_ids], (logits_recon,)


def latent_to_logit_gradients(
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
    k: int | None = None,
) -> tuple[Tensor, Tensor, Tensor, list[int] | None, SparseTensor]:
    """
    Computes the gradients between active latents and some top-k set of logits (we
    use k to avoid having to compute the gradients for all tokens).

    Returns:
        latent_logit_gradients:  The gradients between the SAE's active latents & downstream logits
        logits:                  The model's true logits
        logits_recon:            The model's reconstructed logits (i.e. based on SAE reconstruction)
        token_ids:               The tokens we computed the gradients for
        latent_acts:             The SAE's latent activations
    """
    assert tokens.shape[0] == 1, "Only supports batch size 1 for now"

    ...

    return (
        latent_logit_gradients,
        logits,
        logits_recon,
        token_ids,
        latent_acts,
    )
tests.test_latent_to_logit_gradients(latent_to_logit_gradients, gpt2, gpt2_saes)

layer = 9
prompt = "The Eiffel tower is in the city of"
answer = " Paris"

tokens = gpt2.to_tokens(prompt, prepend_bos=True)
str_toks = gpt2.to_str_tokens(prompt, prepend_bos=True)
k = 25

# Test the model on this prompt, with & without SAEs
test_prompt(prompt, answer, gpt2)

# How about the reconstruction? More or less; the answer drops to rank 9 in the expected output
# below, so still decent
gpt2_saes[layer].use_error_term = False
with gpt2.saes(saes=[gpt2_saes[layer]]):
    test_prompt(prompt, answer, gpt2)

latent_logit_grads, logits, logits_recon, token_ids, latent_acts = latent_to_logit_gradients(
    tokens, sae=gpt2_saes[layer], model=gpt2, k=k
)

# sort by most positive in " Paris" direction
sorted_indices = latent_logit_grads[0].argsort(descending=True)
latent_logit_grads = latent_logit_grads[:, sorted_indices]

fig = px.imshow(
    to_numpy(latent_logit_grads),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    x=[
        f"{str_toks[seq]!r} ({seq}), latent {latent:05}" for (_, seq, latent) in latent_acts.indices[sorted_indices]
    ],
    y=[f"{tok!r} ({gpt2.to_single_str_token(tok)})" for tok in token_ids],
    labels={"x": f"Features in layer {layer}", "y": "Logits"},
    title=f'Gradients between SAE latents in layer {layer} and final logits (only showing top {k} logits)<br><sup>   Prompt: "{"".join(str_toks)}"</sup>',
    width=1900,
    height=800,
    aspect="auto",
)
fig.show()
Click to see the expected output
Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' tower', ' is', ' in', ' the', ' city', ' of']
Tokenized answer: [' Paris']

Performance on answer token:
Rank: 0        Logit: 14.83 Prob:  4.76% Token: | Paris|

Top 0th token. Logit: 14.83 Prob:  4.76% Token: | Paris|
Top 1th token. Logit: 14.63 Prob:  3.90% Token: | London|
Top 2th token. Logit: 14.47 Prob:  3.32% Token: | Amsterdam|
Top 3th token. Logit: 14.02 Prob:  2.11% Token: | Berlin|
Top 4th token. Logit: 13.90 Prob:  1.87% Token: | L|
Top 5th token. Logit: 13.85 Prob:  1.78% Token: | E|
Top 6th token. Logit: 13.77 Prob:  1.64% Token: | Hamburg|
Top 7th token. Logit: 13.75 Prob:  1.61% Token: | B|
Top 8th token. Logit: 13.61 Prob:  1.40% Token: | Cologne|
Top 9th token. Logit: 13.58 Prob:  1.36% Token: | St|

Ranks of the answer tokens: [(' Paris', 0)]

Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' tower', ' is', ' in', ' the', ' city', ' of']
Tokenized answer: [' Paris']

Performance on answer token:
Rank: 9        Logit: 13.20 Prob:  1.18% Token: | Paris|

Top 0th token. Logit: 13.83 Prob:  2.23% Token: | New|
Top 1th token. Logit: 13.74 Prob:  2.03% Token: | London|
Top 2th token. Logit: 13.71 Prob:  1.96% Token: | San|
Top 3th token. Logit: 13.70 Prob:  1.94% Token: | Chicago|
Top 4th token. Logit: 13.59 Prob:  1.75% Token: | E|
Top 5th token. Logit: 13.47 Prob:  1.54% Token: | Berlin|
Top 6th token. Logit: 13.28 Prob:  1.27% Token: | L|
Top 7th token. Logit: 13.27 Prob:  1.27% Token: | Los|
Top 8th token. Logit: 13.21 Prob:  1.20% Token: | St|
Top 9th token. Logit: 13.20 Prob:  1.18% Token: | Paris|

Ranks of the answer tokens: [(' Paris', 9)]


Some observations

We see that feature F9.22250 stands out as boosting the " Paris" token far more than any of the other top predictions. Investigation reveals this feature fires primarily on French language text, which makes sense!

We also see F9.5879, which seems to strongly boost words associated with Germany (e.g. Berlin, Hamburg, Cologne, Zurich). We see a similar pattern there, where that feature mostly fires on German-language text (or more commonly, English-language text talking about Germany).

display_dashboard(sae_id="blocks.9.hook_resid_pre", latent_idx=22250)
display_dashboard(sae_id="blocks.9.hook_resid_pre", latent_idx=5879)
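
One rough way to put a number on how sparse a gradient matrix like this is: count the fraction of entries whose magnitude exceeds some fraction of the largest entry. The helper below is a hypothetical sketch in plain Python (the name `fraction_large_entries` and the 10% default threshold are assumptions, not part of the course code). It takes a nested list, so you could pass it something like `latent_logit_grads.tolist()`.

```python
def fraction_large_entries(matrix, rel_threshold=0.1):
    """
    Fraction of entries with |value| > rel_threshold * max|value|.
    Lower values mean a sparser matrix. Returns 0.0 for an empty or all-zero matrix.
    """
    flat = [abs(v) for row in matrix for v in row]
    if not flat:
        return 0.0
    max_abs = max(flat)
    if max_abs == 0:
        return 0.0
    return sum(v > rel_threshold * max_abs for v in flat) / len(flat)


# Made-up example: one dominant entry, one moderate entry, the rest negligible
toy_grads = [
    [5.0, 0.01, -0.02, 0.0],
    [0.03, -1.0, 0.0, 0.02],
]
print(fraction_large_entries(toy_grads))
```

The choice of threshold is arbitrary, so this is best used to compare matrices against each other (e.g. latent-to-logit vs latent-to-latent gradients) rather than as an absolute measure.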
Solution
def latent_acts_to_logits(
    latent_acts_nonzero: Float[Tensor, " nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae: SAE,
    model: HookedSAETransformer,
    token_ids: list[int] | None = None,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Computes the logits as a downstream function of the SAE's reconstructed residual stream. If we
    supply `token_ids`, it means we only compute & return the logits for those specified tokens.
    """
    # Convert to dense, map through SAE decoder
    latent_acts = SparseTensor.from_sparse((latent_acts_nonzero, latent_acts_nonzero_inds, latent_acts_shape)).dense

    resid = sae.decode(latent_acts)

    # Map through model layers, to the end
    logits_recon = model(resid, start_at_layer=_get_hook_layer(sae))[0, -1]

    return logits_recon[token_ids], (logits_recon,)


def latent_to_logit_gradients(
    tokens: Int[Tensor, "batch seq"],
    sae: SAE,
    model: HookedSAETransformer,
    k: int | None = None,
) -> tuple[Tensor, Tensor, Tensor, list[int] | None, SparseTensor]:
    """
    Computes the gradients between active latents and some top-k set of logits (we
    use k to avoid having to compute the gradients for all tokens).

    Returns:
        latent_logit_gradients:  The gradients between the SAE's active latents & downstream logits
        logits:                  The model's true logits
        logits_recon:            The model's reconstructed logits (i.e. based on SAE reconstruction)
        token_ids:               The tokens we computed the gradients for
        latent_acts:             The SAE's latent activations
    """
    assert tokens.shape[0] == 1, "Only supports batch size 1 for now"

    acts_hook_name = f"{sae.cfg.metadata.hook_name}.hook_sae_acts_post"
    sae.use_error_term = True

    with t.no_grad():
        # Run model up to the position of the first SAE to get those residual stream activations
        logits, cache = model.run_with_cache_with_saes(
            tokens,
            names_filter=[acts_hook_name],
            saes=[sae],
            remove_batch_dim=False,
        )
        latent_acts = cache[acts_hook_name]
        latent_acts = SparseTensor.from_dense(latent_acts)

        logits = logits[0, -1]

    # Get the tokens we'll actually compute gradients for
    token_ids = None if k is None else logits.topk(k=k).indices.tolist()

    # Compute jacobian between latent acts and logits
    latent_logit_gradients, (logits_recon,) = t.func.jacrev(latent_acts_to_logits, has_aux=True)(
        *latent_acts.sparse, sae, model, token_ids
    )

    sae.use_error_term = False

    return (
        latent_logit_gradients,
        logits,
        logits_recon,
        token_ids,
        latent_acts,
    )

Exercise (optional) - find induction circuits in attention SAEs

Difficulty: 🔴🔴🔴🔴🔴
Importance: 🔵🔵🔵⚪⚪
You should spend up to 30-60 minutes on this exercise.

You can study MLP or attention latents in much the same way as you've studied residual stream latents, for any of the last 3 sets of exercises. For example, if we wanted to compute the gradient of logits or some later SAE with respect to an earlier SAE which was trained on the MLP layer, we just replace the MLP layer's activations with that earlier SAE's decoded activations, and then compute the Jacobian of our downstream values wrt this earlier SAE's activations. Note, we can use something like model.run_with_hooks to perform this replacement, without having to manually perform every step of the model's forward pass ourselves.
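
The "replace one activation, let everything else run as normal" idea is the key trick here. The sketch below mimics it with a toy pipeline in plain Python. Note that `run_with_replacements` is a hypothetical stand-in for the kind of behaviour `model.run_with_hooks` gives you, not the TransformerLens API, and the stage names and functions are made up.

```python
def run_with_replacements(stages, x, replacements=None):
    """
    Runs `x` through an ordered list of (name, fn) stages. If a stage name appears in
    `replacements`, that stage's output is passed through replacements[name] before
    moving on, analogous to a forward hook that returns a new activation.
    """
    replacements = replacements or {}
    for name, fn in stages:
        x = fn(x)
        if name in replacements:
            x = replacements[name](x)
    return x


# Made-up three-stage "model"
toy_stages = [
    ("attn_out", lambda x: x + 1.0),
    ("mlp_out", lambda x: x * 2.0),
    ("logits", lambda x: x - 3.0),
]


def toy_sae_reconstruction(mlp_out):
    """Pretend SAE: an imperfect reconstruction which rounds to the nearest integer."""
    return float(round(mlp_out))


clean_output = run_with_replacements(toy_stages, 1.0)
patched_output = run_with_replacements(toy_stages, 1.2, {"mlp_out": toy_sae_reconstruction})
print(clean_output, patched_output)
```

In the real exercise the replacement value is the earlier SAE's decoded activations, and the whole patched forward pass is wrapped in a function of those latent activations so that its Jacobian can be taken.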

Try and write a version of the latent_to_latent_gradients_attn function which works for attention SAEs (docstring below). Can you use this function to find previous token latents & induction latents which have non-zero gradients with respect to each other, and come together to form induction circuits?

Hint - where you should be looking

Start by generating a random sequence of tokens, and using circuitsvis to visualize the attention patterns:

import circuitsvis as cv

seq_len = 10
tokens = t.randint(0, model.cfg.d_vocab, (1, seq_len)).tolist()[0]
tokens = t.tensor([model.tokenizer.bos_token_id] + tokens + tokens)

_, cache = model.run_with_cache(tokens)

prev_token_heads = [(4, 11)]
induction_heads = [(5, 1), (5, 5), (5, 8)]
all_heads = prev_token_heads + induction_heads

html = cv.attention.attention_patterns(
    tokens=model.to_str_tokens(tokens),
    attention=t.stack([cache["pattern", layer][0, head] for layer, head in all_heads]),
    attention_head_names=[f"{layer}.{head}" for (layer, head) in all_heads],
)
display(html)

You can see from this that layer 4 contains a clear previous token head, and 5 contains several induction heads. So you might want to look at latent-to-latent gradients between layer 4 and layer 5.

Remember - the induction circuit works by having sequences AB...AB, where the previous token head attends from the first B back to the first A, then the induction head attends from the second A back to the first B. Keep this in mind when you're looking for evidence of an induction circuit in the latent-to-latent gradients heatmap.

The code below this function plots latent-to-latent gradients, and it also adds black squares to the instances where a layer-4 feature fires on the first B and layer-5 feature fires on the second A, in the AB...AB pattern (this is what we expect from an induction circuit). In other words, if your function is working then you should see a pattern of nonzero values in the regions indicated by these squares.
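
The position arithmetic behind those squares can be checked in isolation. The sketch below is plain Python with made-up token IDs (the helper name `induction_position_pairs` is hypothetical); it assumes the prompt is laid out as a BOS token followed by two copies of the same random sequence, as in the hint above.

```python
def induction_position_pairs(seq_len):
    """
    For a prompt laid out as [BOS] + seq + seq (each copy of length seq_len), returns
    (first_B_posn, second_A_posn) pairs: the layer-4 latent is expected at the first B,
    and the layer-5 latent at the second A.
    """
    return [(first_B_posn, first_B_posn + seq_len - 1) for first_B_posn in range(2, seq_len + 2)]


# Made-up token IDs, with 0 standing in for BOS
half = [11, 22, 33, 44]
toy_tokens = [0] + half + half

for first_B_posn, second_A_posn in induction_position_pairs(len(half)):
    first_A_posn = first_B_posn - 1
    # The second A is a repeat of the token just before the first B
    assert toy_tokens[second_A_posn] == toy_tokens[first_A_posn]
    print(f"first A @ {first_A_posn}, first B @ {first_B_posn}, second A @ {second_A_posn}")
```

Note that in the final pair the "first B" is the first token of the second copy, so the token following the second A falls outside the prompt.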

def latent_acts_to_later_latent_acts_attn(
    latent_acts_nonzero: Float[Tensor, " nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
    resid_pre_clean: Tensor,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Returns the latent activations of an attention SAE, computed downstream of an earlier SAE's
    output (whose values are given in sparse form as the first three arguments).

    `resid_pre_clean` is also supplied, i.e. these are the input values to the attention layer in
    which the earlier SAE is applied.
    """
    ...

    return latent_acts_next_recon.sparse[0], (latent_acts_next_recon.dense,)


def latent_to_latent_gradients_attn(
    tokens: Int[Tensor, "batch seq"],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor, SparseTensor, SparseTensor]:
    """
    Computes the gradients between all active pairs of latents belonging to two SAEs. Both SAEs
    are assumed to be attention SAEs, i.e. they take the concatenated z values as input.

    Returns:
        latent_latent_gradients:  The gradients between all active pairs of latents
        latent_acts_prev:          The latent activations of the first SAE
        latent_acts_next:          The latent activations of the second SAE
        latent_acts_next_recon:    The reconstructed latent activations of the second SAE
    """
    ...

    return (
        latent_latent_gradients,
        latent_acts_prev,
        latent_acts_next,
        latent_acts_next_recon,
    )
attn_saes = {
    layer: SAE.from_pretrained(
        "gpt2-small-hook-z-kk",
        f"blocks.{layer}.hook_z",
        device=device,
        dtype=dtype,
    )
    for layer in range(gpt2.cfg.n_layers)
}

seq_len = 10  # higher seq len / more batches would be more reliable, but this simplifies the plot
tokens = t.randint(0, gpt2.cfg.d_vocab, (1, seq_len)).tolist()[0]
tokens = t.tensor([gpt2.tokenizer.bos_token_id] + tokens + tokens)
str_toks = gpt2.to_str_tokens(tokens)
layer_from = 4
layer_to = 5

# Get latent-to-latent gradients
with t.enable_grad():
    (
        latent_latent_gradients,
        latent_acts_prev,
        latent_acts_next,
        latent_acts_next_recon,
    ) = latent_to_latent_gradients_attn(tokens, attn_saes[layer_from], attn_saes[layer_to], gpt2)

# Check that roughly the same latents are active in the true and reconstructed activations
nonzero_latents = [tuple(x) for x in latent_acts_next.indices.tolist()]
nonzero_latents_recon = [tuple(x) for x in latent_acts_next_recon.indices.tolist()]
alive_in_one_not_both = set(nonzero_latents) ^ set(nonzero_latents_recon)
print(f"# nonzero latents (true): {len(nonzero_latents)}")
print(f"# nonzero latents (reconstructed): {len(nonzero_latents_recon)}")
print(f"# latents alive in one but not both: {len(alive_in_one_not_both)}")

# Create initial figure
fig = px.imshow(
    to_numpy(latent_latent_gradients.T),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    x=[f"F{layer_to}.{latent}, {str_toks[seq]!r} ({seq})" for (_, seq, latent) in latent_acts_next_recon.indices],
    y=[f"F{layer_from}.{latent}, {str_toks[seq]!r} ({seq})" for (_, seq, latent) in latent_acts_prev.indices],
    labels={"y": f"From layer {layer_from}", "x": f"To layer {layer_to}"},
    title=f'Gradients between SAE latents in layer {layer_from} and SAE latents in layer {layer_to}<br><sup>   Prompt: "{"".join(str_toks)}"</sup>',
    width=1200,
    height=1000,
)
# Add rectangles to it, to cover the blocks where the layer 4 & 5 positions correspond to what we
# expect for the induction circuit
for first_B_posn in range(2, seq_len + 2):
    second_A_posn = first_B_posn + seq_len - 1
    x0 = (latent_acts_next_recon.indices[:, 1] < second_A_posn).sum().item()
    x1 = (latent_acts_next_recon.indices[:, 1] <= second_A_posn).sum().item()
    y0 = (latent_acts_prev.indices[:, 1] < first_B_posn).sum().item()
    y1 = (latent_acts_prev.indices[:, 1] <= first_B_posn).sum().item()
    fig.add_shape(type="rect", x0=x0, y0=y0, x1=x1, y1=y1)

fig.show()
Click to see the expected output
# nonzero latents (true): 476
# nonzero latents (reconstructed): 492
# latents alive in one but not both: 150
Solution
def latent_acts_to_later_latent_acts_attn(
    latent_acts_nonzero: Float[Tensor, " nonzero_acts"],
    latent_acts_nonzero_inds: Int[Tensor, "nonzero_acts n_indices"],
    latent_acts_shape: tuple[int, ...],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
    resid_pre_clean: Tensor,
) -> tuple[Tensor, tuple[Tensor]]:
    """
    Returns the latent activations of an attention SAE, computed downstream of an earlier SAE's
    output (whose values are given in sparse form as the first three arguments).

    `resid_pre_clean` is also supplied, i.e. these are the input values to the attention layer in
    which the earlier SAE is applied.
    """
    # Convert to dense, map through SAE decoder
    latent_acts = SparseTensor.from_sparse((latent_acts_nonzero, latent_acts_nonzero_inds, latent_acts_shape)).dense
    z_recon = sae_from.decode(latent_acts)

    hook_name_z_prev = get_act_name("z", _get_hook_layer(sae_from))
    hook_name_z_next = get_act_name("z", _get_hook_layer(sae_to))

    def hook_set_z_prev(z: Tensor, hook: HookPoint):
        return z_recon

    def hook_store_z_next(z: Tensor, hook: HookPoint):
        hook.ctx["z"] = z

    # fwd pass: replace earlier z with SAE reconstructions, and store later z (no SAEs needed yet)
    model.run_with_hooks(
        resid_pre_clean,
        start_at_layer=_get_hook_layer(sae_from),
        stop_at_layer=_get_hook_layer(sae_to) + 1,
        fwd_hooks=[
            (hook_name_z_prev, hook_set_z_prev),
            (hook_name_z_next, hook_store_z_next),
        ],
    )
    z = model.hook_dict[hook_name_z_next].ctx.pop("z")
    latent_acts_next_recon = SparseTensor.from_dense(sae_to.encode(z))

    return latent_acts_next_recon.sparse[0], (latent_acts_next_recon.dense,)


def latent_to_latent_gradients_attn(
    tokens: Int[Tensor, "batch seq"],
    sae_from: SAE,
    sae_to: SAE,
    model: HookedSAETransformer,
) -> tuple[Tensor, SparseTensor, SparseTensor, SparseTensor]:
    """
    Computes the gradients between all active pairs of latents belonging to two SAEs. Both SAEs
    are assumed to be attention SAEs, i.e. they take the concatenated z values as input.

    Returns:
        latent_latent_gradients:  The gradients between all active pairs of latents
        latent_acts_prev:          The latent activations of the first SAE
        latent_acts_next:          The latent activations of the second SAE
        latent_acts_next_recon:    The reconstructed latent activations of the second SAE
    """
    resid_pre_name = get_act_name("resid_pre", _get_hook_layer(sae_from))
    acts_prev_name = f"{sae_from.cfg.metadata.hook_name}.hook_sae_acts_post"
    acts_next_name = f"{sae_to.cfg.metadata.hook_name}.hook_sae_acts_post"
    sae_from.use_error_term = True  # so we can get both true latent acts at once
    sae_to.use_error_term = True  # so we can get both true latent acts at once

    with t.no_grad():
        # Get the true activations for both SAEs
        _, cache = model.run_with_cache_with_saes(
            tokens,
            names_filter=[resid_pre_name, acts_prev_name, acts_next_name],
            stop_at_layer=_get_hook_layer(sae_to) + 1,
            saes=[sae_from, sae_to],
            remove_batch_dim=False,
        )
        latent_acts_prev = SparseTensor.from_dense(cache[acts_prev_name])
        latent_acts_next = SparseTensor.from_dense(cache[acts_next_name])

    # Compute jacobian between earlier and later latent activations (and also get the activations
    # of the later SAE which are downstream of the earlier SAE's reconstructions)
    latent_latent_gradients, (latent_acts_next_recon_dense,) = t.func.jacrev(
        latent_acts_to_later_latent_acts_attn, has_aux=True
    )(*latent_acts_prev.sparse, sae_from, sae_to, model, cache[resid_pre_name])

    latent_acts_next_recon = SparseTensor.from_dense(latent_acts_next_recon_dense)

    # Set SAE state back to default
    sae_from.use_error_term = False
    sae_to.use_error_term = False

    return (
        latent_latent_gradients,
        latent_acts_prev,
        latent_acts_next,
        latent_acts_next_recon,
    )

Note - your SAEs should perform worse on reconstructing this data than they did on the previous exercises in this subsection (if measured in terms of the intersection of activating features). My guess is that this is because the induction sequences are further out-of-distribution for the SAEs than natural text (since they're literally random tokens). Also, it's possible that measuring over a larger batch of data, and keeping only the features that activate on at least some fraction of the total tokens, would give us a less noisy picture.
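
To make "intersection of activating features" concrete, the three counts printed above are enough to recover the overlap, since |A| + |B| = 2|A ∩ B| + |A △ B|. The sketch below (plain Python, hypothetical helper name) does this arithmetic for the expected output of 476 true latents, 492 reconstructed latents, and 150 alive in one but not both.

```python
def overlap_stats(n_true, n_recon, n_symmetric_diff):
    """
    Recovers intersection, union and Jaccard similarity of two sets from their sizes and
    the size of their symmetric difference, using |A| + |B| = 2|A & B| + |A ^ B|.
    """
    twice_intersection = n_true + n_recon - n_symmetric_diff
    if twice_intersection < 0 or twice_intersection % 2 != 0:
        raise ValueError("inconsistent counts")
    intersection = twice_intersection // 2
    union = intersection + n_symmetric_diff
    jaccard = intersection / union if union else 1.0
    return intersection, union, jaccard


intersection, union, jaccard = overlap_stats(476, 492, 150)
print(f"intersection={intersection}, union={union}, Jaccard={jaccard:.3f}")
```

For those numbers, 409 (position, latent) pairs are shared out of a union of 559, i.e. a Jaccard similarity of roughly 0.73.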

Here's some code which filters down the layer-5 features to ones which are active on every token in the second half of the sequence, and also only looks at the layer-4 features which are active on the first half. Try inspecting individual feature pairs which seem to have large gradients with each other - do they seem like previous token features & induction features respectively?

# Filter for layer-5 latents which are active on every token in the second half (which induction
# latents should be!)
acts_on_second_half = latent_acts_next_recon.indices[latent_acts_next_recon.indices[:, 1] >= seq_len + 1]
c = Counter(acts_on_second_half[:, 2].tolist())
top_feats = sorted([feat for feat, count in c.items() if count >= seq_len])
print(f"Layer 5 SAE latents which fired on all tokens in the second half: {top_feats}")
mask_next = (latent_acts_next_recon.indices[:, 2] == t.tensor(top_feats, device=device)[:, None]).any(dim=0) & (
    latent_acts_next_recon.indices[:, 1] >= seq_len + 1
)

# Filter the layer-4 axis to only show activations at sequence positions that we expect to be used
# in induction
mask_prev = (latent_acts_prev.indices[:, 1] >= 1) & (latent_acts_prev.indices[:, 1] <= seq_len)

# Plot the gradients, with both axes filtered down using the masks above
fig = px.imshow(
    to_numpy(latent_latent_gradients[mask_next][:, mask_prev]),
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    y=[
        f"{str_toks[seq]!r} ({seq}), #{latent:05}" for (_, seq, latent) in latent_acts_next_recon.indices[mask_next]
    ],
    x=[f"{str_toks[seq]!r} ({seq}), #{latent:05}" for (_, seq, latent) in latent_acts_prev.indices[mask_prev]],
    labels={"x": f"From layer {layer_from}", "y": f"To layer {layer_to}"},
    title=f'Gradients between SAE latents in layer {layer_from} and SAE latents in layer {layer_to}<br><sup>   Prompt: "{"".join(str_toks)}"</sup>',
    width=1800,
    height=500,
)
fig.show()
Click to see the expected output
Layer 5 SAE latents which fired on all tokens in the second half: [35425, 36126]
Some observations

I didn't rigorously check this (and the code needs a lot of cleaning up!), but I found that the 2 most prominent layer-5 features (35425, 36126) were definitely induction features, and that 3 of the first 4 layer-4 features which strongly composed with them seemed like previous token features:

for (layer, latent_idx) in [(5, 35425), (5, 36126), (4, 22975), (4, 21020), (4, 23954)]:
    display_dashboard(
        sae_release="gpt2-small-hook-z-kk",
        sae_id=f"blocks.{layer}.hook_z",
        latent_idx=latent_idx,
    )