[1.4.2] SAE Circuits

Colab: exercises | solutions

Please report any problems / bugs in the #errata channel of the Slack group, and ask any questions in the dedicated channels for this chapter of material.

If you want to change to dark mode, you can do this by clicking the three horizontal lines in the top-right, then navigating to Settings → Theme.

Links to all other chapters: (0) Fundamentals, (1) Transformer Interpretability, (2) RL.


Introduction

In these exercises, we dive deeply into the interpretability research that can be done with sparse autoencoders. We'll start by introducing two important tools: SAELens (essentially the TransformerLens of SAEs, which also integrates very well with TransformerLens) and Neuronpedia, an open platform for interpretability research. We'll then move through a few other exciting domains in SAE interpretability, grouped into several categories (e.g. understanding / classifying latents, or finding circuits in SAEs).
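To give a flavour of what this looks like in practice, here's a minimal sketch of loading a model and one of its pretrained SAEs with SAELens. The release / ID strings below are example values from the SAELens pretrained directory, and in some versions of sae_lens the from_pretrained method returns a (sae, cfg_dict, sparsity) tuple, hence the [0]:

from sae_lens import SAE, HookedSAETransformer

# Load GPT-2 small as a HookedSAETransformer (a TransformerLens model with SAE hook support)
gpt2 = HookedSAETransformer.from_pretrained("gpt2")

# Load a pretrained residual-stream SAE; release & sae_id are example values from the
# pretrained SAEs directory (see get_pretrained_saes_directory() for the full list)
gpt2_sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    device="cpu",
)[0]

# Run the model with the SAE spliced in, caching both model & SAE activations
logits, cache = gpt2.run_with_cache_with_saes("When John and Mary went to the store", saes=[gpt2_sae])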

We expect some degree of prerequisite knowledge in these exercises. Specifically, it will be very helpful if you understand:

  • What superposition is
  • What the sparse autoencoder architecture is, and why it can help us disentangle features from superposition (see the sketch below for a quick refresher)
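If you want a concrete picture to keep in mind for the second point, here's a minimal sketch of the standard ReLU SAE: an overcomplete encoder/decoder pair trained to reconstruct model activations under a sparsity penalty. This is illustrative code, not the actual SAELens implementation:

import torch as t
from torch import nn

class TinySAE(nn.Module):
    """Illustrative sparse autoencoder: d_sae >> d_in, trained with reconstruction + L1 sparsity loss."""

    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(t.randn(d_in, d_sae) * 0.01)
        self.b_enc = nn.Parameter(t.zeros(d_sae))
        self.W_dec = nn.Parameter(t.randn(d_sae, d_in) * 0.01)
        self.b_dec = nn.Parameter(t.zeros(d_in))

    def forward(self, x: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
        acts = t.relu((x - self.b_dec) @ self.W_enc + self.b_enc)  # sparse latent activations
        x_hat = acts @ self.W_dec + self.b_dec  # reconstruction of the original activation
        return x_hat, acts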

We've included an abridged version of the exercise set 1.3.1 Superposition & SAEs, which contains all the material we view as highly useful for the rest of these exercises. If you've already gone through that exercise set, you can proceed straight to section 1️⃣; if not, we recommend at least skimming section 0️⃣ so that you feel comfortable with the core geometric intuitions for superposition and how SAEs work.

One note before starting - we'll mostly adopt the terminology that features are characteristics of the underlying data distribution that our base models are trained on, and SAE latents (or just "latents") are the directions in the SAE. This avoids overloading the term "feature", and avoids the implicit assumption that "SAE features" correspond to real features in the data. We'll relax this terminology when we're looking at SAE latents which very clearly correspond to specific interpretable features in the data.

Reading Material

Most of this is optional, and can be read at your leisure depending on what interests you most & what level of background you have. If we could recommend just one, it would be "Towards Monosemanticity" - particularly the first half of "Problem Setup", and the sections where they take a deep dive into individual latents.

Content & Learning Objectives

1️⃣ SAE Circuits

SAEs are cool and interesting and we can steer on their latents to produce cool and interesting effects - but does this mean we've truly identified the units of computation used by our models, or have we just found an interesting clustering algorithm? The answer is that we don't really know yet! One strong piece of evidence for the former would be finding circuits with SAEs - in other words, sets of latents in different layers of the transformer which communicate with each other, and explain some particular behaviour in an end-to-end way. How to find these kinds of circuits, and what they look like, is what we'll explore in this section.
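To make "communicate with each other" slightly more concrete: one crude first-pass way to look for direct connections between residual-stream SAEs at two different layers is to project an upstream latent's decoder direction onto the downstream SAE's encoder weights. This is only a sketch - it ignores layernorm and all the attention / MLP computation between the two hook points, and the variable names are illustrative - but it captures the basic idea of latents in one layer writing into latents in a later layer:

def direct_connection_strengths(sae_early: SAE, sae_late: SAE, latent_idx: int) -> Tensor:
    """
    Crude proxy for how strongly one upstream latent writes into each downstream latent:
    dot the upstream decoder direction with every downstream encoder direction.
    Ignores layernorm & everything the model computes between the two hook points.
    """
    upstream_dir = sae_early.W_dec[latent_idx]  # shape (d_model,)
    return upstream_dir @ sae_late.W_enc  # shape (d_sae_late,)

# e.g. the 5 downstream latents most directly "fed" by upstream latent 123:
# strengths = direct_connection_strengths(gpt2_saes[3], gpt2_saes[7], latent_idx=123)
# top_latents = strengths.topk(5)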

Learning Objectives
  • Learn how to find connections between SAE latents in different layers of the transformer
  • Discover how to apply knowledge of SAE circuits to remove the bias from a linear classifier, as described in the Sparse Feature Circuits paper (not implemented yet)
  • Study transcoders, and understand how they can improve circuit analysis compared to regular SAEs

A note on memory usage

In these exercises, we'll be loading some pretty large models into memory (e.g. Gemma 2-2B and its SAEs, as well as a host of other models in later sections of the material). It's useful to have functions which can help profile memory usage for you, so that if you encounter OOM errors you can try and clear out unnecessary models. For example, we've found that with the right memory handling (i.e. deleting models and objects when you're not using them any more) it should be possible to run all the exercises in this material on a Colab Pro notebook, and all the exercises minus the handful involving Gemma on a free Colab notebook.
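If you just want the headline numbers without any helper functions, PyTorch exposes them directly. Here's a minimal sketch (where "allocated" counts live tensors and "reserved" is memory held by PyTorch's caching allocator):

import torch as t

def print_gpu_memory(device_idx: int = 0) -> None:
    # Quick GPU memory summary, in GB
    total = t.cuda.get_device_properties(device_idx).total_memory / 1024**3
    reserved = t.cuda.memory_reserved(device_idx) / 1024**3
    allocated = t.cuda.memory_allocated(device_idx) / 1024**3
    print(f"Allocated = {allocated:.2f} GB, Reserved = {reserved:.2f} GB, Total = {total:.2f} GB")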

See this dropdown for some functions which you might find helpful, and how to use them.

First, we can run some code to inspect our current memory usage. Here's me running this code during the exercise set on SAE circuits, after having already loaded in the Gemma models from the previous section. This was on a Colab Pro notebook.

import part31_superposition_and_saes.utils as utils
# Profile memory usage, to see which objects are taking up the most space on the GPU
namespace = globals().copy() | locals()
utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")
Allocated = 35.88 GB
Total = 39.56 GB
Free = 3.68 GB
┌──────────────────────┬────────────────────────┬──────────┬─────────────┐
│ Name                 │ Object                 │ Device   │   Size (GB) │
├──────────────────────┼────────────────────────┼──────────┼─────────────┤
│ gemma_2_2b           │ HookedSAETransformer   │ cuda:0   │       11.94 │
│ gpt2                 │ HookedSAETransformer   │ cuda:0   │        0.61 │
│ gemma_2_2b_sae       │ SAE                    │ cuda:0   │        0.28 │
│ sae_resid_dirs       │ Tensor (4, 24576, 768) │ cuda:0   │        0.28 │
│ gpt2_sae             │ SAE                    │ cuda:0   │        0.14 │
│ logits               │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ logits_with_ablation │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ clean_logits         │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ _                    │ Tensor (16, 128, 768)  │ cuda:0   │        0.01 │
│ clean_sae_acts_post  │ Tensor (4, 15, 24576)  │ cuda:0   │        0.01 │
└──────────────────────┴────────────────────────┴──────────┴─────────────┘

From this, we see that we've allocated a lot of memory for the Gemma model, so let's delete it. We'll also run some code to move any remaining objects on the GPU which are larger than 100MB to the CPU, and print the memory status again.

del gemma_2_2b
del gemma_2_2b_sae

THRESHOLD = 0.1  # GB
for obj in gc.get_objects():
    try:
        # Move any large nn.Module still tracked by the garbage collector onto the CPU
        if isinstance(obj, t.nn.Module) and utils.get_tensors_size(obj) / 1024**3 > THRESHOLD:
            if hasattr(obj, "cuda"):
                obj.cpu()
            if hasattr(obj, "reset"):
                obj.reset()
    except Exception:
        pass

# Move our gpt2 model & SAEs back to GPU (we'll need them for the exercises we're about to do)
gpt2.to(device)
gpt2_saes = {layer: sae.to(device) for layer, sae in gpt2_saes.items()}

utils.print_memory_status()
Allocated = 14.90 GB
Reserved = 39.56 GB
Free = 24.66 GB

Mission success! We've managed to free up a lot of memory. Note that the code which moves all objects collected by the garbage collector to the CPU is often necessary to actually free up the memory - simply deleting our own references isn't always enough, because other references to the underlying tensors can still be alive elsewhere. In fact, if you add code to the for loop above to print out obj.shape whenever obj is a tensor, you'll see that a lot of those tensors are actually Gemma model weights, even once you've deleted gemma_2_2b.
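For example, here's one way you might modify that loop to print the shapes of the large tensors still alive on the GPU (a quick diagnostic sketch, reusing the THRESHOLD defined above):

for obj in gc.get_objects():
    try:
        # Print the shape & size of any large CUDA tensor the garbage collector still tracks -
        # many of these turn out to be Gemma weight matrices, even after `del gemma_2_2b`
        if isinstance(obj, t.Tensor) and obj.is_cuda:
            size_gb = obj.numel() * obj.element_size() / 1024**3
            if size_gb > THRESHOLD:
                print(tuple(obj.shape), f"{size_gb:.2f} GB")
    except Exception:
        pass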

Setup (don't read, just run)

import gc
import itertools
import os
import random
import sys
from collections import Counter, defaultdict
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)

# Make sure exercises are in the path
chapter = "chapter1_transformer_interp"
section = "part32_interp_with_saes"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

# There's a single utils & tests file for both parts 3.1 & 3.2
import part31_superposition_and_saes.tests as tests
import part31_superposition_and_saes.utils as utils
from plotly_utils import imshow, line

MAIN = __name__ == "__main__"