Exercise Status: All exercises complete and verified

[1.4.2] SAE Circuits

Colab: exercises | solutions

Please report any problems / bugs in the #errata channel of the Slack group, and ask any questions in the dedicated channels for this chapter of material.

If you want to change to dark mode, you can do this by clicking the three horizontal lines in the top-right, then navigating to Settings → Theme.

Links to all other chapters: (0) Fundamentals, (1) Transformer Interpretability, (2) RL.


Introduction

In these exercises, we explore circuits with SAEs: sets of SAE latents in different layers of a transformer which communicate with each other, and explain some particular model behaviour in an end-to-end way. We'll start by computing latent-to-latent, token-to-latent and latent-to-logit gradients, which give us a linear proxy for how latents in different layers are connected. We'll then move on to transcoders, a variant of SAEs which learn to reconstruct a model layer's computation rather than just its activations, and which offer significant advantages for circuit analysis.

We expect some degree of prerequisite knowledge in these exercises. Specifically, it will be very helpful if you understand:

  • What superposition is, and what the sparse autoencoder architecture is (if you need a refresher on these topics, see exercises 1.5.4 Toy Models of SAEs & Superposition)
  • How to use the SAELens library to load and run SAEs alongside TransformerLens models (covered in the first section of exercises 1.3.3 Interpretability with SAEs)

We've included a short section at the start to speedrun the most relevant background from exercise set 1.3.3 (mainly loading models & SAEs and running forward passes with them), so you don't need to have completed that exercise set in full before starting this one.

One note on terminology: we'll be mostly adopting the convention that features are characteristics of the underlying data distribution that our base models are trained on, and SAE latents (or just "latents") are the directions in the SAE. This is to avoid overloading the term "feature", and to avoid the implicit assumption that "SAE features" correspond to real features in the data. We'll relax this terminology when we're looking at SAE latents which very clearly correspond to specific interpretable features in the data.

Reading Material

Content & Learning Objectives

1️⃣ Latent Gradients

SAEs are cool and interesting and we can steer on their latents to produce cool and interesting effects, but does this mean we've truly found the units of computation used by our models, or have we just found an interesting clustering algorithm? The answer is that we don't really know yet! One strong piece of evidence for the former would be finding circuits with SAEs: in other words, sets of latents in different layers of the transformer which communicate with each other, and explain some particular behaviour in an end-to-end way. In this section, we'll compute gradients between latents in different layers to build up a picture of how they communicate.

Learning Objectives
  • Learn how to compute latent-to-latent gradients between SAE latents in different layers of the transformer
  • Compute token-to-latent gradients to understand which input tokens drive particular latent activations
  • Compute latent-to-logit gradients to understand how latents affect the model's output
  • Use these gradient-based methods to find circuits in attention SAEs (e.g. induction circuits)

2️⃣ Transcoders

Transcoders are a variant of SAEs which learn to reconstruct a model layer's computation (e.g. a sparse mapping from MLP input to MLP output), rather than just reconstructing activations at a single point. They offer significant advantages for circuit analysis, since they decompose the function of an MLP layer into sparse, interpretable units. In this section, we'll load and work with transcoders, study their properties using techniques like de-embeddings, and go through a blind case study where we reverse-engineer a transcoder latent purely from weights-based analysis.
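
To make the architecture concrete, here's a minimal sketch of a transcoder in PyTorch (random weights, arbitrary dimensions; not the SAELens implementation). The key difference from a standard SAE is that the encoder reads the MLP's input, but the decoder predicts the MLP's output rather than reconstructing the input:

```python
import torch as t

class Transcoder(t.nn.Module):
    # Sketch: like an SAE, but trained so that `mlp_out_recon` matches the MLP
    # layer's *output*, given a sparse encoding of the MLP layer's *input*.
    def __init__(self, d_model: int, d_tc: int):
        super().__init__()
        self.W_enc = t.nn.Parameter(t.randn(d_model, d_tc) * 0.01)
        self.b_enc = t.nn.Parameter(t.zeros(d_tc))
        self.W_dec = t.nn.Parameter(t.randn(d_tc, d_model) * 0.01)
        self.b_dec = t.nn.Parameter(t.zeros(d_model))

    def forward(self, mlp_in: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
        acts = t.relu(mlp_in @ self.W_enc + self.b_enc)  # sparse latent activations
        mlp_out_recon = acts @ self.W_dec + self.b_dec  # predicted MLP output
        return acts, mlp_out_recon

tc = Transcoder(d_model=768, d_tc=24576)
acts, recon = tc(t.randn(4, 768))
```

Because each latent's output contribution is just its activation times a decoder row, the MLP's function decomposes into a sparse sum of interpretable input-output maps.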

Learning Objectives
  • Understand transcoders, and how they differ from standard SAEs
  • Learn techniques for interpreting transcoder latents: pullbacks, de-embeddings, and extended embeddings
  • Work through a blind case study, interpreting a transcoder latent using only circuit-level analysis (no activation examples)

3️⃣ Attribution graphs

Attribution graphs extend the gradient-based methods from section 1️⃣ into a full framework for understanding end-to-end computation in transformers via transcoder latents. In this section, you'll implement the full attribution graph pipeline from scratch using Gemma 3-1B IT with GemmaScope 2 transcoders: linearising the model by freezing non-linearities, building the reading/writing vector abstraction for all node types, computing edge weights via batched backward passes, and pruning the graph using influence-based Neumann series propagation.

Learning Objectives
  • Understand the local replacement model: why freezing attention patterns, LayerNorm scales, and replacing MLPs with linear skip connections makes the residual stream linear
  • Understand the reading/writing vector abstraction: how token embeddings, transcoder latents, MLP errors, and logit directions all interact with the residual stream
  • Implement the core attribution algorithm: salient logit selection, graph node construction, and edge weight computation via gradient injection
  • Implement graph pruning via node and edge influence thresholding, using the Neumann series on the nilpotent adjacency matrix
  • Build interactive attribution graph visualisations using the same dashboard templates as Anthropic's published work
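
As a preview of the pruning step, here's a toy illustration (not the exercise solution) of why the Neumann series works: for a DAG adjacency matrix A, where A[i, j] is the direct effect of node j on node i, total influence through all paths is B = A + A² + A³ + …. Since the graph is acyclic, A is nilpotent and the series terminates, giving B = (I − A)⁻¹ − I:

```python
import torch as t

# Toy 3-node chain: node 0 -> node 1 (weight 0.5), node 1 -> node 2 (weight 0.4)
A = t.tensor([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 0.4, 0.0]])

# Accumulate the Neumann series A + A^2 + ... (terminates since A is nilpotent)
B, term = t.zeros_like(A), A.clone()
while term.abs().max() > 0:
    B, term = B + term, term @ A

# Equivalent closed form for nilpotent A
assert t.allclose(B, t.linalg.inv(t.eye(3) - A) - t.eye(3))
```

Here B[2, 0] = 0.5 × 0.4 = 0.2 is the indirect influence of node 0 on node 2 through the two-step path.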

4️⃣ Exploring circuits & interventions

Now that you've built the attribution graph algorithm from scratch, you'll use the circuit-tracer library to explore real circuits and perform feature interventions. You'll study the Dallas/Austin two-hop factual recall circuit, test its causal structure via zero ablation, swap features between prompts, and generate text with feature interventions.

Learning Objectives
  • Load and inspect pre-computed attribution graphs and their supernodes
  • Perform zero ablation experiments to test causal claims made by the graph
  • Perform cross-prompt feature swapping to demonstrate compositional circuit structure
  • Use open-ended generation with feature interventions

A note on memory usage

In these exercises, we'll be loading some pretty large models into memory (e.g. Gemma 2-2B and its SAEs, as well as a host of other models in later sections of the material). It's useful to have functions which can profile memory usage for you, so that if you encounter OOM errors you can try to clear out unnecessary models. For example, we've found that with the right memory handling (i.e. deleting models and objects when you're not using them any more) it should be possible to run all the exercises in this material on a Colab Pro notebook, and all the exercises minus the handful involving Gemma on a free Colab notebook.

See this dropdown for some functions which you might find helpful, and how to use them.

First, we can run some code to inspect our current memory usage. Here's an example of running this code on a Colab Pro notebook.

import part42_sae_circuits.utils as utils

# Profile memory usage, and delete gemma models if we've loaded them in
namespace = globals().copy() | locals()
utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")
Allocated = 35.88 GB
Total = 39.56 GB
Free = 3.68 GB
┌──────────────────────┬────────────────────────┬──────────┬─────────────┐
│ Name                 │ Object                 │ Device   │   Size (GB) │
├──────────────────────┼────────────────────────┼──────────┼─────────────┤
│ gemma_2_2b           │ HookedSAETransformer   │ cuda:0   │       11.94 │
│ gpt2                 │ HookedSAETransformer   │ cuda:0   │        0.61 │
│ gemma_2_2b_sae       │ SAE                    │ cuda:0   │        0.28 │
│ sae_resid_dirs       │ Tensor (4, 24576, 768) │ cuda:0   │        0.28 │
│ gpt2_sae             │ SAE                    │ cuda:0   │        0.14 │
│ logits               │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ logits_with_ablation │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ clean_logits         │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ _                    │ Tensor (16, 128, 768)  │ cuda:0   │        0.01 │
│ clean_sae_acts_post  │ Tensor (4, 15, 24576)  │ cuda:0   │        0.01 │
└──────────────────────┴────────────────────────┴──────────┴─────────────┘

From this, we see that we've allocated a lot of memory for the Gemma model, so let's delete it. We'll also run some code to move any remaining objects on the GPU which are larger than 100MB to the CPU, and print the memory status again.

del gemma_2_2b
del gemma_2_2b_sae

THRESHOLD = 0.1  # GB
for obj in gc.get_objects():
    try:
        if isinstance(obj, t.nn.Module) and utils.get_tensors_size(obj) / 1024**3 > THRESHOLD:
            if hasattr(obj, "cuda"):
                obj.cpu()
            if hasattr(obj, "reset"):
                obj.reset()
    except Exception:
        pass

# Move our gpt2 model & SAEs back to GPU (we'll need them for the exercises we're about to do)
gpt2.to(device)
gpt2_saes = {layer: sae.to(device) for layer, sae in gpt2_saes.items()}

utils.print_memory_status()
Allocated = 14.90 GB
Reserved = 39.56 GB
Free = 24.66 GB

Mission success! We've managed to free up a lot of memory. Note that the code which moves all objects collected by the garbage collector to the CPU is often necessary to free up the memory. We can't just delete the objects directly because PyTorch can still sometimes keep references to them (i.e. their tensors) in memory. In fact, if you add code to the for loop above to print out obj.shape when obj is a tensor, you'll see that a lot of those tensors are actually Gemma model weights, even once you've deleted gemma_2_2b.
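
If you want to do that kind of inspection yourself, here's one way to wrap it in a helper (a sketch, not part of utils — the name and threshold are just illustrative):

```python
import gc

import torch as t

def find_large_tensors(threshold_gb: float = 0.1) -> list[tuple[int, ...]]:
    """Return the shapes of all live tensors larger than `threshold_gb`."""
    shapes = []
    for obj in gc.get_objects():
        try:
            # Tensors are tracked by the garbage collector, so this catches tensors
            # that are still referenced somewhere even after you've `del`-ed a model
            if isinstance(obj, t.Tensor) and obj.numel() * obj.element_size() / 1024**3 > threshold_gb:
                shapes.append(tuple(obj.shape))
        except Exception:
            pass
    return shapes
```

Running this after deleting a model will often reveal its weight tensors still alive, which is exactly the situation the move-to-CPU loop above is designed to handle.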

Setup (don't read, just run)

import gc
import os
import sys
from collections import Counter, namedtuple
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Callable, TypeAlias

import einops
import numpy as np
import plotly.express as px
import torch as t
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from IPython.display import IFrame, display
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from sae_lens import SAE, ActivationsStore, HookedSAETransformer
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
from tabulate import tabulate
from torch import Tensor
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

dtype = t.float32  # t.bfloat16
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
device = str(device)  # SAELens expects device as string; this is easier!


def _get_hook_layer(sae: SAE) -> int:
    """Extract the layer number from an SAE's hook name (e.g. 'blocks.7.hook_resid_pre' → 7)."""
    return int(sae.cfg.metadata.hook_name.split(".")[1])


# Make sure exercises are in the path
chapter = "chapter1_transformer_interp"
section = "part42_sae_circuits"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

import part42_sae_circuits.tests as tests
import part42_sae_circuits.utils as utils

MAIN = __name__ == "__main__"

Speedrunning some relevant background

This section covers the essential SAELens concepts you'll need for the rest of these exercises. If you've already completed exercise set 1.3.3 Interpretability with SAEs, you can skip this section and go straight to section 1️⃣.

Loading SAEs with SAE.from_pretrained

SAELens is a library designed to help researchers train and analyse sparse autoencoders. You can think of it as the equivalent of TransformerLens for sparse autoencoders (and it also integrates very well with TransformerLens models, which we'll see shortly).

To load an SAE, use SAE.from_pretrained. This function returns an SAE object directly:

gpt2_sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    device=str(device),
)

You can view the available SAE releases in SAELens with get_pretrained_saes_directory(). Each release contains multiple SAEs (e.g. trained on different layers of the same base model).

Base models are loaded using the HookedSAETransformer class, which is adapted from the TransformerLens HookedTransformer class:

gpt2 = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Running SAEs and caching activations

You can add SAEs to a TransformerLens model when doing forward passes in much the same way you add hook functions. There are several different methods for this:

  • model.run_with_saes(tokens, saes=[list_of_saes]) works like model.run_with_hooks, doing a single forward pass with SAEs attached (then resetting them).
  • logits, cache = model.run_with_cache_with_saes(tokens, saes=[sae]) works like model.run_with_cache, caching all intermediate activations including SAE activations.
  • with model.saes(saes=[sae]): is a context manager which temporarily attaches SAEs.
  • model.add_sae(sae) / model.reset_saes() manually adds and removes SAEs.

To access SAE activations from the cache, the hook names are the concatenation of the HookedTransformer hook_name and the SAE hook name, joined by a period. The most important ones are:

# Post-activation latent values (shape [batch, seq, d_sae]) - this is the one you'll use most
cache[f"{sae.cfg.metadata.hook_name}.hook_sae_acts_post"]

# Pre-activation latent values (before the activation function)
cache[f"{sae.cfg.metadata.hook_name}.hook_sae_acts_pre"]

# The SAE's reconstruction of the original activations
cache[f"{sae.cfg.metadata.hook_name}.hook_sae_recons"]

# The final SAE output (either reconstruction, or reconstruction + error term)
cache[f"{sae.cfg.metadata.hook_name}.hook_sae_output"]

Here's a full example, extracting the top activating latents at the final token of a prompt:

_, cache = gpt2.run_with_cache_with_saes(
    prompt,
    saes=[gpt2_sae],
    stop_at_layer=_get_hook_layer(gpt2_sae) + 1,  # no need to compute past the SAE layer
)
sae_acts_post = cache[f"{gpt2_sae.cfg.metadata.hook_name}.hook_sae_acts_post"][0, -1, :]
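
From there, getting the top activating latents is just a topk over the latent dimension — illustrated here with dummy activations so the snippet is self-contained:

```python
import torch as t

# Stand-in for the cached `sae_acts_post` vector at the final token
sae_acts_post = t.relu(t.randn(24576))

# Values come back sorted in descending order; indices identify the latents
top_values, top_latent_indices = sae_acts_post.topk(5)
```

You can then pass each index in top_latent_indices to a dashboard helper (like the display_dashboard function later in this section) to inspect what the latent represents.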

The use_error_term parameter

The parameter sae.use_error_term determines whether we actually substitute the model's activations with SAE reconstructions during the forward pass:

  • use_error_term=False (default): the SAE's output replaces the transformer's activations. This means downstream computations use the SAE reconstruction rather than the original activations.
  • use_error_term=True: the SAE computes all its internal states (latent activations etc.) in the same way, but the transformer's activations are left intact. This is useful when you want to cache SAE activations without actually intervening on the model.
# Cache SAE activations WITHOUT intervening on the model
gpt2_sae.use_error_term = True
logits, cache = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])
# logits are identical to running the base model without SAEs, but we still get SAE activations in the cache

# Cache SAE activations AND replace model activations with SAE reconstructions
gpt2_sae.use_error_term = False
logits_recon, cache_recon = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])
# logits_recon will differ from the base model's logits due to SAE reconstruction error

This parameter matters a lot for the gradient exercises in this set: we'll typically use use_error_term=True to get the "true" latent activations from a clean forward pass, then use_error_term=False when computing Jacobians (because we want our Jacobian function to actually pass through the SAE decoder and encoder).
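
To preview what "passing through the SAE decoder and encoder" looks like, here's a toy sketch of a latent-to-latent Jacobian, with random weights standing in for the SAEs and for the model computation between them (every name here is made up for illustration):

```python
import torch as t

d_model, d_sae = 16, 64
W_dec_up = t.randn(d_sae, d_model) / d_model**0.5  # upstream SAE decoder
W_enc_down = t.randn(d_model, d_sae) / d_model**0.5  # downstream SAE encoder
layer = t.nn.Linear(d_model, d_model)  # stand-in for the model between the SAEs

def down_latents_from_up_latents(up_acts: t.Tensor) -> t.Tensor:
    resid = up_acts @ W_dec_up  # decode upstream latents into the residual stream
    resid = layer(resid)  # model computation between the two SAE hook points
    return t.relu(resid @ W_enc_down)  # re-encode into downstream latents

up_acts = t.relu(t.randn(d_sae))
# Jacobian of downstream latents w.r.t. upstream latents: [d_sae_down, d_sae_up]
jacobian = t.func.jacrev(down_latents_from_up_latents)(up_acts)
```

Entry [i, j] of the Jacobian is a linear proxy for how strongly upstream latent j drives downstream latent i at this particular input.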

Using ActivationsStore

The ActivationsStore class is a convenient alternative to loading a bunch of data yourself. It streams in data from a given dataset; in the case of the from_sae classmethod, that dataset is taken from your SAE's config (i.e. the dataset the SAE was originally trained on):

gpt2_act_store = ActivationsStore.from_sae(
    model=gpt2,
    sae=gpt2_sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)

# Get a batch of tokens
tokens = gpt2_act_store.get_batch_tokens()
assert tokens.shape == (gpt2_act_store.store_batch_size_prompts, gpt2_act_store.context_size)

Neuronpedia & display_dashboard

Neuronpedia is an open platform for interpretability research. It hosts SAE dashboards which help you quickly understand what a particular SAE latent represents, including components like max activating examples, top logits, activation density plots, and LLM-generated explanations.

We can display these dashboards inline using the following helper function, which we'll use in several places throughout these exercises:

def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
) -> None:
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))