[1.3.3] Interpretability with SAEs
Please send any problems / bugs on the #errata channel in the Slack group, and ask any questions on the dedicated channels for this chapter of material.
Note - there is a very large amount of content in this set of exercises, easily double that of any other single exercise set in ARENA (and some of those exercise sets are meant to last several days). The purpose of these exercises isn't to go through every single one of them, but rather to jump around to the ones you're most interested in. For instance, if you already have a rough idea of what superposition & SAEs are, you can skip past section 0️⃣ and go straight into the later sections.
Rather than working through this material as exercises, you can also just use it as a source of reference code, if you ever want to quickly implement some particular SAE technique or type of forward pass / causal intervention.
You can use the interactive map below to get a better sense of the material content, and the dependencies between different sections.
Introduction
In these exercises, we dive deeply into the interpretability research that can be done with sparse autoencoders. We'll start by introducing two important tools: SAELens (essentially the TransformerLens of SAEs, which also integrates very well with TransformerLens) and Neuronpedia, an open platform for interpretability research. We'll then move through a few other exciting domains in SAE interpretability, grouped into several categories (e.g. understanding / classifying latents, or finding circuits in SAEs).
We expect some degree of prerequisite knowledge in these exercises. Specifically, it will be very helpful if you understand:
- What superposition is
- What the sparse autoencoder architecture is, and why it can help us disentangle features from superposition
We've included an abridged version of the exercise set 1.3.1 Superposition & SAEs, which contains all the material we view as highly useful for the rest of these exercises. If you've already gone through this exercise set then you can proceed straight to section 1️⃣; if not, we recommend at least skimming through section 0️⃣ so that you feel comfortable with the core geometric intuitions for superposition and how SAEs work.
One note before starting - we'll be mostly adopting the terminology that features are characteristics of the underlying data distribution that our base models are trained on, and SAE latents (or just "latents") are the directions in the SAE. This is to avoid overloading the term "feature", and to avoid the implicit assumption that "SAE features" correspond to real features in the data. We'll relax this terminology when we're looking at SAE latents which very clearly correspond to specific interpretable features in the data.
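To make the terminology concrete, here is a minimal sketch of a standard ReLU sparse autoencoder forward pass with an L1 sparsity penalty. It is written in plain NumPy rather than with the SAELens API, the weights are random rather than trained, and the names (`sae_forward`, `l1_coeff`) and sizes are illustrative assumptions. The rows of `W_dec` play the role of the latent directions, and `acts` holds the latent activations.

```python
import numpy as np

rng = np.random.default_rng(0)

d_model, d_sae = 8, 32  # hypothetical sizes; real SAEs are far wider

# Encoder / decoder parameters (randomly initialised here, not trained)
W_enc = rng.normal(scale=0.1, size=(d_model, d_sae))
b_enc = np.zeros(d_sae)
W_dec = rng.normal(scale=0.1, size=(d_sae, d_model))
b_dec = np.zeros(d_model)


def sae_forward(x, l1_coeff=1e-3):
    """Return (latent activations, reconstruction, loss) for a batch of activations x."""
    # Latent activations: non-negative by construction (ReLU)
    acts = np.maximum(0.0, (x - b_dec) @ W_enc + b_enc)
    # Reconstruction: a weighted sum of decoder rows (the latent directions) plus a bias
    x_hat = acts @ W_dec + b_dec
    # Reconstruction error plus an L1 penalty that encourages sparse activations
    recon_loss = ((x - x_hat) ** 2).sum(axis=-1).mean()
    sparsity_loss = np.abs(acts).sum(axis=-1).mean()
    return acts, x_hat, recon_loss + l1_coeff * sparsity_loss


x = rng.normal(size=(4, d_model))  # stand-in for a batch of model activations
acts, x_hat, loss = sae_forward(x)
print(acts.shape, x_hat.shape)  # (4, 32) (4, 8)
```

Whether a given latent direction corresponds to a real feature of the data is an empirical question, which is why the two terms are kept separate.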
Reading Material
Most of this is optional, and can be read at your leisure depending on what interests you most & what level of background you have. If we could recommend just one, it would be "Towards Monosemanticity" - particularly the first half of "Problem Setup", and the sections where they take a deep dive on individual latents.
- Toy Models of Superposition outlines the core ideas behind superposition - what it is, why it matters for interpretability, and what we might be able to do about it.
- Towards Monosemanticity: Decomposing Language Models With Dictionary Learning arguably took the first major stride in mechanistic interpretability with SAEs: training them on a 1-layer model, and extracting a large number of interpretable features.
- Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet shows how you can scale up the science of SAEs to larger models, specifically the SOTA (at the time) model Claude 3 Sonnet. It provides an interesting insight into where the field might be moving in the near future.
- Improving Dictionary Learning with Gated Sparse Autoencoders is a paper from DeepMind which introduces the Gated SAE architecture, demonstrating how it outperforms the standard architecture and also motivating its use by speculating about underlying feature distributions.
- Gemma Scope announces DeepMind's release of a comprehensive suite of open-sourced SAE models (trained with JumpReLU architecture). We'll be working a lot more with Gemma Scope models in subsequent exercises!
- LessWrong, SAEs tag contains a collection of posts on LessWrong that discuss SAEs, and is a great source of inspiration for further independent research!
Content & Learning Objectives
View this interactive map to see all the material which is complete & still in development, as well as figures showing what you'll create at the end of each section. It contains pretty much all the information below, but in an easier-to-visualize format.
Highlighting a few important things which the map should make clear:
- There's no required order to go through this material in! With the exception of the first few chunks of section 1️⃣ you can basically pick whatever you want to work through, depending on what your objectives are. The map shows the dependencies between different sections, which you can use to guide your work.
- Some sections of material are still in development, and will continue to be added to over October & November (although development will mostly be paused during the middle of October, while I'm going through an interview process). We're open to suggestions or recommendations for how to improve / add to this material further!
1️⃣ Intro to SAE Interpretability
The idea is for this section to be an MVP for all basic SAE topics, excluding training & evals (which we'll come back to in section 3️⃣). The focus will be on how to understand & interpret SAE latents (in particular all the components of the SAE dashboard). We'll also look at techniques for finding latents (e.g. ablation & attribution methods), as well as taking a deeper dive into attention SAEs and how they work.
Learning Objectives
- Learn how to use the SAELens library to load in & run SAEs (alongside the TransformerLens models they're attached to)
- Understand the basic features of Neuronpedia, and how it can be used for things like steering and searching over features
- Understand SAE dashboards, what each part of them tells you about a particular latent (as well as how to compute them yourself)
- Learn techniques for finding latents, including direct logit attribution, ablation and attribution patching
- Use attention SAEs, understand how they differ from regular SAEs (as well as topics specific to attention SAEs, like direct latent attribution)
- Learn a bit about different SAE architectures or training methods (e.g. gated, end-to-end, meta-SAEs, transcoders) - some of these will be covered in more detail later
2️⃣ Understanding SAE Latents: A Deeper Dive
This is essentially an a-la-carte batch of several different topics, which aren't specifically related to SAE circuits or training SAEs, but which were too niche or in-depth to cover in the intro section. Much of this represents research being done on the cutting edge of current SAE interpretability, and could be an interesting jumping off point for your own research!
Learning Objectives
- Study feature splitting, and what it means for SAE training
- Use UMAPs and other dimensionality-reduction techniques to better understand SAE latent geometry
- Understand feature absorption, and how meta-SAEs might help us disentangle this problem (not implemented yet)
- Use the logit lens & techniques like token enrichment analysis to better understand & characterize SAE latents (not implemented yet)
- Take a deeper dive into automated interpretability, covering autointerp-based evals & patch scoping (not implemented yet)
3️⃣ Training & Evaluating SAEs
In this section, we first cover some basic material on training SAEs. We'll show you how SAELens supports SAE training and cover a few general pieces of advice, and then go through a few training case studies. Each of these training exercises represents a good jumping off point for further investigation of your trained models (although this is more true for the smaller models e.g. the attention SAE trained on a 2L model, since it's more likely to have found features that match the level of complexity of the base model, given that these tutorials are optimised for brevity and low compute requirements!).
We plan to add a second half to this section which covers evals, but this is currently still in development (and we cover many evals-related things in other parts of the material, e.g. the first half of this section, as well as things like autointerp in section 2️⃣).
Learning Objectives
- Learn how to train SAEs using SAELens
- Understand how to interpret different metrics during training, and understand when & why SAE training fails to produce interpretable latents
- Get hands-on experience training SAEs in a variety of contexts: MLP output of TinyStories-1L, residual stream of Gemma-2-2B, attention output of a 2L model, etc
- Understand how to evaluate SAEs, and why simple metrics can be deceptive (not implemented yet)
A note on memory usage
In these exercises, we'll be loading some pretty large models into memory (e.g. Gemma 2-2B and its SAEs, as well as a host of other models in later sections of the material). It's useful to have functions which can help profile memory usage for you, so that if you encounter OOM errors you can try and clear out unnecessary models. For example, we've found that with the right memory handling (i.e. deleting models and objects when you're not using them any more) it should be possible to run all the exercises in this material on a Colab Pro notebook, and all the exercises minus the handful involving Gemma on a free Colab notebook.
See this dropdown for some functions which you might find helpful, and how to use them.
First, we can run some code to inspect our current memory usage. Here's me running this code during the exercise set on SAE circuits, after having already loaded in the Gemma models from the previous section. This was on a Colab Pro notebook.
import part31_superposition_and_saes.utils as utils
# Profile memory usage, and delete gemma models if we've loaded them in
namespace = globals().copy() | locals()
utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")
Allocated = 35.88 GB Total = 39.56 GB Free = 3.68 GB
┌──────────────────────┬────────────────────────┬──────────┬─────────────┐
│ Name                 │ Object                 │ Device   │   Size (GB) │
├──────────────────────┼────────────────────────┼──────────┼─────────────┤
│ gemma_2_2b           │ HookedSAETransformer   │ cuda:0   │       11.94 │
│ gpt2                 │ HookedSAETransformer   │ cuda:0   │        0.61 │
│ gemma_2_2b_sae       │ SAE                    │ cuda:0   │        0.28 │
│ sae_resid_dirs       │ Tensor (4, 24576, 768) │ cuda:0   │        0.28 │
│ gpt2_sae             │ SAE                    │ cuda:0   │        0.14 │
│ logits               │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ logits_with_ablation │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ clean_logits         │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ _                    │ Tensor (16, 128, 768)  │ cuda:0   │        0.01 │
│ clean_sae_acts_post  │ Tensor (4, 15, 24576)  │ cuda:0   │        0.01 │
└──────────────────────┴────────────────────────┴──────────┴─────────────┘
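As a quick sanity check on the Size column, the tensor rows can be reproduced by hand. The sketch below assumes the tensors are stored as float32 (4 bytes per element) and that 1 GB means 1024**3 bytes, matching the conversion used in the cleanup code later in this dropdown; `tensor_size_gb` is a hypothetical helper, not part of the course utils.

```python
import math

BYTES_PER_FLOAT32 = 4  # assumption: the tensors in the table are stored as float32


def tensor_size_gb(shape, bytes_per_element=BYTES_PER_FLOAT32):
    """Size of a dense tensor in GB, taking 1 GB = 1024**3 bytes."""
    return math.prod(shape) * bytes_per_element / 1024**3


for name, shape in [
    ("sae_resid_dirs", (4, 24576, 768)),
    ("logits", (4, 15, 50257)),
]:
    print(f"{name:<16} {tensor_size_gb(shape):.2f} GB")
# sae_resid_dirs   0.28 GB
# logits           0.01 GB
```

Both values agree with the corresponding rows of the profile, which is a useful way to estimate in advance whether a cached tensor is worth keeping on the GPU.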
From the profile above, we see that we've allocated a lot of memory for the Gemma model, so let's delete it. We'll also run some code to move any remaining modules on the GPU which are larger than 100MB to the CPU, and print the memory status again.
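import gc

import torch as t

del gemma_2_2b
del gemma_2_2b_sae

THRESHOLD = 0.1  # GB

# Assumption: these helpers live in the shared `utils` module imported above
# (the original snippet referred to it as `part32_utils`)
for obj in gc.get_objects():
    try:
        # Only touch modules whose tensors take up more than THRESHOLD GB
        if isinstance(obj, t.nn.Module) and utils.get_tensors_size(obj) / 1024**3 > THRESHOLD:
            if hasattr(obj, "cuda"):
                obj.cpu()
            if hasattr(obj, "reset"):
                obj.reset()
    except Exception:
        # Some tracked objects raise when inspected; skip them
        pass

# Move our gpt2 model & SAEs back to GPU (we'll need them for the exercises we're about to do)
gpt2.to(device)
gpt2_saes = {layer: sae.to(device) for layer, sae in gpt2_saes.items()}

utils.print_memory_status()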
Allocated = 14.90 GB Reserved = 39.56 GB Free = 24.66
Mission success! We've managed to free up a lot of memory. Note that the code which moves all objects collected by the garbage collector to the CPU is often necessary to free up the memory. We can't just delete the objects directly because PyTorch can still sometimes keep references to them (i.e. their tensors) in memory. In fact, if you add code to the for loop above to print out obj.shape when obj is a tensor, you'll see that a lot of those tensors are actually Gemma model weights, even once you've deleted gemma_2_2b.
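The point about lingering references is general Python behaviour rather than anything specific to PyTorch. Here is a minimal standard-library sketch, in which `BigObject` and `cache` are hypothetical stand-ins for a model and for whatever else still holds on to it:

```python
import gc
import weakref


class BigObject:
    """Stand-in for a large model; hypothetical, used only for illustration."""


model = BigObject()
cache = {"owner": model}  # a second reference, e.g. held by some cache
tracker = weakref.ref(model)  # lets us check whether the object is still alive

del model  # removes the name, but `cache` still refers to the object
gc.collect()
alive_after_del = tracker() is not None

cache.clear()  # drop the last remaining reference
gc.collect()
alive_after_clear = tracker() is not None

print(alive_after_del, alive_after_clear)  # True False
```

`del` only removes a name. The object itself is freed once nothing refers to it any more, which is why moving everything the garbage collector still tracks to the CPU can release GPU memory that `del` alone does not.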
Setup (don't read, just run)
import gc
import itertools
import os
import random
import sys
from collections import Counter, defaultdict
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias
import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy
device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)
# Make sure exercises are in the path
chapter = "chapter1_transformer_interp"
section = "part32_interp_with_saes"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))
# There's a single utils & tests file for both parts 3.1 & 3.2
import part31_superposition_and_saes.tests as tests
import part31_superposition_and_saes.utils as utils
from plotly_utils import imshow, line
MAIN = __name__ == "__main__"
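One thing to be aware of in the setup code: `Path.cwd().parents` does not include the current directory itself, and `next(...)` raises a bare `StopIteration` if no match is found. So the `root_dir` line fails with an unhelpful error if you run it from the repository root or from outside the repository. A slightly more defensive variant might look like the sketch below, where `find_root` is a hypothetical helper and not part of the course code:

```python
import tempfile
from pathlib import Path


def find_root(start: Path, marker: str) -> Path:
    """Return the nearest directory at or above `start` that contains `marker`."""
    # Unlike `start.parents`, this also checks `start` itself
    for candidate in [start, *start.parents]:
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"No directory containing {marker!r} found at or above {start}")


# Usage on a throwaway directory tree (names mirror the setup code above)
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()
    nested = root / "chapter1_transformer_interp" / "exercises" / "part32_interp_with_saes"
    nested.mkdir(parents=True)
    found_from_nested = find_root(nested, "chapter1_transformer_interp") == root
    found_from_root = find_root(root, "chapter1_transformer_interp") == root

print(found_from_nested, found_from_root)  # True True
```

In the setup code this would be called as `find_root(Path.cwd(), chapter)`.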