3️⃣ Causal interventions

Learning Objectives
  • Understand why classification accuracy alone is insufficient - causal evidence is needed
  • Implement activation patching with probe directions to flip model predictions
  • Compare the causal effects of MM vs. LR probe directions
  • Appreciate that MM probes find more causally implicated directions despite lower classification accuracy

Before diving into the exercises, here's a preview of the key findings from the Geometry of Truth paper that Sections 1–3 are designed to reproduce:

  • Truth representations become more general with scale. Larger models show a more abstract, cross-domain notion of truth that applies across structurally and topically diverse datasets (cities, numerical comparisons, translations). At smaller scales the representations are more dataset-specific.
  • Only a small subset of hidden states are causally implicated. It's not the case that all layers and positions contribute equally - the truth direction is concentrated in early-to-mid layers, and only activations at specific layer/token combinations actually drive the model's truth judgments when you intervene.
  • Difference-of-means probes find the most causally relevant directions. Despite logistic regression probes achieving higher classification accuracy, the simpler difference-of-means direction turns out to be more causally implicated in model outputs. This disconnect between accuracy and causal relevance is a recurring theme in mechanistic interpretability.

You've verified the first two points in Sections 1 and 2. This section establishes the third via activation patching.

So far, we've shown that truth is linearly separable in activation space and trained probes that can classify it. But classification accuracy alone doesn't prove the model uses these representations. A probe might find a direction that's correlated with truth but not causally involved in the model's computation.

To establish causality, we use activation patching: we add or subtract the truth direction from the model's hidden states during inference and measure whether this changes the model's output. If adding the truth direction to false-statement activations makes the model predict TRUE, the direction is causally implicated.
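To see why adding a vector to the residual stream can move the output, note that the final logits are linear in the last hidden state. Here is a toy sketch with made-up numbers (in the real experiment the patch happens at intermediate layers, so later layers also process the shift, and the effect on logits is no longer exactly linear):

```python
# Toy sketch (plain Python, hypothetical numbers): adding alpha * d to a hidden
# state h shifts the logit for a token with unembedding vector w by exactly
# alpha * (d . w), since logits are linear in the hidden state.

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

h = [0.5, -1.0, 2.0]       # hidden state (toy, 3-dim)
d = [1.0, 0.0, 0.5]        # "truth direction" (toy)
w_true = [2.0, 1.0, -1.0]  # unembedding vector for the TRUE token (toy)
alpha = 2.0

logit_before = dot(h, w_true)
h_patched = [hi + alpha * di for hi, di in zip(h, d)]
logit_after = dot(h_patched, w_true)

# The change decomposes exactly as alpha * (d . w_true):
assert abs((logit_after - logit_before) - alpha * dot(d, w_true)) < 1e-9
```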

From the Geometry of Truth paper:

"We evaluate the causal role of the truth representations by measuring the natural indirect effect (NIE) of interventions on the model's output probabilities."

Experimental setup

We'll use a few-shot TRUE/FALSE classification task:

  1. Give the model 2-4 labeled examples: "Statement. This statement is: TRUE/FALSE"
  2. Present a new statement: "Statement. This statement is:"
  3. Measure P(TRUE) - P(FALSE) from the model's next-token probabilities
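Step 3 boils down to a softmax over the last-token logits. A minimal self-contained sketch (plain Python, with hypothetical logit values and token ids):

```python
import math

def p_diff_from_logits(logits: list[float], true_id: int, false_id: int) -> float:
    """Softmax the last-token logits, then return P(TRUE) - P(FALSE)."""
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return exps[true_id] / total - exps[false_id] / total

# Toy vocabulary of 4 tokens; suppose ids 0 and 1 are " TRUE" and " FALSE".
logits = [3.0, 1.0, 0.0, 0.0]  # hypothetical values
diff = p_diff_from_logits(logits, true_id=0, false_id=1)
assert diff > 0  # the model "prefers" TRUE here
```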

The intervention: at layers between intervene_layer (8) and probe_layer (14), we add or subtract the scaled truth direction at two token positions per sequence: the final period of the statement and the first token of the appended " This statement is:" suffix.
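Note that the layer range is inclusive on both ends. With the notebook's constants (re-declared here so the check is self-contained):

```python
# Layer range for the intervention: inclusive on both ends. These constants
# are defined earlier in the notebook; re-declared here for a standalone check.
INTERVENE_LAYER, PROBE_LAYER = 8, 14
intervene_layer_list = list(range(INTERVENE_LAYER, PROBE_LAYER + 1))
assert intervene_layer_list == [8, 9, 10, 11, 12, 13, 14]
```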

Exercise - implement few-shot evaluation

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise. It establishes the baseline that the model can perform the task in-context.

First, we need to verify that our model can actually perform TRUE/FALSE classification in-context. Implement a function that:

  1. Prepends the few-shot prompt to each statement
  2. Appends " This statement is:" to each statement
  3. Runs the model and extracts P(TRUE) - P(FALSE) from the last-token logits

We'll use the few-shot prompt from the geometry-of-truth repo, which was calibrated for LLaMA-2-13B on sp_en_trans:

# Few-shot prompt from geometry-of-truth/interventions.py. Note 3 true + 1 false examples:
# the paper calibrated this specific prompt for LLaMA-2-13B on sp_en_trans.
FEW_SHOT_PROMPT = """\
The Spanish word 'jirafa' means 'giraffe'. This statement is: TRUE
The Spanish word 'escribir' means 'to write'. This statement is: TRUE
The Spanish word 'gato' means 'cat'. This statement is: TRUE
The Spanish word 'aire' means 'silver'. This statement is: FALSE
"""

# Get token IDs for TRUE and FALSE
TRUE_ID = tokenizer.encode(" TRUE")[-1]
FALSE_ID = tokenizer.encode(" FALSE")[-1]

def few_shot_evaluate(
    statements: list[str],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    few_shot_prompt: str,
    true_id: int,
    false_id: int,
    batch_size: int = 32,
) -> Float[Tensor, " n"]:
    """
    Evaluate P(TRUE) - P(FALSE) for each statement using few-shot classification.

    Args:
        statements: List of statements to classify.
        model: Language model.
        tokenizer: Tokenizer.
        few_shot_prompt: The few-shot prefix prompt.
        true_id: Token ID for " TRUE".
        false_id: Token ID for " FALSE".
        batch_size: Batch size.

    Returns:
        Tensor of P(TRUE) - P(FALSE) for each statement.
    """
    raise NotImplementedError()


# Load sp_en_trans for evaluation (exclude statements used in the few-shot prompt)
sp_df = datasets["sp_en_trans"]
sp_statements = sp_df["statement"].tolist()
sp_labels = t.tensor(sp_df["label"].values, dtype=t.float32)

# Filter out statements that appear in the few-shot prompt
sp_eval_mask = [s not in FEW_SHOT_PROMPT for s in sp_statements]
sp_eval_stmts = [s for s, m in zip(sp_statements, sp_eval_mask) if m]
sp_eval_labels = sp_labels[t.tensor(sp_eval_mask)]

p_diffs = few_shot_evaluate(sp_eval_stmts, model, tokenizer, FEW_SHOT_PROMPT, TRUE_ID, FALSE_ID)

# Compute accuracy
preds = (p_diffs > 0).float()
acc = (preds == sp_eval_labels).float().mean().item()
assert acc > 0.9, f"Few-shot accuracy too low: {acc:.3f} (expected > 0.9)"
true_mean = p_diffs[sp_eval_labels == 1].mean().item()
false_mean = p_diffs[sp_eval_labels == 0].mean().item()

print(f"Few-shot classification accuracy: {acc:.3f}")
print(f"Mean P(TRUE)-P(FALSE) for true statements:  {true_mean:.4f}")
print(f"Mean P(TRUE)-P(FALSE) for false statements: {false_mean:.4f}")

# Histogram
fig = go.Figure()
fig.add_trace(
    go.Histogram(x=p_diffs[sp_eval_labels == 1].numpy(), name="True", marker_color="blue", opacity=0.6, nbinsx=30)
)
fig.add_trace(
    go.Histogram(x=p_diffs[sp_eval_labels == 0].numpy(), name="False", marker_color="red", opacity=0.6, nbinsx=30)
)
fig.add_vline(x=0, line_dash="dash", line_color="gray")
fig.update_layout(
    title="Few-Shot Classification: P(TRUE) - P(FALSE)",
    xaxis_title="P(TRUE) - P(FALSE)",
    yaxis_title="Count",
    barmode="overlay",
    height=400,
    width=700,
)
fig.show()
Click to see the expected output
Few-shot classification accuracy: 0.997
Mean P(TRUE)-P(FALSE) for true statements:  0.7853
Mean P(TRUE)-P(FALSE) for false statements: -0.9150
Solution
def few_shot_evaluate(
    statements: list[str],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    few_shot_prompt: str,
    true_id: int,
    false_id: int,
    batch_size: int = 32,
) -> Float[Tensor, " n"]:
    """
    Evaluate P(TRUE) - P(FALSE) for each statement using few-shot classification.

    Args:
        statements: List of statements to classify.
        model: Language model.
        tokenizer: Tokenizer.
        few_shot_prompt: The few-shot prefix prompt.
        true_id: Token ID for " TRUE".
        false_id: Token ID for " FALSE".
        batch_size: Batch size.

    Returns:
        Tensor of P(TRUE) - P(FALSE) for each statement.
    """
    p_diffs = []

    for i in range(0, len(statements), batch_size):
        batch = statements[i : i + batch_size]
        queries = [few_shot_prompt + stmt + " This statement is:" for stmt in batch]

        inputs = tokenizer(queries, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)

        with t.no_grad():
            outputs = model(**inputs)
            # Get logits at the last non-padding position
            last_idx = inputs["attention_mask"].sum(dim=1) - 1
            batch_indices = t.arange(len(batch), device=outputs.logits.device)
            last_logits = outputs.logits[batch_indices, last_idx]  # [batch, vocab]
            probs = last_logits.softmax(dim=-1)
            p_diff = probs[:, true_id] - probs[:, false_id]
            p_diffs.append(p_diff.cpu().float())

    return t.cat(p_diffs)

The hook used in the intervention experiment needs to handle variable-length sequences in a batch, finding the right token positions dynamically per batch element using attention_mask. The following is a simpler, fixed-position version for reference - the real exercise is implementing the batch-aware version inside intervention_experiment below.

def make_intervention_hook(
    direction: Float[Tensor, " d_model"],
    scale: float,
    positions: list[int],
) -> callable:
    """
    Create a forward hook that adds scale * direction to hidden states at fixed positions.
    This handles both plain-tensor and tuple outputs from transformer layers.
    """

    def hook_fn(module, input, output):
        if isinstance(output, tuple):
            hidden_states = output[0]
        else:
            hidden_states = output

        for pos in positions:
            if 0 <= pos < hidden_states.shape[1]:
                # Note: assumes `direction` is already on the activations' device/dtype
                hidden_states[:, pos, :] += scale * direction

        if isinstance(output, tuple):
            return (hidden_states,) + output[1:]
        else:
            return hidden_states

    return hook_fn
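As a quick sanity check of the hook mechanics (self-contained, assuming PyTorch is installed; `torch.nn.Identity` stands in for a transformer layer): a forward hook's return value replaces the module's output, which is why `make_intervention_hook` returns the modified tensor.

```python
import torch

direction = torch.tensor([1.0, 0.0])
scale = 2.0

def hook_fn(module, inputs, output):
    output = output.clone()               # avoid mutating the caller's tensor
    output[:, 1, :] += scale * direction  # patch token position 1 only
    return output                         # the return value replaces the output

layer = torch.nn.Identity()  # stands in for a transformer layer
handle = layer.register_forward_hook(hook_fn)
try:
    out = layer(torch.zeros(1, 3, 2))  # [batch=1, seq=3, d_model=2]
finally:
    handle.remove()  # always remove hooks when done

assert torch.equal(out[0, 1], torch.tensor([2.0, 0.0]))  # patched position
assert torch.equal(out[0, 0], torch.tensor([0.0, 0.0]))  # untouched position
```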

Exercise - implement the batch-aware intervention hook

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 15-20 minutes on this exercise. This is the most important exercise in this section - it establishes causality.

We've provided the scaffolding for intervention_experiment below. Your task is to implement make_batch_hook - the function that returns a hook which adds the scaled direction vector to hidden states at the right token positions for each batch element.

The hook needs to work with variable-length sequences, since after padding, different sequences in a batch end at different positions. Use attention_mask.sum(dim=1) to find the real sequence length end for each batch element (this assumes right-padding, so real tokens occupy positions 0 to end - 1), and len_suffix (the number of tokens in " This statement is:") to find the two target positions: end - len_suffix - 1 (the final period of the statement) and end - len_suffix (the first token of " This statement is:", i.e. the word "This").
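A worked example of the position arithmetic, with toy numbers and a hypothetical 5-token suffix, assuming right-padding as in the scaffolding:

```python
# Two right-padded sequences of real length 20 and 24, padded to 24,
# with a hypothetical 5-token suffix " This statement is:".
len_suffix = 5
attention_mask = [
    [1] * 20 + [0] * 4,  # real length 20
    [1] * 24,            # real length 24
]

positions = []
for mask in attention_mask:
    end = sum(mask)                          # number of real (non-pad) tokens
    positions.append((end - len_suffix - 1,  # final period of the statement
                      end - len_suffix))     # first token of " This statement is:"

assert positions == [(14, 15), (18, 19)]
```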

Look at the simpler make_intervention_hook above for reference on handling tuple vs. plain tensor outputs.

Read through the full function before you start. The scaffolding:

  • Constructs queries by prepending the few-shot prompt and appending " This statement is:" to each statement
  • Registers hooks on each intervention layer (and removes them in a finally block so errors don't leave hooks attached)
  • Extracts P(TRUE) - P(FALSE) from the last-token logits after the forward pass

def intervention_experiment(
    statements: list[str],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    direction: Float[Tensor, " d_model"],
    few_shot_prompt: str,
    true_id: int,
    false_id: int,
    intervene_layers: list[int],
    intervention: str = "none",
    batch_size: int = 32,
) -> Float[Tensor, " n"]:
    """
    Run the intervention experiment.

    Args:
        statements: Statements to evaluate.
        model: Language model.
        tokenizer: Tokenizer.
        direction: The (already scaled) truth direction vector.
        few_shot_prompt: Few-shot prefix.
        true_id: Token ID for " TRUE".
        false_id: Token ID for " FALSE".
        intervene_layers: List of layer indices to intervene at.
        intervention: "none", "add", or "subtract".
        batch_size: Batch size.

    Returns:
        P(TRUE) - P(FALSE) for each statement.
    """
    assert intervention in ["none", "add", "subtract"]

    # Determine how many tokens " This statement is:" adds
    suffix_tokens = tokenizer.encode(" This statement is:")
    len_suffix = len(suffix_tokens)

    p_diffs = []
    for i in range(0, len(statements), batch_size):
        batch = statements[i : i + batch_size]
        queries = [few_shot_prompt + stmt + " This statement is:" for stmt in batch]

        inputs = tokenizer(queries, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)

        # Register hooks for intervention
        hooks = []
        if intervention != "none":
            dir_device = direction.to(model.device)
            scale = 1.0 if intervention == "add" else -1.0

            # Each sequence in the batch can have a different length, so we iterate over batch
            # elements inside the hook, using attention_mask to find real sequence lengths.
            def make_batch_hook(dir_vec, attn_mask, scl):
                def hook_fn(module, input, output):
                    # YOUR CODE HERE - implement the batch-aware hook:
                    # 1. Extract hidden_states from output (handle tuple or plain tensor)
                    # 2. For each batch element b, find end = attn_mask[b].sum()
                    # 3. Patch at positions end - len_suffix and end - len_suffix - 1
                    # 4. Return the modified output (keeping the tuple structure if applicable)
                    raise NotImplementedError()
                return hook_fn

            for layer_idx in intervene_layers:
                hook = model.model.layers[layer_idx].register_forward_hook(
                    make_batch_hook(dir_device, inputs["attention_mask"], scale)
                )
                hooks.append(hook)

        with t.no_grad():
            # try/finally ensures hooks are removed even if the forward pass raises
            try:
                outputs = model(**inputs)
            finally:
                for hook in hooks:
                    hook.remove()

            # Get logits at the last non-padding position, then get probability differences
            last_idx = inputs["attention_mask"].sum(dim=1) - 1
            batch_indices = t.arange(len(batch), device=outputs.logits.device)
            last_logits = outputs.logits[batch_indices, last_idx]
            probs = last_logits.softmax(dim=-1)
            p_diff = probs[:, true_id] - probs[:, false_id]
            p_diffs.append(p_diff.cpu().float())

    return t.cat(p_diffs)


# Train the intervention probe on cities + neg_cities combined. The paper found that
# "training on statements and their opposites improves generalization" - using both
# a statement and its negation gives the probe a cleaner truth direction.
# Load neg_cities for this paired training
neg_cities_df = pd.read_csv(GOT_DATASETS / "neg_cities.csv")
neg_cities_stmts = neg_cities_df["statement"].tolist()
neg_cities_labels = t.tensor(neg_cities_df["label"].values, dtype=t.float32)

neg_cities_acts_dict = extract_activations(neg_cities_stmts, model, tokenizer, [PROBE_LAYER])
neg_cities_acts = neg_cities_acts_dict[PROBE_LAYER]

# Train probe on cities + neg_cities combined
combined_acts = t.cat([activations["cities"], neg_cities_acts])
combined_labels = t.cat([labels_dict["cities"], neg_cities_labels])
combined_probe = MMProbe.from_data(combined_acts, combined_labels)

# Scale the direction
direction = combined_probe.direction
direction_hat = direction / direction.norm()
true_acts = combined_acts[combined_labels == 1]
false_acts = combined_acts[combined_labels == 0]
true_mean = true_acts.mean(0)
false_mean = false_acts.mean(0)
projection_diff = ((true_mean - false_mean) @ direction_hat).item()
scaled_direction = projection_diff * direction_hat

# Intervene at all layers from INTERVENE_LAYER through PROBE_LAYER. This matches
# the paper's "group (b)" hidden states that were found to be causally implicated.
intervene_layer_list = list(range(INTERVENE_LAYER, PROBE_LAYER + 1))

# Run for all 3 conditions × 2 subsets
results_intervention = {}
for intervention_type in ["none", "add", "subtract"]:
    for subset in ["true", "false"]:
        mask = sp_eval_labels == (1 if subset == "true" else 0)
        subset_stmts = [s for s, m in zip(sp_eval_stmts, mask.tolist()) if m]
        p_diffs = intervention_experiment(
            subset_stmts,
            model,
            tokenizer,
            scaled_direction,
            FEW_SHOT_PROMPT,
            TRUE_ID,
            FALSE_ID,
            intervene_layer_list,
            intervention=intervention_type,
        )
        results_intervention[(intervention_type, subset)] = p_diffs.mean().item()

# Print results
intervention_df = pd.DataFrame(
    {
        "Intervention": ["none", "add", "subtract"],
        "True Stmts (mean P_diff)": [
            f"{results_intervention[('none', 'true')]:.4f}",
            f"{results_intervention[('add', 'true')]:.4f}",
            f"{results_intervention[('subtract', 'true')]:.4f}",
        ],
        "False Stmts (mean P_diff)": [
            f"{results_intervention[('none', 'false')]:.4f}",
            f"{results_intervention[('add', 'false')]:.4f}",
            f"{results_intervention[('subtract', 'false')]:.4f}",
        ],
    }
)
print("\nIntervention results (mean P(TRUE) - P(FALSE)):")
display(intervention_df)

# Grouped bar chart
fig = go.Figure()
for subset, color in [("true", "blue"), ("false", "red")]:
    vals = [results_intervention[(interv, subset)] for interv in ["none", "add", "subtract"]]
    fig.add_trace(
        go.Bar(
            name=f"{subset.capitalize()} statements",
            x=["None", "Add", "Subtract"],
            y=vals,
            marker_color=color,
            opacity=0.7,
        )
    )
fig.update_layout(
    title="Causal Intervention: Effect on P(TRUE) - P(FALSE)",
    yaxis_title="Mean P(TRUE) - P(FALSE)",
    barmode="group",
    height=400,
    width=600,
)
fig.add_hline(y=0, line_dash="dash", line_color="gray")
fig.show()
Click to see the expected output
Solution
def intervention_experiment(
    statements: list[str],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    direction: Float[Tensor, " d_model"],
    few_shot_prompt: str,
    true_id: int,
    false_id: int,
    intervene_layers: list[int],
    intervention: str = "none",
    batch_size: int = 32,
) -> Float[Tensor, " n"]:
    """
    Run the intervention experiment.

    Args:
        statements: Statements to evaluate.
        model: Language model.
        tokenizer: Tokenizer.
        direction: The (already scaled) truth direction vector.
        few_shot_prompt: Few-shot prefix.
        true_id: Token ID for " TRUE".
        false_id: Token ID for " FALSE".
        intervene_layers: List of layer indices to intervene at.
        intervention: "none", "add", or "subtract".
        batch_size: Batch size.

    Returns:
        P(TRUE) - P(FALSE) for each statement.
    """
    assert intervention in ["none", "add", "subtract"]

    # Determine how many tokens " This statement is:" adds
    suffix_tokens = tokenizer.encode(" This statement is:")
    len_suffix = len(suffix_tokens)

    p_diffs = []
    for i in range(0, len(statements), batch_size):
        batch = statements[i : i + batch_size]
        queries = [few_shot_prompt + stmt + " This statement is:" for stmt in batch]

        inputs = tokenizer(queries, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)

        # Register hooks for intervention
        hooks = []
        if intervention != "none":
            dir_device = direction.to(model.device)
            scale = 1.0 if intervention == "add" else -1.0

            # Each sequence in the batch can have a different length, so we iterate over batch
            # elements inside the hook, using attention_mask to find real sequence lengths.
            def make_batch_hook(dir_vec, attn_mask, scl):
                def hook_fn(module, input, output):
                    hidden_states = output[0] if isinstance(output, tuple) else output

                    seq_lens = attn_mask.sum(dim=1)  # [batch]
                    for b in range(hidden_states.shape[0]):
                        end = seq_lens[b].item()
                        for offset in [-len_suffix, -len_suffix - 1]:
                            pos = int(end + offset)
                            if 0 <= pos < hidden_states.shape[1]:
                                hidden_states[b, pos, :] += scl * dir_vec

                    return (hidden_states,) + output[1:] if isinstance(output, tuple) else hidden_states

                return hook_fn

            for layer_idx in intervene_layers:
                hook = model.model.layers[layer_idx].register_forward_hook(
                    make_batch_hook(dir_device, inputs["attention_mask"], scale)
                )
                hooks.append(hook)

        with t.no_grad():
            # try/finally ensures hooks are removed even if the forward pass raises
            try:
                outputs = model(**inputs)
            finally:
                for hook in hooks:
                    hook.remove()

            # Get logits at the last non-padding position, then get probability differences
            last_idx = inputs["attention_mask"].sum(dim=1) - 1
            batch_indices = t.arange(len(batch), device=outputs.logits.device)
            last_logits = outputs.logits[batch_indices, last_idx]
            probs = last_logits.softmax(dim=-1)
            p_diff = probs[:, true_id] - probs[:, false_id]
            p_diffs.append(p_diff.cpu().float())

    return t.cat(p_diffs)

The key result: adding the truth direction to false-statement activations should push P(TRUE) - P(FALSE) upward (making the model more likely to predict TRUE), while subtracting it from true-statement activations should push it downward. This demonstrates that the probe direction is causally implicated in the model's computation, not merely correlated with truth.

Comparing MM vs. LR interventions

Now let's repeat the intervention experiment using the LR probe's direction instead of the MM direction. We'll scale both directions the same way and compare the Natural Indirect Effects (NIEs).
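The NIE bookkeeping is a single subtraction; a toy sketch with hypothetical numbers:

```python
# Hypothetical numbers for illustration: the NIE is the shift in mean
# P(TRUE) - P(FALSE) caused by the intervention, relative to baseline.
p_diff_none_false = -0.90  # mean P(TRUE)-P(FALSE) on false statements, baseline
p_diff_add_false = 0.55    # same statements with the truth direction added

nie_add_false = p_diff_add_false - p_diff_none_false
assert abs(nie_add_false - 1.45) < 1e-9  # large positive NIE: predictions flipped
```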

As a reminder, the NIE for "add" on false statements = P_diff(add) - P_diff(none); for "subtract" on true statements it's P_diff(subtract) - P_diff(none), which should be negative. A larger-magnitude NIE means the direction is more causally implicated. Run the code below to see how the two probe types compare.

# Train LR probe on same data
lr_combined = LRProbe.from_data(combined_acts, combined_labels)
lr_direction = lr_combined.direction.detach()
lr_direction_hat = lr_direction / lr_direction.norm()
lr_proj_diff = ((true_mean - false_mean) @ lr_direction_hat).item()
lr_scaled_direction = lr_proj_diff * lr_direction_hat

# Run intervention for LR direction
lr_results = {}
for intervention_type in ["none", "add", "subtract"]:
    for subset in ["true", "false"]:
        mask = sp_eval_labels == (1 if subset == "true" else 0)
        subset_stmts = [s for s, m in zip(sp_eval_stmts, mask.tolist()) if m]
        p_diffs = intervention_experiment(
            subset_stmts,
            model,
            tokenizer,
            lr_scaled_direction,
            FEW_SHOT_PROMPT,
            TRUE_ID,
            FALSE_ID,
            intervene_layer_list,
            intervention=intervention_type,
        )
        lr_results[(intervention_type, subset)] = p_diffs.mean().item()

# Compute NIEs
mm_nie_false = results_intervention[("add", "false")] - results_intervention[("none", "false")]
mm_nie_true = results_intervention[("subtract", "true")] - results_intervention[("none", "true")]
lr_nie_false = lr_results[("add", "false")] - lr_results[("none", "false")]
lr_nie_true = lr_results[("subtract", "true")] - lr_results[("none", "true")]

nie_df = pd.DataFrame(
    {
        "Probe": ["MM", "MM", "LR", "LR"],
        "Intervention": ["Add to false", "Subtract from true", "Add to false", "Subtract from true"],
        "NIE": [f"{mm_nie_false:.4f}", f"{mm_nie_true:.4f}", f"{lr_nie_false:.4f}", f"{lr_nie_true:.4f}"],
    }
)
print("Natural Indirect Effects (NIE):")
display(nie_df)

# Side-by-side bar chart
fig = go.Figure()
fig.add_trace(
    go.Bar(
        name="MM Probe",
        x=["Add→False", "Sub→True"],
        y=[mm_nie_false, mm_nie_true],
        marker_color="blue",
        opacity=0.7,
    )
)
fig.add_trace(
    go.Bar(
        name="LR Probe",
        x=["Add→False", "Sub→True"],
        y=[lr_nie_false, lr_nie_true],
        marker_color="orange",
        opacity=0.7,
    )
)
fig.update_layout(
    title="Natural Indirect Effect: MM vs LR Probe Directions",
    yaxis_title="NIE (change in P(TRUE)-P(FALSE))",
    barmode="group",
    height=400,
    width=600,
)
fig.show()
Click to see the expected output
Question - Which probe type produces a more causally implicated direction? Why might this be?

The MM (difference-of-means) probe should produce a direction with higher NIE than the LR (logistic regression) probe, even though LR achieves higher classification accuracy. From the Geometry of Truth paper:

"Mass-mean probe directions are highly causal, with MM outperforming LR and CCS in 7/8 experimental conditions, often substantially."

The explanation: LR optimizes for classification accuracy, which means it can exploit any feature that correlates with truth, even if that feature isn't causally used by the model. The MM direction, by contrast, is the geometric center of the true/false clusters. As the paper notes: "In some cases, however, the direction identified by LR can fail to reflect an intuitive best guess for the feature direction, even in the absence of confounding features."
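To build intuition for how an accuracy-optimizing classifier can tilt away from the mean-difference direction, here's a hypothetical 2-D sketch. For Gaussian classes with shared covariance Sigma, the Bayes-optimal linear direction (LDA, which LR approximates) is Sigma^{-1}(mu_true - mu_false); with correlated noise this differs from the raw difference of means:

```python
# Toy 2-D illustration (hypothetical numbers): the class means differ only
# along x, but correlated noise tilts the accuracy-optimal direction.
mu_diff = (1.0, 0.0)              # difference of means points along x
Sigma = ((1.0, 0.8), (0.8, 1.0))  # strongly correlated noise covariance

# Invert the 2x2 covariance by hand
det = Sigma[0][0] * Sigma[1][1] - Sigma[0][1] * Sigma[1][0]
Sigma_inv = (
    (Sigma[1][1] / det, -Sigma[0][1] / det),
    (-Sigma[1][0] / det, Sigma[0][0] / det),
)

# LDA direction: Sigma^{-1} @ mu_diff
lda_dir = (
    Sigma_inv[0][0] * mu_diff[0] + Sigma_inv[0][1] * mu_diff[1],
    Sigma_inv[1][0] * mu_diff[0] + Sigma_inv[1][1] * mu_diff[1],
)

# The classifier direction picks up a nonzero y-component even though the
# class means only differ along x:
assert abs(lda_dir[1]) > 1e-6
```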

For a related but slightly different perspective, see the Adversarial Examples Are Not Bugs, They Are Features paper, which uses the "feature robustness framing" to explain how the direction learned by an accuracy-optimizing classifier might not always line up with the mean-difference direction, or what we would view as the "canonical" direction of the feature we're trying to learn. (Incidentally, this is also a motivator for why encoders and decoders in SAEs are untied - we use encoder vectors for detection but decoder vectors for steering.)

This is an important cautionary tale: high probe accuracy does not guarantee causal relevance. Always validate with interventions!