1️⃣ Intro to SAE Interpretability
Learning Objectives
- Learn how to use the SAELens library to load in & run SAEs (alongside the TransformerLens models they're attached to)
- Understand the basic features of Neuronpedia, and how it can be used for things like steering and searching over features
- Understand SAE dashboards, what each part of them tells you about a particular latent (as well as how to compute them yourself)
- Learn techniques for finding latents, including direct logit attribution, ablation and attribution patching
- Use attention SAEs, understand how they differ from regular SAEs (as well as topics specific to attention SAEs, like direct latent attribution)
- Learn a bit about different SAE architectures or training methods (e.g. gated, end-to-end, meta-saes, transcoders) - some of these will be covered in more detail later
To emphasize - the idea is for this section to be a whirlwind tour of all basic SAE topics, excluding training & evals (which we'll come back to in section 4). The focus will be on how to understand & interpret SAE latents (in particular all the components of the SAE dashboard). We'll also look at techniques for finding latents (e.g. ablation & attribution methods), as well as taking a deeper dive into attention SAEs and how they work. Because there's a lot of material to cover in this section, we'll have a summary of the key points at the top of each main header section. These summaries are all included below for convenience, before we get started. As well as helping to keep you oriented as you work through the material, these should also give you an idea of which sections you can jump to if you only want to cover a few of them.
Intro to SAELens
In this section, you'll learn what SAELens is, and how to use it to load in & inspect the configs of various supported SAEs. Key points:
- SAELens is a library for training and analysing SAEs. It can be thought of as the equivalent of TransformerLens for SAEs (although it also integrates closely with TransformerLens, as we'll see in the "Running SAEs" section)
- SAELens contains many different model releases, each release containing multiple SAEs (e.g. trained on different model layers / hook points, or with different architectures)
- The cfg attribute of an SAE instance contains this information, and anything else that's relevant when performing forward passes
Visualizing SAEs with dashboards
In this section, you'll learn about SAE dashboards, which are a visual tool for quickly understanding what a particular SAE latent represents. Key points:
- Neuronpedia hosts dashboards which help you understand SAE latents
- The 5 main components of the dashboard are: top logit tables, logits histogram, activation density plots, top activating sequences, and autointerp
- All of these components are important for getting a full picture of what a latent represents, but they can also all be misleading
- You can display these dashboards inline, using IFrame
Running SAEs
In this section, you'll learn how to run forward passes with SAEs. This is a pretty simple process, which builds on much of the pre-existing infrastructure in TransformerLens models. Key points:
- You can add SAEs to a TransformerLens model when doing forward passes in pretty much the same way you add hook functions (you can think of SAEs as a special kind of hook function)
- When sae.use_error_term=False (default) you substitute the SAE's output for the transformer activations. When True, you don't substitute (which is sometimes what you want when caching activations)
- There's an analogous run_with_saes that works like run_with_hooks
- There's also run_with_cache_with_saes that works like run_with_cache, but allows you to cache any SAE activations you want
- You can use ActivationsStore to get a large batch of activations at once
Replicating SAE dashboards
In this section, you'll replicate the 5 main components of the SAE dashboard: top logits tables, logits histogram, activation density plots, top activating sequences, and autointerp. There's not really any new content here, just putting into practice what you've learned from the previous 2 sections "Visualizing SAEs with dashboards" and "Running SAEs".
Attention SAEs
In this section, you'll learn about attention SAEs, how they work (mostly quite similar to standard SAEs but with a few other considerations), and how to understand their feature dashboards. Key points:
- Attention SAEs have the same architecture as regular SAEs, except they're trained on the concatenated pre-projection output of all attention heads.
- If a latent fires on a destination token, we can use direct latent attribution to see which source tokens it primarily came from.
- Just like regular SAEs, latents found in different layers of a model are often qualitatively different from each other.
Finding latents for features
In this section, you'll explore different methods (some causal, some not) for finding latents in SAEs corresponding to particular features. Key points:
- You can look at max activating latents on some particular input prompt; this is basically the simplest thing you can do
- Direct logit attribution (DLA) is a bit more refined; you can find latents which have a direct effect on specific logits
- Ablation of SAE latents can help you find latents which are important in a non-direct way
- ...but it's quite costly for a large number of latents, so you can use attribution patching as a cheaper linear approximation of ablation
GemmaScope
This short section introduces you to DeepMind's GemmaScope series, a suite of highly performant SAEs which can be a great source of study in your own interpretability projects!
Feature steering
In this section, you'll learn how to steer on latents to produce interesting model output. Key points:
- Steering involves intervening during a forward pass to change the model's activations in the direction of a particular latent
- The steering behaviour is sometimes unpredictable, and not always equivalent to "produce text of the same type as the latent strongly activates on"
- Neuronpedia has a steering interface which allows you to steer without any code
Other types of SAEs
This section introduces a few different SAE architectures, some of which will be explored in more detail in later sections. There are no exercises here, just brief descriptions. Key points:
- Different activation functions / encoder architectures (e.g. TopK, JumpReLU and Gated models) can solve problems like feature suppression and the pressure for SAEs to be continuous in standard models
- End-to-end SAEs are trained with a different loss function, encouraging them to learn features that are functionally useful for the model's output rather than just minimising MSE reconstruction error
- Transcoders are a type of SAE which learn to reconstruct a model's computation (e.g. a sparse mapping from MLP input to MLP output) rather than just reconstructing activations; they can sometimes lead to easier circuit analysis
Intro to SAELens
In this section, you'll learn what SAELens is, and how to use it to load in & inspect the configs of various supported SAEs. Key points:
- SAELens is a library for training and analysing SAEs. It can be thought of as the equivalent of TransformerLens for SAEs (although it also integrates closely with TransformerLens, as we'll see in the "Running SAEs" section)
- SAELens contains many different model releases, each release containing multiple SAEs (e.g. trained on different model layers / hook points, or with different architectures)
- The cfg attribute of an SAE instance contains this information, and anything else that's relevant when performing forward passes
SAELens is a library designed to help researchers:
- Train sparse autoencoders,
- Analyse sparse autoencoders / research mechanistic interpretability,
- Generate insights which make it easier to create safe and aligned AI systems.
You can think of it as the equivalent of TransformerLens for sparse autoencoders (and it also integrates very well with TransformerLens models, which we'll see shortly).
Additionally, SAELens is closely integrated with Neuronpedia, an open platform for interpretability research developed through the joint efforts of Joseph Bloom and Johnny Lin, and which we'll be using throughout this chapter. Neuronpedia allows you to search over latents, run SAEs, and even upload your own SAEs!
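Throughout this chapter we'll assume a standard set of imports has already been run. The sketch below shows the main ones used in the code cells that follow; exact import paths may vary slightly between SAELens versions, so treat it as illustrative rather than definitive.
import random
import torch as t
from IPython.display import IFrame, display
from tabulate import tabulate
from sae_lens import SAE, ActivationsStore, HookedSAETransformer
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
device = t.device("cuda" if t.cuda.is_available() else "cpu")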
Before we load our SAEs, it can be useful to see which are available. The following snippet shows the currently available SAE releases in SAELens, and will remain up-to-date as SAELens continues to add more SAEs.
print(get_pretrained_saes_directory())
{
'gpt2-small-res-jb': PretrainedSAELookup(
release='gpt2-small-res-jb',
repo_id='jbloom/GPT2-Small-SAEs-Reformatted',
model='gpt2-small',
conversion_func=None,
saes_map={'blocks.0.hook_resid_pre': 'blocks.0.hook_resid_pre', ..., 'blocks.11.hook_resid_post': 'blocks.11.hook_resid_post'},
expected_var_explained={'blocks.0.hook_resid_pre': 0.999, ..., 'blocks.11.hook_resid_post': 0.77},
expected_l0={'blocks.0.hook_resid_pre': 10.0, ..., 'blocks.11.hook_resid_post': 70.0},
neuronpedia_id={'blocks.0.hook_resid_pre': 'gpt2-small/0-res-jb', ..., 'blocks.11.hook_resid_post': 'gpt2-small/12-res-jb'},
config_overrides={'model_from_pretrained_kwargs': {'center_writing_weights': True}}),
'gpt2-small-hook-z-kk': PretrainedSAELookup(
...
)
}
Let's print out all this data in a more readable format, with only a subset of attributes. We'll look at model (the base model), release (the name of the SAE release), repo_id (the id of the HuggingFace repo containing the SAEs), and also the number of SAEs in each release (e.g. a release might contain an SAE trained on each layer of the base model).
metadata_rows = [
[data.model, data.release, data.repo_id, len(data.saes_map)]
for data in get_pretrained_saes_directory().values()
]
# Print all SAE releases, sorted by base model
print(
tabulate(
sorted(metadata_rows, key=lambda x: x[0]),
headers=["model", "release", "repo_id", "n_saes"],
tablefmt="simple_outline",
)
)
┌─────────────────────────────────────┬─────────────────────────────────────────────────────┬────────────────────────────────────────────────────────┬──────────┐ │ model │ release │ repo_id │ n_saes │ ├─────────────────────────────────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────┼──────────┤ │ gemma-2-27b │ gemma-scope-27b-pt-res │ google/gemma-scope-27b-pt-res │ 18 │ │ gemma-2-27b │ gemma-scope-27b-pt-res-canonical │ google/gemma-scope-27b-pt-res │ 3 │ │ gemma-2-2b │ gemma-scope-2b-pt-res │ google/gemma-scope-2b-pt-res │ 310 │ │ gemma-2-2b │ gemma-scope-2b-pt-res-canonical │ google/gemma-scope-2b-pt-res │ 58 │ │ gemma-2-2b │ gemma-scope-2b-pt-mlp │ google/gemma-scope-2b-pt-mlp │ 260 │ │ gemma-2-2b │ gemma-scope-2b-pt-mlp-canonical │ google/gemma-scope-2b-pt-mlp │ 52 │ │ gemma-2-2b │ gemma-scope-2b-pt-att │ google/gemma-scope-2b-pt-att │ 260 │ │ gemma-2-2b │ gemma-scope-2b-pt-att-canonical │ google/gemma-scope-2b-pt-att │ 52 │ │ gemma-2-9b │ gemma-scope-9b-pt-res │ google/gemma-scope-9b-pt-res │ 562 │ │ gemma-2-9b │ gemma-scope-9b-pt-res-canonical │ google/gemma-scope-9b-pt-res │ 91 │ │ gemma-2-9b │ gemma-scope-9b-pt-att │ google/gemma-scope-9b-pt-att │ 492 │ │ gemma-2-9b │ gemma-scope-9b-pt-att-canonical │ google/gemma-scope-9b-pt-att │ 84 │ │ gemma-2-9b │ gemma-scope-9b-pt-mlp │ google/gemma-scope-9b-pt-mlp │ 492 │ │ gemma-2-9b │ gemma-scope-9b-pt-mlp-canonical │ google/gemma-scope-9b-pt-mlp │ 84 │ │ gemma-2-9b │ gemma-scope-9b-it-res │ google/gemma-scope-9b-it-res │ 30 │ │ gemma-2-9b-it │ gemma-scope-9b-it-res-canonical │ google/gemma-scope-9b-it-res │ 6 │ │ gemma-2b │ gemma-2b-res-jb │ jbloom/Gemma-2b-Residual-Stream-SAEs │ 5 │ │ gemma-2b │ sae_bench_gemma-2-2b_sweep_standard_ctx128_ef2_0824 │ canrager/lm_sae │ 180 │ │ gemma-2b │ sae_bench_gemma-2-2b_sweep_standard_ctx128_ef8_0824 │ canrager/lm_sae │ 240 │ │ gemma-2b │ sae_bench_gemma-2-2b_sweep_topk_ctx128_ef2_0824 │ canrager/lm_sae │ 180 │ │ gemma-2b │ sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824 │ canrager/lm_sae │ 240 │ │ gemma-2b-it │ gemma-2b-it-res-jb │ jbloom/Gemma-2b-IT-Residual-Stream-SAEs │ 1 │ ... │ pythia-70m-deduped │ pythia-70m-deduped-res-sm │ ctigges/pythia-70m-deduped__res-sm_processed │ 7 │ │ pythia-70m-deduped │ pythia-70m-deduped-mlp-sm │ ctigges/pythia-70m-deduped__mlp-sm_processed │ 6 │ │ pythia-70m-deduped │ pythia-70m-deduped-att-sm │ ctigges/pythia-70m-deduped__att-sm_processed │ 6 │ └─────────────────────────────────────┴─────────────────────────────────────────────────────┴────────────────────────────────────────────────────────┴──────────┘
Any given SAE release may contain multiple different SAEs. These might have been trained on different hook points or layers in the model, or with different hyperparameters, etc. You can see the data associated with each release as follows:
def format_value(value):
return (
"{{{0!r}: {1!r}, ...}}".format(*next(iter(value.items())))
if isinstance(value, dict)
else repr(value)
)
release = get_pretrained_saes_directory()["gpt2-small-res-jb"]
print(
tabulate(
[[k, format_value(v)] for k, v in release.__dict__.items()],
headers=["Field", "Value"],
tablefmt="simple_outline",
)
)
┌────────────────────────┬─────────────────────────────────────────────────────────────────────────┐
│ Field │ Value │
├────────────────────────┼─────────────────────────────────────────────────────────────────────────┤
│ release │ 'gpt2-small-res-jb' │
│ repo_id │ 'jbloom/GPT2-Small-SAEs-Reformatted' │
│ model │ 'gpt2-small' │
│ conversion_func │ None │
│ saes_map │ {'blocks.0.hook_resid_pre': 'blocks.0.hook_resid_pre', ...} │
│ expected_var_explained │ {'blocks.0.hook_resid_pre': 0.999, ...} │
│ expected_l0 │ {'blocks.0.hook_resid_pre': 10.0, ...} │
│ neuronpedia_id │ {'blocks.0.hook_resid_pre': 'gpt2-small/0-res-jb', ...} │
│ config_overrides │ {'model_from_pretrained_kwargs': {'center_writing_weights': True}, ...} │
└────────────────────────┴─────────────────────────────────────────────────────────────────────────┘
Let's get some more info about each of the SAEs associated with each release. We can print out the SAE id, the path (i.e. in the HuggingFace repo, which points to the SAE model weights) and the Neuronpedia ID (which is how we'll get feature dashboards - more on this soon).
data = [[id, path, release.neuronpedia_id[id]] for id, path in release.saes_map.items()]
print(
tabulate(
data,
headers=["SAE id", "SAE path (HuggingFace)", "Neuronpedia ID"],
tablefmt="simple_outline",
)
)
┌───────────────────────────┬───────────────────────────┬──────────────────────┐ │ SAE id │ SAE path (HuggingFace) │ Neuronpedia ID │ ├───────────────────────────┼───────────────────────────┼──────────────────────┤ │ blocks.0.hook_resid_pre │ blocks.0.hook_resid_pre │ gpt2-small/0-res-jb │ │ blocks.1.hook_resid_pre │ blocks.1.hook_resid_pre │ gpt2-small/1-res-jb │ │ blocks.2.hook_resid_pre │ blocks.2.hook_resid_pre │ gpt2-small/2-res-jb │ │ blocks.3.hook_resid_pre │ blocks.3.hook_resid_pre │ gpt2-small/3-res-jb │ │ blocks.4.hook_resid_pre │ blocks.4.hook_resid_pre │ gpt2-small/4-res-jb │ │ blocks.5.hook_resid_pre │ blocks.5.hook_resid_pre │ gpt2-small/5-res-jb │ │ blocks.6.hook_resid_pre │ blocks.6.hook_resid_pre │ gpt2-small/6-res-jb │ │ blocks.7.hook_resid_pre │ blocks.7.hook_resid_pre │ gpt2-small/7-res-jb │ │ blocks.8.hook_resid_pre │ blocks.8.hook_resid_pre │ gpt2-small/8-res-jb │ │ blocks.9.hook_resid_pre │ blocks.9.hook_resid_pre │ gpt2-small/9-res-jb │ │ blocks.10.hook_resid_pre │ blocks.10.hook_resid_pre │ gpt2-small/10-res-jb │ │ blocks.11.hook_resid_pre │ blocks.11.hook_resid_pre │ gpt2-small/11-res-jb │ │ blocks.11.hook_resid_post │ blocks.11.hook_resid_post │ gpt2-small/12-res-jb │ └───────────────────────────┴───────────────────────────┴──────────────────────┘
Next, we'll load the SAE which we'll be working with for most of these exercises: the layer 7 resid_pre SAE from the GPT2 Small release (as well as a copy of GPT2 Small to attach it to). The model is loaded with the HookedSAETransformer class, which is adapted from the TransformerLens HookedTransformer class.
Note, the SAE.from_pretrained function has return type tuple[SAE, dict, Tensor | None], with the return elements being the SAE, config dict, and a tensor of feature sparsities. The config dict contains useful metadata on e.g. how the SAE was trained (among other things).
t.set_grad_enabled(False)
gpt2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small", device=device)
gpt2_sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.7.hook_resid_pre",
device=str(device),
)
The gpt2_sae object is an instance of the SAE (Sparse Autoencoder) class. There are many different SAE architectures which may have different weights or activation functions. In order to simplify working with SAEs, SAELens handles most of this complexity for you. You can run the cell below to see each of the SAE config parameters for the one we'll be using.
Click to read a description of each of the SAE config parameters.
1. architecture: Specifies the type of SAE architecture being used, in this case, the standard architecture (encoder and decoder with hidden activations, as opposed to a gated SAE).
2. d_in: Defines the input dimension of the SAE, which is 768 in this configuration.
3. d_sae: Sets the dimension of the SAE's hidden layer, which is 24576 here. This represents the number of possible feature activations.
4. activation_fn_str: Specifies the activation function used in the SAE, which is ReLU in this case. TopK is another option that we will not cover here.
5. apply_b_dec_to_input: Determines whether to apply the decoder bias to the input, set to True here.
6. finetuning_scaling_factor: Indicates whether to use a scaling factor to weight initialization and the forward pass. This is not usually used and was introduced to support a [solution for shrinkage](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes).
7. context_size: Defines the size of the context window, which is 128 tokens in this case. It turns out SAEs trained on activations from shorter prompts [often don't perform well on longer prompts](https://www.lesswrong.com/posts/baJyjpktzmcmRfosq/stitching-saes-of-different-sizes).
8. model_name: Specifies the name of the model being used, which is 'gpt2-small' here. [This is a valid model name in TransformerLens](https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html).
9. hook_name: Indicates the specific hook in the model where the SAE is applied.
10. hook_layer: Specifies the layer number where the hook is applied, which is layer 7 in this case.
11. hook_head_index: Defines which attention head to hook into; not relevant here since we are looking at a residual stream SAE.
12. prepend_bos: Determines whether to prepend the beginning-of-sequence token, set to True.
13. dataset_path: Specifies the path to the dataset used for training or evaluation. (Can be local or a huggingface dataset.)
14. dataset_trust_remote_code: Indicates whether to trust remote code (from HuggingFace) when loading the dataset, set to True.
15. normalize_activations: Specifies how to normalize activations, set to 'none' in this config.
16. dtype: Defines the data type for tensor operations, set to 32-bit floating point.
17. device: Specifies the computational device to use.
18. sae_lens_training_version: Indicates the version of SAE Lens used for training, set to None here.
19. activation_fn_kwargs: Allows for additional keyword arguments for the activation function. This would be used if e.g. the activation_fn_str was set to topk, so that k could be specified.
print(tabulate(gpt2_sae.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))
┌──────────────────────────────┬──────────────────────────────────┐
│ name │ value │
├──────────────────────────────┼──────────────────────────────────┤
│ architecture │ standard │
│ d_in │ 768 │
│ d_sae │ 24576 │
│ activation_fn_str │ relu │
│ apply_b_dec_to_input │ True │
│ finetuning_scaling_factor │ False │
│ context_size │ 128 │
│ model_name │ gpt2-small │
│ hook_name │ blocks.7.hook_resid_pre │
│ hook_layer │ 7 │
│ hook_head_index │ │
│ prepend_bos │ True │
│ dataset_path │ Skylion007/openwebtext │
│ dataset_trust_remote_code │ True │
│ normalize_activations │ none │
│ dtype │ torch.float32 │
│ device │ cuda │
│ sae_lens_training_version │ │
│ activation_fn_kwargs │ {} │
│ neuronpedia_id │ gpt2-small/7-res-jb │
│ model_from_pretrained_kwargs │ {'center_writing_weights': True} │
└──────────────────────────────┴──────────────────────────────────┘
Visualizing SAEs with dashboards
In this section, you'll learn about SAE dashboards, which are a visual tool for quickly understanding what a particular SAE latent represents. Key points:
- Neuronpedia hosts dashboards which help you understand SAE latents
- The 5 main components of the dashboard are: top logit tables, logits histogram, activation density plots, top activating sequences, and autointerp
- All of these components are important for getting a full picture of what a latent represents, but they can also all be misleading
- You can display these dashboards inline, using IFrame
In this section, we're going to have a look at our SAEs, and see what they're actually telling us.
Before we dive too deep however, let's recap something - what actually is an SAE latent?
An SAE latent is a particular direction in the base model's activation space, learned by the SAE. Often, these correspond to features** in the data - in other words, meaningful semantic, syntactic or otherwise interpretable patterns or concepts that exist in the distribution of data the base model was trained on, and which were learned by the base model. These features are usually highly sparse, in other words for any given feature only a small fraction of the overall data distribution will activate that feature. It tends to be the case that sparser features are also more interpretable.
**Note - technically saying "direction" is an oversimplification here, because a given latent can have multiple directions in activation space associated with it, e.g. a separate encoder and decoder direction for standard untied SAEs. When we refer to a latent direction or feature direction, we're usually but not always referring to the decoder weights.
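To make this concrete, here's a minimal sketch of how you could pull out the encoder and decoder directions for a single latent of the SAE we loaded above (W_enc and W_dec are the SAE's weight matrices, with shapes (d_in, d_sae) and (d_sae, d_in) respectively):
latent_idx = 9
enc_dir = gpt2_sae.W_enc[:, latent_idx]  # encoder direction for this latent, shape (d_in,) = (768,)
dec_dir = gpt2_sae.W_dec[latent_idx]  # decoder direction for this latent, shape (d_in,) = (768,)
# For standard untied SAEs these two directions differ, but they're often highly correlated
cos_sim = t.cosine_similarity(enc_dir, dec_dir, dim=0)
print(f"Cosine similarity between encoder & decoder directions: {cos_sim:.3f}")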
The dashboard shown below provides a detailed view of a single SAE latent.
def display_dashboard(
sae_release="gpt2-small-res-jb",
sae_id="blocks.7.hook_resid_pre",
latent_idx=0,
width=800,
height=600,
):
release = get_pretrained_saes_directory()[sae_release]
neuronpedia_id = release.neuronpedia_id[sae_id]
url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
print(url)
display(IFrame(url, width=width, height=height))
latent_idx = random.randint(0, gpt2_sae.cfg.d_sae - 1)  # subtract 1, since randint is inclusive at both ends
display_dashboard(latent_idx=latent_idx)
Let's break down the separate components of the visualization:
- Latent Activation Distribution. This shows the proportion of tokens a latent fires on, usually between 0.01% and 1%, and also shows the distribution of positive activations.
- Logits Distribution. This is the projection of the decoder weight onto the unembed and roughly gives us a sense of the tokens promoted by a latent. It's less useful in big models / middle layers.
- Top / Bottom Logits. These are the 10 most positive and most negative logits in the logit weight distribution.
- Max Activating Examples. These are examples of text where the latent fires and usually provide the most information for helping us work out what a latent means.
- Autointerp. These are LLM-generated latent explanations, which use the rest of the data in the dashboard (in particular the max activating examples).
See this section of Towards Monosemanticity for more information.
Neuronpedia is a website that hosts SAE dashboards and which runs servers that can run the model and check latent activations. This makes it very convenient to check that a latent fires on the distribution of text you actually think it should fire on. We've been downloading data from Neuronpedia for the dashboards above.
Exercise - find interesting latents
Spend some time browsing through the SAE dashboard (i.e. running the code with different random indices). What interesting latents can you find? Try and find the following types of latents:
- Latents for token-level features, which seem to only fire on a particular token and basically no others. Do the top logits make sense, when viewed as bigram frequencies?
- Latents for concept-level features, which fire not on single tokens but across multiple tokens, provided that a particular concept is present in the text (e.g. latent_idx=4527 is an example of this). What is the concept that this latent represents? Can you see how the positive logits for this latent make sense?
- Highly sparse latents, with activation density less than 0.05%. Do they seem more interpretable than the average latent?
Click on this dropdown to see examples of each type, which you can compare to the ones you found.
Latent 9 seems to only fire on the word "new", in the context of describing a plurality of new things (often related to policies, business estimates, hiring, etc). The positive logits support this, with bigrams like "new arrivals" and "new developments" being boosted. Interestingly, we also have the bigram "newbie(s)" boosted.
Latent 67 seems to be more concept-level, firing on passages that talk about a country's decision to implement a particular policy or decision (especially when that country is being described as the only one to do it). Although the positive logits are also associated with countries or government policies, they don't directly make sense as bigrams - which is what we'd expect given that the latent fires on multiple different tokens in a sentence when that sentence contains the concept in question.
Latent 13 fires with frequency 0.049%. It seems to activate on the token "win", especially when in the context of winning over people (e.g. winning the presidency, winning over hearts and minds, or winning an argument) or winning a race. This seems pretty specific and interpretable, although it's important when interpreting latents to remember the interpretability illusion - seeing the top activating patterns can give a misplaced sense of confidence in any particular interpretation. In later sections we'll perform more careful hypothesis testing to refine our understanding of latents.
Running SAEs
In this section, you'll learn how to run forward passes with SAEs. This is a pretty simple process, which builds on much of the pre-existing infrastructure in TransformerLens models. Key points:
- You can add SAEs to a TransformerLens model when doing forward passes in pretty much the same way you add hook functions (you can think of SAEs as a special kind of hook function)
- When sae.use_error_term=False (default) you substitute the SAE's output for the transformer activations. When True, you don't substitute (which is sometimes what you want when caching activations)
- There's an analogous run_with_saes that works like run_with_hooks
- There's also run_with_cache_with_saes that works like run_with_cache, but allows you to cache any SAE activations you want
- You can use ActivationsStore to get a large batch of activations at once
Now that we've had a look at some SAEs via Neuronpedia, it's time to load them in and start running them ourselves!
One of the key features of HookedSAETransformer is being able to "splice in" SAEs, replacing model activations with their SAE reconstructions. To run a forward pass with SAEs attached, you can use model.run_with_saes(tokens, saes=[list_of_saes]). This function has similar syntax to the standard forward pass (or to model.run_with_hooks), e.g. it can take arguments like return_type to specify whether the return type should be loss or logits. The attached SAEs will be reset immediately after the forward pass, returning the model to its original state. Under the hood, they work just like adding hooks in TransformerLens, only in this case our hooks are "replace these activations with their SAE reconstructions".
There are a lot of other ways you can do SAE-hooked forward passes, which parallel the multiple ways you can do regular hooked forward passes. For example, just like you can use with model.hooks(fwd_hooks=...) as a context manager to add hooks temporarily, you can also use with model.saes(saes=...) to run a forward pass with SAEs attached. And just like you can use model.add_hook and model.reset_hooks, you can also use model.add_sae and model.reset_saes.
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"
# First see how the model does without SAEs
test_prompt(prompt, answer, gpt2)
# Now see how it does with the SAE reconstruction spliced in (using the context manager)
with gpt2.saes(saes=[gpt2_sae]):
test_prompt(prompt, answer, gpt2)
# Same thing, done in a different way
gpt2.add_sae(gpt2_sae)
test_prompt(prompt, answer, gpt2)
gpt2.reset_saes() # Remember to always do this!
# Using `run_with_saes` method in place of standard forward pass
logits = gpt2(prompt, return_type="logits")
logits_sae = gpt2.run_with_saes(prompt, saes=[gpt2_sae], return_type="logits")
answer_token_id = gpt2.to_single_token(answer)
# Getting model's prediction
top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)
print(f"""Standard model:
top prediction = {gpt2.to_string(token_id_prediction)!r}
prob = {top_prob.item():.2%}
SAE reconstruction:
top prediction = {gpt2.to_string(token_id_prediction_sae)!r}
prob = {top_prob_sae.item():.2%}
""")
Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global'] Tokenized answer: [' priority'] Performance on answer token: Rank: 0 Logit: 19.46 Prob: 52.99% Token: | priority| Top 0th token. Logit: 19.46 Prob: 52.99% Token: | priority| Top 1th token. Logit: 17.44 Prob: 7.02% Token: | effort| Top 2th token. Logit: 16.94 Prob: 4.26% Token: | issue| Top 3th token. Logit: 16.63 Prob: 3.14% Token: | challenge| Top 4th token. Logit: 16.37 Prob: 2.42% Token: | goal| Top 5th token. Logit: 16.06 Prob: 1.78% Token: | concern| Top 6th token. Logit: 15.88 Prob: 1.47% Token: | focus| Top 7th token. Logit: 15.61 Prob: 1.13% Token: | approach| Top 8th token. Logit: 15.53 Prob: 1.04% Token: | policy| Top 9th token. Logit: 15.42 Prob: 0.93% Token: | initiative| Ranks of the answer tokens: [(' priority', 0)]
Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global'] Tokenized answer: [' priority'] Performance on answer token: Rank: 0 Logit: 18.19 Prob: 39.84% Token: | priority| Top 0th token. Logit: 18.19 Prob: 39.84% Token: | priority| Top 1th token. Logit: 16.51 Prob: 7.36% Token: | issue| Top 2th token. Logit: 16.48 Prob: 7.20% Token: | concern| Top 3th token. Logit: 15.94 Prob: 4.19% Token: | challenge| Top 4th token. Logit: 15.30 Prob: 2.21% Token: | goal| Top 5th token. Logit: 15.12 Prob: 1.85% Token: | responsibility| Top 6th token. Logit: 15.04 Prob: 1.69% Token: | problem| Top 7th token. Logit: 14.98 Prob: 1.60% Token: | effort| Top 8th token. Logit: 14.73 Prob: 1.24% Token: | policy| Top 9th token. Logit: 14.66 Prob: 1.16% Token: | imperative| Ranks of the answer tokens: [(' priority', 0)]
Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global'] Tokenized answer: [' priority'] Performance on answer token: Rank: 0 Logit: 18.19 Prob: 39.84% Token: | priority| Top 0th token. Logit: 18.19 Prob: 39.84% Token: | priority| Top 1th token. Logit: 16.51 Prob: 7.36% Token: | issue| Top 2th token. Logit: 16.48 Prob: 7.20% Token: | concern| Top 3th token. Logit: 15.94 Prob: 4.19% Token: | challenge| Top 4th token. Logit: 15.30 Prob: 2.21% Token: | goal| Top 5th token. Logit: 15.12 Prob: 1.85% Token: | responsibility| Top 6th token. Logit: 15.04 Prob: 1.69% Token: | problem| Top 7th token. Logit: 14.98 Prob: 1.60% Token: | effort| Top 8th token. Logit: 14.73 Prob: 1.24% Token: | policy| Top 9th token. Logit: 14.66 Prob: 1.16% Token: | imperative| Ranks of the answer tokens: [(' priority', 0)] Standard model: top prediction = ' priority' prob = 52.99% SAE reconstruction: top prediction = ' priority' prob = 39.84%
Okay, so this is fine if we want to do a forward pass with the model's activations replaced by SAE reconstructions, but what if we want to just get the SAE activations? Well, that's where running with cache comes in! With HookedSAETransformer, you can cache SAE activations (and all the other standard activations) with logits, cache = model.run_with_cache_with_saes(tokens, saes=saes). Just as run_with_saes is a wrapper around the standard forward pass, run_with_cache_with_saes is a wrapper around run_with_cache, and will also only add these saes for one forward pass before returning the model to its original state.
To access SAE activations from the cache, the corresponding hook names will generally be the concatenations of the HookedTransformer hook_name (e.g. "blocks.5.attn.hook_z") and the SAE hook name (e.g. "hook_sae_acts_post"), joined by a period. We can print out all the names below:
_, cache = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])
for name, param in cache.items():
if "hook_sae" in name:
print(f"{name:<43}: {tuple(param.shape)}")
blocks.7.hook_resid_pre.hook_sae_input : (1, 13, 768) blocks.7.hook_resid_pre.hook_sae_acts_pre : (1, 13, 24576) blocks.7.hook_resid_pre.hook_sae_acts_post : (1, 13, 24576) blocks.7.hook_resid_pre.hook_sae_recons : (1, 13, 768) blocks.7.hook_resid_pre.hook_sae_output : (1, 13, 768)
run_with_cache_with_saes makes it easy to explore which SAE latents are active across any input. We can also use this along with the argument stop_at_layer in our forward pass, because we don't need to compute any activations past the SAE layer.
Let's explore the active latents at the final token in our prompt. You should find that the first latent fires on the word "global", particularly in the context of disasters such as "global warming", "global poverty" and "global war".
# Get top activations on final token
_, cache = gpt2.run_with_cache_with_saes(
prompt,
saes=[gpt2_sae],
stop_at_layer=gpt2_sae.cfg.hook_layer + 1,
)
sae_acts_post = cache[f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post"][0, -1, :]
# Plot line chart of latent activations
px.line(
sae_acts_post.cpu().numpy(),
title=f"Latent activations at the final token position ({sae_acts_post.nonzero().numel()} alive)",
labels={"index": "Latent", "value": "Activation"},
width=1000,
).update_layout(showlegend=False).show()
# Print the top 3 latents, and inspect their dashboards
for act, ind in zip(*sae_acts_post.topk(3)):
print(f"Latent {ind} had activation {act:.2f}")
display_dashboard(latent_idx=ind)
Click to see the expected output
Error term
Important note - the parameter sae.use_error_term determines whether we'll actually substitute the activations with SAE reconstructions during our SAE forward pass. If it's False (default) then we do replace activations with SAE reconstructions, but if it's True then we'll just compute the SAE's hidden activations without replacing the transformer activations with its output.
The use_error_term parameter controls behaviour when we do forward passes, hooked forward passes, forward passes with cache, or anything else that involves running the model with SAEs attached (but obviously this parameter only matters when we're caching values, because doing a forward pass with sae.use_error_term=True and not caching any values is equivalent to just running the base model without any SAEs!).
Why is it called use_error_term ?
It's called this because when set to True we'll have the final output of the forward pass be sae_out + sae_error rather than sae_out. This sae_error term is literally defined as sae_in - sae_out, i.e. the difference between the original input and SAE reconstruction. So this is equivalent to the SAE just being the identity function. But we need to do things this way so we still compute all the internal states of the SAE in exactly the same way as we would if we were actually replacing the transformer's activations with SAE reconstructions.
logits_no_saes, cache_no_saes = gpt2.run_with_cache(prompt)
gpt2_sae.use_error_term = False
logits_with_sae_recon, cache_with_sae_recon = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])
gpt2_sae.use_error_term = True
logits_without_sae_recon, cache_without_sae_recon = gpt2.run_with_cache_with_saes(
prompt, saes=[gpt2_sae]
)
# Both SAE caches contain the hook values
assert f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post" in cache_with_sae_recon
assert f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post" in cache_without_sae_recon
# But the final output will be different, because the SAE reconstruction is only substituted in when use_error_term=False
t.testing.assert_close(logits_no_saes, logits_without_sae_recon)
logit_diff_from_sae = (logits_no_saes - logits_with_sae_recon).abs().mean()
print(f"Average logit diff from using SAE reconstruction: {logit_diff_from_sae:.4f}")
Average logit diff from using SAE reconstruction: 0.4117
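As a quick sanity check on the claim above (that the SAE's internal states are computed in exactly the same way in both cases), we can compare the cached latent activations from the two runs - a minimal sketch reusing the caches from the cell above:
# The SAE's latent activations should be identical, whether or not the error term is used
hook_name = f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post"
t.testing.assert_close(cache_with_sae_recon[hook_name], cache_without_sae_recon[hook_name])
print("SAE latent activations match in both caches")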
Using ActivationsStore
The ActivationsStore class is a convenient alternative to loading a bunch of data yourself. It streams in data from a given dataset; in the case of the from_sae class method, that dataset will be given by your SAE's config (which is also the same as the SAE's original training dataset):
print(gpt2_sae.cfg.dataset_path)
Skylion007/openwebtext
Let's load one in now. We'll use fairly conservative parameters here so it can be used without running out of memory, but feel free to increase these parameters if you're able to (or decrease them if you still find yourself running out of memory).
gpt2_act_store = ActivationsStore.from_sae(
model=gpt2,
sae=gpt2_sae,
streaming=True,
store_batch_size_prompts=16,
n_batches_in_buffer=32,
device=str(device),
)
# Example of how you can use this:
tokens = gpt2_act_store.get_batch_tokens()
assert tokens.shape == (gpt2_act_store.store_batch_size_prompts, gpt2_act_store.context_size)
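Note that get_batch_tokens only gives you tokens; if you also want the activations the SAE would see for that batch, one option is simply to run the tokens through the model yourself. The sketch below reuses methods we've already covered (rather than any extra ActivationsStore machinery), stopping the forward pass at the SAE's layer to save compute:
# Get the residual stream activations at the SAE's hook point, for this batch of tokens
_, cache = gpt2.run_with_cache(
    tokens,
    stop_at_layer=gpt2_sae.cfg.hook_layer + 1,
    names_filter=gpt2_sae.cfg.hook_name,
)
resid_pre = cache[gpt2_sae.cfg.hook_name]
assert resid_pre.shape == (*tokens.shape, gpt2_sae.cfg.d_in)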
Replicating SAE dashboards
In this section, you'll replicate the 5 main components of the SAE dashboard: top logits tables, logits histogram, activation density plots, top activating sequences, and autointerp. There's not really any new content here, just putting into practice what you've learned from the previous 2 sections "Visualizing SAEs with dashboards" and "Running SAEs".
Now that we know how to load in and run SAEs, we can start replicating the components of the SAE dashboard in turn. These exercises will help build up your experience running SAEs & working with their activations, as well as helping you dive deeper into the meaning and significance of the different dashboard components.
To review, basic SAE dashboards have 5 main components:
- Activation Distribution - the distribution of a latent's activations
- Logits Distribution - projection of decoder weights onto model's unembedding
- Top / Bottom Logits - the most positive and most negative logits in the logit weight distribution
- Max Activating Examples - sequences (and particular tokens) on which the latent fires strongest
- Autointerp - llm-generated latent explanations
We'll go through each of these in turn. We'll be using latent 9 for this exercise; you can compare your results to the expected dashboard:
display_dashboard(latent_idx=9)
Exercise - get the activation distribution
The function below should iterate through some number of batches (note that you can reduce the default number if you find the code is taking too long), and create a histogram of the activations for a given latent. Try and return the activation density in the histogram's title too.
Reminder - when using model.run_with_cache_with_saes, you can use the arguments stop_at_layer=sae.cfg.hook_layer+1 as well as names_filter=hook_name; these will help you avoid unnecessary computation and memory usage.
Also, note that if you're working in Colab & using Plotly then you might need to adjust this code so it computes & renders the figure in separate code cells - this is a well known but unfixed Colab bug.
def show_activation_histogram(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 200,
):
"""
Displays the activation histogram for a particular latent, computed across `total_batches`
batches from `act_store`.
"""
raise NotImplementedError()
show_activation_histogram(gpt2, gpt2_sae, gpt2_act_store, latent_idx=9)
Click to see the expected output
Click here for some code to plot a histogram (if you don't really care about this being part of the exercise)
This will work assuming all_positive_acts is a list of all non-zero activation values over all batches:
frac_active = len(all_positive_acts) / (
total_batches * act_store.store_batch_size_prompts * act_store.context_size
)
px.histogram(
all_positive_acts,
nbins=50,
title=f"ACTIVATIONS DENSITY {frac_active:.3%}",
labels={"value": "Activation"},
width=800,
template="ggplot2",
color_discrete_sequence=["darkorange"],
).update_layout(bargap=0.02, showlegend=False).show()
Note that if you're in Colab, you might need to return this figure and plot it in a separate cell (cause Colab is weird about plotting in notebooks in the same cell as it performs computation).
Solution
def show_activation_histogram(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 200,
):
"""
Displays the activation histogram for a particular latent, computed across total_batches
batches from act_store.
"""
sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
all_positive_acts = []
for i in tqdm(range(total_batches), desc="Computing activations for histogram"):
tokens = act_store.get_batch_tokens()
_, cache = model.run_with_cache_with_saes(
tokens,
saes=[sae],
stop_at_layer=sae.cfg.hook_layer + 1,
names_filter=[sae_acts_post_hook_name],
)
acts = cache[sae_acts_post_hook_name][..., latent_idx]
all_positive_acts.extend(acts[acts > 0].cpu().tolist())
frac_active = len(all_positive_acts) / (
total_batches * act_store.store_batch_size_prompts * act_store.context_size
)
px.histogram(
all_positive_acts,
nbins=50,
title=f"ACTIVATIONS DENSITY {frac_active:.3%}",
labels={"value": "Activation"},
width=800,
template="ggplot2",
color_discrete_sequence=["darkorange"],
).update_layout(bargap=0.02, showlegend=False).show()
Exercise - find max activating examples
We'll start by finding the max-activating examples - the prompts that show the highest level of activation from a latent. We've given you a function with a docstring to complete, although exactly how you want to present the data is entirely up to you.
We've given you the following helper functions, as well as examples showing how to use them:
- get_k_largest_indices, which will return the batch & seqpos indices of the k largest elements in a (batch, seq)-sized tensor,
- index_with_buffer, which will index into a (batch, seq)-sized tensor with the results of get_k_largest_indices, including the tokens within buffer from the selected indices in the same sequence (this helps us get context around the selected tokens),
- display_top_seqs, which will display sequences (with the relevant token highlighted) in a readable way.
When it comes to decoding sequences, you can use model.to_str_tokens to map a 1D tensor of token IDs to a list of string tokens. Note that you're likely to get some unknown tokens "�" in your output - this is an unfortunate byproduct of tokenization the way we're doing it, and you shouldn't worry about it too much.
def get_k_largest_indices(
x: Float[Tensor, "batch seq"], k: int, buffer: int = 0
) -> Int[Tensor, "k 2"]:
"""
The indices of the top k elements in the input tensor, i.e. output[i, :] is the (batch, seqpos)
value of the i-th largest element in x.
Won't choose any elements within `buffer` from the start or end of their sequence.
"""
if buffer > 0:
x = x[:, buffer:-buffer]
indices = x.flatten().topk(k=k).indices
rows = indices // x.size(1)
cols = indices % x.size(1) + buffer
return t.stack((rows, cols), dim=1)
x = t.arange(40, device=device).reshape((2, 20))
x[0, 10] += 50 # 2nd highest value
x[0, 11] += 100 # highest value
x[1, 1] += 150  # won't be chosen, because it's within `buffer` (3) of the start of the sequence
top_indices = get_k_largest_indices(x, k=2, buffer=3)
assert top_indices.tolist() == [[0, 11], [0, 10]]
def index_with_buffer(
x: Float[Tensor, "batch seq"], indices: Int[Tensor, "k 2"], buffer: int | None = None
) -> Float[Tensor, "k *buffer_x2_plus1"]:
"""
Indexes into `x` with `indices` (which should have come from the `get_k_largest_indices`
function), and takes a +-buffer range around each indexed element. If `indices` are less than
`buffer` away from the start of a sequence then we just take the first `2*buffer+1` elems (same
for at the end of a sequence).
If `buffer` is None, then we don't add any buffer and just return the elements at the given indices.
"""
rows, cols = indices.unbind(dim=-1)
if buffer is not None:
rows = einops.repeat(rows, "k -> k buffer", buffer=buffer * 2 + 1)
cols[cols < buffer] = buffer
cols[cols > x.size(1) - buffer - 1] = x.size(1) - buffer - 1
cols = einops.repeat(cols, "k -> k buffer", buffer=buffer * 2 + 1) + t.arange(
-buffer, buffer + 1, device=cols.device
)
return x[rows, cols]
x_top_values_with_context = index_with_buffer(x, top_indices, buffer=3)
assert x_top_values_with_context[0].tolist() == [
8,
9,
10 + 50,
11 + 100,
12,
13,
14,
] # highest value in the middle
assert x_top_values_with_context[1].tolist() == [
7,
8,
9,
10 + 50,
11 + 100,
12,
13,
] # 2nd highest value in the middle
def display_top_seqs(data: list[tuple[float, list[str], int]]):
"""
Given a list of (activation: float, str_toks: list[str], seq_pos: int), displays a table of
these sequences, with the relevant token highlighted.
We also turn newlines into "\\n", and remove unknown tokens � (usually weird quotation marks)
for readability.
"""
table = Table("Act", "Sequence", title="Max Activating Examples", show_lines=True)
for act, str_toks, seq_pos in data:
formatted_seq = (
"".join(
[
f"[b u green]{str_tok}[/]" if i == seq_pos else str_tok
for i, str_tok in enumerate(str_toks)
]
)
.replace("�", "")
.replace("\n", "↵")
)
table.add_row(f"{act:.3f}", repr(formatted_seq))
rprint(table)
example_data = [
(0.5, [" one", " two", " three"], 0),
(1.5, [" one", " two", " three"], 1),
(2.5, [" one", " two", " three"], 2),
]
display_top_seqs(example_data)
Max Activating Examples ┏━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ ┃ Act ┃ Sequence ┃ ┡━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ │ 0.500 │ ' one two three' │ ├───────┼──────────────────┤ │ 1.500 │ ' one two three' │ ├───────┼──────────────────┤ │ 2.500 │ ' one two three' │ └───────┴──────────────────┘
You should fill in the following function. It should return data as a list of tuples of the form (max activation, list of string tokens, sequence position), and if display is True then it should also call display_top_seqs on this data (you'll find this function helpful when we implement autointerp later!).
def fetch_max_activating_examples(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 100,
k: int = 10,
buffer: int = 10,
) -> list[tuple[float, list[str], int]]:
"""
Returns the max activating examples across a number of batches from the activations store.
"""
raise NotImplementedError()
# Fetch & display the results
buffer = 10
data = fetch_max_activating_examples(
gpt2, gpt2_sae, gpt2_act_store, latent_idx=9, buffer=buffer, k=5
)
display_top_seqs(data)
# Test one of the results, to see if it matches the expected output
first_seq_str_tokens = data[0][1]
assert first_seq_str_tokens[buffer] == " new"
Max Activating Examples ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Act ┃ Sequence ┃ ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ 43.451 │ '.↵↵Airline industry↵↵Under the new rules, payment surcharges will have to reflect the' │ ├────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 41.921 │ '145m.↵↵Enforcement↵↵The new rules are being brought in earlier than the rest of' │ ├────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 39.845 │ ' activity," the minister had said.↵↵The new law is a precursor to banning chewing tobacco in public' │ ├────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 37.051 │ '."↵↵Niedermeyer agreed that the new car excitement "tapers off" the longer it' │ ├────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 37.004 │ ' each other as soon as possible."↵↵The new desert map will be included in the 1.0' │ └────────┴───────────────────────────────────────────────────────────────────────────────────────────────────────┘
Solution
def fetch_max_activating_examples(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 100,
k: int = 10,
buffer: int = 10,
) -> list[tuple[float, list[str], int]]:
"""
Returns the max activating examples across a number of batches from the activations store.
"""
sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
# Create list to store the top k activations for each batch. Once we're done,
# we'll filter this to only contain the top k over all batches
data = []
for _ in tqdm(range(total_batches), desc="Computing activations for max activating examples"):
tokens = act_store.get_batch_tokens()
_, cache = model.run_with_cache_with_saes(
tokens,
saes=[sae],
stop_at_layer=sae.cfg.hook_layer + 1,
names_filter=[sae_acts_post_hook_name],
)
acts = cache[sae_acts_post_hook_name][..., latent_idx]
# Get largest indices, get the corresponding max acts, and get the surrounding indices
k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
tokens_with_buffer = index_with_buffer(tokens, k_largest_indices, buffer=buffer)
str_toks = [model.to_str_tokens(toks) for toks in tokens_with_buffer]
top_acts = index_with_buffer(acts, k_largest_indices).tolist()
data.extend(list(zip(top_acts, str_toks, [buffer] * len(str_toks))))
return sorted(data, key=lambda x: x[0], reverse=True)[:k]
Non-overlapping sequences
For the latent above, returning sequences the way you did probably worked pretty well. But other more concept-level latents (where multiple tokens in a sentence fire strongly) are a bit more annoying. You can try this function on a latent like 16873 (which fires on specific bible passages) - the returned sequences will mostly be the same, just shifted over by a different amount.
data = fetch_max_activating_examples(
gpt2, gpt2_sae, gpt2_act_store, latent_idx=16873, total_batches=200
)
display_top_seqs(data)
Click to see the expected output
Max Activating Examples ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Act ┃ Sequence ┃ ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ 15.400 │ 'aqara 2:221 =↵↵And do not marry polytheistic women until they believe. And' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 14.698 │ 'Baqara 2:221 =↵↵And do not marry polytheistic women until they believe.' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 14.648 │ ' Testament: Verily, verily, I say unto you, Except a corn of wheat fall into the' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 14.124 │ ' not marry polytheistic women until they believe. And a believing slave woman is better than a │ │ │ polythe' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 14.069 │ ' thou to me? And Jesus answering said unto him, Suffer it to be so now: for thus' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 13.461 │ ' "But John forbad him, saying, I have a need to be baptised of thee, and' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 13.024 │ '14–15: "But John forbad him, saying, I have a need to be baptised' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 12.874 │ ' to me? And Jesus answering said unto him, Suffer it to be so now: for thus it' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 12.801 │ ' Suffer it to be so now: for thus it becometh us to fulfil all righteousness", and' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 12.747 │ ':14–15: "But John forbad him, saying, I have a need to be bapt' │ └────────┴────────────────────────────────────────────────────────────────────────────────────────────────────────┘
One way you can combat this is by imposing the restriction that any given top-activating token can only be in one sequence, i.e. when you pick that token you can't pick any in the range [-buffer, buffer] around it. We've given you a new function get_k_largest_indices below. Try it out with no_overlap=True - are the results much better?
def get_k_largest_indices(
x: Float[Tensor, "batch seq"],
k: int,
buffer: int = 0,
no_overlap: bool = True,
) -> Int[Tensor, "k 2"]:
"""
Returns the tensor of (batch, seqpos) indices for each of the top k elements in the tensor x.
Args:
buffer: We won't choose any elements within `buffer` from the start or end of their seq
(this helps if we want more context around the chosen tokens).
no_overlap: If True, this ensures that no 2 top-activating tokens are in the same seq and
within `buffer` of each other.
"""
assert buffer * 2 < x.size(1), "Buffer is too large for the sequence length"
assert not no_overlap or k <= x.size(0), (
"Not enough sequences to have a different token in each sequence"
)
if buffer > 0:
x = x[:, buffer:-buffer]
indices = x.flatten().argsort(-1, descending=True)
rows = indices // x.size(1)
cols = indices % x.size(1) + buffer
if no_overlap:
unique_indices = t.empty((0, 2), device=x.device).long()
while len(unique_indices) < k:
unique_indices = t.cat(
(unique_indices, t.tensor([[rows[0], cols[0]]], device=x.device))
)
is_overlapping_mask = (rows == rows[0]) & ((cols - cols[0]).abs() <= buffer)
rows = rows[~is_overlapping_mask]
cols = cols[~is_overlapping_mask]
return unique_indices
return t.stack((rows, cols), dim=1)[:k]
x = t.arange(40, device=device).reshape((2, 20))
x[0, 10] += 150 # highest value
x[0, 11] += 100 # 2nd highest value, but won't be chosen because of overlap
x[1, 10] += 50 # 3rd highest, will be chosen
top_indices = get_k_largest_indices(x, k=2, buffer=3)
assert top_indices.tolist() == [[0, 10], [1, 10]]
data = fetch_max_activating_examples(
gpt2, gpt2_sae, gpt2_act_store, latent_idx=16873, total_batches=200
)
display_top_seqs(data)
Click to see the expected output
Max Activating Examples ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Act ┃ Sequence ┃ ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ 17.266 │ ' seven times in a day be converted unto thee, saying: I repent: forgive him. And the apostles' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 16.593 │ '. And the apostles said to the Lord: Increase our faith. (Luke 17:3-5)' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 16.146 │ ' say unto you, Arise, and take up your couch, and go into your house." (Luke' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 15.965 │ ' mystery of the kingdom of God: but unto them that are without, all these things are done in par' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 14.985 │ ' parable. And he said unto them: To you it is given to know the mystery of the kingdom' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 14.383 │ ' forgiven you; or to say, Rise up and walk? But that you may know that the Son of' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 13.476 │ 'he said unto him that was palsied:) I say unto you, Arise, and take up' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 13.430 │ ' their thoughts, he answered and said unto them, "What reason have you in your hearts? Which is' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 13.181 │ ' things are done in parables: That seeing they may see, and not perceive, and hearing they may' │ ├────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 13.071 │ ' seven times?" Jesus says unto him, "I say not unto you, Until seven times, but until' │ └────────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘
Exercise - get top / bottom logits
We'll end with the top & bottom logits tables. These don't require data, since they're just functions of the SAE and model's weights. Recall - you can access the unembedding of your base model using model.W_U, and you can access your SAE's decoder weights using sae.W_dec.
def show_top_logits(
model: HookedSAETransformer,
sae: SAE,
latent_idx: int,
k: int = 10,
) -> None:
"""
Displays the top & bottom logits for a particular latent.
"""
raise NotImplementedError()
show_top_logits(gpt2, gpt2_sae, latent_idx=9)
tests.test_show_top_logits(show_top_logits, gpt2, gpt2_sae)
Click to see the expected output
┌─────────────────┬─────────┬─────────────────┬─────────┐
│   Bottom tokens │ Value   │      Top tokens │ Value   │
├─────────────────┼─────────┼─────────────────┼─────────┤
│           'Zip' │ -0.774  │          'bies' │ +1.327  │
│       'acebook' │ -0.761  │           'bie' │ +1.297  │
│           'lua' │ -0.737  │     ' arrivals' │ +1.218  │
│        'ashtra' │ -0.728  │    ' additions' │ +1.018  │
│       'ONSORED' │ -0.708  │      ' edition' │ +0.994  │
│           'OGR' │ -0.705  │   ' millennium' │ +0.966  │
│      'umenthal' │ -0.703  │   ' generation' │ +0.962  │
│        'ecause' │ -0.697  │     ' entrants' │ +0.923  │
│          'icio' │ -0.692  │        ' hires' │ +0.919  │
│          'cius' │ -0.692  │ ' developments' │ +0.881  │
└─────────────────┴─────────┴─────────────────┴─────────┘
Solution
def show_top_logits(
model: HookedSAETransformer,
sae: SAE,
latent_idx: int,
k: int = 10,
) -> None:
"""
Displays the top & bottom logits for a particular latent.
"""
logits = sae.W_dec[latent_idx] @ model.W_U
pos_logits, pos_token_ids = logits.topk(k)
pos_tokens = model.to_str_tokens(pos_token_ids)
neg_logits, neg_token_ids = logits.topk(k, largest=False)
neg_tokens = model.to_str_tokens(neg_token_ids)
print(
tabulate(
zip(map(repr, neg_tokens), neg_logits, map(repr, pos_tokens), pos_logits),
headers=["Bottom tokens", "Value", "Top tokens", "Value"],
tablefmt="simple_outline",
stralign="right",
numalign="left",
floatfmt="+.3f",
)
)
Exercise - autointerp
Automated interpretability is one particularly exciting area of research at the moment. It originated with the OpenAI paper Language models can explain neurons in language models, which showed that we could take a neuron from GPT-2 and use GPT-4 to generate explanations of its behaviour by showing it relevant text sequences and activations. This gives us one possible way to assess and categorize latents at scale, without requiring humans to manually inspect and label each one (which would obviously be impractical). You can also read about some of the more recent advancements made here and here.
This scalable, efficient categorization is one use-case, but there's also a second one: SAE evaluations. This is a topic we'll dive deeper into in later sections, but to summarize now: SAE evaluations are the various ways we measure "how good our SAE is". It turns out this is very hard, because any metric we pick is vulnerable to Goodharting, and not necessarily representative of what we want out of our SAEs. For example, it seems like sparser latents are often more interpretable (if reconstruction loss is held constant), which is why a common way to evaluate SAEs is along a Pareto frontier of sparsity & reconstruction loss. But what if sparsity doesn't lead to more interpretable latents (e.g. because of feature absorption)? Autointerp provides an alternative way of evaluating SAE interpretability, because we can directly quantify how good our latent explanations are! The idea is to convert latent explanations into a set of predictions on some test set of prompts, and then score the accuracy of those predictions. More interpretable latents should lead to better predictions, because the latents will tend to have monosemantic, human-interpretable patterns that can be predicted from their given explanations.
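To make the "explanation → predictions → score" idea concrete, here's a minimal sketch of what a detection-style scoring loop could look like. Everything below is hypothetical scaffolding (the helper names and the predict_fires callable aren't part of SAELens or these exercises); it just illustrates the logic of scoring an explanation by how well it predicts where a latent fires.

```python
# Hypothetical sketch of detection-style autointerp scoring (names are illustrative, not from SAELens)

def score_explanation(
    explanation: str,
    test_seqs: list[str],
    latent_fires: list[bool],  # ground truth: does the latent actually fire on each sequence?
    predict_fires,  # e.g. an LLM call with signature (explanation, seq) -> bool
) -> float:
    """Fraction of test sequences where the explanation-based prediction matches reality."""
    preds = [predict_fires(explanation, seq) for seq in test_seqs]
    return sum(p == target for p, target in zip(preds, latent_fires)) / len(test_seqs)
```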
For now though, we'll focus on just the first half of autointerp, i.e. the generation of explanations. You can download and read the Neuronpedia-hosted latent explanations with the following code:
def get_autointerp_df(
sae_release="gpt2-small-res-jb", sae_id="blocks.7.hook_resid_pre"
) -> pd.DataFrame:
release = get_pretrained_saes_directory()[sae_release]
neuronpedia_id = release.neuronpedia_id[sae_id]
url = "https://www.neuronpedia.org/api/explanation/export?modelId={}&saeId={}".format(
*neuronpedia_id.split("/")
)
headers = {"Content-Type": "application/json"}
response = requests.get(url, headers=headers)
data = response.json()
return pd.DataFrame(data)
explanations_df = get_autointerp_df()
explanations_df.head()
Now, let's try doing some autointerp ourselves! This will involve 3 steps:
- Calling fetch_max_activating_examples to get the top-activating examples for a given latent.
- Calling create_prompt to create a system, user & assistant prompt for the OpenAI API which contains this data.
- Calling get_autointerp_explanation to pass these prompts to the OpenAI API and get a response.
You've already implemented fetch_max_activating_examples; below, you'll implement create_prompt first, and then get_autointerp_explanation.
Click to see recommended autointerp prompt structure
One possible method based on Anthropic's past published material is to show a list of top sequences and the activations for every single token, and ask for an explanation (then in the scoring phase we'd ask the model to predict activation values). However, here we'll do something a bit simpler, and just highlight the top-activating token in any given sequence (without giving numerical activation values).
{
"system": "We're studying neurons in a neural network. Each neuron activates on some particular word or concept in a short document. The activating words in each document are indicated with << ... >>. Look at the parts of the document the neuron activates for and summarize in a single sentence what the neuron is activating on. Try to be specific in your explanations, although don't be so specific that you exclude some of the examples from matching your explanation. Pay attention to things like the capitalization and punctuation of the activating words or concepts, if that seems relevant. Keep the explanation as short and simple as possible, limited to 20 words or less. Omit punctuation and formatting. You should avoid giving long lists of words.",
"user": """The activating documents are given below:
1. and he was <<over the moon>> to find
2. we'll be laughing <<till the cows come home>>! Pro
3. thought Scotland was boring, but really there's more <<than meets the eye>>! I'd""",
"assistant": "this neuron fires on",
}
We feed the system, then user, then assistant prompt into our model. The idea is:
- The system prompt explains what the task will be
- The user prompt contains the actual task data
- The assistant prompt helps condition the model's likely response format
Note - this is all very low-tech, and we'll expand greatly on these methods when we dive deeper into autointerp in later sections.
We recommend you use the augmented version of get_k_largest_indices you were given above, which doesn't allow for sequence overlap. This is because you don't want to send redundant information in your prompt!
def create_prompt(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 100,
k: int = 15,
buffer: int = 10,
) -> dict[Literal["system", "user", "assistant"], str]:
"""
Returns the system, user & assistant prompts for autointerp.
"""
raise NotImplementedError()
# Test your function
prompts = create_prompt(
gpt2, gpt2_sae, gpt2_act_store, latent_idx=9, total_batches=100, k=15, buffer=8
)
assert prompts["system"].startswith("We're studying neurons in a neural network.")
assert "<< new>>" in prompts["user"]
assert prompts["assistant"] == "this neuron fires on"
Solution
def create_prompt(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 100,
k: int = 15,
buffer: int = 10,
) -> dict[Literal["system", "user", "assistant"], str]:
"""
Returns the system, user & assistant prompts for autointerp.
"""
data = fetch_max_activating_examples(
model, sae, act_store, latent_idx, total_batches, k, buffer
)
str_formatted_examples = "\n".join(
f"{i + 1}. {''.join(f'<<{tok}>>' if j == buffer else tok for j, tok in enumerate(seq[1]))}"
for i, seq in enumerate(data)
)
return {
"system": "We're studying neurons in a neural network. Each neuron activates on some particular word or concept in a short document. The activating words in each document are indicated with << ... >>. Look at the parts of the document the neuron activates for and summarize in a single sentence what the neuron is activating on. Try to be specific in your explanations, although don't be so specific that you exclude some of the examples from matching your explanation. Pay attention to things like the capitalization and punctuation of the activating words or concepts, if that seems relevant. Keep the explanation as short and simple as possible, limited to 20 words or less. Omit punctuation and formatting. You should avoid giving long lists of words.",
"user": f"""The activating documents are given below:\n\n{str_formatted_examples}""",
"assistant": "this neuron fires on",
}
Once you've passed the tests for create_prompt, you can implement the full get_autointerp_explanation function:
def get_autointerp_explanation(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 100,
k: int = 15,
buffer: int = 10,
n_completions: int = 1,
) -> list[str]:
"""
Queries OpenAI's API using prompts returned from `create_prompt`, and returns a list of the
completions.
"""
raise NotImplementedError()
API_KEY = os.environ.get("OPENAI_API_KEY", None)
if API_KEY is not None:
completions = get_autointerp_explanation(
gpt2, gpt2_sae, gpt2_act_store, latent_idx=9, n_completions=5
)
for i, completion in enumerate(completions):
print(f"Completion {i + 1}: {completion!r}")
else:
print("No API key found, not running the autointerp code.")
Click to see the expected output
Completion 1: 'the concept of new policies or products being introduced'
Completion 2: 'the concept of new initiatives or products'
Completion 3: 'the concept of new ideas or initiatives'
Completion 4: 'the concept of new developments or initiatives'
Completion 5: 'the concept of new initiatives or products'
Solution
def get_autointerp_explanation(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 100,
k: int = 15,
buffer: int = 10,
n_completions: int = 1,
) -> list[str]:
"""
Queries OpenAI's API using prompts returned from create_prompt, and returns a list of the
completions.
"""
client = OpenAI(api_key=API_KEY)
prompts = create_prompt(model, sae, act_store, latent_idx, total_batches, k, buffer)
result = client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": prompts["system"]},
{"role": "user", "content": prompts["user"]},
{"role": "assistant", "content": prompts["assistant"]},
],
n=n_completions,
max_tokens=50,
stream=False,
)
return [choice.message.content for choice in result.choices]
Attention SAEs
In this section, you'll learn about attention SAEs, how they work (mostly quite similar to standard SAEs but with a few other considerations), and how to understand their feature dashboards. Key points:
- Attention SAEs have the same architecture as regular SAEs, except they're trained on the concatenated pre-projection output of all attention heads.
- If a latent fires on a destination token, we can use direct latent attribution to see which source tokens it primarily came from.
- Just like regular SAEs, latents found in different layers of a model are often qualitatively different from each other.
In this section, we'll be exploring different ways of finding a latent for some given concept. However, before we get into that, we first need to introduce a new concept - attention SAEs.
Research done by Kissane et al. as part of the MATS program has shown that we can use SAEs on the output of attention layers, and that this works well: the SAEs learn sparse, interpretable latents, which gives us insight into what attention layers learn. Subsequent work trained SAEs on the attention output of every layer of GPT2-Small; these are the SAEs we'll be using in today's exercises.
Functionally, these SAEs work just like regular ones, except that they take the z output of the attention layer as input (i.e. after taking a linear combination of value vectors, but before projecting via the output matrix) rather than taking the residual stream or post-ReLU MLP activations. These z vectors are usually concatenated together across attention heads, for a single layer.
Can you see why we take the attention output before projection via the output matrix, rather than after?
It would be a waste of parameters. The SAE encoder is a linear map from activation space to latent space, and the attention layer's output z @ W_O is itself a linear function of z, so it can't have a larger rank than z (though it might have a smaller one). The output will generally be at least as large a vector as z, so training on it would mean more parameters without any extra information, i.e. less efficient training.
Can you guess why we concatenate across attention heads?
We do this because heads might be in superposition, just like neurons in MLP layers. As well as a single attention head containing many latents, we could have a latent which is split across multiple attention heads. Evidence of shared attention head functionality abounds in regular models, for instance in the intro to mech interp ARENA exercises, we examined a 2-layer model where 2 heads in the second layer came together to form a copying head. In that case, we might expect to find latents which are split across both heads.
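To make the "concatenated across heads" point concrete, here's a minimal sketch (using random tensors in place of a real activation cache) of how the hook_z activations get flattened into the attention SAE's input. The shapes follow the standard TransformerLens convention; the variable names are just illustrative.

```python
import einops
import torch as t

batch, seq, n_heads, d_head = 2, 5, 12, 64  # GPT2-Small head dimensions

# Stand-in for cache["blocks.9.attn.hook_z"], i.e. the pre-W_O attention output
z = t.randn(batch, seq, n_heads, d_head)

# The attention SAE's input is all heads concatenated along the final dimension
sae_input = einops.rearrange(z, "batch seq n_heads d_head -> batch seq (n_heads d_head)")
assert sae_input.shape == (batch, seq, n_heads * d_head)  # d_in = 768 for GPT2-Small
```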
However, one interesting thing about attention SAEs is that we also have to think about source tokens, not just the destination token. In other words, once we identify some attention latent that is present at a particular destination token, we still need to ask the question of which source tokens it came from.
This leads us to the tool of direct latent attribution (which we'll abbreviate as "DFA", from the older name "direct feature attribution", so that its acronym doesn't get confused with direct logit attribution!). Just as in direct logit attribution (DLA) we ask which components wrote to the residual stream in ways which directly influenced certain logits, with DFA we decompose the input at the destination token which caused that latent to fire. This can tell us things like which head contributed most to that latent, or which source token did (or both).
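Before we look at real examples, here's a shape-level sketch of that decomposition, using random tensors in place of cached activations and SAE weights (and ignoring the SAE's bias terms). The point is just that the latent's pre-ReLU activation at a destination token splits linearly into one term per source token.

```python
import torch as t

seq, n_heads, d_head, d_sae = 10, 12, 64, 24576
latent_idx, dest = 0, 7

v = t.randn(seq, n_heads, d_head)         # value vectors at each source position
pattern = t.rand(n_heads, seq, seq)       # attention probs, indexed [head, dest, src]
W_enc = t.randn(n_heads * d_head, d_sae)  # SAE encoder weights (biases ignored here)

# z at the destination token is a sum over source tokens of (attn prob * value vector)
v_weighted = v * pattern[:, dest, :].T.unsqueeze(-1)  # [src, n_heads, d_head]
z_dest = v_weighted.sum(0).flatten()                  # [n_heads * d_head]

# DFA = one term per source token; these sum to the latent's pre-ReLU activation (up to biases)
dfa = v_weighted.flatten(1) @ W_enc[:, latent_idx]    # [src]
assert t.allclose(dfa.sum(), z_dest @ W_enc[:, latent_idx], atol=1e-3)
```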
Exercise - explore attention SAE dashboards
Run the code below to see an example of a latent dashboard for the layer-9 attention SAE in GPT2-Small. The green text on the right shows you where this latent was activated, and the orange highlight shows you the DFA, i.e. the source token which contributed most to activating it.
You should specifically try to think about the qualitative differences between latents you see in SAEs trained at different layers (you can change the layer variable below to investigate this). What do you notice? What kinds of latents exist at earlier layers but not later ones? Which layers have more interpretable output logits?
Some things you should see
You should generally find the following themes:
- Early layer head latents often represent low-level grammatical syntax (e.g. firing on single tokens or bigrams).
- Middle layer head latents are often the hardest to interpret because they respond to higher-level semantic information, but also aren't always interpretable in terms of their output logits, likely because they write to intermediate representations which are then used by other heads or MLP layers.
- Late layer head latents are often understood in terms of their output logits (i.e. they're directly writing predictions to the residual stream). The very last layer is something of an exception to this, since it seems to deal largely with grammatical corrections and adjustments.
For more on this, you can read the table in the LessWrong post [We Inspected Every Head In GPT-2 Small using SAEs So You Don’t Have To](https://www.lesswrong.com/posts/xmegeW5mqiBsvoaim/we-inspected-every-head-in-gpt-2-small-using-saes-so-you-don#Overview_of_Attention_Heads_Across_Layers) (which looks at the same SAEs we're working with here).
attn_saes = {
layer: SAE.from_pretrained(
"gpt2-small-hook-z-kk",
f"blocks.{layer}.hook_z",
device=str(device),
)[0]
for layer in range(gpt2.cfg.n_layers)
}
layer = 9
display_dashboard(
sae_release="gpt2-small-hook-z-kk",
sae_id=f"blocks.{layer}.hook_z",
latent_idx=2, # or you can try `random.randint(0, attn_saes[layer].cfg.d_sae)`
)
Exercise - derive attention DFA
Since we've added another component to our latent dashboard, let's perform another derivation! For any given destination token firing, we can define the direct latent attribution (DFA) of a given source token as the dot product of the value vector taken at that source token (scaled by the attention probability) and the SAE encoder direction for this latent. In other words, the pre-ReLU activations of the latent are equal to a sum of its DFA over all source tokens.
When you complete this problem, you'll be able to complete the final part of the attention dashboard. We've given you a function to visualize this, like we used for the residual stream SAEs:
@dataclass
class AttnSeqDFA:
act: float
str_toks_dest: list[str]
str_toks_src: list[str]
dest_pos: int
src_pos: int
def display_top_seqs_attn(data: list[AttnSeqDFA]):
"""
Same as previous function, but we now have 2 str_tok lists and 2 sequence positions to
highlight, the first being for top activations (destination token) and the second for top DFA
(src token). We've given you a dataclass to help keep track of this.
"""
table = Table(
"Top Act",
"Src token DFA (for top dest token)",
"Dest token",
title="Max Activating Examples",
show_lines=True,
)
for seq in data:
formatted_seqs = [
repr(
"".join(
[
f"[b u {color}]{str_tok}[/]" if i == seq_pos else str_tok
for i, str_tok in enumerate(str_toks)
]
)
.replace("�", "")
.replace("\n", "↵")
)
for str_toks, seq_pos, color in [
(seq.str_toks_src, seq.src_pos, "dark_orange"),
(seq.str_toks_dest, seq.dest_pos, "green"),
]
]
table.add_row(f"{seq.act:.3f}", *formatted_seqs)
rprint(table)
str_toks = [" one", " two", " three", " four"]
example_data = [
AttnSeqDFA(
act=0.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=0, src_pos=0
),
AttnSeqDFA(
act=1.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=1, src_pos=1
),
AttnSeqDFA(
act=2.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=2, src_pos=0
),
]
display_top_seqs_attn(example_data)
               Max Activating Examples
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Top Act ┃ Src token DFA (for top dest token) ┃ Dest token        ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ 0.500   │ ' one two three'                   │ ' two three four' │
├─────────┼────────────────────────────────────┼───────────────────┤
│ 1.500   │ ' one two three'                   │ ' two three four' │
├─────────┼────────────────────────────────────┼───────────────────┤
│ 2.500   │ ' one two three'                   │ ' two three four' │
└─────────┴────────────────────────────────────┴───────────────────┘
Now, fill in the function below. You'll be testing your function by comparing the output to the dashboard generated from latent_idx=2 (shown above). You should have observed that the top activating tokens are conjunctions like " and", " or" which connect lists of words like " weapons", " firearms", " missiles", etc. We can also see that the max DFA token is usually a similar word earlier in context, sometimes the word immediately before the top activating token (which makes sense given the positive logits also boost words like this). In other words, part of what this latent seems to be doing is detecting when we're in the middle of a conjunction phrase like "weapons and tactics" or "guns ... not missiles", and predicting logical completions for how this phrase will end.
Note - this is a pretty difficult problem (mostly because of all the rearranging and multiple steps in the solution). If you get stuck, we strongly recommend using the hint below.
Help - I'm still confused about this calculation.
We have the value vectors v of shape (batch, seq, n_heads, d_head) at each position. By broadcasting & multiplying by the attention probabilities, we can get v_weighted of shape (batch, seq_dest, seq_src, n_heads, d_head), which represents the vector that'll be taken at each source position and added to the destination position, and will be summed over to produce z (the values we have before projection by the output matrix W_O to add back to the residual stream).
It's this z (after flattening over attention heads) that the SAE gets trained on, i.e. z @ sae.W_enc are the SAE's pre-ReLU activations. So by writing z as a sum of v_weighted over source positions, we can write the pre-ReLU activation for latent latent_idx as a sum of v_weighted[:, :, src_pos, :, :] @ sae.W_enc[:, latent_idx] over all src_pos values. So for any given sequence b in the batch, and destination position dest_pos, we can compute the scalar v_weighted[b, dest_pos, src_pos, :, :] @ sae.W_enc[:, latent_idx] for each src_pos, and find the largest one.
Reminder - W_enc is actually a linear map from n_heads * d_head to d_sae dimensions, so to perform this calculation we'll first need to flatten the values v_weighted over heads.
Note, some of the src token indexing can get a bit fiddly. In particular, when you get the index positions of the top contributing source tokens, some of them might be within buffer of the start of the sequence. The index_with_buffer function handles this case for you: whenever the indexing values are within buffer of the start or end of the sequence, it'll just take the first or last buffer tokens respectively.
def fetch_max_activating_examples_attn(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 250,
k: int = 10,
buffer: int = 10,
) -> list[AttnSeqDFA]:
"""
Returns the max activating examples across a number of batches from the activations store.
"""
raise NotImplementedError()
# Test your function: compare it to dashboard above
# (max DFA should come from source tokens like " guns", " firearms")
layer = 9
data = fetch_max_activating_examples_attn(gpt2, attn_saes[layer], gpt2_act_store, latent_idx=2)
display_top_seqs_attn(data)
Click to see the expected output
Max Activating Examples ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Top Act ┃ Src token DFA (for top dest token) ┃ Dest token ┃ ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ 2.593 │ ' After all, you cant order numerous guns and │ ' cant order numerous guns and massive amounts of │ │ │ massive amounts of ammunition and also Batman │ ammunition and also Batman costumes on line │ │ │ costumes on' │ without generating search' │ ├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤ │ 2.000 │ ' on, as far as when they would draw a weapon and │ " when they would draw a weapon and when they │ │ │ when they would use force, is based off' │ would use force, is based off of officer's │ │ │ │ perception," │ ├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤ │ 1.921 │ '↵↵So chances are a good guy with a gun will not │ ' guy with a gun will not stop a bad guy with a │ │ │ stop a bad guy with a gun.' │ gun. In fact, trying to produce more' │ ├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤ │ 1.780 │ ' red lines on chemical weapons, he took out the │ ' he took out the weapons, but not those who used │ │ │ weapons, but not those who used them. I don' │ them. I dont think history will' │ ├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤ │ 1.501 │ ' academy teaches young trainees about when to │ 'ees about when to use deadly force or when to │ │ │ use deadly force or when to use non-lethal force, │ use non-lethal force, such as pepper spray or' │ │ │ such' │ │ ├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤ │ 1.430 │ ' the past, religion is being used as both a │ ', religion is being used as both a weapon and a │ │ │ weapon and a shield by those seeking to deny │ shield by those seeking to deny others equality.' │ │ │ others equality' │ │ ├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤ │ 1.389 │ ' in the fact that no ancient American steel │ ' fact that no ancient American steel tools or │ │ │ tools or weapons (or traces or evidence of them │ weapons (or traces or evidence of them or their │ │ │ or their manufacture' │ manufacture) have' │ ├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤ │ 1.381 │ ' with a gun. 
The body riddled with 9mm bullets │ ' The body riddled with 9mm bullets and the │ │ │ and the presence of 9mm shell casings on' │ presence of 9mm shell casings on the floor is │ │ │ │ sufficient' │ ├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤ │ 1.186 │ '↵Unicorn Mask and Pyjamas and weapons and Sticky │ 'Unicorn Mask and Pyjamas and weapons and Sticky │ │ │ bombs.↵↵Thanks @Rock' │ bombs.↵↵Thanks @Rockstar' │ ├─────────┼───────────────────────────────────────────────────┼───────────────────────────────────────────────────┤ │ 1.045 │ ' minutes later, he heard shots being fired from │ ' later, he heard shots being fired from │ │ │ automatic rifles and cries of distress, the group │ automatic rifles and cries of distress, the group │ │ │ said.↵' │ said.↵↵' │ └─────────┴───────────────────────────────────────────────────┴───────────────────────────────────────────────────┘
Solution
def fetch_max_activating_examples_attn(
model: HookedSAETransformer,
sae: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 250,
k: int = 10,
buffer: int = 10,
) -> list[AttnSeqDFA]:
"""
Returns the max activating examples across a number of batches from the activations store.
"""
sae_acts_pre_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_pre"
v_hook_name = get_act_name("v", sae.cfg.hook_layer)
pattern_hook_name = get_act_name("pattern", sae.cfg.hook_layer)
data = []
for _ in tqdm(
range(total_batches), desc="Computing activations for max activating examples (attn)"
):
tokens = act_store.get_batch_tokens()
_, cache = model.run_with_cache_with_saes(
tokens,
saes=[sae],
stop_at_layer=sae.cfg.hook_layer + 1,
names_filter=[sae_acts_pre_hook_name, v_hook_name, pattern_hook_name],
)
acts = cache[sae_acts_pre_hook_name][..., latent_idx] # [batch seq]
# Get largest indices (i.e. dest tokens), and the tokens at those positions (plus buffer)
k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
top_acts = index_with_buffer(acts, k_largest_indices).tolist()
dest_toks_with_buffer = index_with_buffer(tokens, k_largest_indices, buffer=buffer)
str_toks_dest_list = [model.to_str_tokens(toks) for toks in dest_toks_with_buffer]
# Get source token value vectors & dest-to-src attention patterns, for each of our chosen
# destination tokens
batch_indices, dest_pos_indices = k_largest_indices.unbind(-1)
v = cache[v_hook_name][batch_indices] # shape [k src n_heads d_head]
pattern = cache[pattern_hook_name][batch_indices, :, dest_pos_indices] # [k n_heads src]
# Multiply them together to get weighted value vectors, and reshape them to d_in = n_heads d_head
v_weighted = v * einops.rearrange(pattern, "k n src -> k src n 1")
v_weighted = v_weighted.flatten(-2, -1) # [k src d_in]
# Map through our SAE encoder to get direct feature attribution for each src token, and argmax over src tokens
dfa = v_weighted @ sae.W_enc[:, latent_idx] # shape [k src]
src_pos_indices = dfa.argmax(dim=-1)
src_toks_with_buffer = index_with_buffer(
tokens, t.stack([batch_indices, src_pos_indices], -1), buffer=buffer
)
str_toks_src_list = [model.to_str_tokens(toks) for toks in src_toks_with_buffer]
# Add all this data to our list
for act, str_toks_dest, str_toks_src, src_pos in zip(
top_acts, str_toks_dest_list, str_toks_src_list, src_pos_indices
):
data.append(
AttnSeqDFA(
act=act,
str_toks_dest=str_toks_dest, # top activating dest tokens, with buffer
str_toks_src=str_toks_src, # top DFA src tokens for the dest token, with buffer
dest_pos=buffer, # dest token is always in the middle of its buffer
src_pos=min(src_pos, buffer), # deal with case where src token is near start
)
)
return sorted(data, key=lambda x: x.act, reverse=True)[:k]
Finding latents for features
In this section, you'll explore different methods (some causal, some not) for finding latents in SAEs corresponding to particular features. Key points:
- You can look at max activating latents on some particular input prompt; this is basically the simplest thing you can do
- Direct logit attribution (DLA) is a bit more refined; you can find latents which have a direct effect on specific logits
- Ablation of SAE latents can help you find latents which are important in a non-direct way
- ...but it's quite costly for a large number of latents, so you can use attribution patching as a cheaper linear approximation of ablation
We'll now proceed through a set of 3 different methods that can be used to find features which activate on a given concept or language structure. We'll focus on trying to find IOI features - features which seem to activate on the indirect object identification pattern. You might already be familiar with this via the IOI exercises in ARENA. If not, it's essentially sentences of the form "When John and Mary went to the shops, John gave the bag to" -> " Mary". Models like GPT2-Small are able to learn this pattern via the following algorithm:
- Duplicate token heads in layers 0-3 attend from the second " John" token (we call it S2, the second subject) back to the first " John" token (S1), and store the fact that it's duplicated.
- S-inhibition heads in layers 7-8 attend from " to" back to the " John" token, and store information that this token is duplicated.
- Name-mover heads in layers 9-10 attend from " to" back to any non-duplicated names (using Q-composition with the output of the S-inhibition heads to avoid attending to the duplicated " John" tokens). So they attend to " Mary", and move this information into the unembedding space to be used as the model's prediction.
Sadly, our SAEs aren't yet advanced enough to pick up on S-inhibition features. As discussed above, these mid-layer heads which read from & write to subspaces containing intermediate representations are pretty difficult to interpret. In fact, reading this section of the LessWrong post analyzing our GPT2-Small attention SAEs, we find that the worst layers for "% alive features interpretable" as well as "loss recovered" are around 7 and 8, which is precisely the location of our S-Inhibition heads!
However, the authors of that post were able to find duplicate token features as well as name-mover features, and in the following exercises we'll replicate their work!
Before we start, let's first make sure that the model can actually solve this sentence. To make our results a bit more robust (e.g. so we're not just isolating "gender features" or something), we'll control by using 4 different prompts: with "John" and "Mary" as answers flipped around, and also with the sentence structure flipped around (ABBA vs ABAB). The code below also gives you the logits_to_ave_logit_diff function, which you might find useful in some later exercises.
names = [" John", " Mary"]
name_tokens = [gpt2.to_single_token(name) for name in names]
prompt_template = "When{A} and{B} went to the shops,{S} gave the bag to"
prompts = [
prompt_template.format(A=names[i], B=names[1 - i], S=names[j])
for i, j in itertools.product(range(2), range(2))
]
correct_answers = names[::-1] * 2
incorrect_answers = names * 2
correct_toks = gpt2.to_tokens(correct_answers, prepend_bos=False)[:, 0].tolist()
incorrect_toks = gpt2.to_tokens(incorrect_answers, prepend_bos=False)[:, 0].tolist()
def logits_to_ave_logit_diff(
logits: Float[Tensor, "batch seq d_vocab"],
correct_toks: list[int] = correct_toks,
incorrect_toks: list[int] = incorrect_toks,
reduction: Literal["mean", "sum"] | None = "mean",
keep_as_tensor: bool = False,
) -> list[float] | float:
"""
Returns the avg logit diff on a set of prompts, with fixed s2 pos and stuff.
"""
correct_logits = logits[range(len(logits)), -1, correct_toks]
incorrect_logits = logits[range(len(logits)), -1, incorrect_toks]
logit_diff = correct_logits - incorrect_logits
if reduction is not None:
logit_diff = logit_diff.mean() if reduction == "mean" else logit_diff.sum()
return logit_diff if keep_as_tensor else logit_diff.tolist()
# Testing a single prompt (where correct answer is John), verifying model gets it right
test_prompt(prompts[1], names, gpt2)
# Testing logits over all 4 prompts, verifying the model always has a high logit diff
logits = gpt2(prompts, return_type="logits")
logit_diffs = logits_to_ave_logit_diff(logits, reduction=None)
print(
tabulate(
zip(prompts, correct_answers, logit_diffs),
headers=["Prompt", "Answer", "Logit Diff"],
tablefmt="simple_outline",
numalign="left",
floatfmt="+.3f",
)
)
Tokenized prompt: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' gave', ' the', ' bag', ' to'] Tokenized answers: [[' John'], [' Mary']] Performance on answer tokens: Rank: 0 Logit: 18.03 Prob: 69.35% Token: | John| Rank: 3 Logit: 14.83 Prob: 2.82% Token: | Mary| Top 0th token. Logit: 18.03 Prob: 69.35% Token: | John| Top 1th token. Logit: 15.53 Prob: 5.67% Token: | them| Top 2th token. Logit: 15.28 Prob: 4.42% Token: | the| Top 3th token. Logit: 14.83 Prob: 2.82% Token: | Mary| Top 4th token. Logit: 14.16 Prob: 1.44% Token: | her| Top 5th token. Logit: 13.94 Prob: 1.16% Token: | him| Top 6th token. Logit: 13.72 Prob: 0.93% Token: | a| Top 7th token. Logit: 13.68 Prob: 0.89% Token: | Joseph| Top 8th token. Logit: 13.61 Prob: 0.83% Token: | Jesus| Top 9th token. Logit: 13.34 Prob: 0.64% Token: | their| Ranks of the answer tokens: [[(' John', 0), (' Mary', 3)]] ┌────────────────────────────────────────────────────────────┬──────────┬──────────────┐ │ Prompt │ Answer │ Logit Diff │ ├────────────────────────────────────────────────────────────┼──────────┼──────────────┤ │ When John and Mary went to the shops, John gave the bag to │ Mary │ +3.337 │ │ When John and Mary went to the shops, Mary gave the bag to │ John │ +3.202 │ │ When Mary and John went to the shops, John gave the bag to │ Mary │ +3.918 │ │ When Mary and John went to the shops, Mary gave the bag to │ John │ +2.220 │ └────────────────────────────────────────────────────────────┴──────────┴──────────────┘
Exercise - verify model + SAEs can still solve this
We want to make sure that the logit diff for the model is still high when we substitute in SAEs for a given layer. If this isn't the case, then that means our SAEs aren't able to implement the IOI circuit, and there's no reason to expect we'll find any interesting features!
You should use the with model.saes context manager (or whatever your preferred way of running with SAEs is) to get the average logit diff over the 4 prompts, for each layer's attention SAE. Verify that this difference is still high, i.e. the SAEs don't ruin performance.
# YOUR CODE HERE - verify model + SAEs can still solve this
Click to see the expected output
┏━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Ablation   ┃ Logit diff ┃ % of clean ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ Clean      │ +3.169     │ 100.0%     │
│ SAE in L00 │ +3.005     │ 94.8%      │
│ SAE in L01 │ +3.164     │ 99.8%      │
│ SAE in L02 │ +3.155     │ 99.5%      │
│ SAE in L03 │ +2.991     │ 94.4%      │
│ SAE in L04 │ +2.868     │ 90.5%      │
│ SAE in L05 │ +3.308     │ 104.4%     │
│ SAE in L06 │ +2.872     │ 90.6%      │
│ SAE in L07 │ +2.565     │ 80.9%      │
│ SAE in L08 │ +2.411     │ 76.1%      │
│ SAE in L09 │ +3.030     │ 95.6%      │
│ SAE in L10 │ +3.744     │ 118.1%     │
│ SAE in L11 │ +3.809     │ 120.2%     │
└────────────┴────────────┴────────────┘
Solution
logits = gpt2(prompts, return_type="logits")
clean_logit_diff = logits_to_ave_logit_diff(logits)
table = Table("Ablation", "Logit diff", "% of clean")
table.add_row("Clean", f"{clean_logit_diff:+.3f}", "100.0%")
for layer in range(gpt2.cfg.n_layers):
with gpt2.saes(saes=[attn_saes[layer]]):
logits = gpt2(prompts, return_type="logits")
logit_diff = logits_to_ave_logit_diff(logits)
table.add_row(
f"SAE in L{layer:02}",
f"{logit_diff:+.3f}",
f"{logit_diff / clean_logit_diff:.1%}",
)
rprint(table)
Discussion of results
You should find that most layers have close to 100% recovery, with some layers like 10 and 11 even increasing the logit diff when they're substituted in. You shouldn't read too much into that in this case, because we're only working with a small dataset and so some noise is expected.
The worst layers in terms of logit diff recovery are 7 and 8, which makes sense given that these are the layers with our S-inhibition heads. The logit diff is still positive for both, and remains positive when you substitute in both the layer 7 and layer 8 SAEs at once (although it drops to 58% of the clean logit diff). This means these SAEs are presumably still capturing some amount of the S-inhibition heads' behaviour - however, that doesn't guarantee we have monosemantic S-inhibition features, and in fact we don't.
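If you want to check the "both layer 7 and 8 SAEs at once" claim yourself, a minimal sketch is below. It reuses clean_logit_diff from the solution above, and assumes the saes context manager accepts a list with more than one SAE (so far we've only passed it a single SAE at a time).

```python
# Sketch: substitute the layer 7 & 8 attention SAEs simultaneously and compare to the clean run
with gpt2.saes(saes=[attn_saes[7], attn_saes[8]]):
    logits_78 = gpt2(prompts, return_type="logits")
    logit_diff_78 = logits_to_ave_logit_diff(logits_78)

print(f"L7+L8 logit diff = {logit_diff_78:+.3f} ({logit_diff_78 / clean_logit_diff:.1%} of clean)")
```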
Now we've done this, it's time for exercises!
Exercise - find name mover features with max activations
Now we've established that the model can still solve the task with SAEs substituted, we're ready to start looking for features!
In our first exercise, we'll look for name mover features - features which activate on the final token in the prompt, and seem to predict the IO token comes next. Since attention heads in layer 9 are name movers (primarily 9.6 and 9.9), we'll be looking at our layer 9 SAE, i.e. attn_saes[9]. You should fill in the cell below to create a plot of latent activations at the final token, averaged over all 4 prompts, and also display the dashboards of the top 3 ranking latents. (If you want some help, you can borrow bits of code from earlier - specifically, the 3rd code cell of the "Running SAEs" section.) Do you find any latents that seem like they might correspond to name mover features?
# YOUR CODE HERE - find name mover latents by using max activations
Click to see the expected output
Spoiler - what you should find
You should return the features 11368, 18767 and 3101.
These all appear to be name movers:
- 11368 is a name mover for "John" (i.e. it always attends from a token before "John" could appear, back to a previous instance of "John")
- 18767 is a name mover for "Mary"
- 3101 is a name mover for "Jesus"
Note that the activations of 11368 and 18767 should be much larger than the activations of any other feature.
Solution
layer = 9
# Compute mean post-ReLU SAE activations at last token posn
_, cache = gpt2.run_with_cache_with_saes(prompts, saes=[attn_saes[layer]])
sae_acts_post = cache[f"{attn_saes[layer].cfg.hook_name}.hook_sae_acts_post"][:, -1].mean(0)
# Plot the activations
px.line(
sae_acts_post.cpu().numpy(),
title=f"Activations at the final token position ({sae_acts_post.nonzero().numel()} alive)",
labels={"index": "Latent", "value": "Activation"},
template="ggplot2",
width=1000,
).update_layout(showlegend=False).show()
# Print the top 3 latents, and inspect their dashboards
for act, ind in zip(*sae_acts_post.topk(3)):
print(f"Latent {ind} had activation {act:.2f}")
display_dashboard(
sae_release="gpt2-small-hook-z-kk",
sae_id=f"blocks.{layer}.hook_z",
latent_idx=int(ind),
)
Exercise - identify name mover heads
The IOI paper found that heads 9.6 and 9.9 were the primary name movers. Can you verify this by looking at the features' weight distribution over heads, i.e. seeing which heads these particular features have the largest exposure to?
# YOUR CODE HERE - identify the name mover heads
Click to see the expected output
┏━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Head ┃ Latent 18767  ┃ Latent 10651  ┃
┡━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ 9.0  │ 5.36%         │ 7.47%         │
│ 9.1  │ 2.71%         │ 2.50%         │
│ 9.2  │ 5.00%         │ 6.66%         │
│ 9.3  │ 4.29%         │ 5.14%         │
│ 9.4  │ 6.16%         │ 6.09%         │
│ 9.5  │ 3.84%         │ 4.97%         │
│ 9.6  │ 13.17%        │ 18.16%        │
│ 9.7  │ 4.97%         │ 6.60%         │
│ 9.8  │ 6.14%         │ 8.20%         │
│ 9.9  │ 42.66%        │ 26.13%        │
│ 9.10 │ 2.23%         │ 3.49%         │
│ 9.11 │ 3.47%         │ 4.58%         │
└──────┴───────────────┴───────────────┘
Hint - what calculation you should perform
Each latent has an associated decoder weight sae.W_dec[latent_idx], which is a vector of shape (d_in=n_heads*d_head,). We can split this into a (d_head,)-length vector for each attention head (representing the latent's exposure to that head), measure the norm of each of these vectors, and see which head has the largest fraction of the total norm.
Solution
features = [18767, 10651]
decoder_weights = einops.rearrange(
    attn_saes[layer].W_dec[features],
    "feats (n_heads d_head) -> feats n_heads d_head",
    n_heads=gpt2.cfg.n_heads,
)
norm_per_head = decoder_weights.pow(2).sum(-1).sqrt()
norm_frac_per_head = norm_per_head / norm_per_head.sum(-1, keepdim=True)
table = Table("Head", *[f"Latent {i}" for i in features])
for i in range(gpt2.cfg.n_heads):
    table.add_row(f"9.{i}", *[f"{frac:.2%}" for frac in norm_frac_per_head[:, i].tolist()])
rprint(table)
You should find that both these features have largest exposure to 9.9, and second largest to 9.6.
Direct logit attribution
A slightly more refined technique than looking for max activating latents is to look for ones which have a particular effect on some token logits, or logit difference. In the context of the IOI circuit, we look at the logit difference between the indirect object and the subject token (which we'll abbreviate as IO - S). For example, if a latent's output into the residual stream has a very large value of IO - S when we apply the model's unembedding matrix W_U, this suggests that the latent might have been causally important for identifying that the correct answer was "Mary" rather than "John" in sentences like "When John and Mary went to the shops, John gave the bag to ???". This is essentially the same as the DLA you might have already seen in previous contexts or ARENA exercises (e.g. the DLA for attention heads that we covered in the IOI exercises).
Exercise - find name mover features with DLA
Write code below to perform this calculation and visualize the results (line chart & dashboards for the top-scoring latents). Your code here should look a lot like your code for finding name movers via max activations, except rather than argmaxing over avg latent activations you should be argmaxing over the latents' average DLA onto the IO - S direction.
Help - I'm not sure what this calculation should look like.
Start with the latent direction in the decoder, sae.W_dec[latent_idx], which is a vector of length d_in = n_heads * d_head. We interpret this as a direction in the model's concatenated z space (i.e. the linear combinations of value vectors you get before applying the projection matrix). Taking the projection matrix W_O of shape (n_heads, d_head, d_model), flattening its first two dimensions, and multiplying the full decoder matrix W_dec by it gives us a matrix of shape (d_sae, d_model), where each row is the vector that gets added into the residual stream (per unit of activation) when that particular latent fires. Now, we can take this matrix and pass it through our unembedding W_U to get a matrix of shape (d_sae, d_vocab), and then index into it to get the logit difference for each latent (or if we want to be more efficient, we can just dot product our matrix with the vector W_U[:, IO] - W_U[:, S] and directly get a vector of d_sae logit differences).
This is how we get DLA for latents for a single prompt. We can parallelize this to get all 4 prompts at once (i.e. output of shape (4, d_sae)), then average over the batch dimension to get a single DLA vector.
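If you'd like a more concrete starting point before writing your own code, here's a minimal sketch of the recipe above for the layer-9 attention SAE. The variable names are ours, and we weight each latent's decoder direction by its activation at the final token, averaged over the 4 prompts.

```python
layer = 9

# SAE latent activations at the final token position, for all 4 prompts
_, cache = gpt2.run_with_cache_with_saes(prompts, saes=[attn_saes[layer]])
sae_acts_post = cache[f"{attn_saes[layer].cfg.hook_name}.hook_sae_acts_post"][:, -1]  # [batch, d_sae]

# Map each latent's decoder direction through W_O into the residual stream
W_O_flat = gpt2.W_O[layer].flatten(0, 1)               # [n_heads*d_head, d_model]
latent_resid_dirs = attn_saes[layer].W_dec @ W_O_flat  # [d_sae, d_model]

# Dot with the per-prompt (IO - S) unembedding direction, weight by activations, average over prompts
logit_diff_dirs = gpt2.W_U[:, correct_toks] - gpt2.W_U[:, incorrect_toks]  # [d_model, batch]
dla = (sae_acts_post * (latent_resid_dirs @ logit_diff_dirs).T).mean(0)    # [d_sae]

print("Top latents by DLA:", dla.topk(3).indices.tolist())
```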
# YOUR CODE HERE - find name mover latents with DLA