
☆ Bonus

Manual attribution graphs

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 30-40 minutes on these exercises. They deepen your understanding of the linearised model but are not required.

In section 3️⃣, we computed attribution edges using the automatic (gradient-based) method: inject reading vectors as gradient seeds and backward-propagate through the frozen model. This is efficient and correct, but it treats the model as a black box.

An alternative is the manual (forward-tracing) method: explicitly take each source node's writing vector and map it through the frozen intermediate layers (attention + skip connections) until it reaches the target node, then take the dot product with the target's reading vector. This gives the same answer as the gradient method (because the frozen model is linear), but it's much more transparent and useful for debugging.
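To see why the two methods must agree, here's a minimal standalone sketch with a toy linear map standing in for the frozen model (names and shapes are illustrative, not part of the exercise API): seeding the backward pass with the reading vector and dotting the resulting gradient with the writing vector gives exactly the forward-traced edge weight.

```python
import torch

torch.manual_seed(0)

# Toy "frozen model": a single linear map standing in for the frozen
# attention + skip-connection path between source and target.
W = torch.randn(8, 8)
writing_vec = torch.randn(8)  # source node's writing vector
reading_vec = torch.randn(8)  # target node's reading vector

# Manual (forward-tracing): push the writing vector through W, then dot
# with the reading vector.
manual_edge = reading_vec @ (W @ writing_vec)

# Automatic (gradient-based): seed the backward pass with the reading
# vector, read off the gradient at the input, dot with the writing vector.
x = torch.zeros(8, requires_grad=True)
out = W @ x
out.backward(gradient=reading_vec)  # x.grad == W.T @ reading_vec
grad_edge = x.grad @ writing_vec

assert torch.allclose(manual_edge, grad_edge, atol=1e-5)
```

The equivalence holds precisely because the frozen model is linear; for the real (nonlinear) model the gradient method computes a local linearisation instead.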

First, let's recreate the cache we need (this was computed inside attribute() but not returned):

freeze_bonus = FreezeHooks(gemma)
cache_bonus = freeze_bonus.cache_frozen_values(tokens)

Exercise - implement map_through_ln, map_through_attn, map_through_mlp

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 15-20 minutes on this exercise.

These three helper functions map attribution vectors through the frozen (linearised) model components.

map_through_ln applies frozen RMSNorm. Since the normalization scale is frozen, this is simply x / cached_scale * weight (a linear operation).

map_through_attn maps through frozen attention: apply frozen LN, then the W_V projection, then multiply by the frozen attention patterns, then the W_O projection. This is how information flows across sequence positions.

map_through_mlp maps through the frozen MLP skip connection: apply frozen LN, then multiply by W_skip. This is the linear approximation of the MLP that gradients flow through.
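Here's a small standalone sketch (toy shapes, not the exercise API) confirming the key fact that makes map_through_ln valid: RMSNorm with a frozen (cached) scale is a linear map, whereas true RMSNorm, which recomputes the scale from its input, is not.

```python
import torch

torch.manual_seed(0)
d_model = 16
x = torch.randn(d_model)
y = torch.randn(d_model)
weight = torch.randn(d_model)

# Frozen RMSNorm reuses a scale cached from the original forward pass,
# so it's an elementwise linear map: f(a*x + b*y) == a*f(x) + b*f(y).
cached_scale = (x.pow(2).mean(-1, keepdim=True) + 1e-6).sqrt()
frozen = lambda v: v / cached_scale * weight

lhs = frozen(2.0 * x + 3.0 * y)
rhs = 2.0 * frozen(x) + 3.0 * frozen(y)
assert torch.allclose(lhs, rhs, atol=1e-5)

# True RMSNorm recomputes the scale from its input, breaking linearity.
def rmsnorm(v):
    scale = (v.pow(2).mean(-1, keepdim=True) + 1e-6).sqrt()
    return v / scale * weight

assert not torch.allclose(rmsnorm(2.0 * x + 3.0 * y), 2.0 * rmsnorm(x) + 3.0 * rmsnorm(y))
```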

All three functions support arbitrary batch dimensions (indicated by ... in the einsum patterns), which will be needed when we trace multiple source nodes simultaneously.

def map_through_ln(
    x: Float[Tensor, "... seq d_model"],
    cache: ActivationCache,
    model: HookedSAETransformer,
    layer: int | None,
    is_mlp_ln: bool = False,
) -> Float[Tensor, "... seq d_model"]:
    """Apply frozen RMSNorm: output = x / cached_scale * weight.

    Args:
        x: Input vectors to normalise.
        cache: ActivationCache with frozen scale values.
        model: The transformer model (for LN weight parameters).
        layer: Layer index, or None for the final LayerNorm.
        is_mlp_ln: If True, use ln2 (pre-MLP); otherwise use ln1 (pre-attention).
    """
    raise NotImplementedError()


def map_through_attn(
    resid_pre: Float[Tensor, "... seq d_model"],
    cache: ActivationCache,
    model: HookedSAETransformer,
    layer: int,
) -> Float[Tensor, "... seq d_model"]:
    """Map vectors through frozen attention: LN → W_V → frozen_patterns → W_O → post-norm.

    This is where cross-position information flow happens: a vector at position p
    gets redistributed across all positions according to the frozen attention patterns.
    Includes the post-attention sandwich norm (ln1_post) used by Gemma 3.

    Args:
        resid_pre: Residual stream vectors before this attention layer.
        cache: ActivationCache with frozen attention patterns and LN scales.
        model: The transformer model (for W_V, W_O weight matrices).
        layer: Which attention layer to map through.
    """
    raise NotImplementedError()


def map_through_mlp(
    resid_mid: Float[Tensor, "... seq d_model"],
    cache: ActivationCache,
    model: HookedSAETransformer,
    transcoders: dict[int, "Transcoder"],
    layer: int,
) -> Float[Tensor, "... seq d_model"]:
    """Map vectors through frozen MLP skip connection: LN → W_skip → post-norm.

    Includes the post-MLP sandwich norm (ln2_post) used by Gemma 3.

    Args:
        resid_mid: Residual stream vectors after attention (before MLP) at this layer.
        cache: ActivationCache with frozen LN scales.
        model: The transformer model (for LN weight parameters).
        transcoders: Dict mapping layer -> transcoder (for W_skip).
        layer: Which MLP layer to map through.
    """
    raise NotImplementedError()
tests.test_map_through_ln(map_through_ln, gemma, cache_bonus)
tests.test_map_through_attn(map_through_attn, gemma, cache_bonus)
tests.test_map_through_mlp(map_through_mlp, gemma, cache_bonus, transcoders)
Solution
def map_through_ln(
    x: Float[Tensor, "... seq d_model"],
    cache: ActivationCache,
    model: HookedSAETransformer,
    layer: int | None,
    is_mlp_ln: bool = False,
) -> Float[Tensor, "... seq d_model"]:
    """Apply frozen RMSNorm: output = x / cached_scale * weight.

    Args:
        x: Input vectors to normalise.
        cache: ActivationCache with frozen scale values.
        model: The transformer model (for LN weight parameters).
        layer: Layer index, or None for the final LayerNorm.
        is_mlp_ln: If True, use ln2 (pre-MLP); otherwise use ln1 (pre-attention).
    """
    if layer is None:
        scale = cache["ln_final.hook_scale"]
        weight = model.ln_final.w
    elif is_mlp_ln:
        scale = cache[f"blocks.{layer}.ln2.hook_scale"]
        weight = model.blocks[layer].ln2.w
    else:
        scale = cache[f"blocks.{layer}.ln1.hook_scale"]
        weight = model.blocks[layer].ln1.w
    return x / scale * weight


def map_through_attn(
    resid_pre: Float[Tensor, "... seq d_model"],
    cache: ActivationCache,
    model: HookedSAETransformer,
    layer: int,
) -> Float[Tensor, "... seq d_model"]:
    """Map vectors through frozen attention: LN → W_V → frozen_patterns → W_O → post-norm.

    This is where cross-position information flow happens: a vector at position p
    gets redistributed across all positions according to the frozen attention patterns.
    Includes the post-attention sandwich norm (ln1_post) used by Gemma 3.

    Args:
        resid_pre: Residual stream vectors before this attention layer.
        cache: ActivationCache with frozen attention patterns and LN scales.
        model: The transformer model (for W_V, W_O weight matrices).
        layer: Which attention layer to map through.
    """
    x = map_through_ln(resid_pre, cache, model, layer=layer)

    W_V = model.blocks[layer].attn.W_V  # (n_heads, d_model, d_head)
    v = einops.einsum(x, W_V, "... src d_model, n_heads d_model d_head -> ... src n_heads d_head")

    patterns = cache[f"blocks.{layer}.attn.hook_pattern"]  # (1, n_heads, dest, src)
    z = einops.einsum(v, patterns[0], "... src n_heads d_head, n_heads dest src -> ... dest n_heads d_head")

    W_O = model.blocks[layer].attn.W_O  # (n_heads, d_head, d_model)
    result = einops.einsum(z, W_O, "... dest n_heads d_head, n_heads d_head d_model -> ... dest d_model")

    # Post-attention sandwich norm (Gemma 3)
    if hasattr(model.blocks[layer], "ln1_post"):
        post_scale = cache[f"blocks.{layer}.ln1_post.hook_scale"]
        post_w = model.blocks[layer].ln1_post.w
        result = result / post_scale * post_w

    return result


def map_through_mlp(
    resid_mid: Float[Tensor, "... seq d_model"],
    cache: ActivationCache,
    model: HookedSAETransformer,
    transcoders: dict[int, "Transcoder"],
    layer: int,
) -> Float[Tensor, "... seq d_model"]:
    """Map vectors through frozen MLP skip connection: LN → W_skip → post-norm.

    Includes the post-MLP sandwich norm (ln2_post) used by Gemma 3.

    Args:
        resid_mid: Residual stream vectors after attention (before MLP) at this layer.
        cache: ActivationCache with frozen LN scales.
        model: The transformer model (for LN weight parameters).
        transcoders: Dict mapping layer -> transcoder (for W_skip).
        layer: Which MLP layer to map through.
    """
    x = map_through_ln(resid_mid, cache, model, layer=layer, is_mlp_ln=True)
    result = x @ transcoders[layer].W_skip

    # Post-MLP sandwich norm (Gemma 3)
    if hasattr(model.blocks[layer], "ln2_post"):
        post_scale = cache[f"blocks.{layer}.ln2_post.hook_scale"]
        post_w = model.blocks[layer].ln2_post.w
        result = result / post_scale * post_w

    return result

Exercise - implement compute_adjacency_matrix_manual

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 20-25 minutes on this exercise.

Now use these helpers to compute the full adjacency matrix by explicit forward-tracing. For each (source layer, target layer) pair, you:

  1. Place each source node's writing vector at its sequence position, giving a tensor of shape (n_source_nodes, seq_len, d_model).
  2. Map through intermediate layers by adding the attention output (cross-position flow) and MLP skip output (within-position linear transform) to the residual at each layer.
  3. Handle the target layer: for latent/error targets, add attention at the target layer to get resid_mid, then apply the pre-MLP LayerNorm ln2; for logit targets, apply the final LayerNorm ln_final.
  4. Dot the mapped vectors at each target node's position with the target's reading vector.

Remember: node_range_dict maps layer keys ("E", 0, 1, ..., "L") to (start_idx, end_idx) ranges into the node list. Embedding nodes ("E") are sources only; logit nodes ("L") are targets only. A source node at layer l writes to resid_post[l] (after the MLP), so intermediate layers start at l+1. Embedding sources write to resid_pre[0], so their intermediate layers start at 0.
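The layer-range bookkeeping above can be captured in a small standalone helper (hypothetical, for illustration only — not part of the exercise API). It returns the intermediate layers to map through, excluding the target layer itself (which is handled separately in step 3):

```python
# Hypothetical helper illustrating which intermediate layers a
# source -> target pair passes through.
def intermediate_layers(src_key, tgt_key, n_layers):
    # Numeric layer after which the source writes (-1 for embeddings).
    src_writes_after = -1 if src_key == "E" else src_key
    tgt_layer_num = n_layers if tgt_key == "L" else tgt_key
    if src_writes_after >= tgt_layer_num:
        return None  # source does not come strictly before target
    return list(range(src_writes_after + 1, tgt_layer_num))

# Embedding sources write to resid_pre[0], so mapping starts at layer 0:
assert intermediate_layers("E", 2, n_layers=4) == [0, 1]
# A layer-l source writes to resid_post[l], so mapping starts at l+1:
assert intermediate_layers(1, 3, n_layers=4) == [2]
# Logit targets pass through every layer after the source:
assert intermediate_layers(1, "L", n_layers=4) == [2, 3]
# Edges never go backwards:
assert intermediate_layers(2, 1, n_layers=4) is None
```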

Hint - loop structure
for src_key in layer_keys:
    if src_key == "L" or src_key not in graph.node_range_dict:
        continue
    src_start, src_end = graph.node_range_dict[src_key]
    n_src = src_end - src_start
    src_writes_after = -1 if src_key == "E" else src_key  # numeric layer

    # Initialise: place writing vectors at source positions
    vecs = t.zeros(n_src, seq_len, d_model, device=device)
    for i in range(n_src):
        node = graph.nodes[src_start + i]
        vecs[i, node.ctx_idx] = graph.writing_vecs[src_start + i]

    for tgt_key in layer_keys:
        if tgt_key == "E" or tgt_key not in graph.node_range_dict:
            continue
        tgt_layer_num = n_layers if tgt_key == "L" else tgt_key
        if src_writes_after >= tgt_layer_num:
            continue

        mapped = vecs.clone()
        # ... map through intermediate layers, handle target, compute edges
def compute_adjacency_matrix_manual(
    model: HookedSAETransformer,
    cache: ActivationCache,
    graph: GraphNodes,
    transcoders: dict[int, "Transcoder"],
) -> Float[Tensor, "n_nodes n_nodes"]:
    """Compute the attribution adjacency matrix using explicit forward-tracing.

    For each (source_layer, target_layer) pair, traces the source writing vectors through
    all intermediate frozen layers and computes dot products with target reading vectors.

    Args:
        model: The transformer model.
        cache: ActivationCache with frozen values (from FreezeHooks.cache_frozen_values).
        graph: GraphNodes containing all node metadata, writing_vecs, reading_vecs.
        transcoders: Dict mapping layer -> transcoder (for W_skip in MLP skip connections).

    Returns:
        Adjacency matrix of shape (n_nodes, n_nodes), where A[target, source] is the edge weight.
    """
    n_nodes = len(graph.nodes)
    seq_len = graph.seq_len
    n_layers = graph.n_layers
    d_model = model.cfg.d_model
    device = graph.writing_vecs.device
    adjacency = t.zeros(n_nodes, n_nodes, device=device)

    layer_keys = ["E"] + list(range(n_layers)) + ["L"]

    # For each source layer:
    #   For each target layer (that comes after the source):
    #     (1) Map source writing vectors through intermediate layers using map_through_attn and map_through_mlp
    #     (2) Handle the target layer (attention + LN for latents, final LN for logits)
    #     (3) Compute dot products with target reading vectors at their positions
    raise NotImplementedError()

    return adjacency
manual_adj = compute_adjacency_matrix_manual(gemma, cache_bonus, result.graph, transcoders)
tests.test_compute_adjacency_matrix_manual(manual_adj, result.adjacency_matrix)
Solution
def compute_adjacency_matrix_manual(
    model: HookedSAETransformer,
    cache: ActivationCache,
    graph: GraphNodes,
    transcoders: dict[int, "Transcoder"],
) -> Float[Tensor, "n_nodes n_nodes"]:
    """Compute the attribution adjacency matrix using explicit forward-tracing.

    For each (source_layer, target_layer) pair, traces the source writing vectors through
    all intermediate frozen layers and computes dot products with target reading vectors.

    Args:
        model: The transformer model.
        cache: ActivationCache with frozen values (from FreezeHooks.cache_frozen_values).
        graph: GraphNodes containing all node metadata, writing_vecs, reading_vecs.
        transcoders: Dict mapping layer -> transcoder (for W_skip in MLP skip connections).

    Returns:
        Adjacency matrix of shape (n_nodes, n_nodes), where A[target, source] is the edge weight.
    """
    n_nodes = len(graph.nodes)
    seq_len = graph.seq_len
    n_layers = graph.n_layers
    d_model = model.cfg.d_model
    device = graph.writing_vecs.device
    adjacency = t.zeros(n_nodes, n_nodes, device=device)

    layer_keys = ["E"] + list(range(n_layers)) + ["L"]

    for src_key in layer_keys:
        if src_key == "L" or src_key not in graph.node_range_dict:
            continue
        src_start, src_end = graph.node_range_dict[src_key]
        n_src = src_end - src_start
        if n_src == 0:
            continue

        # Numeric layer after which source writes (-1 for embeddings = before layer 0)
        src_writes_after = -1 if src_key == "E" else src_key

        # Initialise: place each source node's writing vector at its sequence position
        vecs = t.zeros(n_src, seq_len, d_model, device=device)
        for i in range(n_src):
            node = graph.nodes[src_start + i]
            vecs[i, node.ctx_idx] = graph.writing_vecs[src_start + i]

        for tgt_key in layer_keys:
            if tgt_key == "E" or tgt_key not in graph.node_range_dict:
                continue
            tgt_start, tgt_end = graph.node_range_dict[tgt_key]
            if tgt_end == tgt_start:
                continue

            tgt_layer_num = n_layers if tgt_key == "L" else tgt_key
            if src_writes_after >= tgt_layer_num:
                continue  # Source must come strictly before target

            # Map through intermediate layers (full attn + MLP skip at each)
            mapped = vecs.clone()
            end_layer = tgt_layer_num  # == n_layers when tgt_key == "L"
            for layer in range(src_writes_after + 1, end_layer):
                mapped = mapped + map_through_attn(mapped, cache, model, layer)
                mapped = mapped + map_through_mlp(mapped, cache, model, transcoders, layer)

            # Handle target layer
            if tgt_key == "L":
                # Logit targets: apply final LayerNorm
                mapped = map_through_ln(mapped, cache, model, layer=None)
            else:
                # Latent/error targets: attention at target layer, then pre-MLP LN
                mapped = mapped + map_through_attn(mapped, cache, model, tgt_layer_num)
                mapped = map_through_ln(mapped, cache, model, tgt_layer_num, is_mlp_ln=True)

            # Compute edge weights via dot products at target positions
            for j in range(tgt_end - tgt_start):
                tgt_node = graph.nodes[tgt_start + j]
                tgt_reading = graph.reading_vecs[tgt_start + j]
                edges = einops.einsum(
                    mapped[:, tgt_node.ctx_idx],
                    tgt_reading,
                    "n_src d_model, d_model -> n_src",
                )
                adjacency[tgt_start + j, src_start:src_end] = edges

    return adjacency

AutoInterp

Attribution graphs give us a picture of which features matter and how they connect, but they don't tell us what those features represent. For a graph to be genuinely interpretable, we need human-readable labels for each node. This is where automated interpretability (autointerp) comes in.

If you completed the exercises in section 1.3.3 (Automated Interpretability), you already have a working autointerp implementation that you can apply here. That section covers the full pipeline: collecting max-activating examples for a latent, passing them to an LLM to generate natural-language explanations, and scoring those explanations. All of those tools transfer directly to labelling attribution graph nodes.

One key advantage of combining autointerp with attribution graphs is cost: a pruned attribution graph typically contains only 10-50 active latent nodes, compared to the thousands or millions of latents in a full SAE/transcoder. This means you can afford to run high-quality explanation generation (e.g. using a large model with many examples per latent) on every node in the graph, which would be prohibitively expensive across the full set of latents.

Here's one concrete approach:

  1. Identify the top-k latent nodes in your pruned graph (those with highest influence scores). The result.graph object stores which latents were kept and their activation values. You can iterate over nodes with:
for node in result.graph.nodes:
    if node.node_type == NodeType.LATENT:
        layer, ctx_idx, feature = node.layer, node.ctx_idx, node.feature
        # ... fetch or generate explanations for this latent
  2. Fetch or generate explanations for each latent. If Neuronpedia has explanations for your SAE release, fetch them with get_autointerp_df from 1.3.3. Otherwise, collect max-activating examples with fetch_max_activating_examples and run your LLM-based explanation pipeline. Since there are so few nodes, you can afford to be thorough.

  3. Annotate your attribution graph by adding the explanation text to each node's label in the visualization.

This is left as open-ended exploration rather than a structured exercise, since the implementation depends on how you structured your autointerp code in 1.3.3 and whether Neuronpedia hosts explanations for the transcoders you're using. A good starting point is to pick one of the circuits from section 4 (e.g. the Dallas circuit) and generate explanations for its top 5 most influential latent nodes.