2️⃣ Transcoders
Learning Objectives
- Understand transcoders, and how they differ from standard SAEs
- Learn techniques for interpreting transcoder latents: pullbacks, de-embeddings, and extended embeddings
- Work through a blind case study, interpreting a transcoder latent using only circuit-level analysis (no activation examples)
Introduction
The MLP-layer SAEs we've looked at attempt to represent activations as a sparse linear combination of latent vectors; importantly, they only operate on activations at a single point in the model. They don't actually learn to perform the MLP layer's computation, rather they learn to reconstruct the results of that computation. It's very hard to do any weights-based analysis on MLP layers in superposition using standard SAEs, since many latents are highly dense in the neuron basis, meaning the neurons are hard to decompose.
In contrast, transcoders take in the activations before the MLP layer (i.e. the possibly-normalized residual stream values) and aim to represent the post-MLP activations of that MLP layer, again as a sparse linear combination of latent vectors. The transcoder terminology is the most common, although these have also been called input-output SAEs (because we take the input to some base model layer, and try to learn the output) and predicting future activations (for obvious reasons). Note that transcoders aren't technically autoencoders, because they're learning a mapping rather than a reconstruction - however a lot of our intuitions from SAEs carry over to transcoders.
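To make this concrete, here's a minimal sketch of a transcoder (toy dimensions and init, not the SAELens implementation): architecturally it's identical to a ReLU SAE, and the only difference is that the reconstruction target is the MLP's output rather than the transcoder's own input.

```python
import torch as t
import torch.nn as nn

class TinyTranscoder(nn.Module):
    """Same encoder/decoder form as a ReLU SAE; only the training target differs."""

    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(t.randn(d_model, d_sae) * 0.02)
        self.b_enc = nn.Parameter(t.zeros(d_sae))
        self.W_dec = nn.Parameter(t.randn(d_sae, d_model) * 0.02)
        self.b_dec = nn.Parameter(t.zeros(d_model))

    def forward(self, mlp_in: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
        acts = t.relu(mlp_in @ self.W_enc + self.b_enc)  # sparse latent activations
        return acts @ self.W_dec + self.b_dec, acts

tc = TinyTranscoder(d_model=16, d_sae=64)
mlp_in = t.randn(4, 16)   # stand-in for the normalized pre-MLP residual stream
mlp_out = t.randn(4, 16)  # stand-in for the true MLP output we want to mimic
pred, acts = tc(mlp_in)

# An SAE's loss would compare `pred` against `mlp_in`; a transcoder's loss
# compares it against `mlp_out`, plus a sparsity penalty on the latents
loss = (pred - mlp_out).pow(2).mean() + 1e-3 * acts.abs().sum(-1).mean()
```
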
Why might transcoders be an improvement over standard SAEs? Mainly, they offer a much clearer insight into the function of a model's layers. From the Transcoders LessWrong post:
One of the strong points of transcoders is that they decompose the function of an MLP layer into sparse, independently-varying, and meaningful units (like neurons were originally intended to be before superposition was discovered). This significantly simplifies circuit analysis.
...
As an analogy, let’s say that we have some complex compiled computer program that we want to understand (a la Chris Olah’s analogy). SAEs are analogous to a debugger that lets us set breakpoints at various locations in the program and read out variables. On the other hand, transcoders are analogous to a tool for replacing specific subroutines in this program with human-interpretable approximations.
Intuitively, it might seem like transcoders are solving a different (and harder) optimization problem - trying to mimic the MLP's computation rather than just reproducing its output - and so should suffer a performance tradeoff relative to standard SAEs. However, the evidence suggests this might not be the case: transcoders may even offer a Pareto improvement over standard SAEs.
We'll start by loading in our transcoders. We load in transcoders exactly the same way as SAEs: with SAE.from_pretrained.
The model we'll be working with has been trained to reconstruct the 8th MLP layer of GPT-2. An important note - we're talking about taking the normalized input to the MLP layer and outputting mlp_out (i.e. the values we'll be adding back to the residual stream). So when we talk about pre-MLP and post-MLP values, we mean this, not pre/post activation function!
gpt2 = HookedSAETransformer.from_pretrained("gpt2-small", device=device, dtype=dtype)
hf_repo_id = "callummcdougall/arena-demos-transcoder"
sae_id = "gpt2-small-layer-{layer}-mlp-transcoder-folded-b_dec_out"
gpt2_transcoders = {
layer: SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id.format(layer=layer), device=device, dtype=dtype)
for layer in tqdm(range(9))
}
layer = 8
gpt2_transcoder = gpt2_transcoders[layer]
print("Transcoder hooks (same as regular SAE hooks):", gpt2_transcoder.hook_dict.keys())
# Load the sparsity values, and plot them
log_sparsity_path = hf_hub_download(hf_repo_id, f"{sae_id.format(layer=layer)}/log_sparsity.pt")
log_sparsity = t.load(log_sparsity_path, map_location="cpu", weights_only=True)
fig = px.histogram(
to_numpy(log_sparsity), width=800, template="ggplot2", title="Transcoder latent sparsity"
).update_layout(showlegend=False)
fig.show()
live_latents = np.arange(len(log_sparsity))[to_numpy(log_sparsity > -4)]
# Get the activations store
gpt2_act_store = ActivationsStore.from_sae(
model=gpt2,
sae=gpt2_transcoders[layer],
dataset="NeelNanda/pile-10k",
streaming=True,
store_batch_size_prompts=16,
n_batches_in_buffer=32,
device=device,
)
tokens = gpt2_act_store.get_batch_tokens()
assert tokens.shape == (gpt2_act_store.store_batch_size_prompts, gpt2_act_store.context_size)
Click to see the expected output
Next, we've given you a helper function which wraps model.run_with_cache_with_saes to handle setting use_error_term on transcoders before running. By default use_error_term=True, meaning the model's activations are left intact and we just cache transcoder activations.
def run_with_cache_with_transcoder(
model: HookedSAETransformer,
transcoders: list[SAE],
tokens: Tensor,
use_error_term: bool = True, # by default we don't intervene, just compute activations
) -> ActivationCache:
"""
Runs MLP transcoder(s) on a batch of tokens, using native SAELens v6 transcoder support.
If use_error_term=True (default), the model's activations are left intact and we just cache
transcoder activations. If False, the model's MLP outputs are replaced with transcoder outputs.
"""
prev_use_error_terms = [tc.use_error_term for tc in transcoders]
for tc in transcoders:
tc.use_error_term = use_error_term
try:
_, cache = model.run_with_cache_with_saes(tokens, saes=transcoders)
finally:
for tc, prev in zip(transcoders, prev_use_error_terms):
tc.use_error_term = prev
return cache
Lastly, we've given you the functions which you should already have encountered in the earlier exercise sets, when we were replicating SAE dashboards (if you've not done these exercises yet, we strongly recommend them!). The only difference is that we use run_with_cache_with_transcoder (a thin wrapper around model.run_with_cache_with_saes) to handle setting use_error_term on transcoders.
def get_k_largest_indices(
x: Float[Tensor, "batch seq"], k: int, buffer: int = 0, no_overlap: bool = True
) -> Int[Tensor, "k 2"]:
if buffer > 0:
x = x[:, buffer:-buffer]
indices = x.flatten().argsort(-1, descending=True)
rows = indices // x.size(1)
cols = indices % x.size(1) + buffer
if no_overlap:
unique_indices = t.empty((0, 2), device=x.device).long()
while len(unique_indices) < k:
unique_indices = t.cat((unique_indices, t.tensor([[rows[0], cols[0]]], device=x.device)))
is_overlapping_mask = (rows == rows[0]) & ((cols - cols[0]).abs() <= buffer)
rows = rows[~is_overlapping_mask]
cols = cols[~is_overlapping_mask]
return unique_indices
return t.stack((rows, cols), dim=1)[:k]
def index_with_buffer(
x: Float[Tensor, "batch seq"], indices: Int[Tensor, "k 2"], buffer: int | None = None
) -> Float[Tensor, " k *buffer_x2_plus1"]:
rows, cols = indices.unbind(dim=-1)
if buffer is not None:
rows = einops.repeat(rows, "k -> k buffer", buffer=buffer * 2 + 1)
cols[cols < buffer] = buffer
cols[cols > x.size(1) - buffer - 1] = x.size(1) - buffer - 1
cols = einops.repeat(cols, "k -> k buffer", buffer=buffer * 2 + 1) + t.arange(
-buffer, buffer + 1, device=x.device
)
return x[rows, cols]
def display_top_seqs(data: list[tuple[float, list[str], int]]):
table = Table("Act", "Sequence", title="Max Activating Examples", show_lines=True)
for act, str_toks, seq_pos in data:
formatted_seq = (
"".join([f"[b u green]{str_tok}[/]" if i == seq_pos else str_tok for i, str_tok in enumerate(str_toks)])
.replace("�", "")
.replace("\n", "↵")
)
table.add_row(f"{act:.3f}", repr(formatted_seq))
rprint(table)
def fetch_max_activating_examples(
model: HookedSAETransformer,
transcoder: SAE,
act_store: ActivationsStore,
latent_idx: int,
total_batches: int = 100,
k: int = 10,
buffer: int = 10,
display: bool = False,
) -> list[tuple[float, list[str], int]]:
data = []
for _ in tqdm(range(total_batches)):
tokens = act_store.get_batch_tokens()
cache = run_with_cache_with_transcoder(model, [transcoder], tokens)
acts = cache[f"{transcoder.cfg.metadata.hook_name}.hook_sae_acts_post"][..., latent_idx]
k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
tokens_with_buffer = index_with_buffer(tokens, k_largest_indices, buffer=buffer)
str_toks = [model.to_str_tokens(toks) for toks in tokens_with_buffer]
top_acts = index_with_buffer(acts, k_largest_indices).tolist()
data.extend(list(zip(top_acts, str_toks, [buffer] * len(str_toks))))
data = sorted(data, key=lambda x: x[0], reverse=True)[:k]
if display:
display_top_seqs(data)
return data
Let's pick latent 1, and compare our results to the neuronpedia dashboard (note that we do have neuronpedia dashboards for this model, even though it's not in SAELens yet).
latent_idx = 1
neuronpedia_id = "gpt2-small/8-tres-dc"
url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
display(IFrame(url, width=800, height=600))
fetch_max_activating_examples(
gpt2, gpt2_transcoder, gpt2_act_store, latent_idx=latent_idx, total_batches=200, display=True
)
Max Activating Examples
┃ Act    ┃ Sequence
│ 13.403 │ ' perception about not only how critical but also how dominant goalies can be in the NHL. '
│ 12.093 │ ' of the best. Hundreds of goalies have laced up the skates, put on'
│ 11.882 │ " 9↵↵December 4↵↵Messi's goal breakdown↵↵Barcelona↵↵Copa del"
│ 11.549 │ 't identify who the six Canadian legend NHL goalies are. We know 5 of them are hall'
│ 11.493 │ ' Steven Gerrard giving the Reds the lead. Second half goals from Rafael and a Robin van Persie penalty won'
│ 11.280 │ "↵↵Messi's month-by-month goal tally↵↵January 7↵↵February 10↵"
│ 10.874 │ '.↵↵Going from most recent to oldest Canadian goalies to win the Calder. 08-09 saw'
│ 10.811 │ ' the NHL. Each of these goalies left his stamp on the game. Now six'
│ 10.734 │ '↵"We just need to make sure that one goal conceded does not create a ripple effect, which then'
│ 10.501 │ ' the win with two sweetly taken second-half goals.↵↵"There were games last year where'
Pullback & de-embeddings
In the exercises on latent-latent gradients at the start of this section, we saw that it could be quite difficult to compute how any two latents in different layers interact with each other. In fact, we could only compute gradients between latents which were both alive in the same forward pass. One way we might have liked to deal with this is by just taking the dot product of the "writing vector" of one latent with the "reading vector" of another. For example, suppose our SAEs were trained on post-ReLU MLP activations; then we could compute:

W_dec[:, f1] @ W_out[layer1] @ W_in[layer2] @ W_enc[f2, :]

where f1 and f2 are our earlier and later latent indices, layer1 and layer2 are our SAE layers, and W_in, W_out are the MLP input & output weight matrices for those layers. To make sense of this formula: the term W_dec[:, f1] @ W_out[layer1] is the "writing vector" being added to the residual stream by the first (earlier) latent, and we take the dot product of this with W_in[layer2] @ W_enc[f2, :] to compute the activation of the second (later) latent.

However, one slight frustration is that we're ignoring the later MLP layer's ReLU function (remember that the SAEs are reconstructing the post-ReLU activations, not pre-ReLU). This might seem like a minor point, but it actually gets to a core limitation of standard SAEs when trained on layers which perform computation: the SAEs reconstruct a snapshot in the model, but they don't give us insight into the layer's actual computation process.
How do transcoders help us here? Well, since transcoders sit around the entire MLP layer (nonlinearity and all), we can literally compute the dot product between the "writing vector" and a downstream "reading vector" to figure out whether any given latent causes another one to be activated (ignoring layernorm). To make a few definitions:
- The pullback of some later latent is $p = W_{dec} f_{later}$, where $W_{dec}$ is an earlier transcoder's decoder matrix and $f_{later}$ is the later latent's encoder vector. In other words, it's the dot product of the later latent's reading vector with each of the earlier latents' writing vectors.
- The de-embedding is a special case: $d = W_E f_{later}$, i.e. instead of asking "which earlier transcoder latents activate some later latent?" we ask "which tokens maximally activate some later latent?".
Note that we can in principle compute both of these quantities for regular MLP SAEs. But they wouldn't be as accurate to the model's actual computation, and so you couldn't draw as many strong conclusions from them.
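At the shape level, both quantities are just matrix-vector products. Here's a sketch with random stand-in tensors and toy sizes (the real transcoder has d_model=768 and d_sae on the order of 24k):

```python
import torch as t

d_model, d_sae, d_vocab = 8, 32, 100     # toy stand-in sizes
W_dec_early = t.randn(d_sae, d_model)    # earlier transcoder's decoder rows = writing vectors
f_later = t.randn(d_model)               # later latent's encoder column = reading vector
W_E = t.randn(d_vocab, d_model)          # token embedding matrix

# Pullback: one score per earlier latent - how strongly does that latent's
# writing vector align with the later latent's reading vector?
pullback = W_dec_early @ f_later         # shape [d_sae]

# De-embedding: the special case where the "writers" are the token embeddings
de_embedding = W_E @ f_later             # shape [d_vocab]
```
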
To complete our circuit picture of (embeddings -> transcoders -> unembeddings), it's worth noting that we can compute the logit lens for a transcoder latent in exactly the same way as regular SAEs: just take the dot product of the transcoder decoder vector with the unembedding matrix. Since this has basically the exact same justification & interpretation as for regular SAEs, we don't need to invent a new term for it, so we'll just keep calling it the logit lens!
Exercise - compute de-embedding
In the cell below, you should compute the de-embedding for this latent (i.e. which tokens cause this latent to fire most strongly). You can use the logit lens function as a guide (which we've provided, from where it was used in the earlier exercises).
def show_top_logits(
model: HookedSAETransformer,
sae: SAE,
latent_idx: int,
k: int = 10,
) -> None:
"""Displays the top & bottom logits for a particular latent."""
logits = sae.W_dec[latent_idx] @ model.W_U
pos_logits, pos_token_ids = logits.topk(k)
pos_tokens = model.to_str_tokens(pos_token_ids)
neg_logits, neg_token_ids = logits.topk(k, largest=False)
neg_tokens = model.to_str_tokens(neg_token_ids)
print(
tabulate(
zip(map(repr, neg_tokens), neg_logits, map(repr, pos_tokens), pos_logits),
headers=["Bottom tokens", "Value", "Top tokens", "Value"],
tablefmt="simple_outline",
stralign="right",
numalign="left",
floatfmt="+.3f",
)
)
print(f"Top logits for transcoder latent {latent_idx}:")
show_top_logits(gpt2, gpt2_transcoder, latent_idx=latent_idx)
def show_top_deembeddings(model: HookedSAETransformer, sae: SAE, latent_idx: int, k: int = 10) -> None:
"""Displays the top & bottom de-embeddings for a particular latent."""
raise NotImplementedError()
print(f"\nTop de-embeddings for transcoder latent {latent_idx}:")
show_top_deembeddings(gpt2, gpt2_transcoder, latent_idx=latent_idx)
tests.test_show_top_deembeddings(show_top_deembeddings, gpt2, gpt2_transcoder)
Top logits for transcoder latent 1:
┌───────────────┬────────┬──────────────┬────────┐
│ Bottom tokens │ Value  │ Top tokens   │ Value  │
├───────────────┼────────┼──────────────┼────────┤
│ ' Goal'       │ -0.812 │ 'keeping'    │ +0.693 │
│ 'nox'         │ -0.638 │ 'bred'       │ +0.690 │
│ 'ussions'     │ -0.633 │ 'urious'     │ +0.663 │
│ ' Vision'     │ -0.630 │ 'cake'       │ +0.660 │
│ 'heses'       │ -0.623 │ 'swick'      │ +0.651 │
│ 'iasco'       │ -0.619 │ 'hedral'     │ +0.647 │
│ ' dream'      │ -0.605 │ 'sy'         │ +0.622 │
│ ' Grenade'    │ -0.594 │ 'ascus'      │ +0.612 │
│ 'rament'      │ -0.586 │ 'ebted'      │ +0.611 │
│ ' imagin'     │ -0.575 │ 'ZE'         │ +0.610 │
└───────────────┴────────┴──────────────┴────────┘

Top de-embeddings for transcoder latent 1:
┌───────────────┬────────┬──────────────┬────────┐
│ Bottom tokens │ Value  │ Top tokens   │ Value  │
├───────────────┼────────┼──────────────┼────────┤
│ 'attr'        │ -0.775 │ 'liga'       │ +1.720 │
│ ' reciproc'   │ -0.752 │ 'GAME'       │ +1.695 │
│ 'oros'        │ -0.712 │ 'jee'        │ +1.676 │
│ ' resists'    │ -0.704 │ ' scorer'    │ +1.649 │
│ ' Advent'     │ -0.666 │ 'ickets'     │ +1.622 │
│ 'gling'       │ -0.646 │ ' scored'    │ +1.584 │
│ ' Barron'     │ -0.630 │ 'artifacts'  │ +1.580 │
│ ' coh'        │ -0.593 │ 'scoring'    │ +1.578 │
│ ' repr'       │ -0.592 │ 'itory'      │ +1.520 │
│ ' reprint'    │ -0.587 │ ' scoring'   │ +1.520 │
└───────────────┴────────┴──────────────┴────────┘
Solution
def show_top_logits(
model: HookedSAETransformer,
sae: SAE,
latent_idx: int,
k: int = 10,
) -> None:
"""Displays the top & bottom logits for a particular latent."""
logits = sae.W_dec[latent_idx] @ model.W_U
pos_logits, pos_token_ids = logits.topk(k)
pos_tokens = model.to_str_tokens(pos_token_ids)
neg_logits, neg_token_ids = logits.topk(k, largest=False)
neg_tokens = model.to_str_tokens(neg_token_ids)
print(
tabulate(
zip(map(repr, neg_tokens), neg_logits, map(repr, pos_tokens), pos_logits),
headers=["Bottom tokens", "Value", "Top tokens", "Value"],
tablefmt="simple_outline",
stralign="right",
numalign="left",
floatfmt="+.3f",
)
)
def show_top_deembeddings(model: HookedSAETransformer, sae: SAE, latent_idx: int, k: int = 10) -> None:
"""Displays the top & bottom de-embeddings for a particular latent."""
de_embeddings = model.W_E @ sae.W_enc[:, latent_idx]
pos_logits, pos_token_ids = de_embeddings.topk(k)
pos_tokens = model.to_str_tokens(pos_token_ids)
neg_logits, neg_token_ids = de_embeddings.topk(k, largest=False)
neg_tokens = model.to_str_tokens(neg_token_ids)
print(
tabulate(
zip(map(repr, neg_tokens), neg_logits, map(repr, pos_tokens), pos_logits),
headers=["Bottom tokens", "Value", "Top tokens", "Value"],
tablefmt="simple_outline",
stralign="right",
numalign="left",
floatfmt="+.3f",
)
)
This is ... pretty underwhelming! It seems very obvious that the top activating token should be " goal" from looking at the dashboard - why are we getting weird words like "liga" and "jee"? Obviously some words make sense like " scored" or " scoring", but overall this isn't what we would expect.
Can you guess what's happening here? (Try and think about it before reading on, since reading the description of the next exercise will give away the answer!)
Hint
If you've done the IOI ARENA exercises (or read the IOI paper), you'll have come across this idea. It has to do with the architecture of GPT2-Small.
Answer
GPT2-Small has tied embeddings, i.e. its embedding matrix is the transpose of its unembedding matrix. This means the direct path is unable to represent bigram frequencies (e.g. it couldn't assign a higher logit to the bigram Barack Obama than to Obama Barack), so the MLP layers have to step in and break the symmetry. MLP0 in particular seems to take on this role, which is why we fold it into what we call the extended embedding (or effective embedding).
The result of this is that the indexed rows of the embedding matrix aren't really a good representation of the thing that the model has actually learned to treat as the embedding of a given token.
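The symmetry argument above can be demonstrated directly with a random toy embedding matrix (hypothetical sizes): with tied weights, the direct-path logit matrix is $W_E W_U = W_E W_E^T$, which is symmetric, so the score for token $j$ following token $i$ always equals the score for $i$ following $j$.

```python
import torch as t

d_vocab, d_model = 50, 16          # toy sizes
W_E = t.randn(d_vocab, d_model)
W_U = W_E.T                        # tied embeddings: unembedding = embedding transposed

# Direct-path logits: entry [i, j] is the score for token j immediately after token i
direct_path_logits = W_E @ W_U     # shape [d_vocab, d_vocab]

# e_i . e_j == e_j . e_i, so the matrix is symmetric: the direct path scores
# "Barack Obama" and "Obama Barack" identically, and can't learn bigram order
assert t.allclose(direct_path_logits, direct_path_logits.T)
```
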
Exercise - correct the de-embedding function
You should fill in the function below to compute the extended embedding, which will allow us to correct the mistake in the function discussed in the dropdowns above.
There are many different ways to compute the extended embedding (e.g. sometimes we include the attention layer and assume it always self-attends, sometimes we only use MLP0's output and sometimes we add it to the raw embeddings, sometimes we use a BOS token to make it more accurate). Most of these methods will get similar quality of results (it's far more important that you include MLP0 than the exact details of how you include it). For the sake of testing though, you should use the following method:
- Take the embedding matrix,
- Apply layernorm to it (i.e. each token's embedding vector is scaled to have unit std dev),
- Apply MLP0 to it (i.e. to each token's normalized embedding vector separately),
- Add the result back to the original embedding matrix.
Tip - rather than writing out the individual operations for layernorm & MLPs, you can just use the forward methods of model.blocks[layer].ln2 or .mlp respectively.
def create_extended_embedding(model: HookedTransformer) -> Float[Tensor, "d_vocab d_model"]:
"""
Creates the extended embedding matrix using the model's layer-0 MLP, and the method described
in the exercise above.
You should also divide the output by its standard deviation across the `d_model` dimension
(this is because that's how it'll be used later e.g. when fed into the MLP layer / transcoder).
"""
raise NotImplementedError()
tests.test_create_extended_embedding(create_extended_embedding, gpt2)
Solution
def create_extended_embedding(model: HookedTransformer) -> Float[Tensor, "d_vocab d_model"]:
"""
Creates the extended embedding matrix using the model's layer-0 MLP, and the method described
in the exercise above.
You should also divide the output by its standard deviation across the `d_model` dimension
(this is because that's how it'll be used later e.g. when fed into the MLP layer / transcoder).
"""
W_E = model.W_E.clone()[:, None, :] # shape [batch=d_vocab, seq_len=1, d_model]
mlp_output = model.blocks[0].mlp(model.blocks[0].ln2(W_E)) # shape [batch=d_vocab, seq_len=1, d_model]
W_E_ext = (W_E + mlp_output).squeeze()
return (W_E_ext - W_E_ext.mean(dim=-1, keepdim=True)) / W_E_ext.std(dim=-1, keepdim=True)
Once you've passed those tests, try rewriting show_top_deembeddings to use the extended embedding. Do the results look better? (Hint - they should!)
Note - don't worry if the magnitude of the results seems surprisingly large. Remember that a normalization step is applied pre-MLP, so the actual activations will be smaller than is suggested by the values in the table you'll generate.
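As a quick toy illustration of why the magnitudes overstate the activations - pre-MLP normalization rescales each residual-stream vector to (roughly) unit standard deviation, ignoring the learned gain and bias, so whatever the raw magnitude, the MLP (and hence the transcoder) sees something much smaller:

```python
import torch as t

x = 10 * t.randn(768)                 # a residual-stream vector with deliberately large entries
x_normed = (x - x.mean()) / x.std()   # what the normalization does (ignoring learned scale/bias)

# The raw vector has std ~10, but the normalized vector the MLP sees has std ~1
print(x.std().item(), x_normed.std().item())
```
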
def show_top_deembeddings_extended(model: HookedSAETransformer, sae: SAE, latent_idx: int, k: int = 10) -> None:
"""Displays the top & bottom de-embeddings for a particular latent."""
raise NotImplementedError()
tests.test_show_top_deembeddings_extended(show_top_deembeddings_extended, gpt2, gpt2_transcoder)
print(f"Top de-embeddings (extended) for transcoder latent {latent_idx}:")
show_top_deembeddings_extended(gpt2, gpt2_transcoder, latent_idx=latent_idx)
Click to see the expected output
Top de-embeddings (extended) for transcoder latent 1:
┌───────────────┬────────┬──────────────┬─────────┐
│ Bottom tokens │ Value  │ Top tokens   │ Value   │
├───────────────┼────────┼──────────────┼─────────┤
│ ' coupled'    │ -8.747 │ 'goal'       │ +14.161 │
│ 'inski'       │ -7.633 │ ' Goal'      │ +13.004 │
│ ' bent'       │ -7.601 │ ' goal'      │ +12.510 │
│ ' Line'       │ -7.357 │ 'Goal'       │ +11.724 │
│ ' Layer'      │ -7.235 │ ' Goals'     │ +11.538 │
│ ' layered'    │ -7.225 │ ' goals'     │ +11.447 │
│ ' lined'      │ -7.206 │ ' goalt'     │ +10.378 │
│ 'avy'         │ -7.110 │ 'score'      │ +10.364 │
│ ' Cassidy'    │ -7.032 │ ' Soccer'    │ +10.162 │
│ 'Contin'      │ -7.006 │ ' puck'      │ +10.122 │
└───────────────┴────────┴──────────────┴─────────┘
Solution
def show_top_deembeddings_extended(model: HookedSAETransformer, sae: SAE, latent_idx: int, k: int = 10) -> None:
"""Displays the top & bottom de-embeddings for a particular latent."""
de_embeddings = create_extended_embedding(model) @ sae.W_enc[:, latent_idx]
pos_logits, pos_token_ids = de_embeddings.topk(k)
pos_tokens = model.to_str_tokens(pos_token_ids)
neg_logits, neg_token_ids = de_embeddings.topk(k, largest=False)
neg_tokens = model.to_str_tokens(neg_token_ids)
print(
tabulate(
zip(map(repr, neg_tokens), neg_logits, map(repr, pos_tokens), pos_logits),
headers=["Bottom tokens", "Value", "Top tokens", "Value"],
tablefmt="simple_outline",
stralign="right",
numalign="left",
floatfmt="+.3f",
)
)
Blind case study
This is an open-ended exploration designed to put everything you've learned in this section into practice. It's challenging and there's no single right answer - treat it as a research exercise rather than a problem with a solution to converge on.
The authors of the post introducing transcoders present the idea of a blind case study. To quote from their post:
...we have some latent in some transcoder, and we want to interpret this transcoder latent without looking at the examples that cause it to activate. Our goal is to instead come to a hypothesis for when the latent activates by solely using the input-independent and input-dependent circuit analysis methods described above.
By input-independent circuit analysis, they mean things like pullbacks and de-embeddings (i.e. things which are a function of just the model & transcoder's weights). By input-dependent, they specifically mean the input-dependent influence, which they define to be the elementwise product of the pullback to some earlier transcoder and the post-ReLU activations of that earlier transcoder. In other words, it tells you not just which earlier latents would affect some later latent when those earlier latents fire, but which ones do affect the later latent on some particular input (i.e. taking into account which ones actually fired).
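The input-dependent influence is just an elementwise product. A shape-level sketch with random stand-in tensors (toy sizes):

```python
import torch as t

k, d_sae = 4, 32                          # toy sizes: k prompts, d_sae earlier latents
pullback = t.randn(d_sae)                 # input-independent: computed from weights alone
earlier_acts = t.relu(t.randn(k, d_sae))  # post-ReLU earlier-latent activations on each prompt

# Input-dependent influence: which earlier latents *actually* pushed the
# later latent up or down on these particular inputs
influence = earlier_acts * pullback       # shape [k, d_sae] (pullback broadcasts over rows)

# A latent that didn't fire has exactly zero influence, however large its pullback
assert (influence[earlier_acts == 0] == 0).all()
```
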
What's the motivation for this? Well, eventually we want to be able to understand latents when they appear in complex circuits, not just as individual units which respond to specific latents in the data. And part of that should involve being able to build up hypotheses about what a given latent is doing based on only its connection to other latents (or to specific tokens in the input). Just looking directly at the top activating examples can definitely be helpful, but not only is it sometimes misleading, it also can only tell you what a latent is doing, without giving much insight into why.
To be clear on the rules:
- You can't look at activations of a latent on specific tokens in specific example prompts.
- You can use input-dependent analysis e.g. the influence of some earlier latents on your target latent on some particular input (however you have to keep the input in terms of token IDs not tokens, because it's cheating to look at the actual content of prompts which activate any of your latents).
- You can use input-independent analysis e.g. a latent's de-embeddings or logit lens.
We're making this a very open-ended exercise: we've written some functions for you above, but others you might have to write yourself, depending on what seems most useful for your analysis (e.g. we've not given you a function to compute the pullback yet). If you want an easier exercise, you can use a latent which the post successfully reverse-engineered (e.g. latent 355, the 300th live latent in the transcoder); for a challenge, you can instead try latent 479 (the 400th live latent in the transcoder, which the authors weren't able to reverse-engineer in their initial post).
If you want a slightly easier version of the game, you can try a rule relaxation where you're allowed to pass your own sequences into the model to test hypotheses (you just can't do something like find the top activating sequences over a large dataset and decode them). This allows you to test your hypotheses in ways that still impose some restrictions on your action space.
blind_study_latent = 479
layer = 8
gpt2_transcoder = gpt2_transcoders[layer]
# YOUR CODE HERE!
You can click on the dropdown below to see my attempt at this exercise, or read this notebook, which walks through the authors' own blind case study interpretation of this latent. Don't visit the notebook until you've given the exercise a good try though, since its title gives part of the answer away!
My attempt
My approach was to work through four stages of analysis, building up a hypothesis and refining it at each stage.
Stage 1: De-embeddings & logit lens
Start with input-independent analysis: what tokens does this latent read from (de-embeddings) and write to (logit lens)?
# (1) look at de-embedding
print("De-embeddings:")
show_top_deembeddings_extended(gpt2, gpt2_transcoder, latent_idx=blind_study_latent)
print("Logit lens:")
show_top_logits(gpt2, gpt2_transcoder, latent_idx=blind_study_latent)
# Results?
# - de-embedding has quite a few words related to finance or accumulation, e.g. " deficits", " output", " amounts", " amassed" (also "imately" could be the second half of "approximately")
# - but definitely not as strong evidence as we got for "goal" earlier
# - logit lens shows us this latent firing will boost words like ' costing' and ' estimated'
# - possible theory: it fires on phrases like "...fines <<costing>>..." or "...amassed <<upwards>> of..."
# - prediction based on theory: we should see earlier latents firing on money-related words, and being attended to
# - e.g. "the bank had <<amassed>> upwards of $100m$": maybe "amassed" attends to "bank"
Stage 2: Influence from earlier latents (direct path)
Next, look at the direct (non-attention-mediated) influence from earlier transcoder latents. Gather the top activating sequences, compute pullback-weighted activations, and inspect the de-embeddings of the most consistently influential latents.
# (2) look at influence from earlier latents
# Gather 20 top activating sequences for the target latent
total_batches = 500
k = 20
buffer = 10
data = [] # list of (seq_pos: int, tokens: list[int], top_act: float)
for _ in tqdm(range(total_batches)):
tokens = gpt2_act_store.get_batch_tokens()
cache = run_with_cache_with_transcoder(gpt2, [gpt2_transcoder], tokens, use_error_term=True)
acts = cache[f"{gpt2_transcoder.cfg.metadata.hook_name}.hook_sae_acts_post"][..., blind_study_latent]
k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer) # [k, 2]
tokens_in_top_sequences = tokens[k_largest_indices[:, 0]] # [k, seq_len]
top_acts = index_with_buffer(acts, k_largest_indices) # [k,]
data.extend(list(zip(k_largest_indices[:, 1].tolist(), tokens_in_top_sequences.tolist(), top_acts.tolist())))
data = sorted(data, key=lambda x: x[2], reverse=True)[:k]
tokens = t.tensor([x[1] for x in data]) # each row is a full sequence, containing one of the max activating tokens
top_seqpos = [x[0] for x in data] # list of sequence positions of the max activating tokens
acts = [x[2] for x in data] # list of max activating values
# Compute pullback from earlier latents to target latent, then compute influence for these top activating sequences
cache = run_with_cache_with_transcoder(gpt2, list(gpt2_transcoders.values()), tokens, use_error_term=True)
t.cuda.empty_cache()
all_influences = []
for _layer in range(layer):
early_acts = cache[f"{gpt2_transcoders[_layer].cfg.metadata.hook_name}.hook_sae_acts_post"]  # shape [k=20, seq_len=128, d_sae=24k]; renamed so we don't clobber `acts` above
acts_at_top_posn = early_acts[range(k), top_seqpos]  # shape [k=20, d_sae=24k]
pullback = gpt2_transcoders[_layer].W_dec @ gpt2_transcoder.W_enc[:, blind_study_latent] # shape [d_sae]
influence = acts_at_top_posn * pullback # shape [k=20, d_sae=24k]
all_influences.append(influence)
# Find the earlier latents which are consistently in the top 10 for influence on target latent, and inspect their de-embeddings
all_influences = t.cat(all_influences, dim=-1) # shape [k, n_layers*d_sae]
top_latents = all_influences.topk(k=10, dim=-1).indices.flatten() # shape [k*10]
top_latents_as_tuples = [(i // gpt2_transcoder.cfg.d_sae, i % gpt2_transcoder.cfg.d_sae) for i in top_latents.tolist()]
top5_latents_as_tuples = sorted(Counter(top_latents_as_tuples).items(), key=lambda x: x[1], reverse=True)[:5]
print(
    tabulate(
        top5_latents_as_tuples,
        headers=["Latent", "Count"],
        tablefmt="simple_outline",
    )
)

for (_layer, _idx), count in top5_latents_as_tuples:
    print(f"Latent {_layer}.{_idx} was in the top 10 for {count}/{k} of the top-activating seqs. Top de-embeddings:")
    show_top_deembeddings_extended(gpt2, gpt2_transcoders[_layer], latent_idx=_idx)
# Results?
# - 7.13166 is very interesting: it's in the top way more than any other latent (17/20 vs 10/20 for the second best), and it boosts quantifiers like " approximately", " exceeding", " EQ", " ≥"
# - Since this is the direct path, possibly we'll find our target latent fires on these kinds of words too? Would make sense given its logit lens results
# - Also more generally, the words we're getting as top de-embeddings in these latents all appear in similar contexts, but they're not similar (i.e. substitutable) words, which makes this less likely to be a token-level latent
Stage 3: Influence via attention heads
Next, we look at attention-mediated influence: which earlier transcoder latents affect the target latent via attention heads? For each head, we map the target latent's reading vector backwards through the OV circuit, then dot it with the attention-weighted writing vectors of earlier transcoder latents.
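The core linear-algebra step here can be sketched with toy tensors (random stand-ins, not GPT-2's real weights or sizes): a vector written to the residual stream at the source token gets mapped through `W_V` then `W_O` before reaching the destination token, so pulling the reading vector back through the OV circuit gives an equivalent "source token reading vector". A minimal sketch, leaving out the attention-pattern weighting for simplicity:

```python
import numpy as np

# Toy dimensions (not GPT-2's real sizes) - just to illustrate the shapes involved
d_model, d_head, d_sae = 16, 4, 32
rng = np.random.default_rng(0)

# Stand-ins for the real weights: the target latent's reading vector (a column of W_enc),
# one attention head's OV matrices, and an earlier transcoder's decoder (writing vectors)
reading_vector = rng.standard_normal(d_model)  # [d_model]
W_V = rng.standard_normal((d_model, d_head))   # [d_model, d_head]
W_O = rng.standard_normal((d_head, d_model))   # [d_head, d_model]
W_dec = rng.standard_normal((d_sae, d_model))  # [d_sae, d_model]

# (A) Pull the reading vector back through the OV circuit: the equivalent
# "source token reading vector" is W_V @ W_O @ reading_vector
reading_vector_src = W_V @ (W_O @ reading_vector)  # [d_model]

# (C) Dotting each earlier latent's writing vector with this gives a per-latent
# influence score (here unweighted by attention, unlike the full method below)
influences = W_dec @ reading_vector_src  # [d_sae]
assert influences.shape == (d_sae,)
```

Step (B) in the full method additionally scales each writing vector by the earlier latent's activation and the attention paid to its position, but the pullback above is the heart of the computation.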
# (3) look at influence coming from attention heads (i.e. embedding -> earlier transcoders -> attention -> target transcoder latent)
# The method here is a bit complicated. We do the following, for each head:
# - (A) Map the target latent's "reading vector" backwards through the attention head, to get a "source token reading vector" (i.e. the vector we'd dot product with the residual stream at the source token to get the latent activation for our target latent at the destination token)
# - (B) For all earlier transcoders, compute their "weighted source token writing vector" (i.e. the vector which they write to the residual stream at each source token, weighted by attention from target position to source position)
# - (C) Take the dot product of these, and find the top early latents for this particular head
top_latents_as_tuples = []

for attn_layer in range(layer + 1):  # include the target layer, because attn comes before the MLP
    for attn_head in range(gpt2.cfg.n_heads):
        for early_transcoder_layer in range(attn_layer):  # exclude the attn layer itself, because attn comes before the MLP
            # Get names
            pattern_name = utils.get_act_name("pattern", attn_layer)
            transcoder_acts_name = f"{gpt2_transcoders[early_transcoder_layer].cfg.metadata.hook_name}.hook_sae_acts_post"
            # (A)
            reading_vector = gpt2_transcoder.W_enc[:, blind_study_latent]  # shape [d_model]
            reading_vector_src = einops.einsum(
                reading_vector,
                gpt2.W_O[attn_layer, attn_head],
                gpt2.W_V[attn_layer, attn_head],
                "d_model_out, d_head d_model_out, d_model_in d_head -> d_model_in",
            )
            # (B)
            writing_vectors = gpt2_transcoders[early_transcoder_layer].W_dec  # shape [d_sae, d_model]
            patterns = cache[pattern_name][range(k), attn_head, top_seqpos]  # shape [k, seq_K]
            early_transcoder_acts = cache[transcoder_acts_name]  # shape [k, seq_K, d_sae]
            pattern_weighted_acts = einops.einsum(patterns, early_transcoder_acts, "k seq_K, k seq_K d_sae -> d_sae")
            weighted_src_token_writing_vectors = einops.einsum(
                pattern_weighted_acts, writing_vectors, "d_sae, d_sae d_model -> d_sae d_model"
            )
            # (C)
            influences = weighted_src_token_writing_vectors @ reading_vector_src  # shape [d_sae]
            top_latents_as_tuples.extend(
                [
                    {
                        "early_latent": repr(f"{early_transcoder_layer}.{idx.item():05d}"),
                        "attn_head": (attn_layer, attn_head),
                        "influence": value.item(),
                    }
                    for value, idx in zip(*influences.topk(k=10, dim=-1))
                ]
            )
top20_latents_as_tuples = sorted(top_latents_as_tuples, key=lambda x: x["influence"], reverse=True)[:20]
print(
    tabulate(
        [v.values() for v in top20_latents_as_tuples],
        headers=["Early latent", "Attention head", "Influence"],
        tablefmt="simple_outline",
    )
)
# Results?
# - Attribution from layer 7 transcoder:
# - 2 latents fire in layer 7, and boost our target latent via head L8H5
# - I'll inspect both of these (prediction = as described above, these latents' de-embeddings will be financial words)
# - Attribution from earlier transcoders:
# - There are a few transcoder latents in layers 0, 1, 2 which have influence mediated through L7 attention heads (mostly L7H3 and L7H4)
# - I'll check these out, but I'll also check the de-embedding mapped directly through these heads (ignoring earlier transcoders), because I suspect these early transcoder latents might just be the extended embedding in disguise
def show_top_deembeddings_extended_via_attention_head(
    model: HookedSAETransformer,
    sae: SAE,
    latent_idx: int,
    attn_head: tuple[int, int] | None = None,
    k: int = 10,
    use_extended: bool = True,
) -> None:
    """
    Displays the top k de-embeddings for a particular latent, optionally after that token's embedding is mapped through
    some attention head.
    """
    t.cuda.empty_cache()
    W_E_ext = create_extended_embedding(model) if use_extended else (model.W_E / model.W_E.std(dim=-1, keepdim=True))
    if attn_head is not None:
        W_V = model.W_V[*attn_head]
        W_O = model.W_O[*attn_head]
        W_E_ext = (W_E_ext @ W_V) @ W_O
        W_E_ext = (W_E_ext - W_E_ext.mean(dim=-1, keepdim=True)) / W_E_ext.std(dim=-1, keepdim=True)
    de_embeddings = W_E_ext @ sae.W_enc[:, latent_idx]
    pos_logits, pos_token_ids = de_embeddings.topk(k)
    pos_tokens = model.to_str_tokens(pos_token_ids)
    print(
        tabulate(
            zip(map(repr, pos_tokens), pos_logits),
            headers=["Top tokens", "Value"],
            tablefmt="simple_outline",
            stralign="right",
            numalign="left",
            floatfmt="+.3f",
        )
    )
print("Layer 7 transcoder latents (these influence the target latent via L8H5):")
for _layer, _idx in [(7, 3373), (7, 14110), (7, 10719), (7, 8696)]:
    print(f"{_layer}.{_idx} de-embeddings:")
    show_top_deembeddings_extended_via_attention_head(gpt2, gpt2_transcoders[_layer], latent_idx=_idx)

print("\n" * 3 + "Layer 1-2 transcoder latents (these influence the target latent via L7H3 and L7H4):")
for _layer, _idx in [(2, 21691), (1, 14997)]:
    print(f"{_layer}.{_idx} de-embeddings:")
    show_top_deembeddings_extended_via_attention_head(gpt2, gpt2_transcoders[_layer], latent_idx=_idx)

print("\n" * 3 + "De-embeddings of target latent via L7H3 and L7H4:")
for attn_layer, attn_head in [(7, 3), (7, 4)]:
    print(f"L{attn_layer}H{attn_head} de-embeddings:")
    show_top_deembeddings_extended_via_attention_head(
        gpt2,
        gpt2_transcoder,
        latent_idx=blind_study_latent,
        attn_head=(attn_layer, attn_head),
    )
# Results?
# - Layer 7 transcoder latents:
# - 14110 & 8696 both seem to fire on financial words, e.g. " revenues" is top word for both and they also both include " GDP" in their top 10
# - They also both fire on words like "deaths" and "fatalities", which also makes sense given my hypothesis (e.g. this could be a sentence like "the number of fatalities* is approximately** totalling", where * = the src token the layer 7 latent fires on, and ** = the word predicted by the target latent)
# - 10719 very specifically fires on the word "estimated" (or variants), which also makes sense: these kinds of sentences can often have the word "estimated" in them (e.g. "the estimated number of fatalities is 1000")
# - 3373 fires on "effectively", "constitutes" and "amounted", which are also likely to appear in sentences like this one (recall we've not looked at where attn is coming from - this could be self-attention!)
# - Earlier transcoder latents:
# - Disappointingly, these don't seem very interpretable (nor do the de-embeddings mapped directly through the attention heads which are meant to be mediating their influence)
Stage 4: Component-level attribution
Finally, tally up the mean attribution from each component (embeddings, attention heads, MLPs) across the top activating sequences. This gives a qualitative picture of which paths matter most.
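The reason this decomposition works can be sketched with toy tensors (random stand-in data, toy dimensions): the latent's pre-activation is, up to LayerNorm, a dot product between its encoder direction and a sum of component outputs, so attribution is just linearity:

```python
import numpy as np

# Toy setup: the residual stream input to the transcoder is (approximately) a sum of
# component outputs - embeddings, attention heads, and earlier MLPs. Dotting each
# component's output with the latent's encoder direction decomposes the pre-activation.
d_model, n_components = 16, 5
rng = np.random.default_rng(0)

latent_dir = rng.standard_normal(d_model)  # stand-in for the target latent's reading vector
component_outputs = rng.standard_normal((n_components, d_model))  # one row per component

attributions = component_outputs @ latent_dir  # [n_components]

# Because the dot product is linear, per-component attributions sum to the attribution
# of the full residual stream (LayerNorm breaks this exact additivity in the real model)
total = component_outputs.sum(axis=0) @ latent_dir
assert np.isclose(attributions.sum(), total)
```

In the real code below, the "component outputs" are the cached `embed`, per-head `z @ W_O`, and `mlp_out` activations, averaged over the top activating positions.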
# (4) Final experiment: component-level attribution
# For all these top examples, I want to tally up the contributions from each component (past MLP layers, attention heads, and direct path) and compare them
# This gives me a qualitative sense of which ones matter more
latent_dir = gpt2_transcoder.W_enc[:, blind_study_latent] # shape [d_model,]
embedding_attribution = cache["embed"][range(k), top_seqpos].mean(0) @ latent_dir
attn_attribution = (
    t.stack(
        [
            einops.einsum(
                cache["z", _layer][range(k), top_seqpos].mean(0),
                gpt2.W_O[_layer],
                "head d_head, head d_head d_model -> head d_model",
            )
            for _layer in range(layer + 1)
        ]
    )
    @ latent_dir
)  # shape [layer+1, n_heads]
mlp_attribution = (
    t.stack([cache["mlp_out", _layer][range(k), top_seqpos].mean(0) for _layer in range(layer)]) @ latent_dir
)
all_attributions = t.zeros((layer + 2, gpt2.cfg.n_heads + 1))
all_attributions[0, 0] = embedding_attribution
all_attributions[1:, :-1] = attn_attribution
all_attributions[1:-1, -1] = mlp_attribution
df = pd.DataFrame(utils.to_numpy(all_attributions))
text = [["W_E", *["" for _ in range(gpt2.cfg.n_heads)]]]
for _layer in range(layer + 1):
    text.append(
        [f"L{_layer}H{_head}" for _head in range(gpt2.cfg.n_heads)] + [f"MLP{_layer}" if _layer < layer else ""]
    )

fig = px.imshow(
    df,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    width=700,
    height=600,
    title="Attribution from different components",
)
fig.data[0].update(text=text, texttemplate="%{text}", textfont={"size": 12})
fig.show()
# Results?
# - Way less impact from W_E than I expected, and even MLP0 (the extended embedding) had a pretty small impact - this is evidence against it being a token-level latent
# - Biggest attributions are from L8H5 and MLP7
# - L8H5 is the one that attends back to (A) tokens with financial/fatalities context, (B) the word "estimated" and its variants, and (C) other related quantifiers like "effectively" or "amounted"
# - MLP7 was seen to contain many latents that fired on words which would appear in sentences related to financial estimations (see (2), where we looked at the top 5 contributing latents - they were all in layer 7)
# - Also, the not-very-interpretable results from heads L7H3 and L7H4 matter less now, because we can see from this that they aren't very important (although I don't know why they ranked so highly before)
Final theory & verification
Based on all the evidence gathered above, formulate a concrete hypothesis and test it by viewing the Neuronpedia dashboard.
# Based on all evidence, this is my final theory:
# - The latent activates primarily on sentences involving estimates of financial quantities (or casualties)
# - For example I expect top activating seqs like:
# - "The number of fatalities is **approximately** totalling..."
# - "The bank had **amassed** upwards of $100m..."
# - "The GDP of the UK **exceeds** $300bn..."
# - "This tech company is estimated to be **roughly** worth..."
# where I've highlighted what I guess to be the top activating token, but the surrounding cluster should also be activating
# - Concretely, what causes it to fire? Most important things (in order) are:
# - (1) Attention head 8.5, which attends back to the output of layer 7 transcoder latents that fire on words which imply we're in sentences discussing financial quantities or fatality estimates (e.g. "fatalities", "bank", "GDP" and "company" in the examples above). Also this head strongly attends back to a layer 7 latent which detects the word "estimated" and its variants, so I expect very strong activations to start after this word appears in a sentence
# - (2) Layer-7 transcoder latents (directly), for example latent 7.13166 fires on the token "≤" and causes our target latent to fire
# - (3) Direct path: the latent should fire strongest on words like **approximately** which rank highly in its de-embedding
# Let's display the latent dashboard for both the target latent and the other latents involved in this theory, and see if the theory is correct:
neuronpedia_id = "gpt2-small/8-tres-dc"
url = f"https://neuronpedia.org/{neuronpedia_id}/{blind_study_latent}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
display(IFrame(url, width=800, height=600))
# Conclusions?
# - Mostly correct:
# - The top activating sequences are mostly financial estimates
# - Activations are very large after the word "estimated" (most of the top examples are sentences containing this word)
# - The latent doesn't seem to be token-level; it fires on a cluster of adjacent words
# - Some areas where the hypothesis was incorrect, or lacking:
# - I didn't give a hypothesis for when the activations would stop - it seems they stop exactly at the estimated value, and I don't think I would have been able to predict that based on the experiments I ran
# - Relatedly, I wouldn't have predicted activations staying high even on small connecting words before the estimated value (e.g. "of" in "monthly rent of...", or "as" in "as much as...")
# - I overestimated the importance of the current word in the sentence (or more generally, I had too rigid a hypothesis about which sentence patterns this latent would activate on, and where within them it would activate)
# - I thought there would be more casualty estimates in the top activating sequences, but there weren't. Subsequent testing (see code below) shows that it does indeed fire strongly on non-financial estimates with the right sentence structure, and the fatalities prompt fires more strongly than the other 2 non-financial examples, but the difference is small, so I think this was still an overestimation in my hypothesis
prompts = {
"fatalities": """Body counts are a crude measure of the war's impact and more reliable estimates will take time to compile. Since war broke out in the Gaza Strip almost a year ago, the official number of Palestinians killed is estimated to exceed 41,000.""",
"emissions": """Environmental measurements are an imperfect gauge of climate change impact and more comprehensive studies will take time to finalize. Since the implementation of new global emissions policies almost a year ago, the reduction in global carbon dioxide emissions is estimated to exceed million metric tons.""",
"visitors": """Visitor counts are a simplistic measure of a national park's popularity and more nuanced analyses will take time to develop. Since the implementation of the new trail system almost a year ago, the number of unique bird species spotted in Yellowstone National Park is estimated to have increased by 47.""",
}
acts_dict = {}
for name, prompt in prompts.items():
    str_tokens = [f"{tok} ({i})" for i, tok in enumerate(gpt2.to_str_tokens(prompt))]
    cache = run_with_cache_with_transcoder(gpt2, [gpt2_transcoder], prompt)
    acts = cache[f"{gpt2_transcoder.cfg.metadata.hook_name}.hook_sae_acts_post"][0, :, blind_study_latent]
    acts_dict[name] = utils.to_numpy(acts).tolist()
min_length = min(len(x) for x in acts_dict.values())
acts_dict = {name: v[-min_length:] for name, v in acts_dict.items()}
df = pd.DataFrame(acts_dict)
px.line(df, y=list(prompts.keys()), height=500, width=800).show()
In sections 1️⃣ and 2️⃣, we worked with GPT-2 Small to study latent-level gradients and transcoders. We saw that gradients between pairs of latents can tell us which upstream latents most influence a given downstream latent, and that transcoders give us interpretable decompositions of MLP computation (with the added benefit of explicit reading and writing vectors for each latent).
In the next section, we'll combine these two ideas into attribution graphs, which give us a fuller causal picture of how a model produces a particular output. We'll also switch from GPT-2 Small to Gemma 3-1B IT with GemmaScope 2 transcoders, giving us a more capable model to study.