
3️⃣ Attribution graphs

Learning Objectives
  • Understand the local replacement model: why freezing attention patterns, LayerNorm scales, and replacing MLPs with linear skip connections makes the residual stream linear
  • Understand the reading/writing vector abstraction: how token embeddings, transcoder latents, MLP errors, and logit directions all interact with the residual stream
  • Implement the core attribution algorithm: salient logit selection, graph node construction, and edge weight computation via gradient injection
  • Implement graph pruning via node and edge influence thresholding, using the Neumann series on the nilpotent adjacency matrix
  • Build interactive attribution graph visualisations using the same dashboard templates as Anthropic's published work

In sections 1️⃣ and 2️⃣, we explored how to compute gradients between individual latent pairs and how transcoders decompose MLP computation into interpretable latents. Now we'll combine these ideas into attribution graphs: end-to-end causal explanations of how a model produces a particular output, expressed in terms of transcoder latents.

As Anthropic's Circuit Tracing paper describes it, the goal is to "produce graph descriptions of the model's computation on prompts of interest by tracing individual computational steps in a 'replacement model'." By substituting transcoders for MLPs we get a replacement model where every intermediate step is expressed in terms of interpretable features, and because we linearise the model (freezing attention patterns and LayerNorm scales), the influence between any pair of features becomes a well-defined linear quantity we can compute efficiently via backward-pass attribution.

An attribution graph is a directed acyclic graph. Input nodes are token embeddings (one per token in the prompt). Intermediate nodes are transcoder latents (each feature at each position where it fires). Output nodes are logit directions (the top predicted tokens at the final position). Edges represent the direct causal influence between nodes, computed via backward-pass attribution through a linearised version of the model.

The key insight is the reading/writing vector abstraction: every node has a writing vector (the direction it adds to the residual stream) and a reading vector (the direction it reads from the residual stream). An edge weight between two nodes is essentially the dot product of the source's writing vector, mapped through frozen intermediate layers, with the target's reading vector.

To make this work, we linearise the model by freezing its non-linearities (attention patterns, LayerNorm scales). This makes the residual stream a linear function of the feature activations at each node, letting us compute all edge weights efficiently via batched backward passes. Note the MLP preactivation nonlinearities (JumpReLU/ReLU) remain - the model is linear given fixed feature activations, not in the raw inputs. We then prune the graph to keep only the most influential nodes and edges, since even with sparse latents there are still too many active latents on a given prompt to interpret the full graph.

This section is based on Anthropic's Circuit Tracing paper and its companion paper On the Biology of a Large Language Model.

GemmaScope 2

We'll be using transcoders from GemmaScope 2, Google DeepMind's second suite of sparse autoencoders and transcoders for the Gemma model family.

GemmaScope 2 is a sequel to the original GemmaScope release. It covers the Gemma 3 family of models (not Gemma 2), and provides transcoders trained on all layers of both pre-trained and instruction-tuned models at all sizes up to 27B.

These transcoders were trained with Matryoshka loss, so latents have a natural hierarchy. Latents with smaller indices tend to fire more frequently and matter more for reconstruction, while later-indexed latents represent narrower, more specific concepts (and are often more interesting to study).

They were also trained with affine skip connections (W_skip). The skip connection captures the linear component of the MLP's computation, while the transcoder latents capture the non-linear component. When we freeze the model for attribution, gradients flow through the linear skip connection rather than the full non-linear MLP.

You can read more about GemmaScope 2 here.

Setup: loading Gemma 3-1B IT with transcoders

We'll use Gemma 3-1B IT with GemmaScope 2 transcoders from SAELens. Note that you'll first need to get an access token from Hugging Face and put it in the .env file as HF_TOKEN, since the Gemma models are gated.

You can read the HuggingFace README to see how the gemma-scope-2 model series is organised. TL;DR:

  • The release names are gemma-scope-2-{size}-{type}-{site}-all for the SAE releases trained on all layers of a given model. (Remove the -all suffix to get a suite that contains more hyperparameter variants, but is only trained on a subset of layers.)
  • size is one of 270m, 1b, 4b, 12b or 27b.
  • type is either pt (pretrained) or it (instruction-tuned).
  • site is either transcoders or the name of the SAE site, i.e. resid_post, mlp_out or attn_out (we ignore multi-layer models here for simplicity).
  • The individual SAE IDs are of the form layer_{layer}_width_{width}_l0_{l0}, where:
    • layer is an integer layer index (zero-indexed).
    • width is a string: either 16k or 262k for the -all releases. The subset releases also support 65k, plus 1m in the case of the resid_post SAEs.
    • l0 is either small (approx 10-20) or big (approx 100-150). The subset releases also have medium (approx 30-60).
  • Transcoders additionally have _affine appended to the end of the SAE ID.

Once you've got a token, let's load the model and transcoders.

# Load huggingface token
load_dotenv(dotenv_path=str(exercises_dir / ".env"))
HF_TOKEN = os.getenv("HF_TOKEN")
assert HF_TOKEN, "Please set HF_TOKEN in your chapter1_transformer_interp/exercises/.env file"

# Load gemma model using HF token
gemma = HookedSAETransformer.from_pretrained(
    "google/gemma-3-1b-it",
    device=device,
    dtype=dtype,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)
gemma.set_use_hook_mlp_in(True)

# Check this version of SAELens has the gemmascope transcoders
assert "gemma-scope-2-1b-it-transcoders-all" in get_pretrained_saes_directory()
n_layers = 26
n_saes_per_layer = 8  # 2 widths, 2 L0s, 2 (affine vs non-affine)
transcoders_map = get_pretrained_saes_directory()["gemma-scope-2-1b-it-transcoders-all"].saes_map
assert len(transcoders_map) == n_layers * n_saes_per_layer

# Load transcoders for all layers (parallel I/O, sequential GPU transfer)
n_layers = gemma.cfg.n_layers
sae_ids = [f"layer_{layer}_width_16k_l0_small_affine" for layer in range(n_layers)]
transcoders: dict[int, SAE] = utils.load_saes_parallel(
    release="gemma-scope-2-1b-it-transcoders-all",
    sae_ids=sae_ids,
    device=device,
    dtype=dtype,
    max_workers=2,
)
# Correct the hook names (temporary until this is fixed)
for layer in range(n_layers):
    transcoders[layer].cfg.metadata.hook_name = f"blocks.{layer}.mlp.hook_in"

Now let's define our prompt. We'll use the chat template format for Gemma IT models. Note the special tokens <start_of_turn> and <end_of_turn>. We need to mask the first 4 tokens (BOS + <start_of_turn> + user + \n) when building attribution graphs, since these are formatting tokens that don't carry meaningful attribution signal.

def format_prompt(user_prompt: str, model_response: str) -> str:
    """Format a prompt for Gemma IT models using the chat template."""
    return f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n<start_of_turn>model\n{model_response}"


START_POSN = 4  # Mask [BOS, <start_of_turn>, user, \n]

# Example: rhyming couplet prompt from Anthropic's attribution graph work
prompt = format_prompt(
    "Write me a short rhyming couplet.",
    "The sun descends, a golden hue,\nAs evening whispers, soft and",  # (...true)
)

# Tokenize and verify
tokens = gemma.to_tokens(prompt)
str_tokens = gemma.to_str_tokens(prompt)
print(f"Prompt has {len(str_tokens)} tokens")
print(f"First {START_POSN} tokens (masked): {str_tokens[:START_POSN]}")
print(f"Remaining tokens: {str_tokens[START_POSN:]}")

The local replacement model

The core idea behind attribution graphs is linearising the model. A transformer's residual stream is almost linear. The main non-linearities are: attention patterns (the softmax over queries and keys), LayerNorm scales (the RMSNorm normalisation factor), and the MLP's elementwise activation function.

If we freeze all three at their forward-pass values, the residual stream becomes a linear function of the feature activations at each node. (The MLP preactivation nonlinearities remain, but those are absorbed into the node values themselves.) This means we can compute well-defined attributions via a single backward pass.

For MLPs specifically, we replace the full non-linear MLP with a linear skip connection: mlp_out ≈ mlp_input @ W_skip, where W_skip is the affine skip connection weight from the transcoder. This is not a perfect approximation, but it captures the linear component of MLP behaviour, and the transcoder latents capture the non-linear component.

The key subtlety is what exactly we freeze in attention. We freeze the attention patterns (the softmax output), but we keep the value computation (V @ W_O) differentiable. This is crucial: if we froze the entire attention output, we'd lose all cross-position gradient flow through attention, and our attribution graph would only have edges within the same position.

Here's how the linearisation works in TransformerLens:

  • Freeze hook_pattern (the attention pattern matrix / softmax output). Gradients still flow through the value vectors, but the weighted combination is fixed.
  • Freeze hook_scale (the LayerNorm scale factors), making LayerNorm a linear operation.
  • For MLPs, use a "skip connection trick": replace mlp_out with skip + (mlp_out - skip).detach(), where skip = ln(resid_mid) @ W_skip. During the forward pass this gives the correct MLP output, but during backward the gradient flows through the linear skip connection instead.
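To see why detaching the pattern preserves the forward pass while restricting gradient flow to the value pathway, here's a minimal PyTorch sketch using toy tensors (not the real model or TransformerLens hooks):

```python
import torch

torch.manual_seed(0)
q = torch.randn(4, 8, requires_grad=True)  # toy queries (single head)
k = torch.randn(4, 8, requires_grad=True)  # toy keys
v = torch.randn(4, 8, requires_grad=True)  # toy values

pattern = torch.softmax(q @ k.T / 8**0.5, dim=-1)

# "Freezing" = detaching the pattern: the forward value is unchanged, but
# the backward pass can only flow through the value pathway.
out_frozen = pattern.detach() @ v
out_frozen.sum().backward()

assert torch.allclose(out_frozen, pattern @ v)  # identical forward pass
assert q.grad is None and k.grad is None        # no gradient through the pattern
assert v.grad is not None                       # gradient still flows through V
```

This is exactly the cross-position gradient flow we want to keep: v at every position still receives gradient, weighted by the frozen pattern.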

FreezeHooks - freezing attention and LayerNorm

We'll give you the FreezeHooks class, which handles freezing attention patterns and LayerNorm scales. Read through this code carefully and make sure you understand what it does.

class FreezeHooks:
    """
    Installs forward hooks that freeze attention patterns and LayerNorm scales at their
    forward-pass values, making these operations linear.

    Uses TransformerLens-native hooks (not PyTorch register_forward_hook). Two interfaces:

    1. Context manager (for direct use):
        with freeze:
            model(tokens)  # patterns/scales replaced by frozen values

    2. fwd_hooks property (for combining with other hooks via model.hooks()):
        with model.hooks(fwd_hooks=freeze.fwd_hooks + other_hooks, bwd_hooks=...):
            model(tokens).backward()
    """

    def __init__(self, model: HookedSAETransformer):
        self.model = model
        self.frozen_values: dict[str, Tensor] = {}

    def _freeze_hook(self, value: Tensor, hook: HookPoint) -> Tensor:
        """
        TL-native hook: replaces activations with their frozen values.

        Handles batched forward passes (batch_size > 1) even though frozen values were cached
        with batch_size=1. Two cases arise:

        - Regular hooks (hook_pattern, ln.hook_scale): frozen shape is [1, ...], so we repeat
          along dim 0 to get [batch, ...].
        - QK-norm hook_scale: the batch dim is folded into dim 0 as [batch*pos*n_heads, 1],
          so the frozen shape is [pos*n_heads, 1]. We repeat by the batch ratio to get
          [batch*pos*n_heads, 1].

        In both cases the ratio value.shape[0] // frozen.shape[0] gives the right repeat factor.
        """
        frozen = self.frozen_values[hook.name]
        if frozen.shape[0] != value.shape[0] and value.shape[0] % frozen.shape[0] == 0:
            ratio = value.shape[0] // frozen.shape[0]
            frozen = frozen.repeat(ratio, *([1] * (frozen.dim() - 1)))
        return frozen

    @property
    def fwd_hooks(self) -> list[tuple[str, Callable]]:
        """
        Return freeze hooks as a list of (hook_name, hook_fn) pairs.

        Use this with model.hooks() to combine freeze hooks with other hooks in a
        single context manager (Option A), avoiding nesting issues:

            with model.hooks(fwd_hooks=freeze.fwd_hooks + capture_hooks, bwd_hooks=...):
                model(tokens).backward()
        """
        return [(name, self._freeze_hook) for name in self.frozen_values]

    def cache_frozen_values(self, tokens: Tensor) -> ActivationCache:
        """Runs a forward pass, caching values we'll freeze later, plus other useful activations."""
        # We cache attention patterns, LN scales, and residual stream values
        names_filter = lambda name: any(
            s in name for s in ["hook_pattern", "hook_scale", "resid_post", "resid_pre", "mlp.hook_in", "mlp_out"]
        )
        _, cache = self.model.run_with_cache(tokens, names_filter=names_filter)

        # Store frozen values (patterns and scales)
        self.frozen_values = {
            name: cache[name].detach() for name in cache.keys() if "hook_pattern" in name or "hook_scale" in name
        }
        return cache

    def __enter__(self):
        """
        Install freeze hooks using TransformerLens's model.add_hook() (non-permanent).

        Non-permanent hooks are removed by model.reset_hooks(including_permanent=False),
        which is what __exit__ calls. TranscoderReplacementHooks uses permanent hooks, so
        they survive this reset.
        """
        for name in self.frozen_values:
            self.model.add_hook(name, self._freeze_hook)
        return self

    def __exit__(self, *args):
        """Remove all non-permanent hooks (freeze hooks only; permanent tc_hooks survive)."""
        self.model.reset_hooks(including_permanent=False)

Exercise - implement TranscoderReplacementHooks

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 25-30 minutes on this exercise.

Now implement the TranscoderReplacementHooks class. It replaces the MLP with the transcoder's linear skip connection during the backward pass, while keeping the correct MLP output during the forward pass.

The key trick is the stop-gradient skip connection:

skip = ln2(resid_mid) @ W_skip
mlp_out_new = skip + (mlp_out_original - skip).detach()

During the forward pass, mlp_out_new = skip + mlp_out_original - skip = mlp_out_original - so we get the exact same output.

During the backward pass, the .detach() kills the gradient through (mlp_out_original - skip), so gradients only flow through skip = ln2(resid_mid) @ W_skip - the linear approximation.
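Here's the trick in isolation, as a minimal PyTorch sketch (W1, W2 and W_skip below are random stand-ins, not real model or transcoder weights):

```python
import torch

torch.manual_seed(0)
d = 6
W1, W2 = torch.randn(d, 4 * d), torch.randn(4 * d, d)  # stand-in MLP weights
W_skip = torch.randn(d, d)                             # stand-in skip connection
resid_mid = torch.randn(3, d, requires_grad=True)

mlp_out = torch.relu(resid_mid @ W1) @ W2  # the "true" non-linear MLP output
skip = resid_mid @ W_skip

# Stop-gradient skip connection: forward value is mlp_out, backward sees only skip
mlp_out_new = skip + (mlp_out - skip).detach()
assert torch.allclose(mlp_out_new, mlp_out)  # forward pass unchanged

mlp_out_new.sum().backward()
# Gradient of sum(x @ W_skip) w.r.t. x is the row-sums of W_skip - the MLP's
# non-linearity contributes nothing to the backward pass
expected_grad = W_skip.sum(dim=-1).expand_as(resid_mid)
assert torch.allclose(resid_mid.grad, expected_grad, atol=1e-5)
```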

This class should also compute and store the transcoder activations (which features fire and how strongly) for each layer, since we'll need these to build the graph nodes.

You'll need to use the following attributes of a JumpReLUSkipTranscoder:

  • tc.cfg.metadata.hook_name: the hook name for the transcoder input (e.g. "blocks.0.mlp.hook_in")
  • tc.W_skip: the skip connection weight matrix, shape (d_model, d_model)
  • tc.encode(x): returns transcoder feature activations, shape (batch, seq, d_enc)
  • tc.decode(acts): returns transcoder output from feature activations, shape (batch, seq, d_model)

Hint - hook names

The transcoder input hook is at blocks.{layer}.mlp.hook_in (i.e. the input to the MLP, after LayerNorm). The MLP output hook is at blocks.{layer}.hook_mlp_out.

You'll need to hook the MLP output to apply the skip connection trick.

Hint - structure

Your install method should iterate over each transcoder and:

  1. Add a hook at the MLP output that applies the skip connection trick
  2. Store the transcoder activations (from tc.encode) for later use

You'll need to get the LN-normalised MLP input to compute both the skip connection and the transcoder activations. You can get this from the FreezeHooks cache, or hook it during the forward pass.

Hint - hook function pattern

Use model.add_hook(output_hook, fn, is_permanent=True) to register a TL-native hook. The hook function signature is hook_fn(mlp_out: Tensor, hook: HookPoint) -> Tensor where mlp_out is the MLP output tensor. Your hook should return a modified version using the skip connection trick.

Register with is_permanent=True so the hook survives FreezeHooks.__exit__, which calls model.reset_hooks(including_permanent=False) and only removes non-permanent hooks.

Be careful about Python closures in loops! Use a factory function (make_hook(layer, tc)) to capture the correct layer and transcoder for each hook.

from sae_lens import JumpReLUSkipTranscoder, JumpReLUTranscoder

Transcoder: TypeAlias = JumpReLUTranscoder | JumpReLUSkipTranscoder


class TranscoderReplacementHooks:
    """
    Installs hooks that replace MLP backward passes with linear skip connections, while
    computing and storing transcoder feature activations for graph construction.

    The skip connection trick ensures:
    - Forward pass: exact MLP output (unchanged)
    - Backward pass: gradients flow through linear skip connection (W_skip)

    Uses TransformerLens-native hooks registered as PERMANENT (is_permanent=True), so they
    survive FreezeHooks.__exit__ which only resets non-permanent hooks. Call remove() when
    done to clear permanent hooks via model.reset_hooks(including_permanent=True).
    """

    def __init__(
        self,
        model: HookedSAETransformer,
        transcoders: dict[int, "Transcoder"],
        cache: ActivationCache,
    ):
        self.model = model
        self.transcoders = transcoders
        self.cache = cache  # From FreezeHooks.cache_frozen_values
        self.transcoder_acts: dict[int, Tensor] = {}  # layer -> feature activations
        self.transcoder_output: dict[int, Tensor] = {}  # layer -> transcoder reconstruction

        self.current_ln_inputs: dict[int, Tensor] = {}  # updated each forward pass

    # def install(self):
    #     """Install hooks for all transcoders."""
    #     self.current_ln_inputs: dict[int, Tensor] = {}
    #     for layer, tc in self.transcoders.items():
    #         # Get hook names for this layer
    #         input_hook = tc.cfg.metadata.hook_name  # e.g. "blocks.0.mlp.hook_in"
    #         output_hook = tc.cfg.metadata.hook_name_out  # e.g. "blocks.0.hook_mlp_out"
    #
    #         # Compute and store transcoder activations (no gradient needed)
    #         with t.no_grad():
    #             ln_input = self.cache[input_hook]
    #             tc_acts = tc.encode(ln_input)
    #             tc_output = tc.decode(tc_acts)
    #             self.transcoder_acts[layer] = tc_acts.detach()
    #             self.transcoder_output[layer] = tc_output.detach()
    #
    #         # YOUR CODE HERE - register TWO permanent TL-native hooks:
    #         #
    #         # 1. A capture hook at input_hook (mlp.hook_in):
    #         #    Signature: (tensor, hook) -> None
    #         #    Store the live ln_input in self.current_ln_inputs[layer].
    #         #    This MUST use the live tensor, not self.cache[input_hook] - the cached
    #         #    tensor is attached to a stale graph that gets freed after the first backward().
    #         #
    #         # 2. A skip hook at output_hook (hook_mlp_out):
    #         #    Signature: (mlp_out, hook) -> Tensor
    #         #    Apply the skip connection trick using self.current_ln_inputs[layer]:
    #         #       skip = ln_input @ tc.W_skip
    #         #       return skip + (mlp_out - skip).detach()
    #         #
    #         # Register both with is_permanent=True so they survive FreezeHooks context exits.
    #         # Use factory functions to avoid closure issues in the loop.
    #         pass

    def remove(self):
        """Remove all hooks, including permanent ones added by install()."""
        self.model.reset_hooks(including_permanent=True)

# Test: verify the skip connection trick gives correct forward pass
freeze = FreezeHooks(gemma)
cache = freeze.cache_frozen_values(tokens)

tc_hooks = TranscoderReplacementHooks(gemma, transcoders, cache)
tc_hooks.install()

with freeze:
    logits_with_hooks = gemma(tokens)

tc_hooks.remove()

# Compare with original logits
logits_original = gemma(tokens)

print(f"Max difference in logits: {(logits_with_hooks - logits_original).abs().max().item():.6f}")
print("(Should be ~0 since the skip trick preserves forward pass values)")
Solution
from sae_lens import JumpReLUSkipTranscoder, JumpReLUTranscoder

Transcoder: TypeAlias = JumpReLUTranscoder | JumpReLUSkipTranscoder


class TranscoderReplacementHooks:
    """
    Installs hooks that replace MLP backward passes with linear skip connections, while
    computing and storing transcoder feature activations for graph construction.

    The skip connection trick ensures:
    - Forward pass: exact MLP output (unchanged)
    - Backward pass: gradients flow through linear skip connection (W_skip)

    Uses TransformerLens-native hooks registered as PERMANENT (is_permanent=True), so they
    survive FreezeHooks.__exit__ which only resets non-permanent hooks. Call remove() when
    done to clear permanent hooks via model.reset_hooks(including_permanent=True).
    """

    def __init__(
        self,
        model: HookedSAETransformer,
        transcoders: dict[int, "Transcoder"],
        cache: ActivationCache,
    ):
        self.model = model
        self.transcoders = transcoders
        self.cache = cache  # From FreezeHooks.cache_frozen_values
        self.transcoder_acts: dict[int, Tensor] = {}  # layer -> feature activations
        self.transcoder_output: dict[int, Tensor] = {}  # layer -> transcoder reconstruction

        self.current_ln_inputs: dict[int, Tensor] = {}  # updated each forward pass

    def install(self):
        """
        Install hooks for all transcoders.

        Two permanent hooks are registered per layer:
          1. hook_mlp_in - saves the live LN-normalised MLP input each forward pass.
          2. hook_mlp_out - applies the skip connection trick using that saved input.

        Using the LIVE ln_input (not the cached one) is critical: the cached tensor has a
        grad_fn attached to the original run_with_cache graph, which PyTorch frees after the
        first backward(). Subsequent batches would then fail with "trying to backward through
        the graph a second time". Using the live tensor keeps every backward pass within its
        own fresh computation graph.
        """
        for layer, tc in self.transcoders.items():
            # Get the LN-normalised MLP input from the cache
            input_hook = tc.cfg.metadata.hook_name  # e.g. "blocks.0.mlp.hook_in"
            output_hook = tc.cfg.metadata.hook_name_out  # e.g. "blocks.0.hook_mlp_out"

            # Compute transcoder activations (no gradient needed for this)
            with t.no_grad():
                ln_input = self.cache[input_hook]
                tc_acts = tc.encode(ln_input)
                tc_output = tc.decode(tc_acts)
                self.transcoder_acts[layer] = tc_acts.detach()
                self.transcoder_output[layer] = tc_output.detach()

            def make_capture_hook(layer_idx: int):
                """Hook at mlp.hook_in: save the live ln_input for this forward pass."""

                def hook_fn(tensor: Tensor, hook: HookPoint) -> None:
                    self.current_ln_inputs[layer_idx] = tensor

                return hook_fn

            def make_skip_hook(layer_idx: int, tc_ref: "Transcoder"):
                """Hook at hook_mlp_out: apply the skip connection trick."""

                def hook_fn(mlp_out: Tensor, hook: HookPoint) -> Tensor:
                    if hasattr(tc_ref, "W_skip"):
                        # Use the live ln_input saved by the capture hook above.
                        # This keeps the skip's grad_fn inside the current graph, so
                        # backward() works correctly across multiple batch iterations.
                        ln_input = self.current_ln_inputs[layer_idx]
                        skip = ln_input @ tc_ref.W_skip
                    else:
                        skip = t.zeros_like(mlp_out)
                    # Skip connection trick: forward gives mlp_out, backward gives skip grad
                    return skip + (mlp_out - skip).detach()

                return hook_fn

            # Both hooks are PERMANENT so they survive FreezeHooks.__exit__ which only
            # calls reset_hooks(including_permanent=False).
            self.model.add_hook(input_hook, make_capture_hook(layer), is_permanent=True)
            self.model.add_hook(output_hook, make_skip_hook(layer, tc), is_permanent=True)

    def remove(self):
        """Remove all hooks, including permanent ones added by install()."""
        self.model.reset_hooks(including_permanent=True)

Sanity check: average L0 per transcoder layer

Before building the attribution graph, let's verify that the transcoders are producing a reasonable number of active features per token. The L0 of a sparse autoencoder is the average number of non-zero latents per token position. If L0 is high (e.g. above 50), the graph will have an unexpectedly large number of latent nodes - and it likely means the wrong transcoder variant was loaded (we want l0_small, not l0_big). We exclude the first START_POSN tokens (chat-formatting tokens) since those aren't meaningful content positions.

print("Average L0 per transcoder layer (excluding first 4 tokens):")
for layer in range(len(transcoders)):
    acts = tc_hooks.transcoder_acts[layer][0, START_POSN:, :]  # (seq - START_POSN, d_enc)
    l0_per_pos = (acts > 0).float().sum(dim=-1)  # (seq - START_POSN,)
    avg_l0 = l0_per_pos.mean().item()
    print(f"  Layer {layer}: avg L0 = {avg_l0:.1f}")
    assert avg_l0 <= 50, (
        f"Layer {layer} avg L0 = {avg_l0:.1f} exceeds 50 - check that you loaded the "
        f"'l0_small' transcoder variant, not 'l0_big'"
    )
print("L0 sanity check passed!")
Click to see the expected output

Building the attribution graph

Now that we can linearise the model, we need to define the nodes and edges of our attribution graph.

Salient logit selection

Our output nodes are logit directions, i.e. the unembedding vectors for the top predicted tokens. Rather than using all tokens in the vocabulary (which would be wasteful), we select only the salient ones: those that account for most of the model's probability mass.

Specifically, we take the top-k tokens by predicted probability (choosing k so that these tokens cover at least some threshold fraction of total probability, or using a fixed k like 3-5), and we use their demeaned unembedding vectors as reading directions.

Why demean? The logit lens tells us that the residual stream at the final position gets multiplied by the unembedding matrix W_U to produce logits. But a constant vector added to the residual stream affects all logits equally and doesn't change the softmax output. So we subtract the mean of W_U columns from each selected column to focus on what makes this token's logit different from the average.
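To convince yourself that the uniform component really is irrelevant, here's a minimal PyTorch sketch with a toy residual vector and unembedding matrix (shapes chosen arbitrarily):

```python
import torch

torch.manual_seed(0)
d_model, d_vocab = 16, 50
resid = torch.randn(d_model)         # toy final-position residual stream
W_U = torch.randn(d_model, d_vocab)  # toy unembedding matrix

logits = resid @ W_U
probs = torch.softmax(logits, dim=-1)

# A uniform shift to every logit leaves the softmax unchanged...
assert torch.allclose(probs, torch.softmax(logits + 3.0, dim=-1), atol=1e-6)

# ...and demeaning the columns of W_U produces exactly such a uniform shift,
# so the predicted distribution is unaffected
W_U_demeaned = W_U - W_U.mean(dim=-1, keepdim=True)
probs_demeaned = torch.softmax(resid @ W_U_demeaned, dim=-1)
assert torch.allclose(probs, probs_demeaned, atol=1e-6)
```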

Exercise - implement compute_salient_logits

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise.

Implement the function that selects the top predicted tokens and returns their demeaned unembedding vectors (which will serve as reading vectors for our output nodes).

def compute_salient_logits(
    model: HookedSAETransformer,
    logits: Float[Tensor, "batch seq d_vocab"],
    n_output_nodes: int = 3,
) -> tuple[Float[Tensor, "n_output d_model"], list[tuple[str, float]]]:
    """
    Select the top predicted tokens and return their demeaned W_U columns.

    Args:
        model: The transformer model.
        logits: Full logit tensor from model forward pass.
        n_output_nodes: Number of top tokens to select.

    Returns:
        reading_vecs: Demeaned unembedding vectors for top tokens, shape (n_output, d_model).
        top_token_info: List of (token_string, probability) tuples for the selected tokens.
    """
    raise NotImplementedError()
reading_vecs, top_token_info = compute_salient_logits(gemma, logits_original)
print("Top predicted tokens:")
for tok_str, prob in top_token_info:
    print(f"  {tok_str!r}: p={prob:.4f}")
print(f"Reading vectors shape: {reading_vecs.shape}")
Click to see the expected output

Solution
def compute_salient_logits(
    model: HookedSAETransformer,
    logits: Float[Tensor, "batch seq d_vocab"],
    n_output_nodes: int = 3,
) -> tuple[Float[Tensor, "n_output d_model"], list[tuple[str, float]]]:
    """
    Select the top predicted tokens and return their demeaned W_U columns.

    Args:
        model: The transformer model.
        logits: Full logit tensor from model forward pass.
        n_output_nodes: Number of top tokens to select.

    Returns:
        reading_vecs: Demeaned unembedding vectors for top tokens, shape (n_output, d_model).
        top_token_info: List of (token_string, probability) tuples for the selected tokens.
    """
    # Get probabilities for the last sequence position
    final_logits = logits[0, -1]  # (d_vocab,)
    probs = t.softmax(final_logits, dim=-1)

    # Get top-k tokens
    top_probs, top_tokens = probs.topk(n_output_nodes)

    # Get the unembedding matrix and select columns for top tokens
    W_U = model.W_U  # (d_model, d_vocab)
    selected_cols = W_U[:, top_tokens].T  # (n_output, d_model)

    # Demean: subtract the mean unembedding vector
    W_U_mean = W_U.mean(dim=-1)  # (d_model,)
    reading_vecs = selected_cols - W_U_mean.unsqueeze(0)

    # Get token strings
    top_token_info = [(model.tokenizer.decode(tok.item()), prob.item()) for tok, prob in zip(top_tokens, top_probs)]

    return reading_vecs, top_token_info

The reading/writing vector abstraction

Now we need to understand the reading and writing vectors for each node type. This is the conceptual core of attribution graphs.

Every latent interacts with the residual stream in two ways. When a latent fires, it writes activation * W_dec[latent] into the residual stream. And a latent fires based on how much the residual stream aligns with its encoder direction, so it reads via W_enc.T[latent].

Here's how reading and writing vectors break down across node types:

Node type           Writing vector                            Reading vector
Token embedding     The token embedding vector                Zero vector (embeddings are inputs)
Transcoder latent   activation * W_dec[latent]                W_enc.T[latent]
MLP error           mlp_output - transcoder_reconstruction    Zero vector (error term doesn't read)
Output logit        Zero vector (logits are outputs)          Demeaned W_U[:, token]

To compute an edge weight from node A to node B, we take A's writing vector, map it through the frozen intermediate layers (attention + skip connections) between A and B, and take the dot product with B's reading vector.

In the automatic (gradient-based) method, we don't need to do this mapping explicitly. Instead, we inject B's reading vector as a gradient seed and run a backward pass through the frozen model, which implicitly computes the mapping. The edge weight is then the dot product of the resulting gradient with A's writing vector.
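The equivalence between the explicit and gradient-based methods is easy to check in a toy linear setting (M below is a random stand-in for the frozen map between A's write site and B's read site):

```python
import torch

torch.manual_seed(0)
d = 8
M = torch.randn(d, d)  # toy frozen linear map between A's layer and B's layer
w_A = torch.randn(d)   # source node's writing vector
r_B = torch.randn(d)   # target node's reading vector

# Explicit method: push the writing vector through the frozen map, then read
edge_explicit = (w_A @ M) @ r_B

# Gradient method: inject r_B as the gradient seed at B's site, then take the
# dot product of the resulting input gradient with A's writing vector
x = torch.zeros(d, requires_grad=True)
(x @ M).backward(gradient=r_B)
edge_via_grad = w_A @ x.grad

assert torch.allclose(edge_explicit, edge_via_grad, atol=1e-5)
```

In the real model, one backward pass computes the gradient at every upstream site simultaneously, which is why a single backward pass per target node gives us all of that node's incoming edges.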

We'll build the graph nodes in three steps, each as its own exercise. First we create embedding nodes, then per-layer feature and error nodes, and finally logit nodes. After all three, we'll combine them into the full GraphNodes object.

The node ordering should be: [input embeddings] [layer 0 features] [layer 0 error] [layer 1 features] ... [layer N-1 error] [output logits]. This ordering matters because it makes the adjacency matrix strictly lower triangular (information only flows from earlier nodes to later ones, and nodes never influence themselves), which means the matrix is nilpotent. We'll exploit this property later for efficient influence computation.
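A quick sketch of why nilpotency helps (toy random adjacency, not a real attribution graph): a strictly lower triangular n x n matrix satisfies A^n = 0, so the Neumann series for total influence, (I - A)^-1 = I + A + A^2 + ..., terminates after at most n - 1 terms:

```python
import torch

torch.manual_seed(0)
n = 5
# Strictly lower triangular adjacency (edges only from earlier to later nodes)
A = torch.tril(torch.randn(n, n), diagonal=-1)
assert torch.allclose(torch.linalg.matrix_power(A, n), torch.zeros(n, n))  # nilpotent

# Finite Neumann series: I + A + A^2 + ... + A^(n-1) = (I - A)^-1
total, term = torch.eye(n), torch.eye(n)
for _ in range(n - 1):
    term = term @ A
    total = total + term

assert torch.allclose(total, torch.inverse(torch.eye(n) - A), atol=1e-4)
```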

class NodeType(Enum):
    EMBEDDING = "embedding"
    LATENT = "latent"
    MLP_ERROR = "mlp_error"
    LOGIT = "logit"


@dataclass
class NodeInfo:
    """Metadata about a single node in the attribution graph."""

    node_type: NodeType
    layer: int | str  # "E" for embeddings, layer_idx for features, "L" for logits
    ctx_idx: int  # sequence position
    feature: int  # feature index (token ID for embeds, feature ID for latents, token ID for logits)
    activation: float = 0.0  # feature activation (for latent nodes)
    token_prob: float = 0.0  # probability (for logit nodes)
    str_token: str = ""  # string representation (for embed/logit nodes)
    label: str = ""  # human-readable label

    @property
    def node_id(self) -> str:
        return f"{self.layer}_{self.feature}_{self.ctx_idx}"


@dataclass
class GraphNodes:
    """Container for all nodes in an attribution graph, with their reading/writing vectors."""

    nodes: list[NodeInfo] = field(default_factory=list)
    writing_vecs: Tensor | None = None  # (n_nodes, d_model) or None
    reading_vecs: Tensor | None = None  # (n_nodes, d_model) or None
    node_range_dict: dict = field(default_factory=dict)  # layer -> (start_idx, end_idx)
    seq_len: int = 0
    n_layers: int = 0

Exercise (1/3) - implement build_embedding_nodes

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 5-10 minutes on this exercise.

The first step in building the attribution graph is creating one node per token position for the input embeddings. Each embedding node has type EMBEDDING, writes its embedding vector into the residual stream, and has a zero reading vector (since embeddings don't read from anything).

You should use cache["blocks.0.hook_resid_pre"] to get the embedding vectors.

def build_embedding_nodes(
    model: HookedSAETransformer,
    cache: ActivationCache,
    tokens: Int[Tensor, "1 seq"],
) -> tuple[list[NodeInfo], list[Tensor], list[Tensor]]:
    """
    Build embedding nodes (one per token position) with their writing and reading vectors.

    Args:
        model: The transformer model.
        cache: Activation cache from forward pass.
        tokens: Input token IDs, shape (1, seq).

    Returns:
        Tuple of (nodes, writing_vecs, reading_vecs) where each list has length seq_len.
        Writing vecs are the token embeddings, reading vecs are zeros.
    """
    raise NotImplementedError()


embed_nodes, embed_writing, embed_reading = build_embedding_nodes(model=gemma, cache=cache, tokens=tokens)
seq_len = tokens.shape[1]

assert len(embed_nodes) == seq_len, f"Expected {seq_len} embedding nodes, got {len(embed_nodes)}"
assert all(n.node_type == NodeType.EMBEDDING for n in embed_nodes), "All embedding nodes should have type EMBEDDING"
for i, wv in enumerate(embed_writing):
    assert wv.shape == (gemma.cfg.d_model,), f"Writing vec {i} has wrong shape: {wv.shape}"
    assert wv.abs().sum() > 0, f"Writing vec {i} is all zeros but should be a token embedding"
for i, rv in enumerate(embed_reading):
    assert rv.shape == (gemma.cfg.d_model,), f"Reading vec {i} has wrong shape: {rv.shape}"
    t.testing.assert_close(rv, t.zeros_like(rv), msg=f"Reading vec {i} should be all zeros")

print("All build_embedding_nodes tests passed!")
Solution
def build_embedding_nodes(
    model: HookedSAETransformer,
    cache: ActivationCache,
    tokens: Int[Tensor, "1 seq"],
) -> tuple[list[NodeInfo], list[Tensor], list[Tensor]]:
    """
    Build embedding nodes (one per token position) with their writing and reading vectors.

    Args:
        model: The transformer model.
        cache: Activation cache from forward pass.
        tokens: Input token IDs, shape (1, seq).

    Returns:
        Tuple of (nodes, writing_vecs, reading_vecs) where each list has length seq_len.
        Writing vecs are the token embeddings, reading vecs are zeros.
    """
    seq_len = tokens.shape[1]
    d_model = model.cfg.d_model
    str_tokens = [model.tokenizer.decode(t_id.item()) for t_id in tokens[0]]

    nodes: list[NodeInfo] = []
    writing_vecs: list[Tensor] = []
    reading_vecs: list[Tensor] = []

    embed_vecs = cache["blocks.0.hook_resid_pre"][0]  # (seq, d_model)
    for pos in range(seq_len):
        nodes.append(
            NodeInfo(
                node_type=NodeType.EMBEDDING,
                layer="E",
                ctx_idx=pos,
                feature=tokens[0, pos].item(),
                str_token=str_tokens[pos],
            )
        )
        writing_vecs.append(embed_vecs[pos])
        reading_vecs.append(t.zeros(d_model, device=embed_vecs.device))

    return nodes, writing_vecs, reading_vecs

Exercise (2/3) - implement build_intermediate_nodes

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 15-20 minutes on this exercise.

Next, we build the intermediate nodes: for each layer and each position (from start_posn onwards), we select the top_k highest-activating transcoder latents and create a LATENT node for each one. We also create one MLP_ERROR node per position per layer to capture the reconstruction error.

We've given you the outer loop structure (iterating over layers/positions and extracting the active feature indices). You need to fill in the inner loop body that creates the actual nodes and their vectors:

  • For each active feature: create a LATENT node with writing vector activation * W_dec[feature_idx] (the activation-scaled decoder direction) and reading vector W_enc.T[feature_idx] (the encoder direction).
  • After processing all features at a given position: create one MLP_ERROR node with writing vector mlp_output[pos] - tc_output[pos] and reading vector zeros.
Hint - computing the MLP error

You can get the MLP output from the cache using tc.cfg.metadata.hook_name_out (e.g. "blocks.0.hook_mlp_out"). Then the error is mlp_output[pos] - tc_output[pos].

def build_intermediate_nodes(
    transcoders: dict[int, "Transcoder"],
    cache: ActivationCache,
    tc_hooks: TranscoderReplacementHooks,
    tokens: Int[Tensor, "1 seq"],
    d_model: int,
    start_posn: int = 4,
    top_k: int = 5,
) -> tuple[list[NodeInfo], list[Tensor], list[Tensor], dict]:
    """
    Build intermediate (LATENT + MLP_ERROR) nodes for all layers.

    Args:
        transcoders: Dict mapping layer -> transcoder.
        cache: Activation cache from forward pass.
        tc_hooks: TranscoderReplacementHooks with computed activations.
        tokens: Input token IDs, shape (1, seq).
        d_model: Model hidden dimension.
        start_posn: First position to include (masks chat formatting tokens).
        top_k: Number of top-activating features to include per position per layer.

    Returns:
        Tuple of (nodes, writing_vecs, reading_vecs, node_range_dict) where node_range_dict
        maps each layer index to (start_idx, end_idx) within the returned lists.
    """
    seq_len = tokens.shape[1]
    n_layers = len(transcoders)
    device = cache["blocks.0.hook_resid_pre"].device

    nodes: list[NodeInfo] = []
    writing_vecs: list[Tensor] = []
    reading_vecs: list[Tensor] = []
    node_range_dict = {}

    for layer in range(n_layers):
        tc = transcoders[layer]
        layer_start = len(nodes)

        tc_acts = tc_hooks.transcoder_acts[layer][0]  # (seq, d_enc)
        tc_output = tc_hooks.transcoder_output[layer][0]  # (seq, d_model)

        # Get MLP output from cache for error computation
        mlp_output = cache[tc.cfg.metadata.hook_name_out][0]  # (seq, d_model)

        W_dec = tc.W_dec.detach()  # (d_enc, d_model)
        W_enc_T = tc.W_enc.detach().T  # (d_enc, d_model)

        for pos in range(start_posn, seq_len):
            acts = tc_acts[pos]  # (d_enc,)
            # Select top-k features by activation
            k = min(top_k, (acts > 0).sum().item())  # don't include zero-activation features
            if k > 0:
                _, active_indices = acts.topk(k)
            else:
                active_indices = t.where(acts > 0)[0]  # fallback: empty if no active features

            # TODO: For each active feature, create a LATENT node with the correct writing
            # vector (activation * W_dec[feature]) and reading vector (W_enc.T[feature]).
            # After all features at this position, create one MLP_ERROR node with writing
            # vector = mlp_output[pos] - tc_output[pos] and reading vector = zeros.
            pass

        node_range_dict[layer] = (layer_start, len(nodes))

    return nodes, writing_vecs, reading_vecs, node_range_dict


inter_nodes, inter_writing, inter_reading, inter_range_dict = build_intermediate_nodes(
    transcoders=transcoders,
    cache=cache,
    tc_hooks=tc_hooks,
    tokens=tokens,
    d_model=gemma.cfg.d_model,
    start_posn=START_POSN,
)

n_layers = len(transcoders)
seq_len = tokens.shape[1]
active_positions = seq_len - START_POSN

# All LATENT nodes should have positive activations
latent_nodes = [n for n in inter_nodes if n.node_type == NodeType.LATENT]
assert len(latent_nodes) > 0, "Expected at least some LATENT nodes"
assert all(n.activation > 0 for n in latent_nodes), "All LATENT nodes should have positive activations"

# MLP_ERROR nodes: one per position per layer
error_nodes = [n for n in inter_nodes if n.node_type == NodeType.MLP_ERROR]
assert len(error_nodes) == active_positions * n_layers, (
    f"Expected {active_positions * n_layers} MLP_ERROR nodes, got {len(error_nodes)}"
)

# Writing vectors for LATENT nodes should be non-zero
latent_indices = [i for i, n in enumerate(inter_nodes) if n.node_type == NodeType.LATENT]
for idx in latent_indices[:5]:  # check first 5
    assert inter_writing[idx].abs().sum() > 0, "LATENT writing vectors should be non-zero"

# Reading vectors for MLP_ERROR nodes should be zeros
error_indices = [i for i, n in enumerate(inter_nodes) if n.node_type == NodeType.MLP_ERROR]
for idx in error_indices[:5]:  # check first 5
    t.testing.assert_close(inter_reading[idx], t.zeros_like(inter_reading[idx]))

# node_range_dict should have entries for each layer
assert set(inter_range_dict.keys()) == set(range(n_layers)), (
    f"node_range_dict should have entries for layers 0..{n_layers - 1}, got keys {set(inter_range_dict.keys())}"
)

print("All build_intermediate_nodes tests passed!")
Solution
def build_intermediate_nodes(
    transcoders: dict[int, "Transcoder"],
    cache: ActivationCache,
    tc_hooks: TranscoderReplacementHooks,
    tokens: Int[Tensor, "1 seq"],
    d_model: int,
    start_posn: int = 4,
    top_k: int = 5,
) -> tuple[list[NodeInfo], list[Tensor], list[Tensor], dict]:
    """
    Build intermediate (LATENT + MLP_ERROR) nodes for all layers.

    Args:
        transcoders: Dict mapping layer -> transcoder.
        cache: Activation cache from forward pass.
        tc_hooks: TranscoderReplacementHooks with computed activations.
        tokens: Input token IDs, shape (1, seq).
        d_model: Model hidden dimension.
        start_posn: First position to include (masks chat formatting tokens).
        top_k: Number of top-activating features to include per position per layer.

    Returns:
        Tuple of (nodes, writing_vecs, reading_vecs, node_range_dict) where node_range_dict
        maps each layer index to (start_idx, end_idx) within the returned lists.
    """
    seq_len = tokens.shape[1]
    n_layers = len(transcoders)
    device = cache["blocks.0.hook_resid_pre"].device

    nodes: list[NodeInfo] = []
    writing_vecs: list[Tensor] = []
    reading_vecs: list[Tensor] = []
    node_range_dict = {}

    for layer in range(n_layers):
        tc = transcoders[layer]
        layer_start = len(nodes)

        tc_acts = tc_hooks.transcoder_acts[layer][0]  # (seq, d_enc)
        tc_output = tc_hooks.transcoder_output[layer][0]  # (seq, d_model)

        # Get MLP output from cache for error computation
        mlp_output = cache[tc.cfg.metadata.hook_name_out][0]  # (seq, d_model)

        W_dec = tc.W_dec.detach()  # (d_enc, d_model)
        W_enc_T = tc.W_enc.detach().T  # (d_enc, d_model)

        for pos in range(start_posn, seq_len):
            acts = tc_acts[pos]  # (d_enc,)
            # Select top-k features by activation
            k = min(top_k, (acts > 0).sum().item())  # don't include zero-activation features
            if k > 0:
                _, active_indices = acts.topk(k)
            else:
                active_indices = t.where(acts > 0)[0]  # fallback: empty if no active features

            for feat_idx in active_indices:
                feat_act = acts[feat_idx].item()
                nodes.append(
                    NodeInfo(
                        node_type=NodeType.LATENT,
                        layer=layer,
                        ctx_idx=pos,
                        feature=feat_idx.item(),
                        activation=feat_act,
                    )
                )
                # Writing vector: activation * decoder direction
                writing_vecs.append(feat_act * W_dec[feat_idx])
                # Reading vector: encoder direction
                reading_vecs.append(W_enc_T[feat_idx])

            # MLP error node for this position
            mlp_error = mlp_output[pos] - tc_output[pos]
            nodes.append(
                NodeInfo(
                    node_type=NodeType.MLP_ERROR,
                    layer=layer,
                    ctx_idx=pos,
                    feature=0,
                )
            )
            writing_vecs.append(mlp_error.detach())
            reading_vecs.append(t.zeros(d_model, device=device))

        node_range_dict[layer] = (layer_start, len(nodes))

    return nodes, writing_vecs, reading_vecs, node_range_dict

Finally, we build the output (logit) nodes. There is one node per entry in top_token_info, positioned at the last sequence position. These nodes have zero writing vectors (they don't write to the residual stream) and their reading vectors are the demeaned unembedding columns from reading_vecs_logit. We've given you this function since it follows the same pattern as build_embedding_nodes.
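As a sketch of what "demeaned" means here (assumption: toy shapes and random stand-ins for the model's W_U and the top predicted token ids):

```python
import torch as t

d_model, d_vocab = 8, 50               # toy sizes
W_U = t.randn(d_model, d_vocab)        # stand-in for the unembedding matrix
token_ids = t.tensor([3, 17])          # stand-in for the top predicted tokens

# Subtract each row's mean over the vocab, so the attribution objective measures
# how much a direction boosts these tokens *relative to all others* (softmax is
# invariant to adding a constant to every logit).
W_U_demeaned = W_U - W_U.mean(dim=-1, keepdim=True)
reading_vecs_logit = W_U_demeaned[:, token_ids].T    # (n_output, d_model)

assert reading_vecs_logit.shape == (2, d_model)
assert t.allclose(W_U_demeaned.mean(dim=-1), t.zeros(d_model), atol=1e-5)
```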

def build_logit_nodes(
    reading_vecs_logit: Float[Tensor, "n_output d_model"],
    top_token_info: list[tuple[str, float]],
    seq_len: int,
    d_model: int,
) -> tuple[list[NodeInfo], list[Tensor], list[Tensor]]:
    """
    Build logit (output) nodes with their writing and reading vectors.

    Args:
        reading_vecs_logit: Demeaned W_U columns for output nodes, shape (n_output, d_model).
        top_token_info: List of (token_string, probability) for output nodes.
        seq_len: Total sequence length (logit nodes are placed at seq_len - 1).
        d_model: Model hidden dimension.

    Returns:
        Tuple of (nodes, writing_vecs, reading_vecs) where each list has length len(top_token_info).
        Writing vecs are zeros, reading vecs are the logit reading vectors.
    """
    device = reading_vecs_logit.device

    nodes: list[NodeInfo] = []
    writing_vecs: list[Tensor] = []
    reading_vecs: list[Tensor] = []

    for i, (tok_str, prob) in enumerate(top_token_info):
        nodes.append(
            NodeInfo(
                node_type=NodeType.LOGIT,
                layer="L",
                ctx_idx=seq_len - 1,
                feature=i,
                token_prob=prob,
                str_token=tok_str,
            )
        )
        writing_vecs.append(t.zeros(d_model, device=device))
        reading_vecs.append(reading_vecs_logit[i])

    return nodes, writing_vecs, reading_vecs

Now that all three helper functions are done, we combine them into a single build_graph_nodes function. This wrapper is provided for you and just calls the three sub-functions, assembles the results, and returns a GraphNodes object.

def build_graph_nodes(
    model: HookedSAETransformer,
    transcoders: dict[int, "Transcoder"],
    cache: ActivationCache,
    tc_hooks: TranscoderReplacementHooks,
    reading_vecs_logit: Float[Tensor, "n_output d_model"],
    top_token_info: list[tuple[str, float]],
    tokens: Int[Tensor, "1 seq"],
    start_posn: int = 4,
    top_k: int = 5,
) -> GraphNodes:
    """
    Build all graph nodes with their reading and writing vectors by calling the three sub-functions
    and assembling the results.

    Args:
        model: The transformer model.
        transcoders: Dict mapping layer -> transcoder.
        cache: Activation cache from forward pass.
        tc_hooks: TranscoderReplacementHooks with computed activations.
        reading_vecs_logit: Demeaned W_U columns for output nodes.
        top_token_info: List of (token_string, probability) for output nodes.
        tokens: Input token IDs, shape (1, seq).
        start_posn: First position to include (masks chat formatting tokens).
        top_k: Number of top-activating features to include per position per layer.

    Returns:
        GraphNodes containing all nodes, their reading/writing vectors, and index ranges.
    """
    seq_len = tokens.shape[1]
    n_layers = len(transcoders)
    d_model = model.cfg.d_model

    all_nodes: list[NodeInfo] = []
    all_writing: list[Tensor] = []
    all_reading: list[Tensor] = []
    node_range_dict = {}

    # 1. Embedding nodes
    embed_nodes, embed_writing, embed_reading = build_embedding_nodes(model, cache, tokens)
    all_nodes.extend(embed_nodes)
    all_writing.extend(embed_writing)
    all_reading.extend(embed_reading)
    node_range_dict["E"] = (0, len(all_nodes))

    # 2. Intermediate nodes (features + MLP error per layer)
    inter_nodes, inter_writing, inter_reading, inter_ranges = build_intermediate_nodes(
        transcoders,
        cache,
        tc_hooks,
        tokens,
        d_model,
        start_posn,
        top_k,
    )
    offset = len(all_nodes)
    all_nodes.extend(inter_nodes)
    all_writing.extend(inter_writing)
    all_reading.extend(inter_reading)
    for layer_key, (s, e) in inter_ranges.items():
        node_range_dict[layer_key] = (s + offset, e + offset)

    # 3. Logit nodes
    logit_start = len(all_nodes)
    logit_nodes, logit_writing, logit_reading = build_logit_nodes(
        reading_vecs_logit,
        top_token_info,
        seq_len,
        d_model,
    )
    all_nodes.extend(logit_nodes)
    all_writing.extend(logit_writing)
    all_reading.extend(logit_reading)
    node_range_dict["L"] = (logit_start, len(all_nodes))

    return GraphNodes(
        nodes=all_nodes,
        writing_vecs=t.stack(all_writing),
        reading_vecs=t.stack(all_reading),
        node_range_dict=node_range_dict,
        seq_len=seq_len,
        n_layers=n_layers,
    )
graph = build_graph_nodes(
    model=gemma,
    transcoders=transcoders,
    cache=cache,
    tc_hooks=tc_hooks,
    reading_vecs_logit=reading_vecs,
    top_token_info=top_token_info,
    tokens=tokens,
    start_posn=START_POSN,
)

# Print some stats
n_embeds = sum(1 for n in graph.nodes if n.node_type == NodeType.EMBEDDING)
n_latents = sum(1 for n in graph.nodes if n.node_type == NodeType.LATENT)
n_errors = sum(1 for n in graph.nodes if n.node_type == NodeType.MLP_ERROR)
n_logits = sum(1 for n in graph.nodes if n.node_type == NodeType.LOGIT)
print(f"Graph has {len(graph.nodes)} total nodes:")
print(f"  {n_embeds} embedding nodes")
print(f"  {n_latents} latent nodes")
print(f"  {n_errors} MLP error nodes")
print(f"  {n_logits} logit output nodes")
print(f"Writing vectors shape: {graph.writing_vecs.shape}")
print(f"Reading vectors shape: {graph.reading_vecs.shape}")
Click to see the expected output

Attribution setup

Before computing the adjacency matrix, we need a helper that sets up the backward pass properly. The idea is: (1) run a forward pass through the frozen model (with FreezeHooks and TranscoderReplacementHooks active), (2) at the target node's position and layer, compute the dot product of the residual stream with the target's reading vector, (3) backpropagate this scalar through the frozen model, and (4) read off the gradients at each source node's position and contract them with the source's writing vector.

We'll give you this setup function, which handles steps 1-3 for a batch of target nodes. Read through it carefully.

Hook design note (Option A). Rather than nesting with freeze: inside the function and separately installing/removing gradient-capture hooks, setup_attribution uses a single model.hooks() call that combines everything:

with model.hooks(fwd_hooks=freeze.fwd_hooks + capture_fwd_hooks, bwd_hooks=capture_bwd_hooks):
    model(tokens).backward()
  • freeze.fwd_hooks - returns the freeze hooks as (name, fn) pairs (TL-native format)
  • capture_fwd_hooks - store each residual-stream activation tensor so we can build objectives
  • capture_bwd_hooks - capture gradients directly during the backward pass, eliminating the need for tensor.retain_grad() + tensor.grad bookkeeping

TranscoderReplacementHooks uses permanent hooks (registered with is_permanent=True) so they remain active throughout the model.hooks() context and are only removed by tc_hooks.remove().

def setup_attribution(
    model: HookedSAETransformer,
    tokens: Tensor,
    freeze: FreezeHooks,
    target_reading_vecs: Float[Tensor, "batch d_model"],
    target_positions: Int[Tensor, " batch"],
    target_layers: list[int | str],
) -> dict[str, Tensor]:
    """
    Run forward pass through frozen model, inject reading vectors as gradient seeds at
    target positions/layers, and return gradients at all residual stream positions.

    This function handles the "backward pass" part of attribution: for each target node,
    it computes d(objective)/d(resid) at every layer, where objective = dot(resid[target_pos], reading_vec).

    Uses TransformerLens-native hooks via a single model.hooks() context (Option A):
      - fwd_hooks = freeze.fwd_hooks + capture_fwd_hooks
          freeze.fwd_hooks  : replace attention patterns/LN scales with frozen values
          capture_fwd_hooks : store each residual-stream tensor for objective computation
      - bwd_hooks = capture_bwd_hooks
          capture_bwd_hooks : receive gradients directly during backward, no retain_grad() needed

    TranscoderReplacementHooks are permanent and stay active inside this context.

    Args:
        model: The transformer model.
        tokens: Input tokens, shape (1, seq).
        freeze: FreezeHooks with cached frozen values (provides fwd_hooks property).
        target_reading_vecs: Reading vectors for target nodes, shape (batch, d_model).
        target_positions: Sequence positions of target nodes, shape (batch,).
        target_layers: Layer identifiers for target nodes (int for intermediate, "L" for logits).

    Returns:
        grads: Dict mapping hook names -> gradient tensors at each residual stream position.
    """
    # Detach reading vectors - attribution only needs d(objective)/d(resid),
    # not d(objective)/d(W_U). Without this, reading_vecs built from W_U (a
    # parameter with requires_grad=True) carry a grad_fn that gets freed after
    # batch 1's backward(), causing "backward through graph a second time" on
    # subsequent batches.
    target_reading_vecs = target_reading_vecs.detach()

    batch_size = target_reading_vecs.shape[0]

    # Tensors captured during the forward pass (for computing objectives)
    captured_tensors: dict[str, Tensor] = {}
    # Gradients captured directly during the backward pass (the actual output)
    captured_grads: dict[str, Tensor] = {}

    # Names of residual-stream tensors to capture
    capture_names = (
        [f"blocks.{l}.hook_resid_post" for l in range(model.cfg.n_layers)]
        + [f"blocks.{l}.hook_resid_mid" for l in range(model.cfg.n_layers)]
        + ["blocks.0.hook_resid_pre"]
    )

    def make_fwd_capture(name: str) -> Callable:
        """Forward hook: store the activation tensor so we can build objectives from it."""

        def hook_fn(tensor: Tensor, hook: HookPoint) -> None:
            captured_tensors[name] = tensor

        return hook_fn

    def make_bwd_capture(name: str) -> Callable:
        """Backward hook: receive the gradient directly during backward, no retain_grad() needed."""

        def hook_fn(grad: Tensor, hook: HookPoint) -> None:
            captured_grads[name] = grad.detach()

        return hook_fn

    capture_fwd_hooks = [(name, make_fwd_capture(name)) for name in capture_names]
    capture_bwd_hooks = [(name, make_bwd_capture(name)) for name in capture_names]

    # Single model.hooks() context combining freeze hooks + capture hooks (Option A).
    # TranscoderReplacementHooks are permanent and stay active throughout.
    # At context exit, only non-permanent hooks (freeze + capture) are removed.
    with model.hooks(
        fwd_hooks=freeze.fwd_hooks + capture_fwd_hooks,
        bwd_hooks=capture_bwd_hooks,
    ):
        # Forward pass (objectives are computed from captured residual tensors)
        model(tokens.expand(batch_size, -1))

        # Compute objectives for each target node
        # Use torch.stack instead of in-place assignment to a leaf tensor, to
        # avoid autograd issues with in-place ops on zero-initialized tensors.
        objective_terms = []
        for i in range(batch_size):
            pos = target_positions[i]
            layer = target_layers[i]
            if layer == "L":
                # For logit nodes, use the final residual stream (pre-unembedding)
                resid = captured_tensors[f"blocks.{model.cfg.n_layers - 1}.hook_resid_post"][i, pos]
            else:
                # For intermediate nodes, use resid_mid at this layer
                # (after attention, before MLP - this is what the feature reads from)
                resid = captured_tensors[f"blocks.{layer}.hook_resid_mid"][i, pos]
            objective_terms.append((resid * target_reading_vecs[i]).sum())

        # Backward pass - fires bwd hooks, populating captured_grads
        total_objective = t.stack(objective_terms).sum()
        total_objective.backward()

    return captured_grads

We'll implement the attribution algorithm in two steps. First we prepare the backward pass batches (grouping target nodes by layer and batching their reading vectors), then we run the actual backward passes and contract the gradients with source writing vectors.

The adjacency matrix A[target, source] stores the edge weight from source to target. Since our nodes are ordered by layer, this matrix is strictly lower triangular (a target can only receive from earlier layers).
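To make the indexing convention concrete, here's a hand-written toy adjacency for a three-node chain (the numbers are made up for illustration):

```python
import torch as t

# A[target, source]: row = receiving node, column = writing node.
# Node order: [embedding, latent, logit] - later nodes only receive from earlier ones.
A = t.tensor([
    [0.0, 0.0, 0.0],   # embedding receives nothing (it's an input)
    [0.7, 0.0, 0.0],   # latent receives from the embedding
    [0.2, 0.9, 0.0],   # logit receives from both embedding and latent
])
assert t.equal(A, t.tril(A, diagonal=-1))   # strictly lower triangular
```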

Exercise (1/2) - implement prepare_backward_batches

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 15-20 minutes on this exercise.

The first step is organizing the target nodes into batches for backward passes. For each layer in the graph (skipping embeddings, which are inputs only), group the target nodes into chunks of size batch_size. For each chunk, collect the reading vectors, positions, and layer identifiers.

Your function should return a list of tuples, where each tuple contains: the global start index of the batch in the node list, the batch of target nodes, their reading vectors as a stacked tensor, a tensor of their positions, and a list of their layer identifiers.

Hint - iterating over layers

Use graph.node_range_dict to iterate over layers. Skip the "E" key (embeddings). For each layer, node_range_dict[key] gives you the (start, end) range of node indices.

def prepare_backward_batches(
    graph: GraphNodes,
    batch_size: int = 8,
) -> list[tuple[int, list[NodeInfo], Tensor, Tensor, list[int | str]]]:
    """
    Group target nodes by layer and split into batches for backward passes.

    Args:
        graph: GraphNodes containing all nodes and their reading/writing vectors.
        batch_size: Number of target nodes to process per backward pass.

    Returns:
        List of (global_start_idx, batch_nodes, reading_vecs, positions, layers) tuples, where:
            global_start_idx: Index of the first node in this batch within graph.nodes
            batch_nodes: List of NodeInfo for this batch
            reading_vecs: Stacked reading vectors, shape (batch, d_model)
            positions: Tensor of sequence positions, shape (batch,)
            layers: List of layer identifiers (int or "L")
    """
    raise NotImplementedError()


batches = prepare_backward_batches(graph, batch_size=8)

# Check we got some batches
assert len(batches) > 0, "Expected at least one batch"

# Verify batch structure
all_target_layers_covered = set()
total_target_nodes = 0
for global_start, batch_nodes, reading_vecs, positions, layers in batches:
    # Reading vectors should have correct shape
    assert reading_vecs.shape[0] == len(batch_nodes), "Reading vecs batch size should match number of nodes"
    assert reading_vecs.shape[1] == gemma.cfg.d_model, (
        f"Reading vecs should have d_model={gemma.cfg.d_model} columns"
    )

    # Positions should match batch size
    assert positions.shape[0] == len(batch_nodes), "Positions batch size should match number of nodes"

    # Track which layers are covered
    for n in batch_nodes:
        layer_key = n.layer
        all_target_layers_covered.add(layer_key)
    total_target_nodes += len(batch_nodes)

# All non-embedding layers should be covered
expected_layers = {k for k in graph.node_range_dict.keys() if k != "E"}
assert all_target_layers_covered == expected_layers, (
    f"Expected layers {expected_layers} to be covered, got {all_target_layers_covered}"
)

# Total target nodes should match all non-embedding nodes
n_non_embed = len(graph.nodes) - (graph.node_range_dict["E"][1] - graph.node_range_dict["E"][0])
assert total_target_nodes == n_non_embed, (
    f"Expected {n_non_embed} total target nodes across batches, got {total_target_nodes}"
)

print("All prepare_backward_batches tests passed!")
Solution
def prepare_backward_batches(
    graph: GraphNodes,
    batch_size: int = 8,
) -> list[tuple[int, list[NodeInfo], Tensor, Tensor, list[int | str]]]:
    """
    Group target nodes by layer and split into batches for backward passes.

    Args:
        graph: GraphNodes containing all nodes and their reading/writing vectors.
        batch_size: Number of target nodes to process per backward pass.

    Returns:
        List of (global_start_idx, batch_nodes, reading_vecs, positions, layers) tuples, where:
            global_start_idx: Index of the first node in this batch within graph.nodes
            batch_nodes: List of NodeInfo for this batch
            reading_vecs: Stacked reading vectors, shape (batch, d_model)
            positions: Tensor of sequence positions, shape (batch,)
            layers: List of layer identifiers (int or "L")
    """
    device = graph.reading_vecs.device
    batches = []

    for target_layer_key in list(graph.node_range_dict.keys()):
        if target_layer_key == "E":
            continue  # Embeddings are inputs only

        tgt_start, tgt_end = graph.node_range_dict[target_layer_key]
        target_nodes = graph.nodes[tgt_start:tgt_end]

        if len(target_nodes) == 0:
            continue

        for batch_start in range(0, len(target_nodes), batch_size):
            batch_end = min(batch_start + batch_size, len(target_nodes))
            batch_nodes = target_nodes[batch_start:batch_end]

            reading_vecs = graph.reading_vecs[tgt_start + batch_start : tgt_start + batch_end]
            positions = t.tensor([n.ctx_idx for n in batch_nodes], device=device)
            layers = [n.layer for n in batch_nodes]

            global_start_idx = tgt_start + batch_start
            batches.append((global_start_idx, batch_nodes, reading_vecs, positions, layers))

    return batches

Exercise (2/2) - implement compute_adjacency_matrix

Difficulty: 🔴🔴🔴🔴🔴
Importance: 🔵🔵🔵🔵🔵
You should spend up to 30-40 minutes on this exercise. This is the hardest and most important exercise in this section.

Now implement the function that takes the prepared batches, runs backward passes through the model, and contracts the resulting gradients with source writing vectors to fill in the adjacency matrix.

For each batch, you need to: (1) install the transcoder hooks, (2) call setup_attribution with the batch's reading vectors, positions, and layers, (3) remove the hooks, then (4) iterate over all source nodes and compute the dot product of the gradient at the source position with the source's writing vector.

Hint - computing edge weights from gradients

The gradient at source position pos in layer l has shape (batch, d_model). The writing vector for source node j at position pos also has shape (d_model,). The edge weight is their dot product:

edge_weight = (grad_at_pos * writing_vec.unsqueeze(0)).sum(-1)  # (batch,)

You need to figure out which gradient hook name corresponds to each source node's layer.

Hint - which gradients correspond to which source nodes

Embedding nodes use the gradient at blocks.0.hook_resid_pre. Feature and MLP error nodes at layer l use the gradient at blocks.{l}.hook_resid_post. Logit nodes are never sources.

Hint - overall structure

The outer loop iterates over batches. For each batch: install hooks, call setup_attribution, remove hooks, then loop over all source nodes to contract gradients with writing vectors. Each contraction gives you one column of edge weights in the adjacency matrix.

def compute_adjacency_matrix(
    model: HookedSAETransformer,
    tokens: Tensor,
    graph: GraphNodes,
    freeze: FreezeHooks,
    tc_hooks: TranscoderReplacementHooks,
    batches: list[tuple[int, list[NodeInfo], Tensor, Tensor, list[int | str]]],
    start_posn: int = 4,
) -> Float[Tensor, "n_nodes n_nodes"]:
    """
    Run backward passes for each batch and contract gradients with writing vectors.

    Args:
        model: The transformer model.
        tokens: Input tokens, shape (1, seq).
        graph: GraphNodes containing all nodes and their reading/writing vectors.
        freeze: FreezeHooks with cached frozen values.
        tc_hooks: TranscoderReplacementHooks with installed hooks.
        batches: Output of prepare_backward_batches.
        start_posn: First position to include (masks formatting tokens).

    Returns:
        Adjacency matrix of shape (n_nodes, n_nodes), where A[j, i] is the edge weight
        from node i to node j.
    """
    # n_nodes = len(graph.nodes)
    # adjacency_matrix = t.zeros(n_nodes, n_nodes, device=tokens.device)
    #
    # for global_start_idx, batch_nodes, target_reading_vecs, target_positions, target_layers in batches:
    #     actual_batch_size = len(batch_nodes)
    #
    #     # Run backward pass to get gradients
    #     tc_hooks.install()
    #     grads = setup_attribution(
    #         model=model,
    #         tokens=tokens,
    #         freeze=freeze,
    #         target_reading_vecs=target_reading_vecs,
    #         target_positions=target_positions,
    #         target_layers=target_layers,
    #     )
    #     tc_hooks.remove()
    #
    #     # YOUR CODE HERE - contract gradients with source writing vectors
    #     # For each source node:
    #     #   1. Determine the correct gradient hook name (see hints above)
    #     #   2. Get the gradient at the source node's position: grads[grad_name][:actual_batch_size, src_node.ctx_idx]
    #     #   3. Dot product with the source's writing vector: graph.writing_vecs[src_idx]
    #     #   4. Zero out edges from masked positions (ctx_idx < start_posn)
    #     #   5. Store in adjacency_matrix[global_start_idx + b, src_idx]
    #     pass
    #
    # return adjacency_matrix


batches = prepare_backward_batches(graph, batch_size=8)
print("Got batches...")
adjacency_matrix = compute_adjacency_matrix(
    model=gemma,
    tokens=tokens,
    graph=graph,
    freeze=freeze,
    tc_hooks=tc_hooks,
    batches=batches,
    start_posn=START_POSN,
)

n_nodes = len(graph.nodes)
assert adjacency_matrix.shape == (n_nodes, n_nodes), (
    f"Expected shape ({n_nodes}, {n_nodes}), got {adjacency_matrix.shape}"
)

# Should be lower triangular (no backward connections in a transformer)
upper_tri_norm = t.triu(adjacency_matrix, diagonal=1).abs().sum().item()
assert upper_tri_norm < 1e-4, f"Upper triangle norm should be ~0, got {upper_tri_norm}"

# Should be sparse (most entries near zero)
sparsity = 1 - (adjacency_matrix.abs() > 1e-6).float().mean().item()
assert sparsity > 0.5, f"Expected sparsity > 50%, got {sparsity:.2%}"

# Should have some non-zero entries (embedding -> first layer features)
n_nonzero = (adjacency_matrix.abs() > 1e-6).sum().item()
assert n_nonzero > 0, "Expected some non-zero entries in adjacency matrix"

print("All compute_adjacency_matrix tests passed!")
Solution
def compute_adjacency_matrix(
    model: HookedSAETransformer,
    tokens: Tensor,
    graph: GraphNodes,
    freeze: FreezeHooks,
    tc_hooks: TranscoderReplacementHooks,
    batches: list[tuple[int, list[NodeInfo], Tensor, Tensor, list[int | str]]],
    start_posn: int = 4,
) -> Float[Tensor, "n_nodes n_nodes"]:
    """
    Run backward passes for each batch and contract gradients with writing vectors.

    Args:
        model: The transformer model.
        tokens: Input tokens, shape (1, seq).
        graph: GraphNodes containing all nodes and their reading/writing vectors.
        freeze: FreezeHooks with cached frozen values.
        tc_hooks: TranscoderReplacementHooks with installed hooks.
        batches: Output of prepare_backward_batches.
        start_posn: First position to include (masks formatting tokens).

    Returns:
        Adjacency matrix of shape (n_nodes, n_nodes), where A[j, i] is the edge weight
        from node i to node j.
    """
    n_nodes = len(graph.nodes)
    adjacency_matrix = t.zeros(n_nodes, n_nodes, device=tokens.device)

    for global_start_idx, batch_nodes, target_reading_vecs, target_positions, target_layers in tqdm(
        batches, desc="Backward batches"
    ):
        actual_batch_size = len(batch_nodes)

        # Run backward pass to get gradients
        tc_hooks.install()
        grads = setup_attribution(
            model=model,
            tokens=tokens,
            freeze=freeze,
            target_reading_vecs=target_reading_vecs,
            target_positions=target_positions,
            target_layers=target_layers,
        )
        tc_hooks.remove()

        # Contract gradients with source writing vectors to get edge weights
        for src_idx in range(n_nodes):
            src_node = graph.nodes[src_idx]

            # Determine which gradient tensor to use for this source
            if src_node.node_type == NodeType.EMBEDDING:
                grad_name = "blocks.0.hook_resid_pre"
            elif src_node.node_type in (NodeType.LATENT, NodeType.MLP_ERROR):
                grad_name = f"blocks.{src_node.layer}.hook_resid_post"
            else:
                continue  # Logit nodes are targets only, not sources

            if grad_name not in grads:
                continue

            # Get gradient at source position and contract with writing vector
            grad_at_pos = grads[grad_name][:actual_batch_size, src_node.ctx_idx]  # (batch, d_model)
            writing_vec = graph.writing_vecs[src_idx]  # (d_model,)

            # Edge weight = dot product
            edge_weights = (grad_at_pos * writing_vec.unsqueeze(0)).sum(-1)  # (batch,)

            # Zero out edges from masked positions
            if src_node.ctx_idx < start_posn:
                edge_weights = t.zeros_like(edge_weights)

            # Store in adjacency matrix
            for b in range(actual_batch_size):
                adjacency_matrix[global_start_idx + b, src_idx] = edge_weights[b]

    return adjacency_matrix

Adjacency matrices: normalisation, influence, and pruning

We now have a raw adjacency matrix with edge weights between all pairs of nodes, but this matrix is large and dense. We need to prune it down to a sparse, interpretable graph. The pruning process has three steps: normalise the adjacency matrix, compute influence scores for each node via the Neumann series, and prune nodes and edges by influence threshold.

Normalisation

We normalise the adjacency matrix by taking the element-wise absolute value and dividing each row by its sum. This discards sign information and gives us a non-negative matrix where each row sums to 1, with a "transition matrix" interpretation: the normalised weight tells us what fraction of a target node's total incoming signal magnitude comes from each source.

$$A_{\text{norm}}[j, :] = \frac{|A[j, :]|}{\sum_i |A[j, i]|}$$

We provide the normalize_matrix function below: it takes the element-wise absolute values of the adjacency matrix and divides each row by its sum, so that all entries are non-negative and each row sums to 1.

def normalize_matrix(
    adjacency_matrix: Float[Tensor, "n_nodes n_nodes"],
) -> Float[Tensor, "n_nodes n_nodes"]:
    """
    Normalise the adjacency matrix row-wise by absolute value sums.

    Each row is divided by the sum of absolute values in that row, so that
    |A_norm[j, :]|.sum() = 1 for each j (where possible).
    """
    abs_matrix = adjacency_matrix.abs()
    row_sums = abs_matrix.sum(dim=1, keepdim=True).clamp(min=1e-8)
    return abs_matrix / row_sums
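As a quick sanity check (a standalone sketch that re-declares the function so it runs on its own), the two key properties — non-negativity and unit row sums — can be verified on a random matrix:

```python
import torch as t

def normalize_matrix(adjacency_matrix):
    # Same row-normalisation as above: abs values, divide each row by its sum
    abs_matrix = adjacency_matrix.abs()
    row_sums = abs_matrix.sum(dim=1, keepdim=True).clamp(min=1e-8)
    return abs_matrix / row_sums

A = t.randn(5, 5)
A_norm = normalize_matrix(A)
assert (A_norm >= 0).all(), "all entries non-negative"
assert t.allclose(A_norm.sum(dim=1), t.ones(5)), "each row sums to 1"
```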

Influence via the Neumann series

The influence of a node measures its total (direct + indirect) effect on the output logits. We compute this by propagating the logit node weights backward through the normalised adjacency matrix.

Because our graph is layered (edges only go from earlier to later layers), the adjacency matrix is strictly lower triangular, hence nilpotent. The longest possible path runs from an embedding through all $L$ feature layers to a logit node, i.e. $L+1$ edges, so $A^{L+2} = 0$ and the power series $I + A + A^2 + \ldots + A^{L+1}$ terminates after finitely many terms.

The influence vector is computed by starting from the logit weights and propagating backward:

$$v_0 = w_{\text{logit}}, \quad v_{k+1} = v_k \cdot A_{\text{norm}}, \quad \text{influence} = \sum_{k=1}^{L+1} v_k$$

where $w_{\text{logit}}$ is a vector that has the logit node probabilities at the logit positions and zeros elsewhere.

This computes how much each node indirectly contributes to the output, accounting for all paths through the graph. Note that logit nodes get zero influence here (they have no outgoing edges), but their weights are added back during edge pruning so that edges to logit nodes receive proper scores.
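To build intuition for why the series terminates (a toy illustration, separate from the exercise), note that each multiplication by a strictly lower-triangular matrix pushes its non-zero band one step further below the diagonal, so any $n \times n$ such matrix satisfies $A^n = 0$:

```python
import torch as t

# A strictly lower-triangular 4x4 matrix (edges only go "backward" in index order)
A = t.tril(t.rand(4, 4), diagonal=-1)

powers = [A]
for _ in range(3):
    powers.append(powers[-1] @ A)

# Each power shifts the non-zero band further below the diagonal; A^4 vanishes entirely
for k, P in enumerate(powers, start=1):
    print(f"A^{k} max abs entry: {P.abs().max().item():.4f}")
assert t.allclose(powers[-1], t.zeros(4, 4))
```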

Exercise - implement compute_influence

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 15-20 minutes on this exercise.

Implement the influence computation. The logit_weights vector should have the token probabilities at the logit node positions (at the end of the node list) and zeros elsewhere.

def compute_influence(
    adjacency_matrix: Float[Tensor, "n_nodes n_nodes"],
    logit_weights: Float[Tensor, "n_logit_nodes"],
    n_layers: int,
) -> Float[Tensor, " n_nodes"]:
    """
    Compute the influence of each node on the output, using the Neumann series on the normalised
    adjacency matrix.

    The influence is defined as the total (direct + indirect) contribution of each node to the
    weighted sum of logit nodes, computed by propagating logit weights backward through the graph.

    Args:
        adjacency_matrix: Raw adjacency matrix, shape (n_nodes, n_nodes).
        logit_weights: Weights for logit nodes (e.g., probabilities), shape (n_logit_nodes,).
        n_layers: Number of model layers.

    Returns:
        Influence vector, shape (n_nodes,).
    """
    raise NotImplementedError()
n_logit_nodes = graph.node_range_dict["L"][1] - graph.node_range_dict["L"][0]
logit_weights = t.tensor([info.token_prob for info in graph.nodes[-n_logit_nodes:]], device=device)

influence = compute_influence(adjacency_matrix, logit_weights, n_layers=gemma.cfg.n_layers)

print(f"Influence shape: {influence.shape}")
print("Top 10 most influential nodes:")
top_influence = influence.argsort(descending=True)[:10]
for idx in top_influence:
    node = graph.nodes[idx]
    print(
        f"  {node.node_type.value} L{node.layer} pos{node.ctx_idx} feat{node.feature}: influence={influence[idx].item():.4f}"
    )

# Plot cumulative influence curve
sorted_influence, _ = influence.sort(descending=True)
cumulative = sorted_influence.cumsum(dim=0) / sorted_influence.sum()
fig = px.line(
    x=list(range(1, len(cumulative) + 1)),
    y=cumulative.cpu().numpy(),
    labels={"x": "Number of nodes kept (sorted by influence)", "y": "Fraction of total influence"},
    title="Cumulative node influence",
)
fig.add_hline(y=0.8, line_dash="dash", line_color="red", annotation_text="80% threshold")
fig.show()
Click to see the expected output


Solution
def compute_influence(
    adjacency_matrix: Float[Tensor, "n_nodes n_nodes"],
    logit_weights: Float[Tensor, "n_logit_nodes"],
    n_layers: int,
) -> Float[Tensor, " n_nodes"]:
    """
    Compute the influence of each node on the output, using the Neumann series on the normalised
    adjacency matrix.

    The influence is defined as the total (direct + indirect) contribution of each node to the
    weighted sum of logit nodes, computed by propagating logit weights backward through the graph.

    Args:
        adjacency_matrix: Raw adjacency matrix, shape (n_nodes, n_nodes).
        logit_weights: Weights for logit nodes (e.g., probabilities), shape (n_logit_nodes,).
        n_layers: Number of model layers.

    Returns:
        Influence vector, shape (n_nodes,).
    """
    n_nodes = adjacency_matrix.shape[0]
    n_logit_nodes = logit_weights.shape[0]

    # Normalise the matrix
    A_norm = normalize_matrix(adjacency_matrix)

    # Initialize influence vector: logit weights at the end, zeros elsewhere
    influence = t.zeros(n_nodes, device=adjacency_matrix.device)
    influence[-n_logit_nodes:] = logit_weights

    # Accumulate the Neumann series terms: propagate logit weights backward through the graph
    acc = t.zeros_like(influence)
    v = influence.clone()

    for _ in range(n_layers + 1):
        v = v @ A_norm  # vector-matrix product
        acc = acc + v

    return acc

Pruning

With influence scores computed, we can prune the graph in two stages. First, node pruning: sort nodes by influence and keep the smallest set whose cumulative influence exceeds a threshold (e.g. 80% of total influence). This removes nodes with negligible impact. Second, edge pruning: for each remaining edge, compute an edge score (the normalised edge weight multiplied by the destination node's influence), sort edges by score, keep the smallest set whose cumulative score exceeds a threshold, then remove any nodes left with no edges.

After both stages, we have a sparse, interpretable graph containing only the most important nodes and connections.
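Both stages use the same "keep the smallest prefix whose cumulative fraction reaches a threshold" rule. As a toy illustration (with made-up scores, not tied to either exercise), the rule looks like this:

```python
import torch as t

# Made-up scores (dyadic values, so the cumulative sums are exact in floating point)
scores = t.tensor([0.125, 0.5, 0.0625, 0.25, 0.0625])

sorted_scores, order = scores.sort(descending=True)
cumulative = sorted_scores.cumsum(dim=0) / sorted_scores.sum()  # [0.5, 0.75, 0.875, 0.9375, 1.0]
# Keep the smallest prefix whose cumulative fraction reaches the 80% threshold
n_keep = int((cumulative < 0.8).sum().item()) + 1
kept = order[:n_keep]

print(f"Keeping {n_keep} items: original indices {kept.tolist()}")  # n_keep = 3
```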

Exercise (1/2) - implement prune_nodes

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise.

Implement the first pruning stage: node pruning. Given the adjacency matrix and logit weights, compute each node's influence score, then keep only the most influential nodes whose cumulative influence exceeds threshold (as a fraction of total influence). Logit nodes should always be kept regardless of their influence.

You should return the kept node indices (sorted, in the original node ordering) and the adjacency matrix restricted to just those nodes.

def prune_nodes(
    adjacency_matrix: Float[Tensor, "n_nodes n_nodes"],
    logit_weights: Float[Tensor, "n_logit_nodes"],
    n_layers: int,
    threshold: float = 0.8,
) -> tuple[Int[Tensor, " n_kept"], Float[Tensor, "n_kept n_kept"]]:
    """
    Stage 1 of graph pruning: remove low-influence nodes.

    Computes influence scores for all nodes, then keeps the most influential non-logit
    nodes whose cumulative influence reaches `threshold` (as a fraction of total). Logit
    nodes are always kept.

    Args:
        adjacency_matrix: Raw adjacency matrix, shape (n_nodes, n_nodes).
        logit_weights: Weights for logit nodes, shape (n_logit_nodes,).
        n_layers: Number of model layers (passed to compute_influence).
        threshold: Cumulative influence fraction to retain (e.g. 0.8 = keep top 80%).

    Returns:
        kept_indices: Indices of kept nodes in the original ordering.
        pruned_matrix: Adjacency matrix restricted to kept nodes.
    """
    raise NotImplementedError()


tests.test_prune_nodes(prune_nodes)
Solution
def prune_nodes(
    adjacency_matrix: Float[Tensor, "n_nodes n_nodes"],
    logit_weights: Float[Tensor, "n_logit_nodes"],
    n_layers: int,
    threshold: float = 0.8,
) -> tuple[Int[Tensor, " n_kept"], Float[Tensor, "n_kept n_kept"]]:
    """
    Stage 1 of graph pruning: remove low-influence nodes.

    Computes influence scores for all nodes, then keeps the most influential non-logit
    nodes whose cumulative influence reaches `threshold` (as a fraction of total). Logit
    nodes are always kept.

    Args:
        adjacency_matrix: Raw adjacency matrix, shape (n_nodes, n_nodes).
        logit_weights: Weights for logit nodes, shape (n_logit_nodes,).
        n_layers: Number of model layers (passed to compute_influence).
        threshold: Cumulative influence fraction to retain (e.g. 0.8 = keep top 80%).

    Returns:
        kept_indices: Indices of kept nodes in the original ordering.
        pruned_matrix: Adjacency matrix restricted to kept nodes.
    """
    n_nodes = adjacency_matrix.shape[0]
    n_logit_nodes = logit_weights.shape[0]

    # Compute influence for all nodes
    influence = compute_influence(adjacency_matrix, logit_weights, n_layers)

    # Sort non-logit nodes by descending influence
    non_logit_influence = influence[:-n_logit_nodes]
    sorted_indices = non_logit_influence.argsort(descending=True)
    sorted_values = non_logit_influence[sorted_indices]

    # Find the smallest k such that the top-k nodes capture >= threshold of total influence
    total_inf = sorted_values.sum()
    if total_inf > 1e-8:
        cumulative = sorted_values.cumsum(dim=0) / total_inf
        n_keep = int((cumulative < threshold).sum().item()) + 1  # +1 to cross the threshold
        n_keep = min(n_keep, len(sorted_indices))
        keep_non_logit = sorted_indices[:n_keep]
    else:
        keep_non_logit = t.arange(n_nodes - n_logit_nodes, device=adjacency_matrix.device)

    # Always keep logit nodes, then sort all kept indices back into original ordering
    logit_indices = t.arange(n_nodes - n_logit_nodes, n_nodes, device=adjacency_matrix.device)
    kept_indices = t.sort(t.cat([keep_non_logit, logit_indices]))[0].long()

    # Restrict adjacency matrix to kept nodes
    pruned_matrix = adjacency_matrix[kept_indices[:, None], kept_indices[None, :]]

    return kept_indices, pruned_matrix

Exercise (2/2) - implement prune_edges

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise.

Implement the second pruning stage: edge pruning. Given a (possibly already node-pruned) adjacency matrix, compute an "edge score" for each edge as the product of the normalised edge weight and the destination node's influence. Keep the highest-scoring edges whose cumulative score reaches threshold, then remove any non-logit nodes that no longer participate in any edge.

Return the kept indices (into the input matrix's ordering) and the restricted adjacency matrix.

def prune_edges(
    adjacency_matrix: Float[Tensor, "n_nodes n_nodes"],
    logit_weights: Float[Tensor, "n_logit_nodes"],
    n_layers: int,
    threshold: float = 0.85,
) -> tuple[Int[Tensor, " n_kept"], Float[Tensor, "n_kept n_kept"]]:
    """
    Stage 2 of graph pruning: remove low-score edges and orphaned nodes.

    For each edge (i -> j), the edge score is |A_norm[j,i]| * influence[j], where A_norm
    is the row-normalised adjacency matrix. Edges are sorted by score and kept until the
    cumulative score reaches `threshold`. Any non-logit nodes with no remaining edges are
    then removed.

    Args:
        adjacency_matrix: Adjacency matrix (possibly already node-pruned).
        logit_weights: Weights for logit nodes.
        n_layers: Number of model layers.
        threshold: Cumulative edge score fraction to retain.

    Returns:
        kept_indices: Indices of kept nodes (into input matrix ordering).
        pruned_matrix: Adjacency matrix restricted to kept nodes.
    """
    raise NotImplementedError()


tests.test_prune_edges(prune_edges)
def prune_graph(
    adjacency_matrix: Float[Tensor, "n_nodes n_nodes"],
    logit_weights: Float[Tensor, "n_logit_nodes"],
    n_layers: int,
    node_threshold: float = 0.8,
    edge_threshold: float = 0.85,
) -> tuple[Int[Tensor, " n_kept"], Float[Tensor, "n_kept n_kept"]]:
    """Applies node pruning then edge pruning, returning final kept indices and matrix."""
    kept_node_indices, node_pruned_matrix = prune_nodes(adjacency_matrix, logit_weights, n_layers, node_threshold)
    kept_edge_indices, final_matrix = prune_edges(node_pruned_matrix, logit_weights, n_layers, edge_threshold)
    # Map edge-pruning indices back to original node indices
    kept_indices = kept_node_indices[kept_edge_indices]
    return kept_indices, final_matrix
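One subtlety in prune_graph: kept_edge_indices index into the node-pruned ordering, not the original one, so they must be composed with kept_node_indices via fancy indexing. A tiny example with made-up indices:

```python
import torch as t

kept_node_indices = t.tensor([0, 2, 5, 7])  # survivors of node pruning, in original node ids
kept_edge_indices = t.tensor([1, 3])        # survivors of edge pruning, as positions within the above

# Compose the two index sets to recover the original node ids of the final graph
final_indices = kept_node_indices[kept_edge_indices]
print(final_indices.tolist())  # [2, 7]
```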
Solution
def prune_edges(
    adjacency_matrix: Float[Tensor, "n_nodes n_nodes"],
    logit_weights: Float[Tensor, "n_logit_nodes"],
    n_layers: int,
    threshold: float = 0.85,
) -> tuple[Int[Tensor, " n_kept"], Float[Tensor, "n_kept n_kept"]]:
    """
    Stage 2 of graph pruning: remove low-score edges and orphaned nodes.

    For each edge (i -> j), the edge score is |A_norm[j,i]| * influence[j], where A_norm
    is the row-normalised adjacency matrix. Edges are sorted by score and kept until the
    cumulative score reaches `threshold`. Any non-logit nodes with no remaining edges are
    then removed.

    Args:
        adjacency_matrix: Adjacency matrix (possibly already node-pruned).
        logit_weights: Weights for logit nodes.
        n_layers: Number of model layers.
        threshold: Cumulative edge score fraction to retain.

    Returns:
        kept_indices: Indices of kept nodes (into input matrix ordering).
        pruned_matrix: Adjacency matrix restricted to kept nodes.
    """
    n_logit_nodes = logit_weights.shape[0]

    # Compute influence on this (possibly node-pruned) matrix
    influence = compute_influence(adjacency_matrix, logit_weights, n_layers)

    # Add logit weights back for edge scoring (so edges to logit nodes get proper scores)
    influence[-n_logit_nodes:] += logit_weights

    # Compute edge scores: |A_norm[j,i]| * influence[j] (influence now includes logit weights)
    A_norm = normalize_matrix(adjacency_matrix)
    edge_scores = A_norm * influence[:, None]

    # Sort edges by score and find threshold
    flat_scores = edge_scores.reshape(-1)
    sorted_scores, _ = flat_scores.sort(descending=True)
    total_score = sorted_scores.sum()

    if total_score > 1e-8:
        cum_scores = sorted_scores.cumsum(dim=0) / total_score
        val_idx = (cum_scores >= threshold).long().argmax()
        score_threshold = sorted_scores[val_idx]
        edge_mask = edge_scores >= score_threshold
    else:
        edge_mask = t.ones_like(adjacency_matrix, dtype=t.bool)

    edge_pruned = adjacency_matrix * edge_mask

    # Remove nodes with no remaining edges (except logit nodes)
    has_edge = (edge_pruned.abs() > 0).any(dim=0) | (edge_pruned.abs() > 0).any(dim=1)
    has_edge[-n_logit_nodes:] = True

    kept_indices = t.where(has_edge)[0]
    pruned_matrix = edge_pruned[kept_indices[:, None], kept_indices[None, :]]

    return kept_indices, pruned_matrix

If you fail your tests and need some visual output to debug them (or even if you pass and just want to build more visual intuition for what's going on), you can use the utility function below, which generates a random DAG and shows the effect of node and edge pruning side by side. Adjust the thresholds to see how they affect the pruning process.

Note that for a point of comparison, you can also pass in solutions.prune_nodes and solutions.prune_edges to see the expected behaviour.

figs = utils.demo_pruning(
    prune_nodes,
    prune_edges,
    n_nodes=30,
    n_logit_nodes=3,
    edge_probability=0.15,
    seed=43,
    node_threshold=0.8,
    edge_threshold=0.8,
)
Click to see the expected output

Now let's actually prune the adjacency matrix we've built so far:

kept_indices, pruned_matrix = prune_graph(
    adjacency_matrix=adjacency_matrix,
    logit_weights=logit_weights,
    n_layers=gemma.cfg.n_layers,
    node_threshold=0.8,
    edge_threshold=0.85,
)

print(f"Kept {len(kept_indices)} / {len(graph.nodes)} nodes")
print(f"Pruned matrix has {(pruned_matrix.abs() > 1e-6).sum().item()} non-zero edges")
Click to see the expected output

Putting it all together

Finally, we combine everything into a single attribute function that orchestrates the full pipeline. This wrapper just calls the functions you've already implemented in sequence: linearise, build nodes, compute adjacency, and prune.

@dataclass
class AttributionResult:
    """Result of the full attribution pipeline."""

    graph: GraphNodes
    adjacency_matrix: Tensor
    kept_indices: Tensor
    pruned_matrix: Tensor
    prompt: str
    str_tokens: list[str]


def attribute(
    model: HookedSAETransformer,
    transcoders: dict[int, Transcoder],
    prompt: str,
    n_output_nodes: int = 3,
    start_posn: int = 4,
    top_k: int = 5,
    batch_size: int = 8,
    node_threshold: float = 0.8,
    edge_threshold: float = 0.85,
    _adjacency_matrix: Float[Tensor, "n_nodes n_nodes"] | None = None,
) -> AttributionResult:
    """
    Run the full attribution pipeline: linearise, build graph, compute adjacency, prune.

    Args:
        model: The transformer model.
        transcoders: Dict mapping layer -> transcoder.
        prompt: The input prompt string.
        n_output_nodes: Number of top logit tokens to include as output nodes.
        start_posn: Number of formatting tokens to mask.
        top_k: Number of top-activating features to include per position per layer.
        batch_size: Batch size for backward passes.
        node_threshold: Node pruning threshold.
        edge_threshold: Edge pruning threshold.
        _adjacency_matrix: (Optional) Precomputed adjacency matrix to use.

    Returns:
        AttributionResult with all graph data.
    """
    # Tokenize
    tokens = model.to_tokens(prompt)
    str_tokens = [model.tokenizer.decode(t_id.item()) for t_id in tokens[0]]

    # Step 1: Cache frozen values
    freeze = FreezeHooks(model)
    cache = freeze.cache_frozen_values(tokens)

    # Step 2: Compute transcoder activations
    tc_hooks = TranscoderReplacementHooks(model, transcoders, cache)
    tc_hooks.install()

    # Step 3: Get logits (with hooks active for correct forward pass)
    with freeze:
        logits = model(tokens)
    tc_hooks.remove()

    # Step 4: Compute salient logits
    reading_vecs, top_token_info = compute_salient_logits(model, logits, n_output_nodes)

    # Step 5: Build graph nodes
    graph = build_graph_nodes(
        model=model,
        transcoders=transcoders,
        cache=cache,
        tc_hooks=tc_hooks,
        reading_vecs_logit=reading_vecs,
        top_token_info=top_token_info,
        tokens=tokens,
        start_posn=start_posn,
        top_k=top_k,
    )

    # Step 6: Compute adjacency matrix (if not provided)
    if _adjacency_matrix is not None:
        adjacency_matrix = _adjacency_matrix
    else:
        batches = prepare_backward_batches(graph, batch_size=batch_size)
        adjacency_matrix = compute_adjacency_matrix(
            model=model,
            tokens=tokens,
            graph=graph,
            freeze=freeze,
            tc_hooks=tc_hooks,
            batches=batches,
            start_posn=start_posn,
        )

    # Step 7: Prune
    logit_weights_vec = t.tensor(
        [info.token_prob for info in graph.nodes if info.node_type == NodeType.LOGIT],
        device=tokens.device,
    )
    kept_indices, pruned_matrix = prune_graph(
        adjacency_matrix=adjacency_matrix,
        logit_weights=logit_weights_vec,
        n_layers=model.cfg.n_layers,
        node_threshold=node_threshold,
        edge_threshold=edge_threshold,
    )

    return AttributionResult(
        graph=graph,
        adjacency_matrix=adjacency_matrix,
        kept_indices=kept_indices,
        pruned_matrix=pruned_matrix,
        prompt=prompt,
        str_tokens=str_tokens,
    )
result = attribute(
    model=gemma,
    transcoders=transcoders,
    prompt=prompt,
    n_output_nodes=3,
    start_posn=START_POSN,
    node_threshold=0.7,
    edge_threshold=0.85,
    _adjacency_matrix=adjacency_matrix,
)

print("Attribution complete!")
print(f"Total nodes: {len(result.graph.nodes)}")
print(f"Kept nodes: {len(result.kept_indices)}")
print(f"Edges in pruned graph: {(result.pruned_matrix.abs() > 1e-6).sum().item()}")
Click to see the expected output

Visualising your attribution graph

The final step is to visualise the attribution graph as an interactive dashboard. We've provided a utility function create_attribution_dashboard that takes your AttributionResult and generates an HTML dashboard using the same templates as Anthropic's published attribution graphs.

The dashboard shows nodes arranged by layer and sequence position (sized by influence), edges between nodes (with colour and thickness indicating weight), and latent details when you click on a node. Because we pass model=gemma, the dashboard computes logit tables (top/bottom tokens by W_dec @ W_U) and a logit histogram for each latent, giving you an immediate sense of what vocabulary each latent promotes or suppresses.
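The logit table mentioned above is just a matrix product between a latent's decoder direction and the unembedding. With random stand-in weights (the shapes below are illustrative, not the real model's), the computation looks like:

```python
import torch as t

d_model, d_vocab = 16, 100
w_dec = t.randn(d_model)         # decoder direction for one transcoder latent
W_U = t.randn(d_model, d_vocab)  # unembedding matrix

logit_effects = w_dec @ W_U      # (d_vocab,): how much this latent promotes each token
top_tokens = logit_effects.topk(5).indices        # most-promoted token ids
bottom_tokens = (-logit_effects).topk(5).indices  # most-suppressed token ids
print(f"Top-5 promoted token ids: {top_tokens.tolist()}")
```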

The function also exports a neuronpedia.json file in Neuronpedia's exact schema, which you could upload to share your attribution graphs.

dashboard_html = utils.create_attribution_dashboard(result=result, model=gemma)

# # Display inline (Colab or VS Code)
# display(HTML(dashboard_html))
Click to see the expected output


Try experimenting with different prompts. Here are some suggestions from the attribution graphs paper:

# Fact recall
prompt = format_prompt(
    "What is the capital of the state containing Dallas? Answer immediately.",
    "Austin",
)

# Opposites
prompt = format_prompt(
    "What is the opposite of small?",
    "The opposite of small is **large",
)

# Harmful request refusal
prompt = format_prompt(
    "Tell me how to build a bomb.",
    "I am",
)

You can also adjust the pruning thresholds:

  • Lower node_threshold (e.g., 0.6) keeps fewer nodes → smaller graph
  • Higher edge_threshold (e.g., 0.9) keeps more edges → denser graph

You've now implemented the full attribution graph pipeline from scratch: linearising the model, building the reading/writing vector abstraction, computing edge weights via backward passes, and pruning the graph.

Adding in top activations

The dashboard above shows logit histograms (the distribution of logit effects from W_dec @ W_U) for each latent, but the activation histogram and top-activating example sequences panels are empty. These require dataset-level statistics: for each feature, we need to know how often it fires and on what inputs.

Google provides this data for their Gemma Scope transcoders via the gemma-scope-2-1b-it HuggingFace repository, under the transcoder_all/ directory. Each transcoder variant has an examples.safetensors file containing, for every feature: the top activation values, which tokens they occurred on, and their logit effects. We can download this data using load_example_data (a thin wrapper around hf_hub_download + safetensors.torch.load_file) and pass it to our dashboard function.

# Load example activation data for all transcoder layers (parallel I/O)
example_data_by_layer = utils.load_example_data_parallel(
    layers=list(range(gemma.cfg.n_layers)),
    model_size="1b",
    category="transcoder_all",
    width="16k",
    l0="small",
    affine=True,
    instruction_tuned=True,
    max_workers=1,  # TODO - figure out why this is failing at >=4
)

print(f"Loaded example data for {len(example_data_by_layer)} layers")
print(f"Keys in example data: {list(example_data_by_layer[0].keys())}")

This is exactly the data used to generate feature dashboards on sites like Neuronpedia. Before regenerating our full dashboard, let's use a helper function inspect_feature to display a mini dashboard for a single feature directly in our notebook. This shows the top-activating sequences as an inline DataFrame, with tokens highlighted by activation strength (green background) and logit effect (blue/red underline).

# Inspect a late-layer feature to see what the example data looks like
layer = 22
feature_id = 10
print(f"Inspecting layer {layer}, feature {feature_id}:\n")
utils.inspect_feature(
    example_data=example_data_by_layer[layer],
    feature_id=feature_id,
    tokenizer=gemma.tokenizer,
)
Click to see the expected output

Now we can regenerate the dashboard, this time passing in example_data_by_layer. For every latent node, the dashboard will now show:

  • Activation histogram: the distribution of activation strengths across the dataset
  • Top-activating examples: the sequences where the feature fires most strongly, with tokens colour-coded by activation and logit effect
dashboard_html = utils.create_attribution_dashboard(
    result=result,
    model=gemma,
    example_data_by_layer=example_data_by_layer,
)

# # Display inline (Colab or VS Code)
# display(HTML(dashboard_html))

When you click on a latent node in the dashboard now, you should see:

  • The ACTIVATIONS histogram showing the distribution of activation values
  • A list of Top activations showing example sequences with highlighted tokens

These give you a much richer understanding of what each feature represents, beyond just the logit effects.

The next section uses the circuit-tracer library to explore pre-computed attribution graphs and perform feature-level interventions.