1️⃣ Intro to SAE Interpretability

Learning Objectives
  • Learn how to use the SAELens library to load in & run SAEs (alongside the TransformerLens models they're attached to)
  • Understand the basic features of Neuronpedia, and how it can be used for things like steering and searching over features
  • Understand SAE dashboards, what each part of them tells you about a particular latent (as well as how to compute them yourself)
  • Learn techniques for finding latents, including direct logit attribution, ablation and attribution patching
  • Use attention SAEs, understand how they differ from regular SAEs (as well as topics specific to attention SAEs, like direct latent attribution)
  • Learn a bit about different SAE architectures or training methods (e.g. gated, end-to-end, meta-SAEs, transcoders); some of these will be covered in more detail later

To emphasize: the idea is for this section to be a whirlwind tour of all basic SAE topics, excluding training & evals (which we'll come back to in section 4). The focus will be on how to understand & interpret SAE latents (in particular all the components of the SAE dashboard). We'll also look at techniques for finding latents (e.g. ablation & attribution methods), as well as taking a deeper dive into attention SAEs and how they work. Because there's a lot of material to cover in this section, we'll have a summary of the key points at the top of each main header section. These summaries are all included below for convenience, before we get started. As well as helping to keep you oriented as you work through the material, they should also give you an idea of which sections you can jump to if you only want to cover a few of them.

Intro to SAELens

In this section, you'll learn what SAELens is, and how to use it to load in & inspect the configs of various supported SAEs. Key points:

  • SAELens is a library for training and analysing SAEs. It can be thought of as the equivalent of TransformerLens for SAEs (although it also integrates closely with TransformerLens, as we'll see in the "Running SAEs" section)
  • SAELens contains many different model releases, each release containing multiple SAEs (e.g. trained on different model layers / hook points, or with different architectures)
  • The cfg attribute of an SAE instance contains this information, and anything else that's relevant when performing forward passes

Visualizing SAEs with dashboards

In this section, you'll learn about SAE dashboards, which are a visual tool for quickly understanding what a particular SAE latent represents. Key points:

  • Neuronpedia hosts dashboards which help you understand SAE latents
  • The 5 main components of the dashboard are: top logit tables, logits histogram, activation density plots, top activating sequences, and autointerp
  • All of these components are important for getting a full picture of what a latent represents, but they can also all be misleading
  • You can display these dashboards inline, using IFrame

Running SAEs

In this section, you'll learn how to run forward passes with SAEs. This is a pretty simple process, which builds on much of the pre-existing infrastructure in TransformerLens models. Key points:

  • You can add SAEs to a TransformerLens model when doing forward passes in pretty much the same way you add hook functions (you can think of SAEs as a special kind of hook function)
  • When sae.use_error_term=False (default) you substitute the SAE's output for the transformer activations. When True, the SAE still runs (so its activations can be cached) but its reconstruction error is added back, so the downstream activations are unchanged; this is sometimes what you want when caching activations
  • There's an analogous run_with_saes that works like run_with_hooks
  • There's also run_with_cache_with_saes that works like run_with_cache, but allows you to cache any SAE activations you want
  • You can use ActivationsStore to get a large batch of activations at once

Replicating SAE dashboards

In this section, you'll replicate the 5 main components of the SAE dashboard: top logits tables, logits histogram, activation density plots, top activating sequences, and autointerp. There's not really any new content here, just putting into practice what you've learned from the previous 2 sections "Visualizing SAEs with dashboards" and "Running SAEs".

Attention SAEs

In this section, you'll learn about attention SAEs, how they work (mostly quite similar to standard SAEs but with a few other considerations), and how to understand their feature dashboards. Key points:

  • Attention SAEs have the same architecture as regular SAEs, except they're trained on the concatenated pre-projection output of all attention heads.
  • If a latent fires on a destination token, we can use direct latent attribution to see which source tokens it primarily came from.
  • Just like regular SAEs, latents found in different layers of a model are often qualitatively different from each other.

Finding latents for features

In this section, you'll explore different methods (some causal, some not) for finding latents in SAEs corresponding to particular features. Key points:

  • You can look at max activating latents on some particular input prompt; this is basically the simplest thing you can do
  • Direct logit attribution (DLA) is a bit more refined; you can find latents which have a direct effect on specific logits
  • Ablation of SAE latents can help you find latents which are important in a non-direct way
  • ...but it's quite costly for a large number of latents, so you can use attribution patching as a cheaper linear approximation of ablation

GemmaScope

This short section introduces you to DeepMind's GemmaScope series, a suite of highly performant SAEs which can be a great source of study in your own interpretability projects!

Feature steering

In this section, you'll learn how to steer on latents to produce interesting model output. Key points:

  • Steering involves intervening during a forward pass to change the model's activations in the direction of a particular latent
  • The steering behaviour is sometimes unpredictable, and not always equivalent to "produce text of the same type as the latent strongly activates on"
  • Neuronpedia has a steering interface which allows you to steer without any code

Other types of SAEs

This section introduces a few different SAE architectures, some of which will be explored in more detail in later sections. There are no exercises here, just brief descriptions. Key points:

  • Different activation functions / encoder architectures (e.g. TopK, JumpReLU and Gated models) can solve problems like feature suppression and the pressure for SAEs to be continuous in standard models
  • End-to-end SAEs are trained with a different loss function, encouraging them to learn features that are functionally useful for the model's output rather than just minimising MSE reconstruction error
  • Transcoders are a type of SAE which learn to reconstruct a model's computation (e.g. a sparse mapping from MLP input to MLP output) rather than just reconstructing activations; they can sometimes lead to easier circuit analysis

Intro to SAELens

In this section, you'll learn what SAELens is, and how to use it to load in & inspect the configs of various supported SAEs. Key points:

  • SAELens is a library for training and analysing SAEs. It can be thought of as the equivalent of TransformerLens for SAEs (although it also integrates closely with TransformerLens, as we'll see in the "Running SAEs" section)
  • SAELens contains many different model releases, each release containing multiple SAEs (e.g. trained on different model layers / hook points, or with different architectures)
  • The cfg attribute of an SAE instance contains this information, and anything else that's relevant when performing forward passes

SAELens is a library designed to help researchers:

  • Train sparse autoencoders,
  • Analyse sparse autoencoders / research mechanistic interpretability,
  • Generate insights which make it easier to create safe and aligned AI systems.

You can think of it as the equivalent of TransformerLens for sparse autoencoders (and it also integrates very well with TransformerLens models, which we'll see shortly).

Additionally, SAELens is closely integrated with Neuronpedia, an open platform for interpretability research developed through the joint efforts of Joseph Bloom and Johnny Lin, and which we'll be using throughout this chapter. Neuronpedia allows you to search over latents, run SAEs, and even upload your own SAEs!

Before we load our SAEs, it can be useful to see which are available. The following snippet shows the currently available SAE releases in SAELens, and will remain up-to-date as SAELens continues to add more SAEs.

print(get_pretrained_saes_directory())
{
    'gpt2-small-res-jb': PretrainedSAELookup(
        release='gpt2-small-res-jb',
        repo_id='jbloom/GPT2-Small-SAEs-Reformatted',
        model='gpt2-small',
        conversion_func=None,
        saes_map={'blocks.0.hook_resid_pre': 'blocks.0.hook_resid_pre', ..., 'blocks.11.hook_resid_post': 'blocks.11.hook_resid_post'},
        expected_var_explained={'blocks.0.hook_resid_pre': 0.999, ..., 'blocks.11.hook_resid_post': 0.77},
        expected_l0={'blocks.0.hook_resid_pre': 10.0, ..., 'blocks.11.hook_resid_post': 70.0},
        neuronpedia_id={'blocks.0.hook_resid_pre': 'gpt2-small/0-res-jb', ..., 'blocks.11.hook_resid_post': 'gpt2-small/12-res-jb'},
        config_overrides={'model_from_pretrained_kwargs': {'center_writing_weights': True}}),
    'gpt2-small-hook-z-kk': PretrainedSAELookup(
        ...
    )
}

Let's print out all this data in a more readable format, with only a subset of attributes. We'll look at model (the base model), release (the name of the SAE release), repo_id (the id of the HuggingFace repo containing the SAEs), and also the number of SAEs in each release (e.g. a release might contain an SAE trained on each layer of the base model).

metadata_rows = [
    [data.model, data.release, data.repo_id, len(data.saes_map)]
    for data in get_pretrained_saes_directory().values()
]

# Print all SAE releases, sorted by base model
print(
    tabulate(
        sorted(metadata_rows, key=lambda x: x[0]),
        headers=["model", "release", "repo_id", "n_saes"],
        tablefmt="simple_outline",
    )
)
┌─────────────────────────────────────┬─────────────────────────────────────────────────────┬────────────────────────────────────────────────────────┬──────────┐
│ model                               │ release                                             │ repo_id                                                │   n_saes │
├─────────────────────────────────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────┼──────────┤
│ gemma-2-27b                         │ gemma-scope-27b-pt-res                              │ google/gemma-scope-27b-pt-res                          │       18 │
│ gemma-2-27b                         │ gemma-scope-27b-pt-res-canonical                    │ google/gemma-scope-27b-pt-res                          │        3 │
│ gemma-2-2b                          │ gemma-scope-2b-pt-res                               │ google/gemma-scope-2b-pt-res                           │      310 │
│ gemma-2-2b                          │ gemma-scope-2b-pt-res-canonical                     │ google/gemma-scope-2b-pt-res                           │       58 │
│ gemma-2-2b                          │ gemma-scope-2b-pt-mlp                               │ google/gemma-scope-2b-pt-mlp                           │      260 │
│ gemma-2-2b                          │ gemma-scope-2b-pt-mlp-canonical                     │ google/gemma-scope-2b-pt-mlp                           │       52 │
│ gemma-2-2b                          │ gemma-scope-2b-pt-att                               │ google/gemma-scope-2b-pt-att                           │      260 │
│ gemma-2-2b                          │ gemma-scope-2b-pt-att-canonical                     │ google/gemma-scope-2b-pt-att                           │       52 │
│ gemma-2-9b                          │ gemma-scope-9b-pt-res                               │ google/gemma-scope-9b-pt-res                           │      562 │
│ gemma-2-9b                          │ gemma-scope-9b-pt-res-canonical                     │ google/gemma-scope-9b-pt-res                           │       91 │
│ gemma-2-9b                          │ gemma-scope-9b-pt-att                               │ google/gemma-scope-9b-pt-att                           │      492 │
│ gemma-2-9b                          │ gemma-scope-9b-pt-att-canonical                     │ google/gemma-scope-9b-pt-att                           │       84 │
│ gemma-2-9b                          │ gemma-scope-9b-pt-mlp                               │ google/gemma-scope-9b-pt-mlp                           │      492 │
│ gemma-2-9b                          │ gemma-scope-9b-pt-mlp-canonical                     │ google/gemma-scope-9b-pt-mlp                           │       84 │
│ gemma-2-9b                          │ gemma-scope-9b-it-res                               │ google/gemma-scope-9b-it-res                           │       30 │
│ gemma-2-9b-it                       │ gemma-scope-9b-it-res-canonical                     │ google/gemma-scope-9b-it-res                           │        6 │
│ gemma-2b                            │ gemma-2b-res-jb                                     │ jbloom/Gemma-2b-Residual-Stream-SAEs                   │        5 │
│ gemma-2b                            │ sae_bench_gemma-2-2b_sweep_standard_ctx128_ef2_0824 │ canrager/lm_sae                                        │      180 │
│ gemma-2b                            │ sae_bench_gemma-2-2b_sweep_standard_ctx128_ef8_0824 │ canrager/lm_sae                                        │      240 │
│ gemma-2b                            │ sae_bench_gemma-2-2b_sweep_topk_ctx128_ef2_0824     │ canrager/lm_sae                                        │      180 │
│ gemma-2b                            │ sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824     │ canrager/lm_sae                                        │      240 │
│ gemma-2b-it                         │ gemma-2b-it-res-jb                                  │ jbloom/Gemma-2b-IT-Residual-Stream-SAEs                │        1 │
...
│ pythia-70m-deduped                  │ pythia-70m-deduped-res-sm                           │ ctigges/pythia-70m-deduped__res-sm_processed           │        7 │
│ pythia-70m-deduped                  │ pythia-70m-deduped-mlp-sm                           │ ctigges/pythia-70m-deduped__mlp-sm_processed           │        6 │
│ pythia-70m-deduped                  │ pythia-70m-deduped-att-sm                           │ ctigges/pythia-70m-deduped__att-sm_processed           │        6 │
└─────────────────────────────────────┴─────────────────────────────────────────────────────┴────────────────────────────────────────────────────────┴──────────┘

Any given SAE release may contain multiple different SAEs. These might have been trained on different hook points or layers in the model, or with different hyperparameters, etc. You can see the data associated with each release as follows:

def format_value(value):
    return (
        "{{{0!r}: {1!r}, ...}}".format(*next(iter(value.items())))
        if isinstance(value, dict)
        else repr(value)
    )


release = get_pretrained_saes_directory()["gpt2-small-res-jb"]

print(
    tabulate(
        [[k, format_value(v)] for k, v in release.__dict__.items()],
        headers=["Field", "Value"],
        tablefmt="simple_outline",
    )
)
┌────────────────────────┬─────────────────────────────────────────────────────────────────────────┐
│ Field                  │ Value                                                                   │
├────────────────────────┼─────────────────────────────────────────────────────────────────────────┤
│ release                │ 'gpt2-small-res-jb'                                                     │
│ repo_id                │ 'jbloom/GPT2-Small-SAEs-Reformatted'                                    │
│ model                  │ 'gpt2-small'                                                            │
│ conversion_func        │ None                                                                    │
│ saes_map               │ {'blocks.0.hook_resid_pre': 'blocks.0.hook_resid_pre', ...}             │
│ expected_var_explained │ {'blocks.0.hook_resid_pre': 0.999, ...}                                 │
│ expected_l0            │ {'blocks.0.hook_resid_pre': 10.0, ...}                                  │
│ neuronpedia_id         │ {'blocks.0.hook_resid_pre': 'gpt2-small/0-res-jb', ...}                 │
│ config_overrides       │ {'model_from_pretrained_kwargs': {'center_writing_weights': True}, ...} │
└────────────────────────┴─────────────────────────────────────────────────────────────────────────┘

Let's get some more info about each of the SAEs in this release. We can print out the SAE id, the path within the HuggingFace repo (which points to the SAE model weights) and the Neuronpedia ID (which is how we'll get feature dashboards; more on this soon).

data = [[id, path, release.neuronpedia_id[id]] for id, path in release.saes_map.items()]

print(
    tabulate(
        data,
        headers=["SAE id", "SAE path (HuggingFace)", "Neuronpedia ID"],
        tablefmt="simple_outline",
    )
)
┌───────────────────────────┬───────────────────────────┬──────────────────────┐
│ SAE id                    │ SAE path (HuggingFace)    │ Neuronpedia ID       │
├───────────────────────────┼───────────────────────────┼──────────────────────┤
│ blocks.0.hook_resid_pre   │ blocks.0.hook_resid_pre   │ gpt2-small/0-res-jb  │
│ blocks.1.hook_resid_pre   │ blocks.1.hook_resid_pre   │ gpt2-small/1-res-jb  │
│ blocks.2.hook_resid_pre   │ blocks.2.hook_resid_pre   │ gpt2-small/2-res-jb  │
│ blocks.3.hook_resid_pre   │ blocks.3.hook_resid_pre   │ gpt2-small/3-res-jb  │
│ blocks.4.hook_resid_pre   │ blocks.4.hook_resid_pre   │ gpt2-small/4-res-jb  │
│ blocks.5.hook_resid_pre   │ blocks.5.hook_resid_pre   │ gpt2-small/5-res-jb  │
│ blocks.6.hook_resid_pre   │ blocks.6.hook_resid_pre   │ gpt2-small/6-res-jb  │
│ blocks.7.hook_resid_pre   │ blocks.7.hook_resid_pre   │ gpt2-small/7-res-jb  │
│ blocks.8.hook_resid_pre   │ blocks.8.hook_resid_pre   │ gpt2-small/8-res-jb  │
│ blocks.9.hook_resid_pre   │ blocks.9.hook_resid_pre   │ gpt2-small/9-res-jb  │
│ blocks.10.hook_resid_pre  │ blocks.10.hook_resid_pre  │ gpt2-small/10-res-jb │
│ blocks.11.hook_resid_pre  │ blocks.11.hook_resid_pre  │ gpt2-small/11-res-jb │
│ blocks.11.hook_resid_post │ blocks.11.hook_resid_post │ gpt2-small/12-res-jb │
└───────────────────────────┴───────────────────────────┴──────────────────────┘
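The table suggests a simple naming pattern for this release: SAE ids are TransformerLens hook names of the form blocks.{layer}.{hook}, and the Neuronpedia layer number matches the layer for hook_resid_pre, while blocks.11.hook_resid_post maps to 12 (the residual stream after block 11 is the same point as the input to a would-be block 12). The sketch below reproduces that mapping. parse_hook_name and neuronpedia_layer are hypothetical helpers inferred from this one table, not part of SAELens; other releases may use different ID schemes, so in practice read release.neuronpedia_id directly.

```python
import re

# Hypothetical helpers, inferred from the gpt2-small-res-jb table above (not a SAELens API).
# Only handles block-level hooks like 'blocks.7.hook_resid_pre'.
HOOK_PATTERN = re.compile(r"^blocks\.(\d+)\.(hook_[a-z_]+)$")


def parse_hook_name(hook_name: str) -> tuple[int, str]:
    """Split 'blocks.7.hook_resid_pre' into (7, 'hook_resid_pre')."""
    match = HOOK_PATTERN.match(hook_name)
    if match is None:
        raise ValueError(f"Unrecognised hook name: {hook_name!r}")
    return int(match.group(1)), match.group(2)


def neuronpedia_layer(hook_name: str) -> int:
    """Residual-stream position: resid_post of block L is the same point as resid_pre of block L+1."""
    layer, hook = parse_hook_name(hook_name)
    if hook == "hook_resid_pre":
        return layer
    if hook == "hook_resid_post":
        return layer + 1
    raise ValueError(f"No residual-stream index for {hook!r}")


for sae_id in ["blocks.0.hook_resid_pre", "blocks.7.hook_resid_pre", "blocks.11.hook_resid_post"]:
    # Prints e.g. "blocks.11.hook_resid_post -> gpt2-small/12-res-jb"
    print(sae_id, "->", f"gpt2-small/{neuronpedia_layer(sae_id)}-res-jb")
```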

Next, we'll load the SAE we'll be working with for most of these exercises: the layer 7 resid_pre SAE from the GPT2 Small SAEs (as well as a copy of GPT2 Small to attach it to). The model is loaded with the HookedSAETransformer class, which is adapted from the TransformerLens HookedTransformer class.

Note that the SAE.from_pretrained function has return type tuple[SAE, dict, Tensor | None]: the return elements are the SAE, a config dict, and (if available) a tensor of feature sparsities. The config dict contains useful metadata, e.g. on how the SAE was trained.

t.set_grad_enabled(False)

gpt2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

gpt2_sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    device=str(device),
)

The gpt2_sae object is an instance of the SAE (Sparse Autoencoder) class. There are many different SAE architectures, which may have different weights or activation functions. To simplify working with SAEs, SAELens handles most of this complexity for you. You can run the cell below to see each of the SAE config parameters for the one we'll be using.

Here is a description of each of the SAE config parameters:

  1. architecture: Specifies the type of SAE architecture being used; in this case, the standard architecture (encoder and decoder with hidden activations, as opposed to a gated SAE).
  2. d_in: Defines the input dimension of the SAE, which is 768 in this configuration.
  3. d_sae: Sets the dimension of the SAE's hidden layer, which is 24576 here. This is the number of latents, i.e. possible feature activations.
  4. activation_fn_str: Specifies the activation function used in the SAE, which is ReLU in this case. TopK is another option that we will not cover here.
  5. apply_b_dec_to_input: Determines whether to subtract the decoder bias from the input before encoding; set to True here.
  6. finetuning_scaling_factor: Indicates whether to use a scaling factor in weight initialization and the forward pass. This is not usually used, and was introduced to support a [solution for shrinkage](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes).
  7. context_size: Defines the size of the context window, which is 128 tokens in this case. It turns out that SAEs trained on activations from short prompts [often don't perform well on longer prompts](https://www.lesswrong.com/posts/baJyjpktzmcmRfosq/stitching-saes-of-different-sizes).
  8. model_name: Specifies the name of the model being used, which is 'gpt2-small' here. [This is a valid model name in TransformerLens](https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html).
  9. hook_name: Indicates the specific hook in the model where the SAE is applied.
  10. hook_layer: Specifies the layer number where the hook is applied, which is layer 7 in this case.
  11. hook_head_index: Defines which attention head to hook into; not relevant here, since we are looking at a residual stream SAE.
  12. prepend_bos: Determines whether to prepend the beginning-of-sequence token; set to True.
  13. dataset_path: Specifies the path to the dataset used for training or evaluation. (Can be local or a HuggingFace dataset.)
  14. dataset_trust_remote_code: Indicates whether to trust remote code (from HuggingFace) when loading the dataset; set to True.
  15. normalize_activations: Specifies how to normalize activations; set to 'none' in this config.
  16. dtype: Defines the data type for tensor operations; set to 32-bit floating point.
  17. device: Specifies the computational device to use.
  18. sae_lens_training_version: Indicates the version of SAELens used for training; set to None here.
  19. activation_fn_kwargs: Allows for additional keyword arguments for the activation function. This would be used if e.g. activation_fn_str was set to topk, so that k could be specified.
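Items 4 and 19 mention ReLU and TopK activation functions. As a rough illustration of the difference (a plain-Python sketch, not the SAELens implementation): ReLU zeroes negative pre-activations independently for each latent, while TopK keeps only the k largest pre-activations and zeroes the rest, so k directly fixes the maximum number of active latents. Whether a TopK SAE also applies ReLU to the kept values varies between implementations; the version below does, which is an assumption.

```python
def relu(pre_acts: list[float]) -> list[float]:
    """Standard ReLU: each latent is active iff its own pre-activation is positive."""
    return [max(0.0, a) for a in pre_acts]


def topk(pre_acts: list[float], k: int) -> list[float]:
    """Keep the k largest pre-activations (then ReLU them); zero everything else."""
    keep = set(sorted(range(len(pre_acts)), key=lambda i: pre_acts[i], reverse=True)[:k])
    return [max(0.0, a) if i in keep else 0.0 for i, a in enumerate(pre_acts)]


pre_acts = [0.5, -1.0, 2.0, 0.1, 1.5]
print(relu(pre_acts))       # [0.5, 0.0, 2.0, 0.1, 1.5]
print(topk(pre_acts, k=2))  # [0.0, 0.0, 2.0, 0.0, 1.5]
```

With ReLU, four of the five latents fire here; with TopK (k=2) exactly two do, regardless of how many pre-activations are positive.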

print(tabulate(gpt2_sae.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))
┌──────────────────────────────┬──────────────────────────────────┐
│ name                         │ value                            │
├──────────────────────────────┼──────────────────────────────────┤
│ architecture                 │ standard                         │
│ d_in                         │ 768                              │
│ d_sae                        │ 24576                            │
│ activation_fn_str            │ relu                             │
│ apply_b_dec_to_input         │ True                             │
│ finetuning_scaling_factor    │ False                            │
│ context_size                 │ 128                              │
│ model_name                   │ gpt2-small                       │
│ hook_name                    │ blocks.7.hook_resid_pre          │
│ hook_layer                   │ 7                                │
│ hook_head_index              │                                  │
│ prepend_bos                  │ True                             │
│ dataset_path                 │ Skylion007/openwebtext           │
│ dataset_trust_remote_code    │ True                             │
│ normalize_activations        │ none                             │
│ dtype                        │ torch.float32                    │
│ device                       │ cuda                             │
│ sae_lens_training_version    │                                  │
│ activation_fn_kwargs         │ {}                               │
│ neuronpedia_id               │ gpt2-small/7-res-jb              │
│ model_from_pretrained_kwargs │ {'center_writing_weights': True} │
└──────────────────────────────┴──────────────────────────────────┘
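From d_in and d_sae we can do a quick back-of-the-envelope calculation of this SAE's size. This assumes the standard architecture has an encoder matrix W_enc of shape (d_in, d_sae), a decoder matrix W_dec of shape (d_sae, d_in), and biases b_enc (length d_sae) and b_dec (length d_in).

```python
d_in, d_sae = 768, 24576  # values from the config table above

expansion_factor = d_sae // d_in  # number of latents per residual-stream dimension

# Parameter counts, assuming the standard (non-gated) SAE parameter shapes
n_params = {
    "W_enc": d_in * d_sae,
    "W_dec": d_sae * d_in,
    "b_enc": d_sae,
    "b_dec": d_in,
}
total = sum(n_params.values())

print(f"expansion factor = {expansion_factor}")        # 32
print(f"total parameters = {total:,}")                  # 37,774,080
print(f"float32 size = {total * 4 / 2**20:.1f} MiB")    # 144.1 MiB
```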

Visualizing SAEs with dashboards

In this section, you'll learn about SAE dashboards, which are a visual tool for quickly understanding what a particular SAE latent represents. Key points:

  • Neuronpedia hosts dashboards which help you understand SAE latents
  • The 5 main components of the dashboard are: top logit tables, logits histogram, activation density plots, top activating sequences, and autointerp
  • All of these components are important for getting a full picture of what a latent represents, but they can also all be misleading
  • You can display these dashboards inline, using IFrame

In this section, we're going to have a look at our SAEs, and see what they're actually telling us.

Before we dive too deep however, let's recap something - what actually is an SAE latent?

An SAE latent is a particular direction in the base model's activation space, learned by the SAE. Often, these correspond to features** in the data: meaningful semantic, syntactic or otherwise interpretable patterns or concepts that exist in the distribution of data the base model was trained on, and which were learned by the base model. These features are usually highly sparse: for any given feature, only a small fraction of the overall data distribution will activate it. Sparser features also tend to be more interpretable.

**Note: technically, saying "direction" is an oversimplification here, because a given latent can have multiple directions in activation space associated with it, e.g. a separate encoder and decoder direction for standard untied SAEs. When we refer to a latent direction or feature direction, we're usually (but not always) referring to the decoder weights.
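To make the encoder/decoder distinction concrete, here is a toy standard-SAE forward pass in plain Python with made-up weights (d_in=2, d_sae=3). It assumes the usual form acts = ReLU((x - b_dec) @ W_enc + b_enc), where subtracting b_dec first corresponds to apply_b_dec_to_input=True in the config above, followed by recon = acts @ W_dec + b_dec. Latent i reads from the input via column i of W_enc and writes to the output via row i of W_dec, so the reconstruction is b_dec plus a sum of decoder directions weighted by each latent's activation. Real SAEs operate on batched tensors; this is only a sketch.

```python
def matvec(x: list[float], W: list[list[float]]) -> list[float]:
    """Compute x @ W for a vector x of length n and a matrix W of shape (n, m)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


# Toy weights, made up for illustration: d_in = 2, d_sae = 3
W_enc = [[1.0, 0.0, 1.0],
         [0.0, 1.0, -1.0]]   # shape (d_in, d_sae): column i = encoder direction of latent i
W_dec = [[1.0, 0.0],
         [0.0, 1.0],
         [0.5, -0.5]]        # shape (d_sae, d_in): row i = decoder direction of latent i
b_enc = [0.0, 0.0, -0.5]
b_dec = [0.5, 0.5]


def sae_forward(x: list[float]) -> tuple[list[float], list[float]]:
    x_cent = [xi - bi for xi, bi in zip(x, b_dec)]              # subtract b_dec before encoding
    pre = [p + b for p, b in zip(matvec(x_cent, W_enc), b_enc)]
    acts = [max(0.0, p) for p in pre]                            # ReLU
    recon = [r + b for r, b in zip(matvec(acts, W_dec), b_dec)]  # b_dec + sum_i acts[i] * W_dec[i]
    return acts, recon


acts, recon = sae_forward([2.5, 0.5])
print(acts)   # [2.0, 0.0, 1.5]
print(recon)  # [3.25, -0.25]
```

Note that latent 1 is inactive, so its decoder direction contributes nothing to the reconstruction; the reconstruction here is also imperfect, as the toy weights were not trained.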

The dashboard shown below provides a detailed view of a single SAE latent.

def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))


latent_idx = random.randint(0, gpt2_sae.cfg.d_sae - 1)  # randint includes both endpoints
display_dashboard(latent_idx=latent_idx)

Let's break down the separate components of the visualization:

  1. Latent Activation Distribution. This shows the proportion of tokens a latent fires on, usually between 0.01% and 1%, and also shows the distribution of positive activations.
  2. Logits Distribution. This is the projection of the decoder weight onto the unembed and roughly gives us a sense of the tokens promoted by a latent. It's less useful in big models / middle layers.
  3. Top / Bottom Logits. These are the 10 most positive and most negative logits in the logit weight distribution.
  4. Max Activating Examples. These are examples of text where the latent fires and usually provide the most information for helping us work out what a latent means.
  5. Autointerp. These are LLM-generated latent explanations, which use the rest of the data in the dashboard (in particular the max activating examples).

See this section of Towards Monosemanticity for more information.
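Components 1 to 3 of the dashboard boil down to simple computations. The sketch below uses made-up numbers: activation density is the fraction of tokens on which a latent's activation is positive, and the logit weights are the latent's decoder direction projected through the unembedding W_U (i.e. W_dec[latent] @ W_U), from which the top / bottom logits are the largest and smallest entries. Real dashboards may apply extra processing (e.g. for the final LayerNorm), which this toy ignores; the function names are hypothetical.

```python
def activation_density(acts: list[float]) -> float:
    """Fraction of tokens on which the latent fires (activation > 0)."""
    return sum(a > 0 for a in acts) / len(acts)


def logit_weights(decoder_dir: list[float], W_U: list[list[float]]) -> list[float]:
    """Project a decoder direction (length d_model) through W_U (shape d_model x d_vocab)."""
    return [sum(decoder_dir[i] * W_U[i][v] for i in range(len(decoder_dir))) for v in range(len(W_U[0]))]


def top_bottom(weights: list[float], vocab: list[str], k: int) -> tuple[list[str], list[str]]:
    order = sorted(range(len(weights)), key=lambda v: weights[v])
    top = [vocab[v] for v in reversed(order[-k:])]  # most positive first
    bottom = [vocab[v] for v in order[:k]]          # most negative first
    return top, bottom


# Toy data: one latent's activations over 10,000 tokens, firing on 12 of them
acts = [0.0] * 9988 + [1.0] * 12
print(f"density = {activation_density(acts):.2%}")  # density = 0.12%

vocab = [" new", " old", " arrivals", " the"]
W_U = [[1.0, -1.0, 0.5, 0.0],
       [0.0, 0.5, 1.0, 0.1]]  # toy unembedding: d_model = 2, d_vocab = 4
print(top_bottom(logit_weights([1.0, 1.0], W_U), vocab, k=2))  # ([' arrivals', ' new'], [' old', ' the'])
```

A density of 0.12% falls inside the typical 0.01% to 1% range described in component 1.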

Neuronpedia is a website that hosts SAE dashboards and runs servers that can run the model and check latent activations. This makes it very convenient to check that a latent fires on the distribution of text you actually think it should fire on. We've been downloading data from Neuronpedia for the dashboards above.

Exercise - find interesting latents

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 5-15 minutes on this exercise.

Spend some time browsing through the SAE dashboard (i.e. running the code with different random indices). What interesting latents can you find? Try and find the following types of latents:

  • Latents for token-level features, which seem to only fire on a particular token and basically no others. Do the top logits make sense, when viewed as bigram frequencies?
  • Latents for concept-level features, which fire not on single tokens but across multiple tokens, provided that a particular concept is present in the text (e.g. latent_idx=4527 is an example of this). What is the concept that this latent represents? Can you see how the positive logits for this latent make sense?
  • Highly sparse latents, with activation density less than 0.05%. Do they seem more interpretable than the average latent?
Below are examples of each type, which you can compare to the ones you found.

Latent 9 seems to only fire on the word "new", in the context of describing a plurality of new things (often related to policies, business estimates, hiring, etc). The positive logits support this, with bigrams like "new arrivals" and "new developments" being boosted. Interestingly, we also have the bigram "newbie(s)" boosted.

Latent 67 seems to be more concept-level, firing on passages that talk about a country's decision to implement a particular policy or decision (especially when that country is being described as the only one to do it). Although the positive logits are also associated with countries or government policies, they don't directly make sense as bigrams - which is what we'd expect given that the latent fires on multiple different tokens in a sentence when that sentence contains the concept in question.

Latent 13 fires with frequency 0.049%. It seems to activate on the token "win", especially in the context of winning over people (e.g. winning the presidency, winning over hearts and minds, or winning an argument) or winning a race. This seems pretty specific and interpretable, although it's important when interpreting latents to remember the interpretability illusion: seeing the top activating patterns can give a misplaced sense of confidence in any particular interpretation. In later sections we'll perform more careful hypothesis testing to refine our understanding of latents.

Running SAEs

In this section, you'll learn how to run forward passes with SAEs. This is a pretty simple process, which builds on much of the pre-existing infrastructure in TransformerLens models. Key points:

  • You can add SAEs to a TransformerLens model when doing forward passes in pretty much the same way you add hook functions (you can think of SAEs as a special kind of hook function)
  • When sae.use_error_term=False (default) you substitute the SAE's output for the transformer activations. When True, the SAE still runs (so its activations can be cached) but its reconstruction error is added back, so the downstream activations are unchanged; this is sometimes what you want when caching activations
  • There's an analogous run_with_saes that works like run_with_hooks
  • There's also run_with_cache_with_saes that works like run_with_cache, but allows you to cache any SAE activations you want
  • You can use ActivationsStore to get a large batch of activations at once

Now that we've had a look at some SAEs via Neuronpedia, it's time to load them in and start running them ourselves!

One of the key features of HookedSAETransformer is being able to "splice in" SAEs, replacing model activations with their SAE reconstructions. To run a forward pass with SAEs attached, you can use model.run_with_saes(tokens, saes=[list_of_saes]). This function has similar syntax to the standard forward pass (or to model.run_with_hooks), e.g. it can take arguments like return_type to specify whether the return type should be loss or logits. The attached SAEs will be reset immediately after the forward pass, returning the model to its original state. Under the hood, they work just like adding hooks in TransformerLens, only in this case our hooks are "replace these activations with their SAE reconstructions".
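The "replace these activations with their SAE reconstructions" behaviour, and the error-term alternative from the key points above, can be sketched in plain Python. This is a toy illustration, not SAELens code; toy_sae and sae_hook are hypothetical, and use_error_term here is just a parameter of the toy function. With the error term on, the hook still computes the latents (so they could be cached) but adds back error = x - recon, so the activation passed downstream equals x (exactly here; up to floating-point rounding in general). Gradient handling is ignored.

```python
def toy_sae(x: list[float]) -> tuple[list[float], list[float]]:
    """Hypothetical lossy SAE: a single latent that reads and writes only the first coordinate."""
    latent = max(0.0, x[0])
    return [latent], [latent, 0.0]  # (latent activations, reconstruction)


def sae_hook(x: list[float], use_error_term: bool) -> tuple[list[float], list[float]]:
    """Return (activation passed downstream, latent activations available for caching)."""
    acts, recon = toy_sae(x)
    if not use_error_term:
        return recon, acts  # spliced in: downstream layers see the reconstruction
    error = [xi - ri for xi, ri in zip(x, recon)]
    return [ri + ei for ri, ei in zip(recon, error)], acts  # recon + error == x


x = [2.0, 0.75]
print(sae_hook(x, use_error_term=False))  # ([2.0, 0.0], [2.0])
print(sae_hook(x, use_error_term=True))   # ([2.0, 0.75], [2.0])
```

In both cases the latent activation is the same; only what flows on to later layers differs.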

There are a lot of other ways you can do SAE-hooked forward passes, which parallel the multiple ways you can do regular hooked forward passes. For example, just like you can use with model.hooks(fwd_hooks=...) as a context manager to add hooks temporarily, you can also use with model.saes(saes=...) to run a forward pass with SAEs attached. And just like you can use model.add_hook and model.reset_hooks, you can also use model.add_sae and model.reset_saes.
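To make the idea of SAEs as "a special kind of hook function" concrete, here is a minimal pure-Python sketch of the attach / splice / reset lifecycle. ToyHookedModel, lossy_sae and the single "resid" hook point are hypothetical stand-ins, not the HookedSAETransformer API: attaching an SAE replaces the activations at its hook point with the SAE's output, and both run_with_saes and the saes() context manager detach it again afterwards (here via try/finally).

```python
from contextlib import contextmanager


class ToyHookedModel:
    """Hypothetical stand-in for HookedSAETransformer with a single hook point, "resid"."""

    def __init__(self):
        self.attached = {}  # hook_name -> function applied to the activations at that hook

    def forward(self, x):
        resid = [2 * v for v in x]  # stand-in for computing the residual stream
        if "resid" in self.attached:
            resid = self.attached["resid"](resid)  # splice in the SAE's output
        return sum(resid)  # stand-in for the logits

    def add_sae(self, hook_name, sae_fn):
        self.attached[hook_name] = sae_fn

    def reset_saes(self):
        self.attached.clear()

    @contextmanager
    def saes(self, saes):
        for hook_name, sae_fn in saes:
            self.add_sae(hook_name, sae_fn)
        try:
            yield self
        finally:
            self.reset_saes()  # detach even if the body raises

    def run_with_saes(self, x, saes):
        with self.saes(saes):
            return self.forward(x)


def lossy_sae(acts):
    """Toy 'reconstruction' that truncates each activation to an integer."""
    return [float(int(a)) for a in acts]


model = ToyHookedModel()
print(model.forward([0.25, 1.75]))  # 4.0: no SAE attached
print(model.run_with_saes([0.25, 1.75], saes=[("resid", lossy_sae)]))  # 3.0: reconstruction spliced in
print(model.forward([0.25, 1.75]))  # 4.0: the SAE was detached again
```

The try/finally is the design choice that matters here: the SAE is removed even if the forward pass raises, which is the same reason it's worth preferring the context manager over a manual add_sae / reset_saes pair.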

prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# First see how the model does without SAEs
test_prompt(prompt, answer, gpt2)

# Test our prompt again, this time with the SAE attached (via a context manager)
with gpt2.saes(saes=[gpt2_sae]):
    test_prompt(prompt, answer, gpt2)

# Same thing, done in a different way
gpt2.add_sae(gpt2_sae)
test_prompt(prompt, answer, gpt2)
gpt2.reset_saes()  # Remember to always do this!

# Using `run_with_saes` method in place of standard forward pass
logits = gpt2(prompt, return_type="logits")
logits_sae = gpt2.run_with_saes(prompt, saes=[gpt2_sae], return_type="logits")
answer_token_id = gpt2.to_single_token(answer)

# Getting model's prediction
top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

print(f"""Standard model:
    top prediction = {gpt2.to_string(token_id_prediction)!r}
    prob = {top_prob.item():.2%}
SAE reconstruction:
    top prediction = {gpt2.to_string(token_id_prediction_sae)!r}
    prob = {top_prob_sae.item():.2%}
""")
Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']

Performance on answer token:
Rank: 0        Logit: 19.46 Prob: 52.99% Token: | priority|

Top 0th token. Logit: 19.46 Prob: 52.99% Token: | priority|
Top 1th token. Logit: 17.44 Prob:  7.02% Token: | effort|
Top 2th token. Logit: 16.94 Prob:  4.26% Token: | issue|
Top 3th token. Logit: 16.63 Prob:  3.14% Token: | challenge|
Top 4th token. Logit: 16.37 Prob:  2.42% Token: | goal|
Top 5th token. Logit: 16.06 Prob:  1.78% Token: | concern|
Top 6th token. Logit: 15.88 Prob:  1.47% Token: | focus|
Top 7th token. Logit: 15.61 Prob:  1.13% Token: | approach|
Top 8th token. Logit: 15.53 Prob:  1.04% Token: | policy|
Top 9th token. Logit: 15.42 Prob:  0.93% Token: | initiative|

Ranks of the answer tokens: [(' priority', 0)]

Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']

Performance on answer token:
Rank: 0        Logit: 18.19 Prob: 39.84% Token: | priority|

Top 0th token. Logit: 18.19 Prob: 39.84% Token: | priority|
Top 1th token. Logit: 16.51 Prob:  7.36% Token: | issue|
Top 2th token. Logit: 16.48 Prob:  7.20% Token: | concern|
Top 3th token. Logit: 15.94 Prob:  4.19% Token: | challenge|
Top 4th token. Logit: 15.30 Prob:  2.21% Token: | goal|
Top 5th token. Logit: 15.12 Prob:  1.85% Token: | responsibility|
Top 6th token. Logit: 15.04 Prob:  1.69% Token: | problem|
Top 7th token. Logit: 14.98 Prob:  1.60% Token: | effort|
Top 8th token. Logit: 14.73 Prob:  1.24% Token: | policy|
Top 9th token. Logit: 14.66 Prob:  1.16% Token: | imperative|

Ranks of the answer tokens: [(' priority', 0)]

Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']

Performance on answer token:
Rank: 0        Logit: 18.19 Prob: 39.84% Token: | priority|

Top 0th token. Logit: 18.19 Prob: 39.84% Token: | priority|
Top 1th token. Logit: 16.51 Prob:  7.36% Token: | issue|
Top 2th token. Logit: 16.48 Prob:  7.20% Token: | concern|
Top 3th token. Logit: 15.94 Prob:  4.19% Token: | challenge|
Top 4th token. Logit: 15.30 Prob:  2.21% Token: | goal|
Top 5th token. Logit: 15.12 Prob:  1.85% Token: | responsibility|
Top 6th token. Logit: 15.04 Prob:  1.69% Token: | problem|
Top 7th token. Logit: 14.98 Prob:  1.60% Token: | effort|
Top 8th token. Logit: 14.73 Prob:  1.24% Token: | policy|
Top 9th token. Logit: 14.66 Prob:  1.16% Token: | imperative|

Ranks of the answer tokens: [(' priority', 0)]

Standard model:
    top prediction = ' priority'
    prob = 52.99%
SAE reconstruction:
    top prediction = ' priority'
    prob = 39.84%

Okay, so this is fine if we want to do a forward pass with the model's output replaced by SAE output, but what if we just want the SAE activations? That's where running with cache comes in! With HookedSAETransformer, you can cache SAE activations (and all the other standard activations) with logits, cache = model.run_with_cache_with_saes(tokens, saes=saes). Just as run_with_saes is a wrapper around the standard forward pass, run_with_cache_with_saes is a wrapper around run_with_cache, and likewise only attaches the SAEs for a single forward pass before returning the model to its original state.

To access SAE activations from the cache, the corresponding hook names will generally be the concatenation of the HookedTransformer hook_name (e.g. "blocks.5.attn.hook_z") and the SAE hook name (e.g. "hook_sae_acts_post"), joined by a period. We can print out all the names below:

_, cache = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])

for name, param in cache.items():
    if "hook_sae" in name:
        print(f"{name:<43}: {tuple(param.shape)}")
blocks.7.hook_resid_pre.hook_sae_input     : (1, 13, 768)
blocks.7.hook_resid_pre.hook_sae_acts_pre  : (1, 13, 24576)
blocks.7.hook_resid_pre.hook_sae_acts_post : (1, 13, 24576)
blocks.7.hook_resid_pre.hook_sae_recons    : (1, 13, 768)
blocks.7.hook_resid_pre.hook_sae_output    : (1, 13, 768)

run_with_cache_with_saes makes it easy to explore which SAE latents are active across any input. We can also use this along with the argument stop_at_layer in our forward pass, because we don't need to compute any activations past the SAE layer.
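Why stop_at_layer=hook_layer + 1 rather than hook_layer? The toy sketch below assumes TransformerLens's conventions as we understand them: blocks.{n}.hook_resid_pre fires at the start of block n, and stop_at_layer=n runs only blocks 0 to n-1. The toy_forward function is hypothetical (each "block" just adds 1 to a scalar); it records which blocks ran and whether a hook on block 7's input fired.

```python
def toy_forward(x, n_layers, stop_at_layer=None, hooks=None):
    """Hypothetical mini 'transformer' where every block adds 1 to a scalar residual stream.

    Assumes the TransformerLens-style convention that blocks.{n}.hook_resid_pre fires at
    the start of block n, and that stop_at_layer=n runs only blocks 0..n-1.
    """
    hooks = hooks or {}
    blocks_run = []
    last_block = n_layers if stop_at_layer is None else stop_at_layer
    for layer in range(last_block):
        hook_name = f"blocks.{layer}.hook_resid_pre"
        if hook_name in hooks:
            hooks[hook_name](x)  # e.g. an SAE reading this block's input
        x = x + 1  # stand-in for attention + MLP
        blocks_run.append(layer)
    return x, blocks_run


hook_layer = 7
for stop in (hook_layer, hook_layer + 1):
    seen = []
    _, blocks_run = toy_forward(
        0, n_layers=12, stop_at_layer=stop,
        hooks={f"blocks.{hook_layer}.hook_resid_pre": seen.append},
    )
    print(f"stop_at_layer={stop}: ran {len(blocks_run)} blocks, hook fired: {bool(seen)}")
# stop_at_layer=7: ran 7 blocks, hook fired: False
# stop_at_layer=8: ran 8 blocks, hook fired: True
```

Under these assumptions, hook_layer + 1 is the smallest value that still runs the block containing the SAE's hook point, while skipping every block after it.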

Let's explore the active latents at the final token in our prompt. You should find that the first latent fires on the word "global", particularly in the context of disasters such as "global warming", "global poverty" and "global war".

# Get top activations on final token
_, cache = gpt2.run_with_cache_with_saes(
    prompt,
    saes=[gpt2_sae],
    stop_at_layer=gpt2_sae.cfg.hook_layer + 1,
)
sae_acts_post = cache[f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post"][0, -1, :]

# Plot line chart of latent activations
px.line(
    sae_acts_post.cpu().numpy(),
    title=f"Latent activations at the final token position ({sae_acts_post.nonzero().numel()} alive)",
    labels={"index": "Latent", "value": "Activation"},
    width=1000,
).update_layout(showlegend=False).show()

# Print the top 3 latents, and inspect their dashboards
for act, ind in zip(*sae_acts_post.topk(3)):
    print(f"Latent {ind} had activation {act:.2f}")
    display_dashboard(latent_idx=ind)

Error term

Important note - the parameter sae.use_error_term determines whether we'll actually substitute the activations with SAE reconstructions during our SAE forward pass. If it's False (default) then we do replace activations with SAE reconstructions, but if it's True then we'll just compute the SAE's hidden activations without replacing the transformer activations with its output.

The use_error_term parameter controls behaviour when we do forward passes, hooked forward passes, forward passes with cache, or anything else that involves running the model with SAEs attached (but obviously this parameter only matters when we're caching values, because doing a forward pass with sae.use_error_term=True and not caching any values is equivalent to just running the base model without any SAEs!).

Why is it called use_error_term?

It's called this because when set to True we'll have the final output of the forward pass be sae_out + sae_error rather than sae_out. This sae_error term is literally defined as sae_in - sae_out, i.e. the difference between the original input and SAE reconstruction. So this is equivalent to the SAE just being the identity function. But we need to do things this way so we still compute all the internal states of the SAE in exactly the same way as we would if we were actually replacing the transformer's activations with SAE reconstructions.
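The identity sae_out + (sae_in - sae_out) = sae_in can be checked on a toy example. The sketch below uses a hypothetical elementwise SAE (one latent per input dimension) with made-up scalar weights W_enc=1, b_enc=-0.5, W_dec=1, b_dec=0.5; it is not a real SAELens SAE, and the dictionary keys only loosely mirror the hook names above. The latent activations are identical in both modes, but only use_error_term=False changes the output. (The equality is exact here because the numbers are binary-exact; in general it only holds up to floating-point rounding.)

```python
def relu(v):
    return max(0.0, v)


def toy_sae_forward(sae_in, use_error_term=False):
    """Hypothetical elementwise SAE: one latent per input dimension, scalar weights."""
    w_enc, b_enc, w_dec, b_dec = 1.0, -0.5, 1.0, 0.5
    acts_post = [relu(w_enc * x + b_enc) for x in sae_in]  # loosely like hook_sae_acts_post
    sae_out = [w_dec * a + b_dec for a in acts_post]  # loosely like hook_sae_recons
    if use_error_term:
        sae_error = [x - y for x, y in zip(sae_in, sae_out)]
        output = [y + e for y, e in zip(sae_out, sae_error)]  # recovers sae_in
    else:
        output = sae_out  # the reconstruction replaces the original activations
    return {"acts_post": acts_post, "output": output}


sae_in = [2.0, 0.25]
print(toy_sae_forward(sae_in, use_error_term=False))
# {'acts_post': [1.5, 0.0], 'output': [2.0, 0.5]}
print(toy_sae_forward(sae_in, use_error_term=True))
# {'acts_post': [1.5, 0.0], 'output': [2.0, 0.25]}
```

The second input shows why the error term matters: the latent is dead (activation 0.0), so the reconstruction 0.5 is wrong, but adding the error term -0.25 restores the original 0.25.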

logits_no_saes, cache_no_saes = gpt2.run_with_cache(prompt)

gpt2_sae.use_error_term = False
logits_with_sae_recon, cache_with_sae_recon = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])

gpt2_sae.use_error_term = True
logits_without_sae_recon, cache_without_sae_recon = gpt2.run_with_cache_with_saes(
    prompt, saes=[gpt2_sae]
)

# Both SAE caches contain the hook values
assert f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post" in cache_with_sae_recon
assert f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post" in cache_without_sae_recon

# But the final outputs differ: with use_error_term=True the logits match the base model, whereas splicing in the SAE reconstruction changes them
t.testing.assert_close(logits_no_saes, logits_without_sae_recon)
logit_diff_from_sae = (logits_no_saes - logits_with_sae_recon).abs().mean()
print(f"Average logit diff from using SAE reconstruction: {logit_diff_from_sae:.4f}")
Average logit diff from using SAE reconstruction: 0.4117

Using ActivationsStore

The ActivationsStore class is a convenient alternative to loading a bunch of data yourself. It streams in data from a given dataset; when you construct it with the from_sae class method, that dataset is given by your SAE's config (which is also the dataset the SAE was originally trained on):

print(gpt2_sae.cfg.dataset_path)
Skylion007/openwebtext

Let's load one in now. We'll use fairly conservative parameters here so it can be used without running out of memory, but feel free to increase these parameters if you're able to (or decrease them if you still find yourself running out of memory).

gpt2_act_store = ActivationsStore.from_sae(
    model=gpt2,
    sae=gpt2_sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)

# Example of how you can use this:
tokens = gpt2_act_store.get_batch_tokens()
assert tokens.shape == (gpt2_act_store.store_batch_size_prompts, gpt2_act_store.context_size)

Replicating SAE dashboards

In this section, you'll replicate the 5 main components of the SAE dashboard: top logits tables, logits histogram, activation density plots, top activating sequences, and autointerp. There's not really any new content here, just putting into practice what you've learned from the previous 2 sections "Visualizing SAEs with dashboards" and "Running SAEs".

Now that we know how to load in and run SAEs, we can start replicating the components of the SAE dashboard in turn. These exercises will help build up your experience running SAEs & working with their activations, as well as helping you dive deeper into the meaning and significance of the different dashboard components.

To review, basic SAE dashboards have 5 main components:

  1. Activation Distribution - the distribution of a latent's activations
  2. Logits Distribution - projection of decoder weights onto model's unembedding
  3. Top / Bottom Logits - the most positive and most negative logits in the logit weight distribution
  4. Max Activating Examples - sequences (and particular tokens) on which the latent fires strongest
  5. Autointerp - LLM-generated latent explanations
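Components 2 and 3 come purely from the weights: a latent's decoder row projected through the unembedding, W_dec[latent_idx] @ W_U, gives one value per vocabulary token, and the top / bottom logits tables list the k most positive and k most negative of these. The pure-Python sketch below illustrates this on made-up 2-dimensional weights and a hypothetical 4-token vocabulary; it ignores the final LayerNorm, as this simple projection does, and with real tensors you'd use a matrix multiplication and topk rather than loops.

```python
def latent_logit_effects(w_dec_row, w_u):
    """Project one latent's decoder direction (length d_model) through W_U (d_model x d_vocab)."""
    d_vocab = len(w_u[0])
    return [sum(w_dec_row[i] * w_u[i][j] for i in range(len(w_dec_row))) for j in range(d_vocab)]


def top_bottom_tokens(effects, vocab, k):
    """Return the k largest (descending) and k smallest (ascending) (token, value) pairs."""
    order = sorted(range(len(effects)), key=lambda j: effects[j])
    bottom = [(vocab[j], effects[j]) for j in order[:k]]
    top = [(vocab[j], effects[j]) for j in reversed(order[-k:])]
    return top, bottom


# Made-up toy weights: d_model = 2, vocabulary of 4 tokens
vocab = [" new", " old", " same", " gone"]
w_u = [[2.0, 0.0, 1.0, -1.0],
       [0.0, 2.0, 1.0, 0.5]]
w_dec_row = [1.0, -1.0]

effects = latent_logit_effects(w_dec_row, w_u)
print(effects)  # [2.0, -2.0, 0.0, -1.5]
print(top_bottom_tokens(effects, vocab, k=2))
# ([(' new', 2.0), (' same', 0.0)], [(' old', -2.0), (' gone', -1.5)])
```

The logits histogram (component 2) is just a histogram of the full effects list, which is why neither table needs any data.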

We'll go through each of these in turn. We'll be using latent 9 for this exercise; you can compare your results to the expected dashboard:

display_dashboard(latent_idx=9)

Exercise - get the activation distribution

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 15-30 minutes on this exercise.

The function below should iterate through some number of batches (you can reduce the default number if you find the code is taking too long) and create a histogram of the activations for a given latent. Try to include the activation density in the histogram's title too.

Reminder - when using model.run_with_cache_with_saes, you can use the arguments stop_at_layer=sae.cfg.hook_layer+1 as well as names_filter=hook_name; these will help you avoid unnecessary computation and memory usage.

Also, note that if you're working in Colab & using Plotly then you might need to adjust this code so it computes & renders the figure in separate code cells - this is a well-known but unfixed Colab bug.

def show_activation_histogram(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 200,
):
    """
    Displays the activation histogram for a particular latent, computed across `total_batches`
    batches from `act_store`.
    """
    raise NotImplementedError()


show_activation_histogram(gpt2, gpt2_sae, gpt2_act_store, latent_idx=9)
Some code to plot a histogram (if you don't really care about this being part of the exercise):

This will work assuming all_positive_acts is a list of all non-zero activation values over all batches:

frac_active = len(all_positive_acts) / (
    total_batches * act_store.store_batch_size_prompts * act_store.context_size
)
px.histogram(
    all_positive_acts,
    nbins=50,
    title=f"ACTIVATIONS DENSITY {frac_active:.3%}",
    labels={"value": "Activation"},
    width=800,
    template="ggplot2",
    color_discrete_sequence=["darkorange"],
).update_layout(bargap=0.02, showlegend=False).show()

Note that if you're in Colab, you might need to return this figure and plot it in a separate cell (because Colab is weird about plotting in the same cell as it performs computation).

Solution
def show_activation_histogram(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 200,
):
    """
    Displays the activation histogram for a particular latent, computed across total_batches
    batches from act_store.
    """
    sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    all_positive_acts = []
    for _ in tqdm(range(total_batches), desc="Computing activations for histogram"):
        tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_post_hook_name],
        )
        acts = cache[sae_acts_post_hook_name][..., latent_idx]
        all_positive_acts.extend(acts[acts > 0].cpu().tolist())

    # Density = fraction of all token positions on which the latent fires
    frac_active = len(all_positive_acts) / (
        total_batches * act_store.store_batch_size_prompts * act_store.context_size
    )

    px.histogram(
        all_positive_acts,
        nbins=50,
        title=f"ACTIVATIONS DENSITY {frac_active:.3%}",
        labels={"value": "Activation"},
        width=800,
        template="ggplot2",
        color_discrete_sequence=["darkorange"],
    ).update_layout(bargap=0.02, showlegend=False).show()

Exercise - find max activating examples

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 15-30 minutes on this exercise.

We'll start by finding the max-activating examples - the prompts that show the highest level of activation from a latent. We've given you a function with a docstring to complete, although exactly how you want to present the data is entirely up to you.

We've given you the following helper functions, as well as examples showing how to use them:

  • get_k_largest_indices, which will return the batch & seqpos indices of the k largest elements in a (batch, seq)-sized tensor,
  • index_with_buffer, which will index into a (batch, seq)-sized tensor with the results of get_k_largest_indices, including the tokens within buffer from the selected indices in the same sequence (this helps us get context around the selected tokens),
  • display_top_seqs, which will display sequences (with the relevant token highlighted) in a readable way.

When it comes to decoding sequences, you can use model.to_str_tokens to map a 1D tensor of token IDs to a list of string tokens. Note that you're likely to get some unknown tokens "�" in your output - this is an unfortunate byproduct of tokenization the way we're doing it, and you shouldn't worry about it too much.

def get_k_largest_indices(
    x: Float[Tensor, "batch seq"], k: int, buffer: int = 0
) -> Int[Tensor, "k 2"]:
    """
    The indices of the top k elements in the input tensor, i.e. output[i, :] is the (batch, seqpos)
    value of the i-th largest element in x.

    Won't choose any elements within `buffer` from the start or end of their sequence.
    """
    if buffer > 0:
        x = x[:, buffer:-buffer]
    indices = x.flatten().topk(k=k).indices
    rows = indices // x.size(1)
    cols = indices % x.size(1) + buffer
    return t.stack((rows, cols), dim=1)


x = t.arange(40, device=device).reshape((2, 20))
x[0, 10] += 50  # 2nd highest value
x[0, 11] += 100  # highest value
x[1, 1] += 150  # not inside buffer (it's less than 3 from the start of the sequence)
top_indices = get_k_largest_indices(x, k=2, buffer=3)
assert top_indices.tolist() == [[0, 11], [0, 10]]


def index_with_buffer(
    x: Float[Tensor, "batch seq"], indices: Int[Tensor, "k 2"], buffer: int | None = None
) -> Float[Tensor, "k *buffer_x2_plus1"]:
    """
    Indexes into `x` with `indices` (which should have come from the `get_k_largest_indices`
    function), and takes a +-buffer range around each indexed element. If `indices` are less than
    `buffer` away from the start of a sequence then we just take the first `2*buffer+1` elems (same
    for at the end of a sequence).

    If `buffer` is None, then we don't add any buffer and just return the elements at the given indices.
    """
    rows, cols = indices.unbind(dim=-1)
    if buffer is not None:
        rows = einops.repeat(rows, "k -> k buffer", buffer=buffer * 2 + 1)
        cols[cols < buffer] = buffer
        cols[cols > x.size(1) - buffer - 1] = x.size(1) - buffer - 1
        cols = einops.repeat(cols, "k -> k buffer", buffer=buffer * 2 + 1) + t.arange(
            -buffer, buffer + 1, device=cols.device
        )
    return x[rows, cols]


x_top_values_with_context = index_with_buffer(x, top_indices, buffer=3)
assert x_top_values_with_context[0].tolist() == [
    8,
    9,
    10 + 50,
    11 + 100,
    12,
    13,
    14,
]  # highest value in the middle
assert x_top_values_with_context[1].tolist() == [
    7,
    8,
    9,
    10 + 50,
    11 + 100,
    12,
    13,
]  # 2nd highest value in the middle


def display_top_seqs(data: list[tuple[float, list[str], int]]):
    """
    Given a list of (activation: float, str_toks: list[str], seq_pos: int), displays a table of
    these sequences, with the relevant token highlighted.

    We also turn newlines into "↵", and remove unknown tokens � (usually weird quotation marks)
    for readability.
    """
    table = Table("Act", "Sequence", title="Max Activating Examples", show_lines=True)
    for act, str_toks, seq_pos in data:
        formatted_seq = (
            "".join(
                [
                    f"[b u green]{str_tok}[/]" if i == seq_pos else str_tok
                    for i, str_tok in enumerate(str_toks)
                ]
            )
            .replace("�", "")
            .replace("\n", "↵")
        )
        table.add_row(f"{act:.3f}", repr(formatted_seq))
    rprint(table)


example_data = [
    (0.5, [" one", " two", " three"], 0),
    (1.5, [" one", " two", " three"], 1),
    (2.5, [" one", " two", " three"], 2),
]
display_top_seqs(example_data)
  Max Activating Examples   
┏━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃ Act   ┃ Sequence         ┃
┡━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ 0.500 │ ' one two three' │
├───────┼──────────────────┤
│ 1.500 │ ' one two three' │
├───────┼──────────────────┤
│ 2.500 │ ' one two three' │
└───────┴──────────────────┘

You should fill in the following function. It should return data as a list of tuples of the form (max activation, list of string tokens, sequence position), which you can then pass to display_top_seqs (you'll find this function helpful when we implement autointerp later!).

def fetch_max_activating_examples(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 10,
    buffer: int = 10,
) -> list[tuple[float, list[str], int]]:
    """
    Returns the max activating examples across a number of batches from the activations store.
    """
    raise NotImplementedError()


# Fetch & display the results
buffer = 10
data = fetch_max_activating_examples(
    gpt2, gpt2_sae, gpt2_act_store, latent_idx=9, buffer=buffer, k=5
)
display_top_seqs(data)

# Test one of the results, to see if it matches the expected output
first_seq_str_tokens = data[0][1]
assert first_seq_str_tokens[buffer] == " new"
                                             Max Activating Examples                                              
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Act    ┃ Sequence                                                                                              ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 43.451 │ '.↵↵Airline industry↵↵Under the new rules, payment surcharges will have to reflect the'               │
├────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 41.921 │ '145m.↵↵Enforcement↵↵The new rules are being brought in earlier than the rest of'                     │
├────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 39.845 │ ' activity," the minister had said.↵↵The new law is a precursor to banning chewing tobacco in public' │
├────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 37.051 │ '."↵↵Niedermeyer agreed that the new car excitement "tapers off" the longer it'                       │
├────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 37.004 │ ' each other as soon as possible."↵↵The new desert map will be included in the 1.0'                   │
└────────┴───────────────────────────────────────────────────────────────────────────────────────────────────────┘
Solution
def fetch_max_activating_examples(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 10,
    buffer: int = 10,
) -> list[tuple[float, list[str], int]]:
    """
    Returns the max activating examples across a number of batches from the activations store.
    """
    sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    # Create list to store the top k activations for each batch. Once we're done,
    # we'll filter this to only contain the top k over all batches
    data = []

    for _ in tqdm(range(total_batches), desc="Computing activations for max activating examples"):
        tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_post_hook_name],
        )
        acts = cache[sae_acts_post_hook_name][..., latent_idx]

        # Get largest indices, get the corresponding max acts, and get the surrounding indices
        k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
        tokens_with_buffer = index_with_buffer(tokens, k_largest_indices, buffer=buffer)
        str_toks = [model.to_str_tokens(toks) for toks in tokens_with_buffer]
        top_acts = index_with_buffer(acts, k_largest_indices).tolist()
        data.extend(list(zip(top_acts, str_toks, [buffer] * len(str_toks))))

    return sorted(data, key=lambda x: x[0], reverse=True)[:k]

Non-overlapping sequences

For the latent above, returning sequences the way you did probably worked pretty well. But more concept-level latents (where multiple tokens in a sentence fire strongly) are a bit more annoying. You can try this function on a latent like 16873 (which fires on specific Bible passages): the returned sequences will mostly be the same passage, just shifted over by different amounts.

data = fetch_max_activating_examples(
    gpt2, gpt2_sae, gpt2_act_store, latent_idx=16873, total_batches=200
)
display_top_seqs(data)
                                              Max Activating Examples                                              
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Act    ┃ Sequence                                                                                               ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 15.400 │ 'aqara 2:221 =↵↵And do not marry polytheistic women until they believe. And'                           │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 14.698 │ 'Baqara 2:221 =↵↵And do not marry polytheistic women until they believe.'                              │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 14.648 │ ' Testament: Verily, verily, I say unto you, Except a corn of wheat fall into the'                     │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 14.124 │ ' not marry polytheistic women until they believe. And a believing slave woman is better than a        │
│        │ polythe'                                                                                               │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 14.069 │ ' thou to me? And Jesus answering said unto him, Suffer it to be so now: for thus'                     │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 13.461 │ ' "But John forbad him, saying, I have a need to be baptised of thee, and'                             │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 13.024 │ '14–15: "But John forbad him, saying, I have a need to be baptised'                                    │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 12.874 │ ' to me? And Jesus answering said unto him, Suffer it to be so now: for thus it'                       │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 12.801 │ ' Suffer it to be so now: for thus it becometh us to fulfil all righteousness", and'                   │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 12.747 │ ':14–15: "But John forbad him, saying, I have a need to be bapt'                                       │
└────────┴────────────────────────────────────────────────────────────────────────────────────────────────────────┘

One way you can combat this is by imposing the restriction that any given top-activating token can only be in one sequence, i.e. when you pick that token you can't pick any other token in the range [-buffer, buffer] around it. We've given you a new version of get_k_largest_indices below. Try it out with no_overlap=True: are the results much better?

def get_k_largest_indices(
    x: Float[Tensor, "batch seq"],
    k: int,
    buffer: int = 0,
    no_overlap: bool = True,
) -> Int[Tensor, "k 2"]:
    """
    Returns the tensor of (batch, seqpos) indices for each of the top k elements in the tensor x.

    Args:
        buffer:     We won't choose any elements within `buffer` from the start or end of their seq
                    (this helps if we want more context around the chosen tokens).
        no_overlap: If True, this ensures that no 2 top-activating tokens are in the same seq and
                    within `buffer` of each other.
    """
    assert buffer * 2 < x.size(1), "Buffer is too large for the sequence length"
    assert not no_overlap or k <= x.size(0), (
        "Not enough sequences to have a different token in each sequence"
    )

    if buffer > 0:
        x = x[:, buffer:-buffer]

    indices = x.flatten().argsort(-1, descending=True)
    rows = indices // x.size(1)
    cols = indices % x.size(1) + buffer

    if no_overlap:
        unique_indices = t.empty((0, 2), device=x.device).long()
        while len(unique_indices) < k:
            unique_indices = t.cat(
                (unique_indices, t.tensor([[rows[0], cols[0]]], device=x.device))
            )
            is_overlapping_mask = (rows == rows[0]) & ((cols - cols[0]).abs() <= buffer)
            rows = rows[~is_overlapping_mask]
            cols = cols[~is_overlapping_mask]
        return unique_indices

    return t.stack((rows, cols), dim=1)[:k]


x = t.arange(40, device=device).reshape((2, 20))
x[0, 10] += 150  # highest value
x[0, 11] += 100  # 2nd highest value, but won't be chosen because of overlap
x[1, 10] += 50  # 3rd highest, will be chosen
top_indices = get_k_largest_indices(x, k=2, buffer=3)
assert top_indices.tolist() == [[0, 10], [1, 10]]


data = fetch_max_activating_examples(
    gpt2, gpt2_sae, gpt2_act_store, latent_idx=16873, total_batches=200
)
display_top_seqs(data)
                                            Max Activating Examples                                            
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Act    ┃ Sequence                                                                                           ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 17.266 │ ' seven times in a day be converted unto thee, saying: I repent: forgive him. And the apostles'    │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 16.593 │ '. And the apostles said to the Lord: Increase our faith. (Luke 17:3-5)'                           │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 16.146 │ ' say unto you, Arise, and take up your couch, and go into your house." (Luke'                     │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 15.965 │ ' mystery of the kingdom of God: but unto them that are without, all these things are done in par' │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 14.985 │ ' parable. And he said unto them: To you it is given to know the mystery of the kingdom'           │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 14.383 │ ' forgiven you; or to say, Rise up and walk? But that you may know that the Son of'                │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 13.476 │ 'he said unto him that was palsied:) I say unto you, Arise, and take up'                           │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 13.430 │ ' their thoughts, he answered and said unto them, "What reason have you in your hearts? Which is'  │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 13.181 │ ' things are done in parables: That seeing they may see, and not perceive, and hearing they may'   │
├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 13.071 │ ' seven times?" Jesus says unto him, "I say not unto you, Until seven times, but until'            │
└────────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘

Exercise - get top / bottom logits

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise.

We'll end with the top & bottom logits tables. These don't require data, since they're just functions of the SAE and model's weights. Recall - you can access the unembedding of your base model using model.W_U, and you can access your SAE's decoder weights using sae.W_dec.

def show_top_logits(
    model: HookedSAETransformer,
    sae: SAE,
    latent_idx: int,
    k: int = 10,
) -> None:
    """
    Displays the top & bottom logits for a particular latent.
    """
    raise NotImplementedError()


show_top_logits(gpt2, gpt2_sae, latent_idx=9)
tests.test_show_top_logits(show_top_logits, gpt2, gpt2_sae)
Expected output:
┌─────────────────┬─────────┬─────────────────┬─────────┐
│   Bottom tokens │ Value   │      Top tokens │ Value   │
├─────────────────┼─────────┼─────────────────┼─────────┤
│           'Zip' │ -0.774  │          'bies' │ +1.327  │
│       'acebook' │ -0.761  │           'bie' │ +1.297  │
│           'lua' │ -0.737  │     ' arrivals' │ +1.218  │
│        'ashtra' │ -0.728  │    ' additions' │ +1.018  │
│       'ONSORED' │ -0.708  │      ' edition' │ +0.994  │
│           'OGR' │ -0.705  │   ' millennium' │ +0.966  │
│      'umenthal' │ -0.703  │   ' generation' │ +0.962  │
│        'ecause' │ -0.697  │     ' entrants' │ +0.923  │
│          'icio' │ -0.692  │        ' hires' │ +0.919  │
│          'cius' │ -0.692  │ ' developments' │ +0.881  │
└─────────────────┴─────────┴─────────────────┴─────────┘
Solution
def show_top_logits(
    model: HookedSAETransformer,
    sae: SAE,
    latent_idx: int,
    k: int = 10,
) -> None:
    """
    Displays the top & bottom logits for a particular latent.
    """
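    # Project this latent's decoder direction onto the unembedding: [d_model] @ [d_model d_vocab] -> [d_vocab]
    logits = sae.W_dec[latent_idx] @ model.W_U
    # Tokens whose logits this latent most directly boosts (top k) and suppresses (bottom k)
    pos_logits, pos_token_ids = logits.topk(k)
    pos_tokens = model.to_str_tokens(pos_token_ids)
    neg_logits, neg_token_ids = logits.topk(k, largest=False)
    neg_tokens = model.to_str_tokens(neg_token_ids)
    print(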
        tabulate(
            zip(map(repr, neg_tokens), neg_logits, map(repr, pos_tokens), pos_logits),
            headers=["Bottom tokens", "Value", "Top tokens", "Value"],
            tablefmt="simple_outline",
            stralign="right",
            numalign="left",
            floatfmt="+.3f",
        )
    )

Exercise - autointerp

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵⚪⚪⚪⚪
You should consider skipping this exercise / reading the solution unless you're really interested in autointerp.

Automated interpretability is one particularly exciting area of research at the moment. It originated with the OpenAI paper Language models can explain neurons in language models, which showed that we could take a neuron from GPT-2 and use GPT-4 to generate explanations of its behaviour by showing it relevant text sequences and activations. This shows us one possible way to assess and categorize latents at scale, without requiring humans to manually inspect and label them (which would obviously be totally impractical to scale). You can also read about some more of the recent advancements made here and here.

This scalable, efficient categorization is one use-case, but there's also a second one: SAE evaluations. This is a topic we'll dive deeper into in later sections, but to summarize now: SAE evaluations are the ways we measure "how good our SAE is" along a variety of different axes. It turns out this is very hard, because any metric we pick is vulnerable to Goodharting and not necessarily representative of what we want out of our SAEs. For example, it seems like sparser latents are often more interpretable (if reconstruction loss is held constant), which is why a common way to evaluate SAEs is along a Pareto frontier of sparsity & reconstruction loss. But what if sparsity doesn't lead to more interpretable latents (e.g. because of feature absorption)? Autointerp provides an alternative way of evaluating SAE interpretability, because we can directly quantify how good our latent explanations are! The idea is to convert latent explanations into a set of predictions on some test set of prompts, and then score the accuracy of those predictions. More interpretable latents should lead to better predictions, because they tend to have monosemantic, human-interpretable patterns that can be predicted from their explanations.
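To make the scoring idea concrete, here is a minimal sketch, assuming a simulator model has already turned an explanation into a yes/no prediction of whether the latent fires on each test prompt. The function name `score_explanation`, the `threshold` parameter and the choice of balanced accuracy are all illustrative assumptions, not the scoring method used later in these exercises or by any particular library.

```python
# Illustrative scoring rule (balanced accuracy): an assumption, not the course's method.
def score_explanation(
    predicted_fires: list[bool], true_acts: list[float], threshold: float = 0.0
) -> float:
    """
    Balanced accuracy of binary "does the latent fire on this prompt?" predictions.

    A prompt counts as "firing" if the latent's true activation exceeds `threshold`.
    """
    actual = [act > threshold for act in true_acts]
    preds_on_firing = [p for p, a in zip(predicted_fires, actual) if a]
    preds_on_silent = [p for p, a in zip(predicted_fires, actual) if not a]
    # True-positive rate: fraction of firing prompts the explanation correctly flagged
    tpr = sum(preds_on_firing) / len(preds_on_firing) if preds_on_firing else 0.0
    # True-negative rate: fraction of silent prompts the explanation correctly rejected
    tnr = sum(not p for p in preds_on_silent) / len(preds_on_silent) if preds_on_silent else 0.0
    return (tpr + tnr) / 2


# Made-up predictions and activations, purely for illustration
preds = [True, True, False, False, True, False]
acts = [3.1, 0.0, 0.0, 0.0, 2.4, 1.7]
score = score_explanation(preds, acts)
```

Balanced accuracy averages the hit rate on prompts where the latent fires and the rejection rate on prompts where it doesn't, so an explanation can't score well just by always predicting the majority class (always answering "fires" scores exactly 0.5).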

For now though, we'll focus on just the first half of autointerp, i.e. the generation of explanations. You can download and read the Neuronpedia-hosted latent explanations with the following code:

def get_autointerp_df(
    sae_release="gpt2-small-res-jb", sae_id="blocks.7.hook_resid_pre"
) -> pd.DataFrame:
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = "https://www.neuronpedia.org/api/explanation/export?modelId={}&saeId={}".format(
        *neuronpedia_id.split("/")
    )
    headers = {"Content-Type": "application/json"}
    response = requests.get(url, headers=headers)
    response.raise_for_status()  # fail loudly on HTTP errors rather than trying to parse an error page

    data = response.json()
    return pd.DataFrame(data)


explanations_df = get_autointerp_df()
explanations_df.head()

Now, let's try doing some autointerp ourselves! This will involve 3 steps:

  1. Calling fetch_max_activating_examples to get the top-activating examples for a given latent.
  2. Calling create_prompt to create a system, user & assistant prompt for the OpenAI API which contains this data.
  3. Calling get_autointerp_explanation to pass these prompts to the OpenAI API and get a response.

You've already implemented fetch_max_activating_examples. You'll implement create_prompt first, and then get_autointerp_explanation, which sends these prompts to the OpenAI API.

Recommended autointerp prompt structure

One possible method based on Anthropic's past published material is to show a list of top sequences and the activations for every single token, and ask for an explanation (then in the scoring phase we'd ask the model to predict activation values). However, here we'll do something a bit simpler, and just highlight the top-activating token in any given sequence (without giving numerical activation values).

{
    "system": "We're studying neurons in a neural network. Each neuron activates on some particular word or concept in a short document. The activating words in each document are indicated with << ... >>. Look at the parts of the document the neuron activates for and summarize in a single sentence what the neuron is activating on. Try to be specific in your explanations, although don't be so specific that you exclude some of the examples from matching your explanation. Pay attention to things like the capitalization and punctuation of the activating words or concepts, if that seems relevant. Keep the explanation as short and simple as possible, limited to 20 words or less. Omit punctuation and formatting. You should avoid giving long lists of words.",
    "user": """The activating documents are given below:
1. and he was <<over the moon>> to find
2. we'll be laughing <<till the cows come home>>! Pro
3. thought Scotland was boring, but really there's more <<than meets the eye>>! I'd""",
    "assistant": "this neuron fires on",
}

We feed the system, then user, then assistant prompt into our model. The idea is:

- The system prompt explains what the task will be.
- The user prompt contains the actual task data.
- The assistant prompt helps condition the model's likely response format.

Note - this is all very low-tech, and we'll expand greatly on these methods when we dive deeper into autointerp in later sections.
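The `<< ... >>` convention in the user prompt is just string formatting over token strings. Here is a minimal sketch, assuming each example is a list of string tokens plus the index of its top-activating token; the helper names `format_example` and `format_examples` are hypothetical, not functions provided in these exercises.

```python
# Hypothetical helpers illustrating the << ... >> highlighting format.
def format_example(str_toks: list[str], highlight_pos: int) -> str:
    """Join string tokens back into text, wrapping the token at highlight_pos in << ... >>."""
    return "".join(
        f"<<{tok}>>" if i == highlight_pos else tok for i, tok in enumerate(str_toks)
    )


def format_examples(examples: list[tuple[list[str], int]]) -> str:
    """Number each (str_toks, highlight_pos) example, one example per line."""
    return "\n".join(
        f"{i + 1}. {format_example(str_toks, pos)}" for i, (str_toks, pos) in enumerate(examples)
    )


print(format_examples([([" the", " new", " policy"], 1), ([" a", " new", " product"], 1)]))
```

Because GPT-2 tokens usually keep their leading space, the highlighted token appears as `<< new>>` rather than `<<new>>`, which is the form the create_prompt test below checks for.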

We recommend you use the augmented version of get_k_largest_indices you were given above, which doesn't allow for sequence overlap. This is because you don't want to send redundant information in your prompt!
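The idea behind ruling out overlap can be sketched as a greedy selection: take the highest activation, then skip any later candidate that is too close to one already chosen. This is a pure-Python illustration under our own assumptions (two picks count as overlapping if they're in the same sequence and within `buffer` positions of each other); `top_k_no_overlap` is hypothetical, and the get_k_largest_indices helper you were given works on tensors and may use a different rule.

```python
# Hypothetical helper: NOT the get_k_largest_indices function you were given.
def top_k_no_overlap(acts: list[list[float]], k: int, buffer: int) -> list[tuple[int, int]]:
    """
    Greedy non-overlapping top-k over a [batch][seq] grid of activations.

    Only positions with at least `buffer` tokens on each side are considered, so a
    (2 * buffer + 1)-token context window always fits. After picking a position, any
    other candidate in the same sequence within `buffer` positions of it is skipped.
    """
    seq_len = len(acts[0])
    candidates = [
        (acts[b][p], b, p)
        for b in range(len(acts))
        for p in range(buffer, seq_len - buffer)
    ]
    candidates.sort(key=lambda c: c[0], reverse=True)  # highest activation first
    chosen: list[tuple[int, int]] = []
    for _, b, p in candidates:
        if len(chosen) == k:
            break
        # Accept only if this pick doesn't overlap any earlier pick in the same sequence
        if all(b != cb or abs(p - cp) > buffer for cb, cp in chosen):
            chosen.append((b, p))
    return chosen


acts = [[0, 5, 4, 0, 0, 3, 0], [0, 0, 1, 0, 0, 0, 0]]
print(top_k_no_overlap(acts, k=3, buffer=1))
```

Without the overlap check, positions 1 and 2 of the first sequence would both be chosen, sending two nearly identical contexts to the model.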

def create_prompt(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 15,
    buffer: int = 10,
) -> dict[Literal["system", "user", "assistant"], str]:
    """
    Returns the system, user & assistant prompts for autointerp.
    """
    raise NotImplementedError()


# Test your function
prompts = create_prompt(
    gpt2, gpt2_sae, gpt2_act_store, latent_idx=9, total_batches=100, k=15, buffer=8
)
assert prompts["system"].startswith("We're studying neurons in a neural network.")
assert "<< new>>" in prompts["user"]
assert prompts["assistant"] == "this neuron fires on"
Solution
def create_prompt(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 15,
    buffer: int = 10,
) -> dict[Literal["system", "user", "assistant"], str]:
    """
    Returns the system, user & assistant prompts for autointerp.
    """
    data = fetch_max_activating_examples(
        model, sae, act_store, latent_idx, total_batches, k, buffer
    )
    str_formatted_examples = "\n".join(
        f"{i + 1}. {''.join(f'<<{tok}>>' if j == buffer else tok for j, tok in enumerate(seq[1]))}"
        for i, seq in enumerate(data)
    )
    return {
        "system": "We're studying neurons in a neural network. Each neuron activates on some particular word or concept in a short document. The activating words in each document are indicated with << ... >>. Look at the parts of the document the neuron activates for and summarize in a single sentence what the neuron is activating on. Try to be specific in your explanations, although don't be so specific that you exclude some of the examples from matching your explanation. Pay attention to things like the capitalization and punctuation of the activating words or concepts, if that seems relevant. Keep the explanation as short and simple as possible, limited to 20 words or less. Omit punctuation and formatting. You should avoid giving long lists of words.",
        "user": f"""The activating documents are given below:\n\n{str_formatted_examples}""",
        "assistant": "this neuron fires on",
    }

Once you've passed the tests for create_prompt, you can implement the full get_autointerp_explanation function:

def get_autointerp_explanation(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 15,
    buffer: int = 10,
    n_completions: int = 1,
) -> list[str]:
    """
    Queries OpenAI's API using prompts returned from `create_prompt`, and returns a list of the
    completions.
    """
    raise NotImplementedError()


API_KEY = os.environ.get("OPENAI_API_KEY", None)

if API_KEY is not None:
    completions = get_autointerp_explanation(
        gpt2, gpt2_sae, gpt2_act_store, latent_idx=9, n_completions=5
    )
    for i, completion in enumerate(completions):
        print(f"Completion {i + 1}: {completion!r}")
else:
    print("No API key found, not running the autointerp code.")
Expected output:
Completion 1: 'the concept of new policies or products being introduced'
Completion 2: 'the concept of new initiatives or products'
Completion 3: 'the concept of new ideas or initiatives'
Completion 4: 'the concept of new developments or initiatives'
Completion 5: 'the concept of new initiatives or products'
Solution
def get_autointerp_explanation(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 15,
    buffer: int = 10,
    n_completions: int = 1,
) -> list[str]:
    """
    Queries OpenAI's API using prompts returned from create_prompt, and returns a list of the
    completions.
    """
    client = OpenAI(api_key=API_KEY)
    prompts = create_prompt(model, sae, act_store, latent_idx, total_batches, k, buffer)
    result = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": prompts["system"]},
            {"role": "user", "content": prompts["user"]},
            {"role": "assistant", "content": prompts["assistant"]},
        ],
        n=n_completions,
        max_tokens=50,
        stream=False,
    )
    return [choice.message.content for choice in result.choices]

Attention SAEs

In this section, you'll learn about attention SAEs, how they work (mostly quite similar to standard SAEs but with a few other considerations), and how to understand their feature dashboards. Key points:

  • Attention SAEs have the same architecture as regular SAEs, except they're trained on the concatenated pre-projection output of all attention heads.
  • If a latent fires on a destination token, we can use direct latent attribution to see which source tokens it primarily came from.
  • Just like regular SAEs, latents found in different layers of a model are often qualitatively different from each other.

In the next section, we'll be exploring different ways of finding a latent for some given concept. However, before we get into that, we first need to introduce a new concept: attention SAEs.

Research done by Kissane et al. as part of the MATS program has shown that we can also train SAEs on the output of attention layers, and that this works well: the SAEs learn sparse, interpretable latents, which gives us insight into what attention layers learn. Subsequent work trained SAEs on the attention output of every layer of GPT2-Small; these are the SAEs we'll be using in today's exercises.

Functionally, these SAEs work just like regular ones, except that they take the z output of the attention layer as input (i.e. after taking a linear combination of value vectors, but before projecting via the output matrix) rather than taking the residual stream or post-ReLU MLP activations. These z vectors are usually concatenated together across attention heads, for a single layer.
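A shape-level sketch of what the attention SAE sees, using random NumPy arrays in place of real activations and weights: the per-head z vectors at each position are concatenated into one n_heads * d_head vector before the encoder is applied. The head count and head size match GPT2-Small (12 heads of size 64), but `d_sae`, the random weights and the plain ReLU encoder (with no decoder-bias subtraction) are simplifying assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, n_heads, d_head, d_sae = 2, 5, 12, 64, 100

# Random stand-in for hook_z activations: one d_head vector per head, per position
z = rng.standard_normal((batch, seq, n_heads, d_head))

# Concatenate across heads: [batch seq n_heads d_head] -> [batch seq n_heads*d_head]
# (head h occupies the slice h*d_head : (h+1)*d_head of the flattened vector)
z_flat = z.reshape(batch, seq, n_heads * d_head)

# Toy ReLU encoder acting on the concatenated vector (random weights, zero bias)
W_enc = rng.standard_normal((n_heads * d_head, d_sae))
b_enc = np.zeros(d_sae)
latent_acts = np.maximum(z_flat @ W_enc + b_enc, 0.0)  # [batch seq d_sae]
```

Because the encoder reads all heads at once, a single latent's encoder column can put weight on several heads' slices, which is what allows a latent to be split across heads.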

Can you see why we take the attention output before projection via the output matrix, rather than after?

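It would be a waste of parameters. The encoder is a linear map from activation space to latent space, and an attention head's output z @ W_O can't have a larger rank than z (though it might have a smaller one). However, z @ W_O lives in a higher-dimensional space (d_model rather than d_head for each head), so an encoder reading from it would need more parameters and would train less efficiently, without gaining any information that wasn't already in z.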
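A quick NumPy check of the rank argument for a single head, with random matrices standing in for real activations and weights (the sizes are GPT2-Small's d_head = 64 and d_model = 768): projecting through W_O makes each vector 12 times longer, but the projected vectors still span at most a 64-dimensional subspace.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, d_head, d_model = 1000, 64, 768

z = rng.standard_normal((n_samples, d_head))  # random stand-in for one head's z vectors
W_O = rng.standard_normal((d_head, d_model))  # random stand-in for that head's output projection
out = z @ W_O                                 # what the head writes to the residual stream

print("dims:", z.shape[1], "->", out.shape[1])
print("ranks:", np.linalg.matrix_rank(z), "->", np.linalg.matrix_rank(out))
```

An encoder reading `out` would need d_model weights per latent instead of d_head, just to recover information that was already fully present in z.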

Can you guess why we concatenate across attention heads?

We do this because heads might be in superposition, just like neurons in MLP layers. As well as a single attention head containing many latents, we could have a latent which is split across multiple attention heads. Evidence of shared attention head functionality abounds in regular models: for instance, in the intro to mech interp ARENA exercises, we examined a 2-layer model where 2 heads in the second layer came together to form a copying head. In that case, we might expect to find latents which are split across both heads.

However, one interesting thing about attention SAEs is that we also have to think about source tokens, not just the destination token. In other words, once we identify some attention latent that is present at a particular destination token, we still need to ask the question of which source tokens it came from.

This leads us to the tool of direct latent attribution (which we'll abbreviate as "DFA" or "direct feature attribution", just so it doesn't get confused with direct logit attribution!). Just as in direct logit attribution (DLA) we ask which components wrote to the residual stream in ways which directly influenced certain logits, with DFA we decompose the latent's pre-activation at the destination token into contributions from the inputs which caused it to fire. This can tell us which head contributed most to that latent, which source token did, or both.

Exercise - explore attention SAE dashboards

Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-20 minutes on this exercise.

Run the code below to see an example of a latent dashboard for the layer-9 attention SAE in GPT2-Small. The green text on the right shows you where this latent activated, and the orange highlight shows you the DFA, i.e. which source tokens the activation primarily came from.

You should specifically try to think about the qualitative differences between latents you see in SAEs trained at different layers (you can change the layer variable below to investigate this). What do you notice? What kinds of latents exist at earlier layers but not later ones? Which layers have more interpretable output logits?

Some things you should see

You should generally find the following themes:

- Early layer head latents often represent low-level grammatical syntax (e.g. firing on single tokens or bigrams).
- Middle layer head latents are often the hardest to interpret because they respond to higher-level semantic information, but also aren't always interpretable in terms of their output logits, likely because they write to intermediate representations which are then used by other heads or MLP layers.
- Late layer head latents are often understood in terms of their output logits (i.e. they're directly writing predictions to the residual stream). The very last layer is something of an exception to this, since it seems to deal largely with grammatical corrections and adjustments.

For more on this, you can read the table in the LessWrong post [We Inspected Every Head In GPT-2 Small using SAEs So You Don’t Have To](https://www.lesswrong.com/posts/xmegeW5mqiBsvoaim/we-inspected-every-head-in-gpt-2-small-using-saes-so-you-don#Overview_of_Attention_Heads_Across_Layers) (which looks at the same SAEs we're working with here).

attn_saes = {
    layer: SAE.from_pretrained(
        "gpt2-small-hook-z-kk",
        f"blocks.{layer}.hook_z",
        device=str(device),
    )[0]
    for layer in range(gpt2.cfg.n_layers)
}

layer = 9

display_dashboard(
    sae_release="gpt2-small-hook-z-kk",
    sae_id=f"blocks.{layer}.hook_z",
    latent_idx=2,  # or you can try `random.randint(0, attn_saes[layer].cfg.d_sae - 1)` (randint includes its upper bound)
)

Exercise - derive attention DFA

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 25-45 minutes on this exercise.

Since we've added another component to our latent dashboard, let's perform another derivation! For any given destination token firing, we can define the direct latent attribution (DFA) of a given source token as the dot product of the value vector taken at that source token (scaled by the attention probability) and the SAE encoder direction for this latent. In other words, the latent's pre-ReLU activation is equal to the sum of its DFA over all source tokens, plus bias terms which don't depend on the source token.
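Here is a minimal NumPy sketch of that definition for a single destination token, with random arrays standing in for real value vectors, attention probabilities and encoder weights (bias terms are left out). It computes a DFA score for every (source token, head) pair, then sums over heads to rank source tokens, or over source tokens to rank heads; the grand total matches the encoder's dot product with the concatenated z.

```python
import numpy as np

rng = np.random.default_rng(0)
n_src, n_heads, d_head = 6, 12, 64

# Random stand-ins: value vectors, one destination token's attention pattern, one encoder column
v = rng.standard_normal((n_src, n_heads, d_head))  # [src n_heads d_head]
scores = rng.standard_normal((n_heads, n_src))
pattern = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # [n_heads src], rows sum to 1
w_enc = rng.standard_normal(n_heads * d_head)  # encoder column for one latent

# z at the destination token: attention-weighted sum of value vectors, separately for each head
z = np.einsum("hs,shd->hd", pattern, v)  # [n_heads d_head]

# Weighted value vectors, before summing over source positions
v_weighted = pattern.T[:, :, None] * v  # [src n_heads d_head]

# Dot each weighted value vector with the matching head's slice of the encoder column
dfa = np.einsum("shd,hd->sh", v_weighted, w_enc.reshape(n_heads, d_head))  # [src n_heads]

dfa_per_src = dfa.sum(axis=1)   # total contribution of each source token
dfa_per_head = dfa.sum(axis=0)  # total contribution of each head
print("top source position:", dfa_per_src.argmax(), "| top head:", dfa_per_head.argmax())
print(np.isclose(dfa.sum(), z.reshape(-1) @ w_enc))  # DFA sums to the bias-free pre-activation
```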

When you complete this problem, you'll be able to complete the final part of the attention dashboard. We've given you a function to visualize this, like we used for the residual stream SAEs:

@dataclass
class AttnSeqDFA:
    act: float
    str_toks_dest: list[str]
    str_toks_src: list[str]
    dest_pos: int
    src_pos: int


def display_top_seqs_attn(data: list[AttnSeqDFA]):
    """
    Same as previous function, but we now have 2 str_tok lists and 2 sequence positions to
    highlight, the first being for top activations (destination token) and the second for top DFA
    (src token). We've given you a dataclass to help keep track of this.
    """
    table = Table(
        "Top Act",
        "Src token DFA (for top dest token)",
        "Dest token",
        title="Max Activating Examples",
        show_lines=True,
    )
    for seq in data:
        formatted_seqs = [
            repr(
                "".join(
                    [
                        f"[b u {color}]{str_tok}[/]" if i == seq_pos else str_tok
                        for i, str_tok in enumerate(str_toks)
                    ]
                )
                .replace("�", "")
                .replace("\n", "↵")
            )
            for str_toks, seq_pos, color in [
                (seq.str_toks_src, seq.src_pos, "dark_orange"),
                (seq.str_toks_dest, seq.dest_pos, "green"),
            ]
        ]
        table.add_row(f"{seq.act:.3f}", *formatted_seqs)
    rprint(table)


str_toks = [" one", " two", " three", " four"]
example_data = [
    AttnSeqDFA(
        act=0.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=0, src_pos=0
    ),
    AttnSeqDFA(
        act=1.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=1, src_pos=1
    ),
    AttnSeqDFA(
        act=2.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=2, src_pos=0
    ),
]
display_top_seqs_attn(example_data)
                      Max Activating Examples                       
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Top Act ┃ Src token DFA (for top dest token) ┃ Dest token        ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ 0.500   │ ' one two three'                   │ ' two three four' │
├─────────┼────────────────────────────────────┼───────────────────┤
│ 1.500   │ ' one two three'                   │ ' two three four' │
├─────────┼────────────────────────────────────┼───────────────────┤
│ 2.500   │ ' one two three'                   │ ' two three four' │
└─────────┴────────────────────────────────────┴───────────────────┘

Now, fill in the function below. You'll be testing your function by comparing the output to the dashboard generated from latent_idx=2 (shown above). You should have observed that the top activating tokens are conjunctions like " and" and " or", which connect lists of words like " weapons", " firearms", " missiles", etc. We can also see that the max DFA token is usually a similar word earlier in context, sometimes the word immediately before the top activating token (which makes sense given the positive logits also boost words like this). In other words, part of what this latent seems to be doing is detecting when we're in the middle of a conjunction phrase like "weapons and tactics" or "guns ... not missiles", and predicting logical completions for how this phrase will end.

Note - this is a pretty difficult problem (mostly because of all the rearranging and multiple steps in the solution). If you get stuck, we strongly recommend using the hint below.

Help - I'm still confused about this calculation.

We have the value vectors v of shape (batch, seq, n_heads, d_head) at each position. By broadcasting & multiplying by the attention probabilities, we can get v_weighted of shape (batch, seq_dest, seq_src, n_heads, d_head), which represents the vector that'll be taken at each source position and added to the destination position, and will be summed over to produce z (the values we have before projection by the output matrix W_O to add back to the residual stream).

It's this z (after flattening over attention heads) that the SAE gets trained on, i.e. z @ sae.W_enc gives the SAE's pre-ReLU activations (up to bias terms, which don't depend on the source position). So by writing z as a sum of v_weighted over source positions, we can write the pre-ReLU activation for latent latent_idx as a sum of v_weighted[:, :, src_pos, :, :] @ sae.W_enc[:, latent_idx] over all src_pos values (plus those bias terms). So for any given sequence b in the batch, and destination position dest_pos, we can compute the scalar v_weighted[b, dest_pos, src_pos, :, :] @ sae.W_enc[:, latent_idx] for each src_pos, and find the largest one.

Reminder - W_enc is actually a linear map from n_heads * d_head to d_sae dimensions, so to perform this calculation we'll first need to flatten the values v_weighted over heads.
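The hint's broadcasting and flattening steps can be checked on random NumPy arrays. This is a sketch under simplifying assumptions (random values, a causal softmax pattern, no SAE biases, and plain `reshape`/`transpose` in place of `einops`): it builds v_weighted with shape [batch dest src n_heads d_head], confirms that summing over source positions gives back z, and then takes the argmax-DFA source token for every destination. Masking out future source positions before the argmax is our own addition: their DFA is exactly zero, so without the mask they could win whenever every real contribution is negative.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, n_heads, d_head, d_sae = 2, 7, 12, 64, 50
latent_idx = 3

# Random stand-ins for value vectors and SAE encoder weights
v = rng.standard_normal((batch, seq, n_heads, d_head))  # [batch src n_heads d_head]
W_enc = rng.standard_normal((n_heads * d_head, d_sae))

# Causal attention pattern [batch n_heads dest src]: no attention to future source positions
future = np.triu(np.ones((seq, seq), dtype=bool), k=1)  # future[dest, src] is True when src > dest
scores = rng.standard_normal((batch, n_heads, seq, seq))
scores[..., future] = -np.inf
pattern = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

# Broadcast to [batch dest src n_heads d_head]
v_weighted = pattern.transpose(0, 2, 3, 1)[..., None] * v[:, None]

# Summing over source positions recovers z: [batch dest n_heads d_head]
z = v_weighted.sum(axis=2)

# Flatten heads so the encoder column can be applied: one DFA score per (dest, src) pair
dfa = v_weighted.reshape(batch, seq, seq, n_heads * d_head) @ W_enc[:, latent_idx]  # [batch dest src]

# Top source token for every destination, ignoring future positions
top_src = np.where(future, -np.inf, dfa).argmax(axis=-1)  # [batch dest]
```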

Note that some of the src token indexing can get a bit fiddly. In particular, when you get the index positions of the top contributing source tokens, some of them might be within buffer of the start of the sequence. The index_with_buffer function handles this case for you: whenever the indexing values are within buffer of the start or end of the sequence, it'll just take the first or last buffer tokens respectively.
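The clamping behaviour described above can be sketched in plain Python for a single sequence. `window_with_clamp` is a hypothetical stand-in, not the index_with_buffer function you were given (which works on batched tensors): it shifts the window rather than truncating it, so it always returns 2 * buffer + 1 tokens, and reports where the chosen position ended up inside the window.

```python
# Hypothetical stand-in for index_with_buffer's clamping, for a single sequence.
def window_with_clamp(seq: list[str], pos: int, buffer: int) -> tuple[list[str], int]:
    """
    Return a (2 * buffer + 1)-token window around `pos`, plus the index of `pos` inside it.

    If `pos` is within `buffer` tokens of either end, the window is shifted (not truncated)
    so it keeps its full length, and the returned index moves off-centre accordingly.
    Assumes len(seq) >= 2 * buffer + 1.
    """
    width = 2 * buffer + 1
    start = min(max(pos - buffer, 0), len(seq) - width)
    return seq[start : start + width], pos - start


seq = list("abcdefghij")
print(window_with_clamp(seq, 1, buffer=2))
```

For a position near the start of the sequence the returned index is min(pos, buffer), which mirrors the solution's use of src_pos=min(src_pos, buffer).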

def fetch_max_activating_examples_attn(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 250,
    k: int = 10,
    buffer: int = 10,
) -> list[AttnSeqDFA]:
    """
    Returns the max activating examples across a number of batches from the activations store.
    """
    raise NotImplementedError()


# Test your function: compare it to dashboard above
# (max DFA should come from source tokens like " guns", " firearms")
layer = 9
data = fetch_max_activating_examples_attn(gpt2, attn_saes[layer], gpt2_act_store, latent_idx=2)
display_top_seqs_attn(data)
Expected output:
                                              Max Activating Examples                                              
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Top Act ┃ Src token DFA (for top dest token)                ┃ Dest token                                        ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 2.593   │ ' After all, you cant order numerous guns and     │ ' cant order numerous guns and massive amounts of │
│         │ massive amounts of ammunition and also Batman     │ ammunition and also Batman costumes on line       │
│         │ costumes on'                                      │ without generating search'                        │
├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ 2.000   │ ' on, as far as when they would draw a weapon and │ " when they would draw a weapon and when they     │
│         │ when they would use force, is based off'          │ would use force, is based off of officer's        │
│         │                                                   │ perception,"                                      │
├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ 1.921   │ '↵↵So chances are a good guy with a gun will not  │ ' guy with a gun will not stop a bad guy with a   │
│         │ stop a bad guy with a gun.'                       │ gun. In fact, trying to produce more'             │
├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ 1.780   │ ' red lines on chemical weapons, he took out the  │ ' he took out the weapons, but not those who used │
│         │ weapons, but not those who used them. I don'      │ them. I dont think history will'                  │
├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ 1.501   │ ' academy teaches young trainees about when to    │ 'ees about when to use deadly force or when to    │
│         │ use deadly force or when to use non-lethal force, │ use non-lethal force, such as pepper spray or'    │
│         │ such'                                             │                                                   │
├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ 1.430   │ ' the past, religion is being used as both a      │ ', religion is being used as both a weapon and a  │
│         │ weapon and a shield by those seeking to deny      │ shield by those seeking to deny others equality.' │
│         │ others equality'                                  │                                                   │
├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ 1.389   │ ' in the fact that no ancient American steel      │ ' fact that no ancient American steel tools or    │
│         │ tools or weapons (or traces or evidence of them   │ weapons (or traces or evidence of them or their   │
│         │ or their manufacture'                             │ manufacture) have'                                │
├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ 1.381   │ ' with a gun. The body riddled with 9mm bullets   │ ' The body riddled with 9mm bullets and the       │
│         │ and the presence of 9mm shell casings on'         │ presence of 9mm shell casings on the floor is     │
│         │                                                   │ sufficient'                                       │
├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ 1.186   │ '↵Unicorn Mask and Pyjamas and weapons and Sticky │ 'Unicorn Mask and Pyjamas and weapons and Sticky  │
│         │ bombs.↵↵Thanks @Rock'                             │ bombs.↵↵Thanks @Rockstar'                         │
├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤
│ 1.045   │ ' minutes later, he heard shots being fired from  │ ' later, he heard shots being fired from          │
│         │ automatic rifles and cries of distress, the group │ automatic rifles and cries of distress, the group │
│         │ said.↵'                                           │ said.↵↵'                                          │
└─────────┴───────────────────────────────────────────────────┴───────────────────────────────────────────────────┘
Solution
def fetch_max_activating_examples_attn(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 250,
    k: int = 10,
    buffer: int = 10,
) -> list[AttnSeqDFA]:
    """
    Returns the max activating examples across a number of batches from the activations store.
    """
    sae_acts_pre_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_pre"
    v_hook_name = get_act_name("v", sae.cfg.hook_layer)
    pattern_hook_name = get_act_name("pattern", sae.cfg.hook_layer)
    data = []

    for _ in tqdm(
        range(total_batches), desc="Computing activations for max activating examples (attn)"
    ):
        tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_pre_hook_name, v_hook_name, pattern_hook_name],
        )
        acts = cache[sae_acts_pre_hook_name][..., latent_idx]  # [batch seq]

        # Get largest indices (i.e. dest tokens), and the tokens at those positions (plus buffer)
        k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
        top_acts = index_with_buffer(acts, k_largest_indices).tolist()
        dest_toks_with_buffer = index_with_buffer(tokens, k_largest_indices, buffer=buffer)
        str_toks_dest_list = [model.to_str_tokens(toks) for toks in dest_toks_with_buffer]

        # Get source token value vectors & dest-to-src attention patterns, for each of our chosen
        # destination tokens
        batch_indices, dest_pos_indices = k_largest_indices.unbind(-1)
        v = cache[v_hook_name][batch_indices]  # shape [k src n_heads d_head]
        pattern = cache[pattern_hook_name][batch_indices, :, dest_pos_indices]  # [k n_heads src]

        # Multiply them together to get weighted value vectors, and reshape them to d_in = n_heads * d_head
        v_weighted = v * einops.rearrange(pattern, "k n src -> k src n 1")  # [k src n_heads d_head]
        v_weighted = v_weighted.flatten(-2, -1)  # [k src d_in]

        # Map through our SAE encoder to get direct feature attribution for each src token, and argmax over src tokens
        dfa = v_weighted @ sae.W_enc[:, latent_idx]  # shape [k src]
        src_pos_indices = dfa.argmax(dim=-1)
        src_toks_with_buffer = index_with_buffer(
            tokens, t.stack([batch_indices, src_pos_indices], -1), buffer=buffer
        )
        str_toks_src_list = [model.to_str_tokens(toks) for toks in src_toks_with_buffer]
        # Add all this data to our list
        for act, str_toks_dest, str_toks_src, src_pos in zip(
            top_acts, str_toks_dest_list, str_toks_src_list, src_pos_indices
        ):
            data.append(
                AttnSeqDFA(
                    act=act,
                    str_toks_dest=str_toks_dest,  # top activating dest tokens, with buffer
                    str_toks_src=str_toks_src,  # top DFA src tokens for the dest token, with buffer
                    dest_pos=buffer,  # dest token is always in the middle of its buffer
                    src_pos=min(src_pos, buffer),  # deal with case where src token is near start
                )
            )
    return sorted(data, key=lambda x: x.act, reverse=True)[:k]

Finding latents for features

In this section, you'll explore different methods (some causal, some not) for finding latents in SAEs corresponding to particular features. Key points:

  • You can look at max activating latents on some particular input prompt, this is basically the simplest thing you can do
  • Direct logit attribution (DLA) is a bit more refined; you can find latents which have a direct effect on specific logits
  • Ablation of SAE latents can help you find latents which are important in a non-direct way
  • ...but it's quite costly for a large number of latents, so you can use attribution patching as a cheaper linear approximation of ablation

We'll now proceed through a set of 3 different methods that can be used to find features which activate on a given concept or language structure. We'll focus on trying to find IOI features - features which seem to activate on the indirect object identification pattern. You might already be familiar with this via the IOI exercises in ARENA. If not, it's essentially sentences of the form "When John and Mary went to the shops, John gave the bag to" -> " Mary". Models like GPT2-Small are able to learn this pattern via the following algorithm:

  • Duplicate token heads in layers 0-3 attend from the second " John" token (we call it S2; the second subject) back to the first " John" token (S1), and store the fact that it's duplicated.
  • S-inhibition heads in layers 7-8 attend from " to" back to the " John" token, and store information that this token is duplicated.
  • Name-mover heads in layers 9-10 attend from " to" back to any non-duplicated names (using Q-composition with the output of the S-Inhibition heads to avoid attending to the duplicated " John" tokens). So they attend to " Mary", and move this information into the unembedding space to be used as the model's prediction.

Sadly, our SAEs aren't yet advanced enough to pick up on S-inhibition features. As discussed above, these mid-layer heads which read from & write to subspaces containing intermediate representations are pretty difficult to interpret. In fact, if you read this section of the LessWrong post analyzing our GPT2-Small attention SAEs, you'll find that the worst layers for both "% alive features interpretable" and "loss recovered" are around 7 and 8, which is precisely where our S-Inhibition heads are!

However, the authors of that post were able to find duplicate token features as well as name-mover features, and in the following exercises we'll replicate their work!

Before we start, let's first make sure that the model can actually solve this sentence. To make our results a bit more robust (e.g. so we're not just isolating "gender features" or something), we'll use 4 different prompts: with "John" and "Mary" swapped as answers, and also with the sentence structure flipped around (ABBA vs ABAB). The code below also gives you the logits_to_ave_logit_diff function, which you might find useful in some later exercises.

names = [" John", " Mary"]
name_tokens = [gpt2.to_single_token(name) for name in names]

prompt_template = "When{A} and{B} went to the shops,{S} gave the bag to"
prompts = [
    prompt_template.format(A=names[i], B=names[1 - i], S=names[j])
    for i, j in itertools.product(range(2), range(2))
]
correct_answers = names[::-1] * 2
incorrect_answers = names * 2
correct_toks = gpt2.to_tokens(correct_answers, prepend_bos=False)[:, 0].tolist()
incorrect_toks = gpt2.to_tokens(incorrect_answers, prepend_bos=False)[:, 0].tolist()


def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    correct_toks: list[int] = correct_toks,
    incorrect_toks: list[int] = incorrect_toks,
    reduction: Literal["mean", "sum"] | None = "mean",
    keep_as_tensor: bool = False,
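) -> Tensor | list[float] | float:
    """
    Returns the logit difference (correct minus incorrect answer) at the final sequence position,
    averaged or summed over prompts (or one value per prompt, if reduction is None).
    """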
    correct_logits = logits[range(len(logits)), -1, correct_toks]
    incorrect_logits = logits[range(len(logits)), -1, incorrect_toks]
    logit_diff = correct_logits - incorrect_logits
    if reduction is not None:
        logit_diff = logit_diff.mean() if reduction == "mean" else logit_diff.sum()
    return logit_diff if keep_as_tensor else logit_diff.tolist()


# Testing a single prompt (where correct answer is John), verifying model gets it right
test_prompt(prompts[1], names, gpt2)

# Testing logits over all 4 prompts, verifying the model always has a high logit diff
logits = gpt2(prompts, return_type="logits")
logit_diffs = logits_to_ave_logit_diff(logits, reduction=None)
print(
    tabulate(
        zip(prompts, correct_answers, logit_diffs),
        headers=["Prompt", "Answer", "Logit Diff"],
        tablefmt="simple_outline",
        numalign="left",
        floatfmt="+.3f",
    )
)
Tokenized prompt: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' gave', ' the', ' bag', ' to']
Tokenized answers: [[' John'], [' Mary']]

Performance on answer tokens:
Rank: 0        Logit: 18.03 Prob: 69.35% Token: | John|
Rank: 3        Logit: 14.83 Prob:  2.82% Token: | Mary|

Top 0th token. Logit: 18.03 Prob: 69.35% Token: | John|
Top 1th token. Logit: 15.53 Prob:  5.67% Token: | them|
Top 2th token. Logit: 15.28 Prob:  4.42% Token: | the|
Top 3th token. Logit: 14.83 Prob:  2.82% Token: | Mary|
Top 4th token. Logit: 14.16 Prob:  1.44% Token: | her|
Top 5th token. Logit: 13.94 Prob:  1.16% Token: | him|
Top 6th token. Logit: 13.72 Prob:  0.93% Token: | a|
Top 7th token. Logit: 13.68 Prob:  0.89% Token: | Joseph|
Top 8th token. Logit: 13.61 Prob:  0.83% Token: | Jesus|
Top 9th token. Logit: 13.34 Prob:  0.64% Token: | their|

Ranks of the answer tokens: [[(' John', 0), (' Mary', 3)]]

┌────────────────────────────────────────────────────────────┬──────────┬──────────────┐
│ Prompt                                                     │ Answer   │ Logit Diff   │
├────────────────────────────────────────────────────────────┼──────────┼──────────────┤
│ When John and Mary went to the shops, John gave the bag to │ Mary     │ +3.337       │
│ When John and Mary went to the shops, Mary gave the bag to │ John     │ +3.202       │
│ When Mary and John went to the shops, John gave the bag to │ Mary     │ +3.918       │
│ When Mary and John went to the shops, Mary gave the bag to │ John     │ +2.220       │
└────────────────────────────────────────────────────────────┴──────────┴──────────────┘

Exercise - verify model + SAEs can still solve this

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise.

We want to make sure that the logit diff for the model is still high when we substitute in SAEs for a given layer. If this isn't the case, then that means our SAEs aren't able to implement the IOI circuit, and there's no reason to expect we'll find any interesting features!

You should use the model.saes context manager (or whatever your preferred way of running with SAEs is) to get the average logit diff over the 4 prompts, for each layer's attention SAE in turn. Verify that this logit diff is still high, i.e. that the SAEs don't ruin performance.

# YOUR CODE HERE - verify model + SAEs can still solve this
Expected output:
┏━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Ablation   ┃ Logit diff ┃ % of clean ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ Clean      │ +3.169     │ 100.0%     │
│ SAE in L00 │ +3.005     │ 94.8%      │
│ SAE in L01 │ +3.164     │ 99.8%      │
│ SAE in L02 │ +3.155     │ 99.5%      │
│ SAE in L03 │ +2.991     │ 94.4%      │
│ SAE in L04 │ +2.868     │ 90.5%      │
│ SAE in L05 │ +3.308     │ 104.4%     │
│ SAE in L06 │ +2.872     │ 90.6%      │
│ SAE in L07 │ +2.565     │ 80.9%      │
│ SAE in L08 │ +2.411     │ 76.1%      │
│ SAE in L09 │ +3.030     │ 95.6%      │
│ SAE in L10 │ +3.744     │ 118.1%     │
│ SAE in L11 │ +3.809     │ 120.2%     │
└────────────┴────────────┴────────────┘
Solution
logits = gpt2(prompts, return_type="logits")
clean_logit_diff = logits_to_ave_logit_diff(logits)
table = Table("Ablation", "Logit diff", "% of clean")
table.add_row("Clean", f"{clean_logit_diff:+.3f}", "100.0%")
for layer in range(gpt2.cfg.n_layers):
    with gpt2.saes(saes=[attn_saes[layer]]):
        logits = gpt2(prompts, return_type="logits")
        logit_diff = logits_to_ave_logit_diff(logits)
        table.add_row(
            f"SAE in L{layer:02}",
            f"{logit_diff:+.3f}",
            f"{logit_diff / clean_logit_diff:.1%}",
        )
rprint(table)
Discussion of results

You should find that most layers have close to 100% recovery, with some layers (like 10 and 11) even increasing the logit diff when they're substituted in. You shouldn't read too much into that, because we're only working with a small dataset and so some noise is expected.

The worst layers in terms of logit diff recovery are 7 and 8, which makes sense given that these are the layers with our S-Inhibition heads. The logit diff is still positive for both, and remains positive when you substitute in the layer 7 and 8 SAEs at once (although it does drop to 58% of the clean logit diff). This means these SAEs are presumably still capturing some of the S-Inhibition heads' behaviour, but it doesn't guarantee we have monosemantic S-Inhibition features (which, in fact, we don't).

Now we've done this, it's time for exercises!

Exercise - find name mover features with max activations

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise.

Now we've established that the model can still solve the task with SAEs substituted, we're ready to start looking for features!

In our first exercise, we'll look for name mover features: features which activate on the final token in the prompt, and seem to predict that the IO token comes next. Since the main name mover heads are in layer 9 (primarily 9.6 and 9.9), we'll be looking at our layer 9 SAE, i.e. attn_saes[9]. You should fill in the cell below to create a plot of latent activations at the final token, averaged over all 4 prompts, and also display the dashboards of the top 3 ranking latents. (If you want some help, you can borrow bits of code from earlier, specifically the 3rd code cell of the "Running SAEs" section.) Do you find any latents that seem like they might correspond to name mover features?

# YOUR CODE HERE - find name mover latents by using max activations

Spoiler - what you should find

You should return the features 11368, 18767 and 3101.

These all appear to be name movers:

  • 11368 is a name mover for "John" (i.e. it always attends from a token before "John" could appear, back to a previous instance of "John")
  • 18767 is a name mover for "Mary"
  • 3101 is a name mover for "Jesus"

Note that the activations of 11368 and 18767 should be much larger than the activations of any other feature.

Solution
layer = 9
# Compute mean post-ReLU SAE activations at last token posn
_, cache = gpt2.run_with_cache_with_saes(prompts, saes=[attn_saes[layer]])
sae_acts_post = cache[f"{attn_saes[layer].cfg.hook_name}.hook_sae_acts_post"][:, -1].mean(0)
# Plot the activations
px.line(
    sae_acts_post.cpu().numpy(),
    title=f"Activations at the final token position ({sae_acts_post.nonzero().numel()} alive)",
    labels={"index": "Latent", "value": "Activation"},
    template="ggplot2",
    width=1000,
).update_layout(showlegend=False).show()
# Print the top 3 latents, and inspect their dashboards
for act, ind in zip(*sae_acts_post.topk(3)):
    print(f"Latent {ind} had activation {act:.2f}")
    display_dashboard(
        sae_release="gpt2-small-hook-z-kk",
        sae_id=f"blocks.{layer}.hook_z",
        latent_idx=int(ind),
    )

Exercise - identify name mover heads

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to ~10 minutes on this exercise.

The IOI paper found that heads 9.6 and 9.9 were the primary name movers. Can you verify this by looking at the features' weight distribution over heads, i.e. seeing which heads these particular features have the largest exposure to?

# YOUR CODE HERE - identify name mover heads
Expected output:
┏━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Head ┃ Latent 18767  ┃ Latent 10651  ┃
┡━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ 9.0  │ 5.36%         │ 7.47%         │
│ 9.1  │ 2.71%         │ 2.50%         │
│ 9.2  │ 5.00%         │ 6.66%         │
│ 9.3  │ 4.29%         │ 5.14%         │
│ 9.4  │ 6.16%         │ 6.09%         │
│ 9.5  │ 3.84%         │ 4.97%         │
│ 9.6  │ 13.17%        │ 18.16%        │
│ 9.7  │ 4.97%         │ 6.60%         │
│ 9.8  │ 6.14%         │ 8.20%         │
│ 9.9  │ 42.66%        │ 26.13%        │
│ 9.10 │ 2.23%         │ 3.49%         │
│ 9.11 │ 3.47%         │ 4.58%         │
└──────┴───────────────┴───────────────┘
Hint - what calculation you should perform

Each feature has an associated decoder weight sae.W_dec[feature_idx] (a row of W_dec, which has shape (d_sae, d_in)), of shape (d_in=n_heads*d_head,). We can reshape it into n_heads vectors of length d_head, measure the norm of each one (representing the feature's exposure to that attention head), and see which head has the largest fraction of the total norm.
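The calculation in this hint can be sketched in plain PyTorch on a toy decoder. This is a minimal illustration under stated assumptions: the sizes mimic GPT-2 Small's 12 heads of 64 dims, the decoder matrix is random rather than a trained SAE, the helper head_norm_fractions is hypothetical, and the d_in axis is assumed to be laid out head-major, as in the "(n_heads d_head)" rearrange used in the solution. We plant a latent that writes mostly through head 9 to check that the calculation picks it out.

```python
import torch as t

# Hypothetical toy sizes: 12 heads x 64 dims (like GPT-2 Small), 100 latents, random weights
n_heads, d_head, d_sae = 12, 64, 100
t.manual_seed(0)
W_dec = t.randn(d_sae, n_heads * d_head)  # (d_sae, d_in): row i is latent i's decoder direction


def head_norm_fractions(W_dec: t.Tensor, latent_idx: int, n_heads: int) -> t.Tensor:
    """Fraction of a latent's summed per-head decoder norm that falls in each head's slice."""
    # Assumes the d_in axis is head-major, i.e. "(n_heads d_head)" in einops notation
    per_head = W_dec[latent_idx].reshape(n_heads, -1)  # (n_heads, d_head)
    norms = per_head.norm(dim=-1)  # L2 norm of each head's d_head-length slice
    return norms / norms.sum()


# Plant a latent that writes mostly through head 9 (and a little through head 6)
W_dec[0] = 0.0
W_dec[0, 9 * d_head : 10 * d_head] = 1.0
W_dec[0, 6 * d_head : 7 * d_head] = 0.1

fracs = head_norm_fractions(W_dec, latent_idx=0, n_heads=n_heads)
print(f"Largest exposure: head {fracs.argmax().item()} ({fracs.max().item():.1%})")
```

Using norms (rather than, say, squared norms) matches the solution below; either choice gives the same ranking of heads.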

Solution
features = [18767, 10651]
decoder_weights = einops.rearrange(
    attn_saes[layer].W_dec[features],
    "feats (n_heads d_head) -> feats n_heads d_head",
    n_heads=gpt2.cfg.n_heads,
)
norm_per_head = decoder_weights.pow(2).sum(-1).sqrt()
norm_frac_per_head = norm_per_head / norm_per_head.sum(-1, keepdim=True)
table = Table("Head", *[f"Latent {i}" for i in features])
for i in range(gpt2.cfg.n_heads):
    table.add_row(
        f"9.{i}", *[f"{frac:.2%}" for frac in norm_frac_per_head[:, i].tolist()]
    )
rprint(table)

You should find that both these features have largest exposure to 9.9, and second largest to 9.6.

Direct logit attribution

A slightly more refined technique than looking for max activating latents is to look for ones which have a particular effect on some token logits, or logit difference. In the context of the IOI circuit, we look at the logit difference between the indirect object and the subject token (which we'll abbreviate as IO - S). For example, if a latent's output into the residual stream has a very large value of IO - S when we apply the model's unembedding matrix W_U, this suggests that the latent might have been causally important for identifying that the correct answer was "Mary" rather than "John" in sentences like "When John and Mary went to the shops, John gave the bag to ???". This is essentially the same as DLA you might have already seen in previous contexts or ARENA exercises (e.g. DLA for attention heads that we covered in the IOI exercises).

Exercise - find name mover features with DLA

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 20-40 minutes on this exercise.

Write code below to perform this calculation and visualize the results (line chart & dashboards for the top-scoring latents). Your code here should look a lot like your code for finding name movers via max activations, except rather than argmaxing over avg latent activations you should be argmaxing over the latents' average DLA onto the IO - S direction.

Help - I'm not sure what this calculation should look like.

Start with the latent direction in the decoder, which is a vector of length d_in = n_heads * d_head. We interpret this as a vector in the space of the model's concatenated z vectors (i.e. the linear combinations of value vectors you get before applying the projection matrix). Taking the projection matrix W_O of shape (n_heads, d_head, d_model), flattening it to shape (n_heads * d_head, d_model), and multiplying our decoder matrix W_dec of shape (d_sae, d_in) by it gives us a matrix of shape (d_sae, d_model), where each row is the vector that gets added into the residual stream per unit of that latent's activation. Now, we can pass this matrix through our unembedding W_U to get a matrix of shape (d_sae, d_vocab), and then index into it to get the logit difference for each latent (or, more efficiently, take the dot product of our matrix with the vector W_U[:, IO] - W_U[:, S] to directly get a vector of d_sae logit differences). Finally, multiply by each latent's activation on the prompt to get its actual DLA.

This is how we get DLA for latents for a single prompt. We can parallelize this to get all 4 prompts at once (i.e. output of shape (4, d_sae)), then average over the batch dimension to get a single DLA vector.
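Here is a minimal sketch of the single-prompt calculation described above, with random toy weights standing in for one attention layer, its SAE and the unembedding. All sizes, and the token ids io_tok / s_tok, are hypothetical. It computes each latent's DLA both via the full (d_sae, d_vocab) logit matrix and via the cheaper dot product with W_U[:, IO] - W_U[:, S], and checks that the two routes agree.

```python
import torch as t

t.manual_seed(0)
dtype = t.float64  # double precision so the two routes agree to tight tolerance
# Hypothetical toy sizes and random weights (not GPT-2 Small)
n_heads, d_head, d_model, d_vocab, d_sae = 4, 8, 16, 50, 30
W_dec = t.randn(d_sae, n_heads * d_head, dtype=dtype)  # SAE decoder, (d_sae, d_in)
W_O = t.randn(n_heads, d_head, d_model, dtype=dtype)  # attention output projection
W_U = t.randn(d_model, d_vocab, dtype=dtype)  # unembedding
io_tok, s_tok = 3, 7  # hypothetical token ids for the IO and S names
acts = t.relu(t.randn(d_sae, dtype=dtype))  # latent activations at the final position

# Residual-stream vector written per unit of each latent's activation: (d_sae, d_model)
latent_resid_dirs = W_dec @ W_O.flatten(0, 1)

# Route 1: unembed everything, then index the two tokens' logits
logits_per_latent = latent_resid_dirs @ W_U  # (d_sae, d_vocab)
dla_full = acts * (logits_per_latent[:, io_tok] - logits_per_latent[:, s_tok])

# Route 2 (cheaper): dot product with the IO - S logit direction, never forming d_vocab columns
logit_dir = W_U[:, io_tok] - W_U[:, s_tok]  # (d_model,)
dla_fast = acts * (latent_resid_dirs @ logit_dir)  # (d_sae,)

print(t.allclose(dla_full, dla_fast), dla_fast.argmax().item())
```

For the batched version, acts would have shape (4, d_sae) and the logit directions shape (4, d_model), one per prompt; averaging the resulting (4, d_sae) DLA over the batch dimension gives the single DLA vector.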

# YOUR CODE HERE - find name mover latents with DLA

Solutions & discussion

Solution code:

# Get the "IO - S" logit directions in the residual stream, of shape (4, d_model)
logit_direction = gpt2.W_U.T[correct_toks] - gpt2.W_U.T[incorrect_toks]
# Get latent activations, of shape (4, d_sae)
sae_acts_post_hook_name = f"{attn_saes[layer].cfg.hook_name}.hook_sae_acts_post"
_, cache = gpt2.run_with_cache_with_saes(
    prompts, saes=[attn_saes[layer]], names_filter=[sae_acts_post_hook_name]
)
sae_acts_post = cache[sae_acts_post_hook_name][:, -1]
# Get values written to the residual stream by each latent
sae_resid_dirs = einops.einsum(
    sae_acts_post,
    attn_saes[layer].W_dec,
    gpt2.W_O[layer].flatten(0, 1),
    "batch d_sae, d_sae nheads_x_dhead, nheads_x_dhead d_model -> batch d_sae d_model",
)
# Get DLA by computing average dot product of each latent's residual dir onto the logit dir
dla = (sae_resid_dirs * logit_direction[:, None, :]).sum(-1).mean(0)
# Display the results
px.line(
    dla.cpu().numpy(),
    title="Latent DLA (in IO - S direction) at the final token position",
    labels={"index": "Latent", "value": "DLA"},
    template="ggplot2",
    width=1000,
).update_layout(showlegend=False).show()
# Print the top 3 features, and inspect their dashboards
for value, ind in zip(*dla.topk(3)):
    print(f"Latent {ind} had max act {sae_acts_post[:, ind].max():.2f}, mean DLA {value:.2f}")
    display_dashboard(
        sae_release="gpt2-small-hook-z-kk",
        sae_id=f"blocks.{layer}.hook_z",
        latent_idx=int(ind),
    )

You should get the same top 2 features as before, but the difference is that now other features are much weaker: no feature has a DLA greater than 10% of our top 2 features. The new 3rd feature we find seems like it might be a name mover for " Joan", and is only activating (very weakly) because the names " John" and " Joan" are similar.

This highlights an important property we might expect from features: sparsity of their functional form. SAEs are built on the assumption of a certain kind of sparsity (i.e. viewing a feature as a property of the dataset which is only present in a small fraction of all tokens), but a true feature should also be considered as a functional form, with functional sparsity. In other words, any feature can be described as doing a very narrow set of things in the model. Here, we see that by filtering not just for features which were active on a particular token, but which were active and contributing to the IO - S direction, we were able to more effectively isolate our name movers.

This idea of contrastive pairs (here, looking at the IO - S logit difference rather than at a single logit) is one that comes up again and again.

Exercise (optional) - replicate these results, with sentiment

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪

You can replicate these results with the sentiment-based prompts from Anthropic's Scaling Monosemanticity paper:

prompt = 'John says, "I want to be alone right now." John feels very'
correct_completion = " sad"
incorrect_completion = " happy"

test_prompt(prompt, correct_completion, gpt2)
test_prompt(prompt, incorrect_completion, gpt2)
Tokenized prompt: ['<|endoftext|>', 'John', ' says', ',', ' "', 'I', ' want', ' to', ' be', ' alone', ' right', ' now', '."', ' John', ' feels', ' very']
Tokenized answer: [' sad']

Performance on answer token:
Rank: 4        Logit: 15.65 Prob:  3.37% Token: | sad|

Top 0th token. Logit: 16.83 Prob: 11.03% Token: | lonely|
Top 1th token. Logit: 16.27 Prob:  6.31% Token: | alone|
Top 2th token. Logit: 16.02 Prob:  4.88% Token: | uncomfortable|
Top 3th token. Logit: 15.66 Prob:  3.41% Token: | much|
Top 4th token. Logit: 15.65 Prob:  3.37% Token: | sad|
Top 5th token. Logit: 15.21 Prob:  2.17% Token: | guilty|
Top 6th token. Logit: 15.19 Prob:  2.13% Token: | bad|
Top 7th token. Logit: 15.13 Prob:  2.02% Token: |,|
Top 8th token. Logit: 15.07 Prob:  1.90% Token: | comfortable|
Top 9th token. Logit: 14.85 Prob:  1.52% Token: | strongly|

Ranks of the answer tokens: [(' sad', 4)]

Tokenized prompt: ['<|endoftext|>', 'John', ' says', ',', ' "', 'I', ' want', ' to', ' be', ' alone', ' right', ' now', '."', ' John', ' feels', ' very']
Tokenized answer: [' happy']

Performance on answer token:
Rank: 11       Logit: 14.68 Prob:  1.28% Token: | happy|

Top 0th token. Logit: 16.83 Prob: 11.03% Token: | lonely|
Top 1th token. Logit: 16.27 Prob:  6.31% Token: | alone|
Top 2th token. Logit: 16.02 Prob:  4.88% Token: | uncomfortable|
Top 3th token. Logit: 15.66 Prob:  3.41% Token: | much|
Top 4th token. Logit: 15.65 Prob:  3.37% Token: | sad|
Top 5th token. Logit: 15.21 Prob:  2.17% Token: | guilty|
Top 6th token. Logit: 15.19 Prob:  2.13% Token: | bad|
Top 7th token. Logit: 15.13 Prob:  2.02% Token: |,|
Top 8th token. Logit: 15.07 Prob:  1.90% Token: | comfortable|
Top 9th token. Logit: 14.85 Prob:  1.52% Token: | strongly|

Ranks of the answer tokens: [(' happy', 11)]

The model seems to understand that the sentiment of this sentence is negative, because words like "lonely", "alone" and "uncomfortable" are the top predictions (and there's a positive logit diff between "sad" and "happy"). Can you find features which seem to represent this sentiment? You might want to go back to our original layer-7 SAE (gpt2_sae) rather than using the attention SAEs.

How much better is DLA than just taking the argmax over feature activations? Does either (or both) of these techniques find negative sentiment features?

# YOUR CODE HERE - replicate these results with sentiment prompts

Solution
logit_dir = (
    gpt2.W_U[:, gpt2.to_single_token(correct_completion)]
    - gpt2.W_U[:, gpt2.to_single_token(incorrect_completion)]
)
_, cache = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])
sae_acts_post = cache[f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post"][0, -1, :]
sae_attribution = sae_acts_post * (gpt2_sae.W_dec @ logit_dir)
px.line(
    sae_attribution.cpu().numpy(),
    title=f"Attributions for (sad - happy) at the final token position ({sae_attribution.nonzero().numel()} non-zero attribution)",
    labels={"index": "Latent", "value": "Attribution"},
    template="ggplot2",
    width=1000,
).update_layout(showlegend=False).show()
for attr, ind in zip(*sae_attribution.topk(3)):
    print(f"#{ind} had attribution {attr:.2f}, activation {sae_acts_post[ind]:.2f}")
    display_dashboard(latent_idx=int(ind))

Ablation

Techniques like DLA work fine when you expect your features to have a significant direct effect on the model's output - but how about when you think your features are causally important, but not in a direct way? This is where techniques like ablation or activation patching / path patching come in. You can causally intervene on your model during a forward pass, and set activations to zero (or to their value on some different distribution), and see how some downstream metric (e.g. loss, or something more specific to the task you're investigating) changes.

For more on ablation and activation / path patching, you can look at the ARENA IOI exercises, or the later exercise sets in this chapter on sparse feature circuits. In these exercises however, we'll mostly keep things simple - we'll just focus on single feature ablation.

Exercise - find duplicate token features with ablation

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 15-30 minutes on this exercise.

In this exercise, we'll try to find duplicate token features with ablation. These are features which fire in the early layers of the model (specifically layers 0 & 3) and seem to attend from a token back to previous instances of that same token. These don't directly affect the logit output, but are still important parts of the IOI circuit, so we should find that ablating them has a large effect on the model's logit diff.

You should fill in the code below to perform this ablation. You should:

  • Complete the ablate_sae_latent hook function, which sets the SAE activations at a given position & latent to zero (and returns the modified activations),
  • Create a tensor ablation_effects of shape (d_sae,), where the i-th element is the decrease in logit diff when the i-th feature is ablated at the s2_pos position, averaged across the 4 prompts.

You'll find the function model.run_with_hooks_with_saes useful for this. It takes a list of saes like the functions you've used in previous exercises, as well as fwd_hooks, a list of (hook_name, hook_fn) tuples like you might have come across in previous work with TransformerLens.

layer = 3
s2_pos = 10
assert gpt2.to_str_tokens(prompts[0])[s2_pos] == " John"


def ablate_sae_latent(
    sae_acts: Tensor,
    hook: HookPoint,
    latent_idx: int | None = None,
    seq_pos: int | None = None,
) -> Tensor:
    """
    Ablate a particular latent at a particular sequence position. If either argument is None, we
    ablate at all latents / sequence positions respectively.
    """
    raise NotImplementedError()


# YOUR CODE HERE - compute ablation_effects and find duplicate token latents

Help - I'm not sure how to compute ablation_effects.

First, you can get a list of all the features which are non-zero on at least one of the S2 positions across the 4 prompts. This prevents you having to run through all features (since obviously if a feature isn't active then ablating it won't do anything!).

Next, you can iterate through those features, ablating them one at a time, getting the logits using model.run_with_hooks_with_saes, and then getting the logit diff using the logits_to_ave_logit_diff function.

Finally, you can store the difference between this and the logit diff you get when running with SAEs, but no ablation.
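The three steps above can be sketched on a toy problem. This is a hypothetical stand-in, not the TransformerLens/SAELens call sequence: the "model" is just a random decoder followed by a tanh and a linear readout, and acts plays the role of the SAE activations at the S2 position for 4 prompts. Only alive latents are ablated, and each effect is stored as (unablated metric) minus (ablated metric), matching the sign convention used in the solution.

```python
import torch as t

t.manual_seed(0)
# Hypothetical toy sizes and weights
n_prompts, d_sae, d_in = 4, 20, 8
W_dec = t.randn(d_sae, d_in) / d_in**0.5
readout = t.randn(d_in)
acts = t.relu(t.randn(n_prompts, d_sae))  # stand-in for SAE activations at the S2 position


def metric(acts: t.Tensor) -> float:
    """Toy scalar metric (think: avg logit diff), averaged over prompts."""
    return (t.tanh(acts @ W_dec) @ readout).mean().item()


# Step 1: latents active on at least one prompt (squeeze(-1) keeps a list even if only one is alive)
alive_latents = (acts > 0).any(dim=0).nonzero().squeeze(-1).tolist()

# Steps 2 and 3: ablate each alive latent in turn and record the drop in the metric
metric_no_ablation = metric(acts)
ablation_effects = t.zeros(d_sae)
for i in alive_latents:
    acts_ablated = acts.clone()
    acts_ablated[:, i] = 0.0  # zero-ablate latent i on every prompt
    ablation_effects[i] = metric_no_ablation - metric(acts_ablated)

print(f"{len(alive_latents)} alive; largest effect from latent {ablation_effects.abs().argmax().item()}")
```

Skipping dead latents is what keeps this affordable: ablating a latent that is already zero leaves every activation unchanged, so its effect is exactly zero without running anything.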

Solution (and explanation)

Solution code:

layer = 3
s2_pos = 10
assert gpt2.to_str_tokens(prompts[0])[s2_pos] == " John"
def ablate_sae_latent(
    sae_acts: Tensor,
    hook: HookPoint,
    latent_idx: int | None = None,
    seq_pos: int | None = None,
) -> Tensor:
    """
    Ablate a particular latent at a particular sequence position. If either argument is None, we
    ablate at all latents / sequence positions respectively.
    """
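    # Use slice(None) for a None argument, since indexing with None would insert a new axis
    pos = slice(None) if seq_pos is None else seq_pos
    lat = slice(None) if latent_idx is None else latent_idx
    sae_acts[:, pos, lat] = 0.0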
    return sae_acts
_, cache = gpt2.run_with_cache_with_saes(prompts, saes=[attn_saes[layer]])
acts = cache[hook_sae_acts_post := f"{attn_saes[layer].cfg.hook_name}.hook_sae_acts_post"]
alive_latents = (acts[:, s2_pos] > 0.0).any(dim=0).nonzero().squeeze(-1).tolist()
ablation_effects = t.zeros(attn_saes[layer].cfg.d_sae)
logits = gpt2.run_with_saes(prompts, saes=[attn_saes[layer]])
logit_diff = logits_to_ave_logit_diff(logits)
for i in tqdm(alive_latents, desc="Computing causal effects for ablating each latent"):
    logits_with_ablation = gpt2.run_with_hooks_with_saes(
        prompts,
        saes=[attn_saes[layer]],
        fwd_hooks=[
            (hook_sae_acts_post, partial(ablate_sae_latent, latent_idx=i, seq_pos=s2_pos))
        ],
    )
    logit_diff_with_ablation = logits_to_ave_logit_diff(logits_with_ablation)
    ablation_effects[i] = logit_diff - logit_diff_with_ablation
px.line(
    ablation_effects.cpu().numpy(),
    title=f"Causal effects of latent ablation on logit diff ({len(alive_latents)} alive)",
    labels={"index": "Latent", "value": "Causal effect on logit diff"},
    template="ggplot2",
    width=1000,
).update_layout(showlegend=False).show()
# Print the top 3 latents, and inspect their dashboards
for value, ind in zip(*ablation_effects.topk(3)):
    print(f"#{ind} had mean act {acts[:, s2_pos, ind].mean():.2f}, causal effect {value:.2f}")
    display_dashboard(
        sae_release="gpt2-small-hook-z-kk",
        sae_id=f"blocks.{layer}.hook_z",
        latent_idx=int(ind),
    )

In layer 3, you should find a few features that appear to be duplicate token features, with some of these (e.g. the top two, 7803 and 10137) having particularly strong activations on names or other capitalized words. Using the same method as we did for our name mover features, we can see that these features have the largest exposure to head 3.0 (which is what we expect from the IOI paper).

Strangely, although layer 0 also has duplicate token features which can be found using the same method, their ablation seems to increase the logit diff rather than decrease it. I'm not exactly sure why; possibly a larger dataset than just 4 prompts with the same 2 names would give different results. This is left as an exercise to the reader!

Attribution patching

Ablation is one way to measure the downstream effect of a particular feature. Another is attribution patching, a gradient-based attribution technique which can serve as a good approximation to ablation or other kinds of activation patching, allowing for more efficient and scalable circuit analysis with SAEs. This is particularly valuable given how many SAE features there are (many times more than the number of neurons or residual stream directions in your base model).

The formula works as follows: for any inputs $x_{\text{clean}}$ and $x_{\text{corrupted}}$, plus some scalar metric function $M = M(x)$ (in our case $M$ is the logit difference), we can estimate the effect of patching clean activations into the corrupted run using a first-order approximation:

$$ M_{\text{clean}} - M_{\text{corrupted}} \approx (\hat{x}_{\text{clean}} - \hat{x}_{\text{corrupted}}) \cdot \nabla_x M_{\text{corrupted}} $$

Furthermore, if we want to estimate the effect of patching multiple independent components at once (independent in the sense that one isn't a function of the other, e.g. these could be feature activations in the same layer of a model), then we can replace the dot product with an elementwise multiplication, giving one estimate per component:

$$ M_{\text{clean}} - M_\text{corrupted} \approx (\hat{x}_{\text{clean}} - \hat{x}_{\text{corrupted}}) \times \nabla_x M_{\text{corrupted}} $$

In this instance, we'll be approximating the effect of ablation (i.e. setting $\hat{x}_{\text{corrupted}}$ to zero) on the logit difference, i.e. estimating $M_{\text{clean}} - M_{\text{ablated}}$. Since ablation is a very OOD operation (in comparison to e.g. activation patching from a slightly modified distribution), we'll use the easier version of this formula where we take the clean gradient rather than the corrupted one:

$$ M_{\text{clean}} - M_\text{ablated} \approx \hat{x}_{\text{clean}} \times \nabla_x M_{\text{clean}} $$

Fortunately, TransformerLens makes computing gradients easy with the model.add_hook method. This takes 3 important arguments:

  • name, which can be a string or a filter function mapping strings to booleans
  • hook, our hook function
  • dir, which is either the string "fwd" or "bwd" indicating whether we want a forward or backward hook

Backward hooks work exactly the same as forward hooks, except they're called when you call .backward() on a tensor you've computed. The hook function still takes 2 arguments (a tensor and the hook itself); the only difference is that this tensor is now the gradient of the output with respect to the activation, rather than the activation itself. The hook point names are the same for forward and backward hooks.

This is also easy when adding SAEs into the mix! They are automatically spliced into the computational graph when you call model.add_sae (or use them in a context manager), allowing us to implement attribution patching easily.

Exercise - compare ablation to attribution patching

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-20 minutes on this exercise. Understanding what's happening is the important thing here; the code itself is pretty short.

You should fill in the code below (where it says # YOUR CODE HERE) to compare the effect of ablation and attribution patching on our features. Your final output will be a scatter plot of ablation effect vs attribution patching value for each live feature, hopefully showing an almost exact relationship (see the solutions Colab for what this should look like).

We've given you the function get_cache_fwd_and_bwd to help you - this will return the forward and backward caches for a model given a particular metric (and a particular SAE or list of SAEs to add).

A few notes on the code below:

  • You'll need to have computed the ablation_effects tensor from the previous exercise, otherwise the code won't run.
  • For our metric function passed to get_cache_fwd_and_bwd, we've used reduction="sum" rather than "mean". Can you see why?
Answer: why we reduce using sum not mean

To return to the formula above, we're computing the value $\hat{x}_{\text{clean}} \times \nabla_x M_{\text{clean}}$, where $M$ is the logit diff metric. We do this by effectively computing a vector of 4 of these values (one for each prompt) in a single forward & backward pass, then averaging over that vector. But if we used the mean over prompts as our metric, then each sequence would only get 1/4 of the gradient it would get if we ran it on its own. We'd then be averaging over 4 terms of the form $\hat{x}_{\text{clean}} \times \nabla_x M_{\text{clean}} / 4$, so the quantity we'd end up estimating would be 1/4 of the logit diff effect rather than the full effect.

This is related to something we saw in a previous section, where we were training multiple instances of our toy model at once: we need the loss function to be a sum of losses over instances, so that backpropagating wrt that loss function once is equivalent to backpropagating wrt each instance's loss function individually.
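The sum-vs-mean point can be checked numerically. In this hedged sketch (a toy linear "logit diff" per prompt, with hypothetical names), central finite differences show that a sum-reduced metric gives each prompt exactly the gradient it would get on its own, while a mean-reduced metric scales every prompt's gradient by 1/4.

```python
import numpy as np


def batch_metric(x: np.ndarray, w: np.ndarray, reduction: str) -> float:
    # Toy per-prompt "logit diff" (w . x_i for prompt i), reduced over prompts
    per_prompt = x @ w
    return float(per_prompt.sum() if reduction == "sum" else per_prompt.mean())


def numerical_grad(x: np.ndarray, w: np.ndarray, reduction: str, eps: float = 1e-6) -> np.ndarray:
    # Central finite-difference gradient of the batch metric wrt every entry of x
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[idx] += eps
        x_minus[idx] -= eps
        grad[idx] = (batch_metric(x_plus, w, reduction) - batch_metric(x_minus, w, reduction)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))  # 4 prompts, 3-dim toy activations
w = rng.normal(size=3)
grad_sum = numerical_grad(x, w, "sum")
grad_mean = numerical_grad(x, w, "mean")
```

Each row of grad_sum equals w (the gradient prompt i would get if run alone), while each row of grad_mean is w / 4.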

def get_cache_fwd_and_bwd(
    model: HookedSAETransformer, saes: list[SAE], input, metric
) -> tuple[ActivationCache, ActivationCache]:
    """
    Get forward and backward caches for a model, given a metric.
    """
    filter_sae_acts = lambda name: "hook_sae_acts_post" in name

    # This hook function will store activations in the appropriate cache
    cache_dict = {"fwd": {}, "bwd": {}}

    def cache_hook(act, hook, dir: Literal["fwd", "bwd"]):
        cache_dict[dir][hook.name] = act.detach()

    with model.saes(saes=saes):
        # We add hooks to cache values from the forward and backward pass respectively
        with model.hooks(
            fwd_hooks=[(filter_sae_acts, partial(cache_hook, dir="fwd"))],
            bwd_hooks=[(filter_sae_acts, partial(cache_hook, dir="bwd"))],
        ):
            # Forward pass fills the fwd cache, then backward pass fills the bwd cache (we don't
            # care about metric value)
            metric(model(input)).backward()

    return (
        ActivationCache(cache_dict["fwd"], model),
        ActivationCache(cache_dict["bwd"], model),
    )


clean_logits = gpt2.run_with_saes(prompts, saes=[attn_saes[layer]])
clean_logit_diff = logits_to_ave_logit_diff(clean_logits)

t.set_grad_enabled(True)
clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(
    gpt2,
    [attn_saes[layer]],
    prompts,
    lambda logits: logits_to_ave_logit_diff(logits, keep_as_tensor=True, reduction="sum"),
)
t.set_grad_enabled(False)

# YOUR CODE HERE - compute `attribution_values` from the clean activations & clean grad cache
attribution_values = ...

# Visualize results
px.scatter(
    pd.DataFrame(
        {
            "Ablation": ablation_effects[alive_latents].cpu().numpy(),
            "Attribution Patching": attribution_values.cpu().numpy(),
            "Latent": alive_latents,
        }
    ),
    x="Ablation",
    y="Attribution Patching",
    hover_data=["Latent"],
    title="Attribution Patching vs Ablation",
    template="ggplot2",
    width=800,
    height=600,
).add_shape(
    type="line",
    x0=attribution_values.min(),
    x1=attribution_values.max(),
    y0=attribution_values.min(),
    y1=attribution_values.max(),
    line=dict(color="red", width=2, dash="dash"),
).show()
Solution
# Extract activations and gradients
hook_sae_acts_post = f"{attn_saes[layer].cfg.hook_name}.hook_sae_acts_post"
clean_sae_acts_post = clean_cache[hook_sae_acts_post]
clean_grad_sae_acts_post = clean_grad_cache[hook_sae_acts_post]
# Compute attribution values for all latents, then index to get the live ones (averaged over prompts)
attribution_values = (clean_grad_sae_acts_post * clean_sae_acts_post)[:, s2_pos, alive_latents].mean(0)

GemmaScope

Before introducing the final set of exercises in this section, we'll take a moment to talk about a recent release of sparse autoencoders from Google DeepMind, which any would-be SAE researchers should be aware of. From their associated blog post published on 31st July 2024:

Today, we’re announcing Gemma Scope, a new set of tools to help researchers understand the inner workings of Gemma 2, our lightweight family of open models. Gemma Scope is a collection of hundreds of freely available, open sparse autoencoders (SAEs) for Gemma 2 9B and Gemma 2 2B.

If you're interested in analyzing large and well-trained sparse autoencoders, there's a good chance that GemmaScope is the best available release you could be using.

Let's first load in the SAE. We're using the canonical recommendations for working with GemmaScope SAEs, which were chosen based on their L0 values (see the exercises on SAE training for more about how to think about these kinds of metrics!). This particular SAE was trained on the residual stream of layer 20 of the Gemma-2-2B model, has a width of 16k, and uses a JumpReLU activation function. See the short section at the end for more on this activation function, although you don't really need to worry about the details now.

Note that you'll probably have to go through a couple of steps before gaining access to these SAE models. You should do the following:

  1. Visit the gemma-2-2b HuggingFace repo and click "Agree and access repository".
  2. When you've been granted access, create a read token in your user settings and copy it, then run the command huggingface-cli login --token <your-token-here> in your terminal (alternatively, just run huggingface-cli login, then create a token at the link it prints and paste it in).

Once you've done this, you should be able to load in your models as follows:

USING_GEMMA = os.environ.get("HUGGINGFACE_KEY") is not None

if USING_GEMMA:
    !huggingface-cli login --token {os.environ["HUGGINGFACE_KEY"]}

    gemma_2_2b = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)

    gemmascope_sae_release = "gemma-scope-2b-pt-res-canonical"
    gemmascope_sae_id = "layer_20/width_16k/canonical"
    gemma_2_2b_sae = SAE.from_pretrained(
        gemmascope_sae_release, gemmascope_sae_id, device=str(device)
    )[0]
else:
    print("Please supply your Hugging Face API key before running this cell")

You should inspect the configs of these objects, and make sure you roughly understand their structure. You can also try displaying a few latent dashboards, to get a sense of what the latents look like.

Help - I get the error "Not enough free disk space to download the file."

In this case, try to free up space by clearing your cache of Hugging Face models: run huggingface-cli delete-cache in your terminal (you might have to pip install huggingface_hub[cli] first). You'll be shown an interface where you can navigate using the up/down arrow keys, press space to select which models to delete, and then press enter to confirm deletion.

If you still get the above error message after clearing your cache of all models you're no longer using (or you're getting other errors e.g. OOMs when you try to run the model), we recommend one of the following options:

  • Choosing a latent from the GPT2-Small model you've been working with so far, and doing the exercises with that instead (note that at time of writing there are no highly performant SAEs trained on GPT2-Medium, Large, or XL models, but this might not be the case when you're reading this, in which case you could try those instead!).
  • Using float16 precision for the model rather than float32 (you can pass dtype="float16" to the from_pretrained method).
  • Using a more powerful machine, e.g. renting an A100 from vast.ai or using Google Colab Pro (or Pro+).

Feature Steering

In this section, you'll learn how to steer on latents to produce interesting model output. Key points:

  • Steering involves intervening during a forward pass to change the model's activations in the direction of a particular latent
  • The steering behaviour is sometimes unpredictable, and not always equivalent to "produce text of the same type as the latent strongly activates on"
  • Neuronpedia has a steering interface which allows you to steer without any code

Before we wrap up this set of exercises, let's do something fun!

Once we've found a latent corresponding to some particular feature, we can use it to steer our model, resulting in a corresponding behavioural change. You might already have come across this via Anthropic's viral Golden Gate Claude model. Steering simply involves intervening on the model's activations during a forward pass, and adding some multiple of a feature's decoder weight into our residual stream (or possibly scaling the component that was already present in the residual stream, or just clamping this component to some fixed value). When choosing the value, we are usually guided by the maximum activation of this feature over some distribution of text (so we don't get too OOD).
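The three kinds of intervention mentioned above (adding a multiple of the decoder direction, scaling the existing component, or clamping it) can be sketched in NumPy as follows. This is illustrative only: direction stands in for a decoder row such as sae.W_dec[latent_idx], and for scaling and clamping we assume the "existing component" is the projection onto the unit-normalised decoder direction. That is a simplification, since an SAE's actual latent activation is computed by its encoder.

```python
import numpy as np


def steer_add(resid: np.ndarray, direction: np.ndarray, coeff: float) -> np.ndarray:
    # Add coeff * direction at every position (the additive steering used in this section)
    return resid + coeff * direction


def _unit(direction: np.ndarray) -> np.ndarray:
    return direction / np.linalg.norm(direction)


def steer_scale(resid: np.ndarray, direction: np.ndarray, factor: float) -> np.ndarray:
    # Assumption: "existing component" = projection onto the unit decoder direction.
    # Multiply that component by `factor`, leaving the orthogonal part untouched.
    u = _unit(direction)
    comp = np.sum(resid * u, axis=-1, keepdims=True)
    return resid + (factor - 1.0) * comp * u


def steer_clamp(resid: np.ndarray, direction: np.ndarray, value: float) -> np.ndarray:
    # Same assumption; set the component along the unit direction to a fixed `value`
    u = _unit(direction)
    comp = np.sum(resid * u, axis=-1, keepdims=True)
    return resid + (value - comp) * u


rng = np.random.default_rng(0)
resid = rng.normal(size=(2, 5, 16))  # (batch, pos, d_model)
direction = rng.normal(size=16)      # stand-in for a decoder row
added = steer_add(resid, direction, 5.0)
scaled = steer_scale(resid, direction, 2.0)
clamped = steer_clamp(resid, direction, 3.0)
```

Additive steering is the simplest and doesn't depend on what the model was already representing; scaling and clamping instead depend on (or overwrite) the existing component.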

Sadly we can't quite replicate Golden Gate Claude with GemmaScope SAEs. There are some features which seem to fire on the word "Golden", especially in the context of titles like "Golden Gate Bridge" (e.g. feature 14667 in the layer 18 canonical 16k-width residual stream GemmaScope SAE, or feature 1566 in the layer 20 SAE), but these are mostly single-token features (i.e. they fire on just the word "Golden" rather than on context which discusses the Golden Gate Bridge), so they're of limited use for causing these kinds of behavioural changes. For example, imagine you really did find a bigram feature that just caused the model to output "Gate" after "Golden": steering on it would eventually just make the model output an endless string of "Gate" tokens (something like this does in fact happen for the 2 aforementioned features, and you can try it for yourself if you want). Instead, we want to look for a feature with a higher consistent activation heuristic value. Roughly speaking, this is the correlation between feature activations on adjacent tokens, so a high value might suggest a concept-level feature rather than a token-level one. Specifically, we'll be using a "dog feature" which seems to activate on discussions of dogs:

if USING_GEMMA:
    latent_idx = 12082

    display_dashboard(
        sae_release=gemmascope_sae_release, sae_id=gemmascope_sae_id, latent_idx=latent_idx
    )
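The consistent activation heuristic was described above only roughly, as the correlation between a latent's activations on adjacent tokens. The sketch below computes exactly that Pearson correlation with NumPy, as a hedged illustration of the intuition; it is not claimed to match Neuronpedia's precise definition, and the toy activation patterns (and the adjacent_token_correlation helper) are made up for illustration.

```python
import numpy as np


def adjacent_token_correlation(acts: np.ndarray) -> float:
    # acts: (n_seqs, seq_len) activations of one latent.
    # Pearson correlation between the activation at position t and at t+1.
    prev = acts[:, :-1].ravel()
    nxt = acts[:, 1:].ravel()
    if prev.std() == 0 or nxt.std() == 0:
        return 0.0
    return float(np.corrcoef(prev, nxt)[0, 1])


# Concept-like latent: active over contiguous spans of text
concept_like = np.zeros((2, 20))
concept_like[0, 3:12] = 1.0
concept_like[1, 8:18] = 2.0

# Token-like latent: isolated spikes on single tokens
token_like = np.zeros((2, 20))
token_like[0, [3, 9, 15]] = 5.0
token_like[1, [2, 11]] = 5.0

print(adjacent_token_correlation(concept_like), adjacent_token_correlation(token_like))
```

The span-style latent scores high because neighbouring tokens tend to be active together, whereas isolated spikes are followed by zeros and score low (here even negative).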

Exercise - implement generate_with_steering

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-30 minutes on completing the set of functions below.

First, you should implement the basic function steering_hook below. This will be added to your model as a hook function during its forward pass, and it should add a multiple steering_coefficient of the steering vector (i.e. the decoder weight for this feature) to the activations tensor.

def steering_hook(
    activations: Float[Tensor, "batch pos d_in"],
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by returning a modified activations tensor, with some multiple of the steering
    vector added to all sequence positions.
    """
    return activations + steering_coefficient * sae.W_dec[latent_idx]


if USING_GEMMA:
    tests.test_steering_hook(steering_hook, gemma_2_2b_sae)

You should now finish this exercise by implementing generate_with_steering. You can run this function to produce your own steered output text!

Help - I'm not sure about the model syntax for generating text with steering.

You can add a hook in a context manager, then steer like this:

with model.hooks(fwd_hooks=[(hook_name, steering_hook)]):
    output = model.generate(
        prompt,
        max_new_tokens=max_new_tokens,
        prepend_bos=sae.cfg.prepend_bos,
        **GENERATE_KWARGS
    )

Make sure you remember to use the prepend_bos argument - it can often be important for getting the right behaviour!

We've given you suggested sampling parameters in the GENERATE_KWARGS dict.

The output will by default be a string.

Help - I'm not sure what hook to add my steering hook to.

You should add it to sae.cfg.hook_name, since these are the activations that get reconstructed by the SAE.

Note that we can choose the value of steering_coefficient based on the maximum activation of the latent we're steering on. It's usually wise to choose a value fairly close to the max activation, but not so far above it that you steer the model far out of distribution. However, this varies from latent to latent: for this particular latent, we'll find it still produces coherent output quite far above the max activation value. Here we can read the max activation off the Neuronpedia dashboard; without that, we'd be better off measuring the max activation over some suitably large dataset to guide our choice of steering coefficient.
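Measuring a latent's max activation yourself could look something like the following sketch, which streams over batches so the full dataset never has to sit in memory. It's a hedged illustration: in practice each batch would be SAE activations of shape (batch, pos, d_sae), for instance from a hook_sae_acts_post cache, whereas here the batches are random NumPy arrays and max_latent_activation is a hypothetical helper.

```python
from typing import Iterable

import numpy as np


def max_latent_activation(batches: Iterable[np.ndarray], latent_idx: int) -> float:
    # Track a running max of one latent's activation across all batches and positions
    running_max = -np.inf
    for acts in batches:
        running_max = max(running_max, float(acts[..., latent_idx].max()))
    return running_max


rng = np.random.default_rng(0)
# Stand-in for SAE activations: 3 batches of shape (batch=4, pos=10, d_sae=32)
batches = [np.maximum(rng.normal(size=(4, 10, 32)), 0.0) for _ in range(3)]
max_act = max_latent_activation(batches, latent_idx=7)
steering_coefficient = 1.5 * max_act  # e.g. a modest multiple of the observed max
```

Because batches can be any iterable, a generator that runs the model on one batch of text at a time works just as well as a list.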

GENERATE_KWARGS = dict(temperature=0.5, freq_penalty=2.0, verbose=False)


def generate_with_steering(
    model: HookedSAETransformer,
    sae: SAE,
    prompt: str,
    latent_idx: int,
    steering_coefficient: float = 1.0,
    max_new_tokens: int = 50,
):
    """
    Generates text with steering. A multiple of the steering vector (the decoder weight for this
    latent) is added to every sequence position processed in each forward pass.
    """
    raise NotImplementedError()


if USING_GEMMA:
    prompt = "When I look at myself in the mirror, I see"
    latent_idx = 12082

    no_steering_output = gemma_2_2b.generate(prompt, max_new_tokens=50, **GENERATE_KWARGS)

    table = Table(show_header=False, show_lines=True, title="Steering Output")
    table.add_row("Normal", no_steering_output)
    for i in tqdm(range(3), "Generating steered examples..."):
        table.add_row(
            f"Steered #{i}",
            generate_with_steering(
                gemma_2_2b,
                gemma_2_2b_sae,
                prompt,
                latent_idx,
                steering_coefficient=240.0,  # roughly 1.5-2x the latent's max activation
            ).replace("\n", "↵"),
        )
    rprint(table)
Expected output:
                                                  Steering Output                                                  
┌────────────┬────────────────────────────────────────────────────────────────────────────────────────────────────┐
│ Normal     │ When I look at myself in the mirror, I see a beautiful woman.                                      │
│            │                                                                                                    │
│            │ I’m not perfect, but I’m pretty good looking.                                                      │
│            │                                                                                                    │
│            │ I have a round face and full lips. My eyes are deep set and my nose is small. My hair is light     │
│            │ brown with highlights of blonde and                                                                │
├────────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ Steered #0 │ When I look at myself in the mirror, I see a dog.↵I’s not like my parents are used to seeing a     │
│            │ person in the mirror, but they don’t see me as a dog either.↵↵My tail is always wagging and I have │
│            │ a big smile on my face because                                                                     │
├────────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ Steered #1 │ When I look at myself in the mirror, I see a lot of things.↵↵I see a dog-eared, wrinkled and       │
│            │ overweight owner of a small, fluffy and very well-trained dog.↵↵I am also the owner of a young     │
│            │ adult that is still learning about life.↵↵He’s                                                     │
├────────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ Steered #2 │ When I look at myself in the mirror, I see a person who loves to chase after her dreams.↵↵I’ve     │
│            │ been on a journey of learning and training for over 7 years now, and it’s been an incredible       │
│            │ journey.↵↵I’ve trained with some of the best trainers in                                           │
└────────────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
Solution
GENERATE_KWARGS = dict(temperature=0.5, freq_penalty=2.0, verbose=False)
def generate_with_steering(
    model: HookedSAETransformer,
    sae: SAE,
    prompt: str,
    latent_idx: int,
    steering_coefficient: float = 1.0,
    max_new_tokens: int = 50,
):
    """
    Generates text with steering. A multiple of the steering vector (the decoder weight for this
    latent) is added to every sequence position processed in each forward pass.
    """
    # Fix the SAE, latent and coefficient so the hook has the (activations, hook) signature
    _steering_hook = partial(
        steering_hook,
        sae=sae,
        latent_idx=latent_idx,
        steering_coefficient=steering_coefficient,
    )
    # Steer at the hook point the SAE reconstructs, for the duration of generation only
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, _steering_hook)]):
        output = model.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            prepend_bos=sae.cfg.prepend_bos,
            **GENERATE_KWARGS,
        )
    return output

Steering with Neuronpedia

Neuronpedia actually has a steering interface, which you can use to see the effect of steering on particular latents without even writing any code! Visit the associated Neuronpedia page to try it out. You can hover over the "How it works" button to see how the different coefficients in the steering API are interpreted (it's pretty similar to how we've used them in our experiments).

Try experimenting with the steering API, with this latent and some others. You can also try some other models, like the instruction-tuned Gemma models from DeepMind. There are some interesting patterns that start appearing when we get to finetuned models, such as a divergence between what a latent seems to be firing on and the downstream effect of steering on that latent. For example, you might find latents which activate on certain kinds of harmful or offensive language, but which induce refusal behaviour when steered on: possibly those latents existed in the non-finetuned model and would have steered towards more harmful behaviour when steered on, but during finetuning their output behaviour was re-learned. This links to one key idea when doing latent interpretability: the duality between the view of latents as representations and latents as functions (see the section on circuits for more on this).

Other types of SAEs

This section introduces a few different SAE architectures, some of which will be explored in more detail in later sections. There are no exercises here, just brief descriptions. Key points:

  • Different activation functions / encoder architectures (e.g. TopK, JumpReLU and Gated SAEs) can solve problems like feature suppression, and the inability of standard ReLU SAEs to represent discontinuous (on/off) activations
  • End-to-end SAEs are trained with a different loss function, encouraging them to learn features that are functionally useful for the model's output rather than just minimising MSE reconstruction error
  • Meta SAEs are SAEs trained to decompose the latents of another SAE, specifically its decoder directions (since we might not always expect SAE latents to be monosemantic, for reasons like feature absorption)
  • Transcoders are a type of SAE which learn to reconstruct a model's computation (e.g. a sparse mapping from MLP input to MLP output) rather than just reconstructing activations; they can sometimes lead to easier circuit analysis

In this section, we'll touch briefly on other kinds of SAEs that we haven't discussed yet, some of which will be explored in later exercises. This section won't have any exercises; the purpose of it is just to try and paint more of a complete picture for anyone who is only completing up to the end of part 1 but not planning to do any later exercises.

The topics here are listed roughly in order of complexity, starting with relatively simple extensions to standard SAEs and ending with some more complex ideas which we'll return to in later exercises.

TopK, JumpReLU, Gated

These represent three different small modifications to the basic SAE architecture, all of which seem to offer improvements over standard SAEs. To address them each in turn:

TopK SAEs use a different activation function: rather than $z = \operatorname{ReLU}(W_{enc}(x-b_{dec}) + b_{enc})$, we compute $z = \operatorname{TopK}(W_{enc}(x-b_{dec}))$, where $\operatorname{TopK}$ keeps the $K$ largest elements of its input and sets the rest to zero. This removes the need for the $L_1$ penalty (and with it problems the $L_1$ penalty causes, like feature suppression), and lets us set the $L_0$ value directly rather than tuning the sparsity coefficient to hit some target value. TopK can also be composed with arbitrary activation functions.
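A minimal NumPy sketch of the TopK encoder described above. It uses the row-vector convention (W_enc has shape (d_model, d_sae), so pre-activations are (x - b_dec) @ W_enc); the function name and shapes are assumptions for illustration. Composing with another activation function, e.g. a ReLU on the kept values, would be a one-line change.

```python
import numpy as np


def topk_sae_encode(x: np.ndarray, W_enc: np.ndarray, b_dec: np.ndarray, k: int) -> np.ndarray:
    # Pre-activations: (batch, d_sae); no encoder bias in this formula
    pre = (x - b_dec) @ W_enc
    # Indices of the k largest pre-activations in each row (unordered among themselves)
    top_idx = np.argpartition(pre, -k, axis=-1)[..., -k:]
    # Copy only those k values into an otherwise-zero latent vector
    z = np.zeros_like(pre)
    np.put_along_axis(z, top_idx, np.take_along_axis(pre, top_idx, axis=-1), axis=-1)
    return z


rng = np.random.default_rng(0)
d_model, d_sae, k = 6, 10, 3
x = rng.normal(size=(5, d_model))
W_enc = rng.normal(size=(d_model, d_sae))
b_dec = rng.normal(size=d_model)
z = topk_sae_encode(x, W_enc, b_dec, k)
print((z != 0).sum(axis=-1))  # number of active latents per row (the L0)
```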

JumpReLU SAEs use the JumpReLU activation function in place of a regular ReLU. JumpReLU is a ReLU with an extra (often learnable) threshold: $\operatorname{JumpReLU}_\theta(x) = xH(x-\theta)$, where $\theta$ is the threshold parameter and $H$ is the Heaviside step function (equalling 1 when its argument is positive, and 0 otherwise). Inputs below $\theta$ are zeroed and inputs above it pass through unchanged, so the output jumps from 0 to roughly $\theta$ at the threshold.

Intuitively, why might we expect this to perform better than regular ReLU SAEs? One reason is that, empirically, features seem to "want to be binary". For instance, we often see features like "is this about basketball" which are better thought of as "off" or "on" than as occupying some continuous range from 0 to 1. In practice reconstructing the precise coefficients does matter, and they often seem important for indicating something like the model's confidence in a particular feature being present. But still, we'd ideally like an architecture which can learn this discontinuity.

JumpReLU is unfortunately difficult to train (since we have to resort to sneaky methods to get around the function's jump discontinuity, which makes it non-differentiable). As a result, many groups seem to have failed to replicate DeepMind's initial paper.
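The JumpReLU definition above translates directly into NumPy. This sketch (with hypothetical heaviside and jumprelu helpers) uses the same Heaviside convention as the text, H(0) = 0, and allows θ to be a per-latent array. It only covers the forward pass: the training difficulty lies in getting useful gradients for θ, which this sketch doesn't attempt.

```python
import numpy as np


def heaviside(x: np.ndarray) -> np.ndarray:
    # H(x) = 1 where x > 0, else 0 (so H(0) = 0, as in the text)
    return (x > 0).astype(x.dtype)


def jumprelu(x: np.ndarray, theta) -> np.ndarray:
    # JumpReLU_theta(x) = x * H(x - theta); theta broadcasts against the last axis
    return x * heaviside(x - theta)


x = np.array([-1.0, 0.2, 0.5, 0.7, 2.0])
print(jumprelu(x, 0.5))

# Per-latent thresholds: one theta per column (latent)
acts = np.array([[0.3, 0.9, 0.6], [0.05, 1.5, 0.4]])
theta = np.array([0.1, 1.0, 0.5])
per_latent = jumprelu(acts, theta)
```

With theta = 0 this reduces to an ordinary ReLU, which is one way to see JumpReLU as a strict generalisation.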

Gated SAEs are another recent variant on the basic SAE architecture, coming from another DeepMind paper. They offer the same jump-discontinuity benefit as JumpReLU (in fact you can show that with a bit of weight tying Gated SAEs are equivalent to JumpReLUs), but they offer one other advantage too: they decouple the jump discontinuity from the magnitude. With JumpReLU functions there's only one axis to vary along, but ideally you'd want the freedom to independently determine whether a feature should be on or off, and what its magnitude should be when it's on. Gated SAEs accomplish this by having 2 separate encoder weight matrices, one for computing magnitude and one for masking. Like JumpReLUs, they're also discontinuous and need a special training objective function, but unlike JumpReLUs they've generally proven much easier to train. From Neel's Extremely Opinionated Annotated List of My Favourite Mechanistic Interpretability Papers:

"I (very biasedly) think [the DeepMind paper] is worth reading as a good exemplar of how to rigorously evaluate whether an SAE change was an improvement, and because I recommend using Gated SAEs where possible."

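Here is a hedged NumPy sketch of a gated encoder as described above, with two separate encoder matrices: W_gate decides which latents are on (via a Heaviside step) and W_mag decides how large they are (via a ReLU). Names and shapes are assumptions. The DeepMind paper ties the two matrices (the "bit of weight tying" behind the JumpReLU equivalence mentioned above); here they're kept independent to match the two-matrix description.

```python
import numpy as np


def gated_sae_encode(x, W_gate, b_gate, W_mag, b_mag, b_dec):
    # x: (batch, d_model); W_gate, W_mag: (d_model, d_sae)
    x_centered = x - b_dec
    gate_pre = x_centered @ W_gate + b_gate
    f_gate = (gate_pre > 0).astype(x.dtype)              # Heaviside: is each latent on or off?
    f_mag = np.maximum(x_centered @ W_mag + b_mag, 0.0)  # ReLU: how large is it when on?
    return f_gate * f_mag


rng = np.random.default_rng(0)
d_model, d_sae = 8, 20
x = rng.normal(size=(4, d_model))
W_gate, W_mag = rng.normal(size=(2, d_model, d_sae))
b_gate, b_mag = rng.normal(size=(2, d_sae))
b_dec = rng.normal(size=d_model)
z = gated_sae_encode(x, W_gate, b_gate, W_mag, b_mag, b_dec)
```

Because on/off and magnitude come from different pathways, a latent that is switched on can take any non-negative magnitude, rather than being forced to start at a threshold value as with a single JumpReLU.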
End-to-End SAEs

In the paper Identifying Functionally Important Features with End-to-End Sparse Dictionary Learning, the authors propose a different way of training standard SAEs. Rather than using mean squared reconstruction error (MSE) of activations as the training objective, they use the KL divergence between the original output logits and the output logits we get when passing the SAE output through the rest of the network. The intuition here is that we want to identify the functionally important features which are actually important for explaining the model's behaviour on the training distribution. Minimizing MSE can be a good heuristic for this (because important features often need to be represented with high magnitude), but it's not directly getting at the thing we want to measure, and could be considered vulnerable to Goodhart's Law ("When a measure becomes a target, it ceases to be a good measure").
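The end-to-end objective compares output distributions rather than activations. A minimal NumPy sketch, assuming logits of shape (batch, pos, d_vocab): kl_from_logits computes KL(P_orig || P_sae) averaged over positions, and e2e_loss adds an L1 sparsity penalty on the latents, since a sparsity term is still needed alongside the KL term. The helper names and the coefficient name sparsity_coeff are our own.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable log-softmax over the vocab (last) axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def kl_from_logits(orig_logits: np.ndarray, sae_logits: np.ndarray) -> float:
    # KL(P_orig || P_sae) per position, then averaged over batch and positions
    log_p = log_softmax(orig_logits)
    log_q = log_softmax(sae_logits)
    return float((np.exp(log_p) * (log_p - log_q)).sum(axis=-1).mean())


def e2e_loss(orig_logits, sae_logits, latent_acts, sparsity_coeff: float) -> float:
    # The KL term replaces the usual MSE reconstruction term; the L1 sparsity term stays
    l1 = np.abs(latent_acts).sum(axis=-1).mean()
    return kl_from_logits(orig_logits, sae_logits) + sparsity_coeff * float(l1)


rng = np.random.default_rng(0)
orig_logits = rng.normal(size=(2, 5, 50))  # (batch, pos, d_vocab)
sae_logits = orig_logits + 0.3 * rng.normal(size=orig_logits.shape)
latent_acts = np.maximum(rng.normal(size=(2, 5, 64)), 0.0)
loss = e2e_loss(orig_logits, sae_logits, latent_acts, sparsity_coeff=1e-3)
```

Note that the KL term is unchanged by adding a constant to all logits, so it only penalises changes that actually alter the model's output distribution.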

The full paper contains the results of experiments run on end-to-end (e2e) SAEs, compared to standard (local) SAEs. They find that e2e SAEs tend to require a smaller L0 for the same level of model performance captured, although they have much larger per-layer MSE (the authors suggest some ways to mitigate this, and find a balance between the local and end-to-end objectives).

Meta SAEs

Meta SAEs are a special type of SAE, trained to reconstruct the decoder directions of a normal SAE. This allows us to find sparse reconstructions of the base SAE latents, in situations where the SAE latents aren't monosemantic. One reason why we might not expect them to be monosemantic is feature absorption: the SAE might learn a feature like "starts with e", but that feature fails to fire on "elephant" because there's already a learned feature for "elephant" that has absorbed the "starts with e" information. This is better for the SAE's sparsity (because it means the "starts with e" feature fires on one fewer word), but unfortunately it prevents our features from being monosemantic, and prevents our SAEs from giving us a decomposition into a sparse set of causal mediators.
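To make the data flow concrete: a meta-SAE's training examples are the decoder rows of the base SAE, one per base latent, and the meta-SAE itself is just an ordinary SAE applied to them. The NumPy sketch below uses random weights in place of trained ones and a plain ReLU encoder for simplicity (the meta-SAE paper's exact architecture may differ); all names are hypothetical.

```python
import numpy as np


def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    # Standard ReLU SAE: encode to sparse latents, decode back to the input space
    z = np.maximum((x - b_dec) @ W_enc + b_enc, 0.0)
    return z, z @ W_dec + b_dec


rng = np.random.default_rng(0)
d_model, d_sae, d_meta = 16, 64, 8

# Base SAE decoder: one unit-norm direction per base latent -> the meta-SAE's dataset
base_W_dec = rng.normal(size=(d_sae, d_model))
base_W_dec /= np.linalg.norm(base_W_dec, axis=-1, keepdims=True)

# Meta-SAE parameters (random stand-ins for trained weights)
meta_W_enc = 0.1 * rng.normal(size=(d_model, d_meta))
meta_b_enc = np.zeros(d_meta)
meta_W_dec = 0.1 * rng.normal(size=(d_meta, d_model))
meta_b_dec = np.zeros(d_model)

# Each base latent's decoder direction is decomposed into meta-latent activations
meta_acts, recon = sae_forward(base_W_dec, meta_W_enc, meta_b_enc, meta_W_dec, meta_b_dec)
```

After training, row i of meta_acts would give the sparse set of meta-latents making up base latent i.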

The paper on Meta-SAEs finds the following key results:

  • SAE latents can be decomposed into more atomic, interpretable meta-latents.
  • We show that when latents in a larger SAE have split out from latents in a smaller SAE, a meta SAE trained on the larger SAE often recovers this structure.
  • We demonstrate that meta-latents allow for more precise causal interventions on model behavior than SAE latents on a targeted knowledge editing task.

You can visit the dashboard which the authors built, that lets you explore meta-SAE latents.

Transcoders

The MLP-layer SAEs we've looked at attempt to represent activations as a sparse linear combination of feature vectors; importantly, they only operate on activations at a single point in the model. They don't actually learn to perform the MLP layer's computation, rather they learn to reconstruct the results of that computation. It's very hard to do any weights-based analysis on MLP layers in superposition using standard SAEs, since many features are highly dense in the neuron basis, meaning the neurons are hard to decompose.

In contrast, transcoders take in the activations before the MLP layer (i.e. the possibly-normalized residual stream values) and aim to represent the post-MLP activations of that MLP layer, again as a sparse linear combination of feature vectors. The transcoder terminology is the most common, although these have also been called input-output SAEs (because we take the input to some base model layer, and try to learn the output) and predicting future activations (for obvious reasons). Note that transcoders aren't technically autoencoders, because they're learning a mapping rather than a reconstruction - however a lot of our intuitions from SAEs carry over to transcoders.
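The difference from a standard SAE is mostly in what the loss compares. In this hedged NumPy sketch (a toy MLP with a ReLU in place of GELU and no biases; all names hypothetical), the transcoder encodes the MLP input but is trained to predict the MLP output, whereas an SAE on the same layer would be trained to reconstruct its own input.

```python
import numpy as np


def toy_mlp(x, W_in, W_out):
    # Stand-in for a base-model MLP layer (ReLU instead of GELU, no biases)
    return np.maximum(x @ W_in, 0.0) @ W_out


def transcoder_forward(mlp_in, W_enc, b_enc, W_dec, b_dec):
    # Sparse latents computed from the MLP *input* ...
    z = np.maximum(mlp_in @ W_enc + b_enc, 0.0)
    # ... decoded into a prediction of the MLP *output*
    return z, z @ W_dec + b_dec


def transcoder_loss(mlp_in, mlp_out, params, l1_coeff):
    z, pred = transcoder_forward(mlp_in, *params)
    mse = ((pred - mlp_out) ** 2).sum(axis=-1).mean()  # target is the MLP output, not mlp_in
    return float(mse + l1_coeff * np.abs(z).sum(axis=-1).mean())


rng = np.random.default_rng(0)
d_model, d_mlp, d_tc = 8, 32, 48
mlp_in = rng.normal(size=(10, d_model))  # e.g. the (normalized) residual stream before the MLP
mlp_out = toy_mlp(mlp_in, rng.normal(size=(d_model, d_mlp)), rng.normal(size=(d_mlp, d_model)))
params = (
    0.1 * rng.normal(size=(d_model, d_tc)),  # W_enc
    np.zeros(d_tc),                          # b_enc
    0.1 * rng.normal(size=(d_tc, d_model)),  # W_dec
    np.zeros(d_model),                       # b_dec
)
loss = transcoder_loss(mlp_in, mlp_out, params, l1_coeff=1e-3)
```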

Why might transcoders be an improvement over standard SAEs? Mainly, they offer a much clearer insight into the function of a model's layers. From the Transcoders LessWrong post:

One of the strong points of transcoders is that they decompose the function of an MLP layer into sparse, independently-varying, and meaningful units (like neurons were originally intended to be before superposition was discovered). This significantly simplifies circuit analysis.

Intuitively it might seem like transcoders are solving a different (more complicated) kind of optimization problem (trying to mimic the MLP's computation rather than just reproduce its output), and so they would suffer a performance tradeoff relative to standard SAEs. However, evidence suggests that this might not be the case, and transcoders might offer a Pareto improvement over standard SAEs.

In the section on Circuits with SAEs, we'll dive much deeper into transcoders, and how & why they work so well.

Bonus

We've now finished the main content of this section! We recommend at this point that you jump to later sections (whichever ones interest you most), however you can also try out some of the bonus sections below to dive a bit deeper on some topics we covered earlier in this exercise set.

Reproduce circular subspace geometry from Not All Language Model Features Are Linear

In our replication of the latent dashboards, we've written code pretty similar to the code we need for replicating some of the analysis from Not All Language Model Features Are Linear. In this paper, the authors demonstrate an awesome circular representation of latents representing days of the week in GPT2 Small.

You can now replicate the circular geometry results, mostly using the code you've already written. The end goal is to produce a plot like Figure 1 from the first page of the paper.

As a guide, you should do the following:

  1. Get activations for all the days of the week latents identified by the authors (we've included them in a list below, as well as the latents for months of the year & years of the 20th century). Note that these latents correspond to the SAE we've been working with, i.e. release "gpt2-small-res-jb" and id "blocks.7.hook_resid_pre".
  2. For every token where at least one of those latents is active, compute the SAE reconstruction for just the subset of those latents (in other words, the activations for the nonzero latents in your latent list, mapped through the SAE decoder). Also store the token: we can simplify by grouping the tokens into one of 8 groups: either one of the 7 days of the week, or "Other" (see Figure 1 for an illustration).
  3. Perform PCA over all the SAE reconstructions, and plot the second & third principal components. You should observe a circular geometry, with the days of the week forming a circle in the 2D plane (and same for the months of the year, and the years of the 20th century, if you try those too).
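Steps 2 and 3 can be sketched in NumPy as below, assuming you already have an (n_tokens, d_sae) array of SAE activations and the decoder matrix; subset_reconstructions and pca are hypothetical helpers, and PCA is done via SVD rather than sklearn so the sketch stays dependency-light. The decoder bias is left out because PCA centres the data anyway, so adding the same constant vector to every reconstruction wouldn't change the principal components. The random inputs only exercise the shapes.

```python
import numpy as np


def subset_reconstructions(sae_acts: np.ndarray, W_dec: np.ndarray, latent_subset: list[int]):
    # sae_acts: (n_tokens, d_sae); W_dec: (d_sae, d_model)
    sub_acts = sae_acts[:, latent_subset]
    # Keep only tokens where at least one latent in the subset is active
    mask = (sub_acts > 0).any(axis=-1)
    # Decode using just the subset's activations and decoder rows
    recons = sub_acts[mask] @ W_dec[latent_subset]
    return recons, mask


def pca(X: np.ndarray, n_components: int = 3):
    # Centre the data, then project onto the top right-singular vectors
    X_centered = X - X.mean(axis=0)
    _, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ Vt[:n_components].T, S[:n_components]


rng = np.random.default_rng(0)
sae_acts = np.maximum(rng.normal(size=(50, 20)), 0.0)  # random stand-in activations
W_dec = rng.normal(size=(20, 8))
recons, mask = subset_reconstructions(sae_acts, W_dec, [1, 4, 7])
embedding, singular_values = pca(recons, n_components=3)
pc2, pc3 = embedding[:, 1], embedding[:, 2]  # the two axes to plot
```

The mask tells you which tokens to keep when building the token and token_group labels, so they line up row-for-row with the embedding.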

We also encourage you to think about why circular geometry might be useful when representing latents like days of the week. We definitely recommend reading the full paper if this is an area that interests you!

from sklearn.decomposition import PCA

day_of_the_week_latents = [2592, 4445, 4663, 4733, 6531, 8179, 9566, 20927, 24185]
# months_of_the_year = [3977, 4140, 5993, 7299, 9104, 9401, 10449, 11196, 12661, 14715, 17068, 17528, 19589, 21033, 22043, 23304]
# years_of_20th_century = [1052, 2753, 4427, 6382, 8314, 9576, 9606, 13551, 19734, 20349]

# YOUR CODE HERE - replicate circular subspace geometry, for `day_of_the_week_latents`
Hint (if you're stuck): the PCA code

Assuming all_reconstructions is a tensor of shape (n_datapoints, d_model), this code will create a NumPy array of shape (n_datapoints, 3) containing the first 3 principal components of the reconstructions:

from sklearn.decomposition import PCA
pca = PCA(n_components=3)
pca_embedding = pca.fit_transform(all_reconstructions.detach().cpu().numpy())
Hint (if you're stuck): code for plotting the results

This code will work, assuming pca_df is a dataframe with the following columns:

- PC2 and PC3, containing the 2nd and 3rd principal components of the SAE reconstructions (restricted to the days-of-the-week features)
- token, containing the token from which the reconstructions were taken
- token_group, the same as token but with all non-day-of-week tokens replaced by "Other"
- context, a string of the context around each token (optional; you can remove it if you'd prefer)
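The token, token_group and context columns can be built with a few plain-Python helpers. This is a minimal sketch: `token_group`, `context_slice` and `build_rows` are hypothetical names, and the input is a hand-written list of decoded token strings rather than real GPT-2 output. GPT-2 tokens usually carry their leading space, which is why the grouping strips before matching; a day split across several tokens (like " Tues" + "day") falls into "Other".

```python
DAYS_OF_THE_WEEK = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]


def token_group(token: str) -> str:
    # Tokens usually include a leading space (e.g. " Monday"), so strip before matching
    stripped = token.strip()
    return stripped if stripped in DAYS_OF_THE_WEEK else "Other"


def context_slice(seq: int, buffer: int, seq_len: int) -> range:
    # Positions within `buffer` tokens either side of `seq`, clipped to [0, seq_len)
    return range(max(seq - buffer, 0), min(seq + buffer + 1, seq_len))


def build_rows(tokens, fired_positions, buffer=2):
    # One dict per firing position; a list of these can be passed to pd.DataFrame
    rows = []
    for seq in fired_positions:
        rows.append(
            {
                "token": tokens[seq],
                "token_group": token_group(tokens[seq]),
                "context": "".join(tokens[i] for i in context_slice(seq, buffer, len(tokens))),
            }
        )
    return rows


rows = build_rows(["See", " you", " on", " Monday", " or", " Tues", "day"], [3, 5])
print(rows)
```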

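import plotly.express as px

# Assumes `days_of_the_week` is the list of the 7 day names ("Monday" to "Sunday")
px.scatter(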
    pca_df,
    x="PC2",
    y="PC3",
    hover_data=["context"],
    hover_name="token",
    height=700,
    width=1000,
    color="token_group",
    color_discrete_sequence=px.colors.sample_colorscale("Viridis", 7) + ["#aaa"],
    title="PCA Subspace Reconstructions",
    labels={"token_group": "Activating token"},
    category_orders={"token_group": days_of_the_week + ["Other"]},
).show()
Solution
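import pandas as pd
import plotly.express as px
import torch as t
from sklearn.decomposition import PCA
from tqdm import tqdm

# Assumes `gpt2`, `gpt2_sae` and `gpt2_act_store` were loaded earlier in this section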
day_of_the_week_latents = [2592, 4445, 4663, 4733, 6531, 8179, 9566, 20927, 24185]
# months_of_the_year = [3977, 4140, 5993, 7299, 9104, 9401, 10449, 11196, 12661, 14715, 17068, 17528, 19589, 21033, 22043, 23304]
# years_of_20th_century = [1052, 2753, 4427, 6382, 8314, 9576, 9606, 13551, 19734, 20349]
days_of_the_week = [
    "Monday",
    "Tuesday",
    "Wednesday",
    "Thursday",
    "Friday",
    "Saturday",
    "Sunday",
]
buffer = 5
seq_len = gpt2_act_store.context_size
sae_acts_post_hook_name = f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post"
all_data = {"recons": [], "context": [], "token": [], "token_group": []}
total_batches = 400
for i in tqdm(range(total_batches), desc="Computing activations data for PCA, over all batches"):
    _, cache = gpt2.run_with_cache_with_saes(
        tokens := gpt2_act_store.get_batch_tokens(),
        saes=[gpt2_sae],
        stop_at_layer=gpt2_sae.cfg.hook_layer + 1,
        names_filter=[sae_acts_post_hook_name],
    )
    acts = cache[sae_acts_post_hook_name][..., day_of_the_week_latents].flatten(0, 1)
    # Keep only positions where at least one day-of-week latent fires
    any_latent_fired = (acts > 0).any(dim=1)
    acts = acts[any_latent_fired]
    # Reconstruct using only the selected latents' decoder rows (b_dec omitted)
    reconstructions = acts @ gpt2_sae.W_dec[day_of_the_week_latents]
    all_data["recons"].append(reconstructions)
    for batch_seq_flat_idx in t.nonzero(any_latent_fired).squeeze(-1).tolist():
        batch, seq = divmod(batch_seq_flat_idx, seq_len)  # type: ignore
        token = gpt2.tokenizer.decode(tokens[batch, seq])  # type: ignore
        token_group = token.strip() if token.strip() in days_of_the_week else "Other"
        context = gpt2.tokenizer.decode(  # type: ignore
            tokens[batch, max(seq - buffer, 0) : min(seq + buffer + 1, seq_len)]
        )
        all_data["context"].append(context)
        all_data["token"].append(token)
        all_data["token_group"].append(token_group)
pca = PCA(n_components=3)
pca_embedding = pca.fit_transform(t.concat(all_data.pop("recons")).detach().cpu().numpy())
px.scatter(
    pd.DataFrame(all_data | {"PC2": pca_embedding[:, 1], "PC3": pca_embedding[:, 2]}),
    x="PC2",
    y="PC3",
    hover_data=["context"],
    hover_name="token",
    height=700,
    width=1000,
    color="token_group",
    color_discrete_sequence=px.colors.sample_colorscale("Viridis", 7) + ["#aaa"],
    title="PCA Subspace Reconstructions",
    labels={"token_group": "Activating token"},
    category_orders={"token_group": days_of_the_week + ["Other"]},
).show()
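The solution flattens the (batch, seq) dimensions before masking and then uses `divmod` to recover each position. Flattening is row-major, so flat index `i` equals `batch * seq_len + seq`, and `divmod(i, seq_len)` inverts it. Here is a small NumPy check of that correspondence on hypothetical toy activations (NumPy's `reshape` uses the same row-major order as `.flatten(0, 1)` in PyTorch):

```python
import numpy as np

batch_size, seq_len, n_latents = 4, 6, 3  # toy sizes (hypothetical)
rng = np.random.default_rng(1)
acts = np.maximum(rng.normal(size=(batch_size, seq_len, n_latents)) - 1.0, 0.0)

# Row-major flatten of (batch, seq), matching the order of torch's .flatten(0, 1)
flat = acts.reshape(batch_size * seq_len, n_latents)
fired = (flat > 0).any(axis=1)

positions = []
for flat_idx in np.nonzero(fired)[0]:
    batch, seq = divmod(int(flat_idx), seq_len)
    # Row `flat_idx` of the flattened array is exactly position (batch, seq) of the original
    assert np.array_equal(flat[flat_idx], acts[batch, seq])
    positions.append((batch, seq))

print(positions)
```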

Long and short-prefix induction

In their LessWrong post, the authors not only find induction features in GPT2-Small, but also show that SAE latents can tell us something meaningful about the roles of individual heads, by distinguishing between two L5 induction heads. Of the 2 induction heads in this layer (5.1 and 5.5), one seems to specialise in "long prefix induction" while the other mostly does "standard induction".

For example, the authors find two different latents, primarily attributed to heads 5.1 and 5.5 respectively. Both latents attend back to "-"-joined expressions, but they have slightly different applications. One of them primarily performs "short prefix induction", i.e. induction where there is no long prefix from which to infer that we're in an induction pattern, for example:

  • "Indo-German ... Indo-" → "German"
  • "center-left ... center-" → "left"

The other one primarily performs "long prefix induction", i.e. induction where the second half of the induction pattern has been going on for a while, for example:

  • "Ways To Prevent Computer-Related Eye Strain ... Ways To Prevent Computer-" → "Related"
  • "shooting, a number of NRA-supported legislators ... a number of NRA-" → "supported"
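To make the short/long distinction concrete, here is a toy string-matching sketch (hypothetical, and not a description of how heads 5.1 or 5.5 actually compute anything). It predicts the next token by copying whatever followed the most recent earlier occurrence of the current token, and reports how many tokens before that match also agree with the tokens before the current position. The word-level tokenisation of the examples is also an assumption.

```python
def induction_prediction(tokens):
    """Toy induction: copy the token after the latest earlier match of the current token,
    and return the length of the matching prefix before that match."""
    cur = len(tokens) - 1
    for j in range(cur - 1, -1, -1):  # scan backwards for an earlier copy of the current token
        if tokens[j] == tokens[cur]:
            # Count how many preceding tokens also agree between the match and the current position
            prefix = 0
            while j - prefix - 1 >= 0 and tokens[j - prefix - 1] == tokens[cur - prefix - 1]:
                prefix += 1
            return tokens[j + 1], prefix
    return None, 0


short = ["Indo", "-", "German", "...", "Indo", "-"]
long = ["Ways", "To", "Prevent", "Computer", "-", "Related", "...", "Ways", "To", "Prevent", "Computer", "-"]
print(induction_prediction(short))
print(induction_prediction(long))
```

Both examples predict the right continuation, but the matching prefix is 1 token in the short-prefix case and 4 tokens in the long-prefix case.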

Can you find the two latents in question, and figure out which head is which? Can you find them using several of the techniques from the "finding latents for features" section?

# YOUR CODE HERE - replicate long and short-prefix induction results
Expected output:
Top latent (long_form)
┌─────────────┬──────────────┐
│ Latent idx  │ 19293        │
│ Attribution │ 6.847        │
│ Activation  │ 2.427        │
│ Top head    │ 5.1 (43.45%) │
│ Second head │ 5.5 (22.84%) │
└─────────────┴──────────────┘

Top latent (short_form)
┌─────────────┬──────────────┐
│ Latent idx  │ 35744        │
│ Attribution │ 1.107        │
│ Activation  │ 0.476        │
│ Top head    │ 5.5 (58.49%) │
│ Second head │ 5.8 (8.38%)  │
└─────────────┴──────────────┘
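The "Top head" percentages in these tables are each head's share of the latent's decoder vector. The vector is split into per-head slices of `hook_z`, each slice's L2 norm is taken, and the norms are divided by their sum (not the sum of squared norms), as in the solution code. Here is a NumPy sketch on a hypothetical random decoder row, deliberately concentrated on head 1; only the GPT-2 Small shapes (12 heads, `d_head` = 64) come from the real model.

```python
import numpy as np

n_heads, d_head = 12, 64  # GPT-2 Small attention shape
rng = np.random.default_rng(0)

# Hypothetical decoder row for one attention-SAE latent, written over the flattened hook_z
latent_dir = rng.normal(size=n_heads * d_head) * 0.01
latent_dir[1 * d_head : 2 * d_head] += 0.1  # concentrate this toy latent on head 1

# L2 norm of each head's slice, then each head's fraction of the summed norms
norm_per_head = np.linalg.norm(latent_dir.reshape(n_heads, d_head), axis=-1)
norm_frac_per_head = norm_per_head / norm_per_head.sum()
top2 = np.argsort(norm_frac_per_head)[::-1][:2]

for head in top2:
    print(f"head {head}: {norm_frac_per_head[head]:.2%}")
```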

Solution
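import plotly.express as px
from tabulate import tabulate

# Assumes `gpt2`, `attn_saes` and the `display_dashboard` helper were defined earlier in this section
induction_prompts = {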
    "long_form": [
        "To reduce the risk of computer-related injuries, it's important to maintain proper posture and take regular breaks. To reduce the risk of computer",
        "observed that many people suffer from stress-induced headaches, which can be alleviated through relaxation techniques. And because many people suffer from stress",
        "Experts are increasingly worried about the impact of technology-driven automation on jobs. Experts are increasingly worried about the impact of technology",
    ],
    "short_form": [
        "A lot of NRA-supported legislation has been controversial. Furthermore, NRA",
        "The company is pursuing technology-driven solutions. This is because technology",
        "Humanity is part-angel, part",
    ],
}
layer = 5
sae_acts_post_hook_name = f"{attn_saes[layer].cfg.hook_name}.hook_sae_acts_post"
logit_dir = gpt2.W_U[:, gpt2.to_single_token("-")]
for induction_type in ["long_form", "short_form"]:
    prompts = induction_prompts[induction_type]
    _, cache = gpt2.run_with_cache_with_saes(
        prompts, saes=[attn_saes[layer]], names_filter=[sae_acts_post_hook_name]
    )
    sae_acts_post = cache[sae_acts_post_hook_name][:, -1, :].mean(0)
    # squeeze(-1) keeps a list even when only one latent is active
    alive_latents = sae_acts_post.nonzero().squeeze(-1).tolist()
    # Direct logit attribution of each latent to the "-" logit (ignoring the final LayerNorm)
    sae_attribution = sae_acts_post * (
        attn_saes[layer].W_dec @ gpt2.W_O[layer].flatten(0, 1) @ logit_dir
    )
    ind = sae_attribution.argmax().item()
    # Fraction of the top latent's decoder norm that falls in each head's slice of hook_z
    latent_dir = attn_saes[layer].W_dec[ind]
    norm_per_head = latent_dir.reshape(gpt2.cfg.n_heads, gpt2.cfg.d_head).pow(2).sum(-1).sqrt()
    norm_frac_per_head = norm_per_head / norm_per_head.sum(-1, keepdim=True)
    top_head_values, top_heads = norm_frac_per_head.topk(2, dim=-1)
    print(
        f"Top latent ({induction_type})\n"
        + tabulate(
            [
                ["Latent idx", ind],
                ["Attribution", f"{sae_attribution[ind]:.3f}"],
                ["Activation", f"{sae_acts_post[ind]:.3f}"],
                ["Top head", f"{layer}.{top_heads[0]} ({top_head_values[0]:.2%})"],
                ["Second head", f"{layer}.{top_heads[1]} ({top_head_values[1]:.2%})"],
            ],
            tablefmt="simple_outline",
        ),
    )
    # Line chart of latent attributions
    px.line(
        sae_attribution.cpu().numpy(),
        title=f"Attributions for correct token ({induction_type} induction) at final token position ({len(alive_latents)} non-zero attribution)",
        labels={"index": "Latent", "value": "Attribution"},
        template="ggplot2",
        width=1000,
    ).update_layout(showlegend=False).show()
    # Display dashboard
    display_dashboard(
        sae_release="gpt2-small-hook-z-kk",
        sae_id=f"blocks.{layer}.hook_z",
        latent_idx=int(ind),
    )
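The attribution line in the solution multiplies the attention SAE's decoder (which writes to the flattened `hook_z`) by the layer's `W_O` flattened over (head, d_head), then by the unembedding column for "-". The NumPy sketch below uses hypothetical random weights at toy sizes to check that this equals the sum of per-head contributions computed separately with `einsum`. That equivalence is what lets you ask which heads a latent's attribution flows through. Like the solution, it ignores the final LayerNorm, so it is a direct-logit-attribution approximation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, d_head, d_model, d_sae = 4, 8, 16, 32  # toy sizes (hypothetical)

W_dec = rng.normal(size=(d_sae, n_heads * d_head))  # attn SAE decoder, writes to flattened hook_z
W_O = rng.normal(size=(n_heads, d_head, d_model))  # per-head output matrices for one layer
logit_dir = rng.normal(size=d_model)  # unembedding column for the target token
acts = np.maximum(rng.normal(size=d_sae), 0.0)  # latent activations at the final position

# As in the solution: flatten W_O over (head, d_head) so its rows line up with W_dec's columns
attribution = acts * (W_dec @ W_O.reshape(n_heads * d_head, d_model) @ logit_dir)

# Same quantity, split into per-head contributions before summing over heads
W_dec_per_head = W_dec.reshape(d_sae, n_heads, d_head)
per_head = acts[:, None] * np.einsum("shd,hdm,m->sh", W_dec_per_head, W_O, logit_dir)

print(np.allclose(per_head.sum(axis=1), attribution))
print(int(attribution.argmax()))
```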