
3️⃣ White-box Methods

Learning Objectives
  • Understand receiver heads: attention heads that aggregate information from reasoning steps into the final answer
  • Compute vertical attention scores: quantifying how much each token position attends to specific reasoning steps
  • Implement RoPE (Rotary Position Embeddings) to understand positional encoding in attention
  • Build attention suppression interventions to causally validate which attention patterns matter
  • Compare white-box attention metrics with black-box importance scores to validate mechanistic understanding

Connecting black-box to white-box: In Section 2, we used black-box methods (resampling, counterfactuals) to identify which sentences are important for the final answer. These methods told us what matters, but not why or how the model implements this importance mechanistically.

Now we'll open the black box and look inside the model. The key question: if certain sentences are behaviorally important (high counterfactual importance), can we find attention heads that attend strongly to those same sentences? If so, we've found a mechanistic explanation for the black-box patterns.

The hypothesis: Attention is the mechanism by which information from earlier sentences flows to later positions in the sequence. If a sentence is important for the final answer, we should expect to see attention heads at later positions (especially at the final answer token) attending heavily to that sentence. These are called receiver heads - heads that aggregate important information from reasoning steps.

What we'll do:

  1. Compute vertical attention scores - for each sentence, measure how much future tokens attend to it
  2. Identify receiver heads - heads with spiky attention to specific sentences (high kurtosis)
  3. Validate mechanistically - compare attention patterns to our black-box importance scores from Section 2
  4. Causal intervention - use attention suppression to prove these patterns matter causally

By the end, we'll have shown that the black-box patterns (which sentences matter) correspond to white-box mechanisms (which sentences get attended to), giving us a complete mechanistic understanding of thought anchors.

The black-box methods we've used so far treat the model as a black box - we only observe inputs and outputs. But we have access to the model's internals! In this section, we'll look inside the model to understand how it implements the sentence importance patterns we observed.

The key insight from the paper is that certain attention heads ("receiver heads") have attention patterns that correlate strongly with our black-box importance metrics. These heads tend to attend heavily to thought anchors - the sentences that shape the model's reasoning trajectory.

First, let's define a helper function for getting whitebox data from our loaded problem_data. We define it because, in this section, we'll only ever need the full CoT & chunks plus the counterfactual importance scores (which we'll later compare against our whitebox analysis).

Receiver Head Analysis

Key concepts:

  • Vertical attention score: For each sentence, how much do all future sentences attend to it?
  • Receiver heads: Attention heads with high kurtosis in their vertical attention scores (i.e., they attend spikily to specific sentences rather than uniformly)

For these exercises, we'll use a few helper functions:

  • utils.get_whitebox_example_data(): Extracts a tuple of the following from problem_data, since this is all we'll be using in this section:
    • Full CoT reasoning trace (string)
    • List of chunks (list of strings)
    • Counterfactual importance scores (numpy array) for us to compare with the whitebox results
  • utils.get_sentence_token_boundaries(): Maps sentences to token positions (handles tokenization complexities)
  • utils.average_attention_by_sentences(): Aggregates token-level attention to sentence-level
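
To make the last of these concrete, here's a rough sketch (not the actual utils implementation, just the idea) of what sentence-level averaging does, assuming boundaries is a list of per-sentence (start, end) token index pairs:

import numpy as np


def average_attention_by_sentences_sketch(
    token_attention: np.ndarray,  # (seq_len, seq_len) token-level attention
    boundaries: list[tuple[int, int]],  # hypothetical format: per-sentence (start, end) token indices
) -> np.ndarray:
    """For each (query sentence, key sentence) pair, average the corresponding block of token attention."""
    n = len(boundaries)
    sentence_attention = np.zeros((n, n))
    for i, (q_start, q_end) in enumerate(boundaries):
        for j, (k_start, k_end) in enumerate(boundaries):
            sentence_attention[i, j] = token_attention[q_start:q_end, k_start:k_end].mean()
    return sentence_attention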

Exercise - extract attention matrices using hooks

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵⚪
> You should spend up to 20-25 minutes on this exercise.

The first step in whitebox analysis is to extract attention patterns from the model's forward pass. You can pass output_attentions=True to the forward pass to get the attention weights - they're then accessible as output.attentions[layer][batch_idx, head_idx] (note that this only works because we set the model's config to output attentions, and to use the eager attention implementation, when loading it in the code below).

The function extract_attention_matrix should return the attention weights for a particular layer & head, and also return the list of string tokens for convenience.

def extract_attention_matrix(
    text: str,
    model,
    tokenizer,
    layer: int,
    head: int,
) -> tuple[Float[np.ndarray, "seq seq"], list[str]]:
    """
    Extract attention matrix from a specific layer and head using hooks.

    Args:
        text: Input text to analyze
        model: HuggingFace model (already loaded)
        tokenizer: Corresponding tokenizer
        layer: Which layer to extract from (0-indexed)
        head: Which attention head to extract from (0-indexed)

    Returns:
        attention_matrix: Shape (seq_len, seq_len) attention weights for the specified head
        tokens: List of token strings for visualization
    """
    raise NotImplementedError("Implement attention extraction using hooks")


# Load model and tokenizer (with attention output enabled - important!)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_1B)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_1B, device_map="auto")  # torch_dtype=torch.bfloat16
model.config._attn_implementation = "eager"
model.config.output_attentions = True
model.eval()

# Test with a short example
test_text = "The cat sat on the mat. It was sleeping."

# Extract attention from middle layer
n_layers = len(model.model.layers)
layer_to_check = n_layers // 2
head_to_check = 8  # head 9 is better for the 1.5b model
print(f"\nExtracting attention from head L{layer_to_check}H{head_to_check}")
attention_matrix, tokens = extract_attention_matrix(
    test_text, model, tokenizer, layer=layer_to_check, head=head_to_check
)
assert attention_matrix.shape[0] == len(tokens), "Shape mismatch"
assert attention_matrix.shape[0] == attention_matrix.shape[1], "Not square"
assert np.allclose(attention_matrix.sum(axis=1), 1.0, atol=1e-5), "Rows don't sum to 1"

# Check causal structure (upper triangle should be mostly zeros)
upper_triangle_sum = np.triu(attention_matrix, k=1).sum()
assert upper_triangle_sum < 1e-5, "Upper triangle has non-zero values, not causal attention"

# Visualize
display(
    cv.attention.attention_heads(
        attention=attention_matrix[None],
        tokens=tokens,
    )
)
Click to see the expected output
Hint #1

If you go the hook route: the forward method of Qwen2Attention returns a tuple (attn_output, attn_weights, past_key_value), so the hook should extract attn_weights from the output tuple. Alternatively (as the solution does), you can skip hooks entirely, pass output_attentions=True to the forward pass, and read output.attentions[layer].
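
For reference, here's a minimal sketch of the hook-based route (assuming the eager attention implementation, so the module actually returns attention weights in its output tuple, and assuming text, layer and head are defined as in the test cell); the solution takes the simpler output.attentions route instead:

captured = {}


def capture_attn(module, args, output):
    # For eager Qwen2 attention, output is (attn_output, attn_weights, ...)
    captured["attn"] = output[1].detach()


inputs = tokenizer(text, return_tensors="pt").to(model.device)
handle = model.model.layers[layer].self_attn.register_forward_hook(capture_attn)
try:
    with torch.no_grad():
        model(**inputs, use_cache=False, output_attentions=True)
finally:
    handle.remove()  # always remove the hook, even if the forward pass errors

attention_matrix = captured["attn"][0, head].float().cpu().numpy()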

Solution
def extract_attention_matrix(
    text: str,
    model,
    tokenizer,
    layer: int,
    head: int,
) -> tuple[Float[np.ndarray, "seq seq"], list[str]]:
    """
    Extract attention matrix from a specific layer and head using hooks.
    Args:
        text: Input text to analyze
        model: HuggingFace model (already loaded)
        tokenizer: Corresponding tokenizer
        layer: Which layer to extract from (0-indexed)
        head: Which attention head to extract from (0-indexed)
    Returns:
        attention_matrix: Shape (seq_len, seq_len) attention weights for the specified head
        tokens: List of token strings for visualization
    """
    # Tokenize input and move to the model's device
    inputs = tokenizer(text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Forward pass - output_attentions=True makes the attention weights available
    with torch.no_grad():
        output = model(**inputs, use_cache=False, output_attentions=True)

    # output.attentions[layer] has shape (batch_size, num_heads, seq_len, seq_len)
    attn_tensor = output.attentions[layer][0, head]  # Select batch 0, specific head

    # Convert to numpy
    attention_matrix = attn_tensor.cpu().numpy().astype(np.float32)

    # Get token strings
    token_ids = inputs["input_ids"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    return attention_matrix, tokens

Exercise - compute vertical attention scores (Part 1: Simple version)

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵🔵
> You should spend up to 10-15 minutes on this exercise.

Vertical attention scores measure how much future sentences attend back to each sentence. This is the core metric for identifying "thought anchors" - sentences that shape downstream reasoning.

In this first part, implement a simple version that computes the basic vertical attention scores without any normalization.
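
In symbols: if $A_{ij}$ is the (sentence-level) attention from sentence $i$ to sentence $j$, $p$ is proximity_ignore, and $n$ is the number of sentences, the simple vertical score you're asked to compute is

$$ v_j \;=\; \frac{1}{\,\lvert \{\, i : j + p \le i < n \,\} \rvert\,} \sum_{i = j + p}^{n - 1} A_{ij}, $$

i.e. the mean of column $j$ over all rows at least $p$ steps below the diagonal (ignoring NaN entries in practice).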

def get_vertical_scores_simple(
    avg_attention_matrix: Float[np.ndarray, "n_sentences n_sentences"],
    proximity_ignore: int = 4,
) -> Float[np.ndarray, " n_sentences"]:
    """
    Compute basic vertical attention scores for each sentence (no normalization).

    Args:
        avg_attention_matrix: Shape (n_sentences, n_sentences), where entry (i, j)
            is the average attention from sentence i to sentence j
        proximity_ignore: Ignore this many nearby sentences (to avoid trivial patterns)

    Returns:
        Array of shape (n_sentences,) with vertical scores

    The vertical score for sentence j is the mean attention it receives from
    all sentences i where i >= j + proximity_ignore.

    Adapted from thought-anchors: attn_funcs.py:get_vertical_scores
    """
    raise NotImplementedError("Implement basic vertical attention scores")
Solution
def get_vertical_scores_simple(
    avg_attention_matrix: Float[np.ndarray, "n_sentences n_sentences"],
    proximity_ignore: int = 4,
) -> Float[np.ndarray, " n_sentences"]:
    """
    Compute basic vertical attention scores for each sentence (no normalization).
    Args:
        avg_attention_matrix: Shape (n_sentences, n_sentences), where entry (i, j)
            is the average attention from sentence i to sentence j
        proximity_ignore: Ignore this many nearby sentences (to avoid trivial patterns)
    Returns:
        Array of shape (n_sentences,) with vertical scores
    The vertical score for sentence j is the mean attention it receives from
    all sentences i where i >= j + proximity_ignore.
    Adapted from thought-anchors: attn_funcs.py:get_vertical_scores
    """
    n = avg_attention_matrix.shape[0]
    mat = avg_attention_matrix.copy()

    # Step 1: Clean matrix - set upper triangle to NaN (can't attend to future)
    mat[np.triu_indices_from(mat, k=1)] = np.nan

    # Step 2: Remove proximity - set near-diagonal to NaN (ignore nearby sentences)
    mat[np.triu_indices_from(mat, k=-proximity_ignore + 1)] = np.nan

    # Step 3: Compute vertical scores: mean attention received from future sentences
    vert_scores = []
    for j in range(n):
        # Extract "vertical line" - attention from all future sentences to sentence j
        future_attention = mat[j + proximity_ignore :, j]
        if len(future_attention) == 0 or np.all(np.isnan(future_attention)):
            vert_scores.append(np.nan)
        else:
            vert_scores.append(np.nanmean(future_attention))

    return np.array(vert_scores)

Let's test the simple version on real CoT attention patterns.

# Get example problem data
text, sentences, _ = utils.get_whitebox_example_data(problem_data)

# Extract attention from a middle layer
layer, head = 15, 8
print(f"\nExtracting attention from head L{layer}H{head}...")
token_attention, tokens = extract_attention_matrix(text, model, tokenizer, layer, head)

# Convert to sentence-level attention
boundaries = utils.get_sentence_token_boundaries(text, sentences, tokenizer)
sentence_attention = utils.average_attention_by_sentences(token_attention, boundaries)

# Compute vertical scores (simple version)
vert_scores_simple = get_vertical_scores_simple(sentence_attention, proximity_ignore=4)

# Visualize
hover_texts = [s[:80] + "..." if len(s) > 80 else s for s in sentences]

fig = go.Figure()
fig.add_trace(
    go.Bar(
        x=list(range(len(vert_scores_simple))),
        y=vert_scores_simple,
        marker_color="steelblue",
        hovertemplate="<b>Sentence %{x}</b><br>Score: %{y:.4f}<br>%{customdata}<extra></extra>",
        customdata=hover_texts,
    )
)
fig.add_hline(y=0, line_color="black", line_width=0.5)
fig.update_layout(
    title="Simple Vertical Scores (No Normalization)",
    xaxis_title="Sentence Index",
    yaxis_title="Vertical Score",
    width=1000,
    height=450,
)
fig.show()
Click to see the expected output

Exercise - compute vertical attention scores (Part 2: Full version with depth control)

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
> You should spend up to 15-20 minutes on this exercise.

Challenge: later sentences have more past tokens to spread their attention over than earlier sentences, which biases the raw attention values. The paper addresses this with depth control: a rank-based normalization applied per row.

Now implement the full algorithm including depth control and optional z-score normalization.

Hint: Understanding the depth-control normalization

The depth bias problem: sentence 20 has ~20 previous sentences to attend to, while sentence 5 has only ~5. So raw attention values aren't comparable across rows.

The fix uses rank normalization per row:

  1. For each row $i$, count the number of valid (non-NaN) positions: n_valid = sum(~isnan(row))
  2. Replace each attention value with its rank among valid values in that row
  3. Divide by n_valid to normalize to [0, 1]

This way, every row has the same scale regardless of how many positions are available. You can use scipy.stats.rankdata with nan_policy="omit" to rank values while ignoring NaNs.

After depth control, the vertical score computation (step 4) and optional z-scoring (step 5) are the same as in the simple version.
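
As a tiny illustration of the rank step (made-up numbers for a single row):

import numpy as np
from scipy import stats

row = np.array([0.05, 0.30, 0.10, np.nan, np.nan])  # NaNs = masked (future / too-close) positions
n_valid = np.sum(~np.isnan(row))  # 3 valid positions
ranks = stats.rankdata(row, nan_policy="omit")  # array([1., 3., 2., nan, nan])
print(ranks / n_valid)  # [0.333..., 1.0, 0.666..., nan, nan]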

def get_vertical_scores(
    avg_attention_matrix: Float[np.ndarray, "n_sentences n_sentences"],
    proximity_ignore: int = 4,
    control_depth: bool = True,
    return_z_scores: bool = False,
) -> Float[np.ndarray, " n_sentences"]:
    """
    Compute vertical attention scores for each sentence with optional depth control and z-scoring.

    Args:
        avg_attention_matrix: Shape (n_sentences, n_sentences), where entry (i, j)
            is the average attention from sentence i to sentence j
        proximity_ignore: Ignore this many nearby sentences (to avoid trivial patterns)
        control_depth: Apply rank normalization to control for depth effects
        return_z_scores: If True, return z-score normalized values

    Returns:
        Array of shape (n_sentences,) with vertical scores

    The vertical score for sentence j is the mean attention it receives from
    all sentences i where i >= j + proximity_ignore.

    Depth control (when enabled): Normalizes each row by ranking attention values
    and dividing by the count of valid (non-NaN) positions. This ensures fair
    comparison between early sentences (fewer tokens to attend to) and late
    sentences (many tokens to attend to).

    Adapted from thought-anchors: attn_funcs.py:get_vertical_scores
    """
    # Step 1: Copy matrix and set upper triangle to NaN (can't attend to future)
    # Step 2: Set near-diagonal to NaN using proximity_ignore (ignore nearby sentences)
    # Step 3: If control_depth, rank-normalize each row:
    #   - Count valid (non-NaN) values per row
    #   - Use stats.rankdata(mat, axis=1, nan_policy="omit") to rank within rows
    #   - Divide by the per-row valid counts to normalize to [0, 1]
    # Step 4: Compute vertical scores (same as simple version)
    # Step 5: If return_z_scores, z-score normalize the result
    raise NotImplementedError()
Solution
def get_vertical_scores(
    avg_attention_matrix: Float[np.ndarray, "n_sentences n_sentences"],
    proximity_ignore: int = 4,
    control_depth: bool = True,
    return_z_scores: bool = False,
) -> Float[np.ndarray, " n_sentences"]:
    """
    Compute vertical attention scores for each sentence with optional depth control and z-scoring.
    Args:
        avg_attention_matrix: Shape (n_sentences, n_sentences), where entry (i, j)
            is the average attention from sentence i to sentence j
        proximity_ignore: Ignore this many nearby sentences (to avoid trivial patterns)
        control_depth: Apply rank normalization to control for depth effects
        return_z_scores: If True, return z-score normalized values
    Returns:
        Array of shape (n_sentences,) with vertical scores
    The vertical score for sentence j is the mean attention it receives from
    all sentences i where i >= j + proximity_ignore.
    Depth control (when enabled): Normalizes each row by ranking attention values
    and dividing by the count of valid (non-NaN) positions. This ensures fair
    comparison between early sentences (fewer tokens to attend to) and late
    sentences (many tokens to attend to).
    Adapted from thought-anchors: attn_funcs.py:get_vertical_scores
    """
    n = avg_attention_matrix.shape[0]
    mat = avg_attention_matrix.copy()

    # Step 1: Clean matrix - set upper triangle to NaN (can't attend to future)
    mat[np.triu_indices_from(mat, k=1)] = np.nan

    # Step 2: Remove proximity - set near-diagonal to NaN (ignore nearby sentences)
    mat[np.triu_indices_from(mat, k=-proximity_ignore + 1)] = np.nan

    # Step 3: Depth control normalization (critical for fair comparison!)
    if control_depth:
        # Count non-NaN values per row (available positions to attend to)
        per_row = np.sum(~np.isnan(mat), axis=1)

        # Rank-normalize each row: convert attention values to ranks, then divide by
        # the number of valid positions. This puts all rows on the same scale
        # regardless of depth.
        mat = stats.rankdata(mat, axis=1, nan_policy="omit") / per_row[:, None]

    # Step 4: Compute vertical scores: mean attention received from future sentences
    vert_scores = []
    for j in range(n):
        # Extract "vertical line" - attention from all future sentences to sentence j
        future_attention = mat[j + proximity_ignore :, j]
        if len(future_attention) == 0 or np.all(np.isnan(future_attention)):
            vert_scores.append(np.nan)
        else:
            vert_scores.append(np.nanmean(future_attention))

    vert_scores = np.array(vert_scores)

    # Step 5: Optional z-score normalization
    if return_z_scores:
        vert_scores = (vert_scores - np.nanmean(vert_scores)) / np.nanstd(vert_scores)

    return vert_scores

Let's test vertical scores on real CoT attention patterns and see the effect of depth control.

# Compute vertical scores WITH and WITHOUT depth control (using z-scores for comparison)
vert_scores_no_control = get_vertical_scores(
    sentence_attention, proximity_ignore=4, control_depth=False, return_z_scores=True
)
vert_scores_with_control = get_vertical_scores(
    sentence_attention, proximity_ignore=4, control_depth=True, return_z_scores=True
)

# Get top-3 sentences (filtering NaN values)
valid_mask = ~np.isnan(vert_scores_no_control) & ~np.isnan(vert_scores_with_control)
valid_indices = np.where(valid_mask)[0]
top_3_no_control = valid_indices[np.argsort(vert_scores_no_control[valid_mask])[-3:][::-1]]
top_3_with_control = valid_indices[np.argsort(vert_scores_with_control[valid_mask])[-3:][::-1]]

print(f"Top-3 sentences (no depth control): {top_3_no_control}")
print(f"Top-3 sentences (with depth control): {top_3_with_control}")
print(
    f"Correlation between methods: {np.corrcoef(vert_scores_no_control[valid_mask], vert_scores_with_control[valid_mask])[0, 1]:.3f}"
)

# Plot comparison (already z-scored from the function)
hover_texts = [s[:80] + "..." if len(s) > 80 else s for s in sentences]

fig = go.Figure()
for name, scores, color in [
    ("No depth control", vert_scores_no_control, "steelblue"),
    ("With depth control", vert_scores_with_control, "darkorange"),
]:
    fig.add_trace(
        go.Bar(
            x=list(range(len(scores))),
            y=scores,
            name=name,
            opacity=0.7,
            marker_color=color,
            hovertemplate="<b>Sentence %{x}</b><br>Z-score: %{y:.2f}<br>%{customdata}<extra></extra>",
            customdata=hover_texts,
        )
    )
fig.add_hline(y=0, line_color="black", line_width=0.5)
fig.update_layout(
    title="Vertical Scores Comparison (Z-scored)",
    xaxis_title="Sentence Index",
    yaxis_title="Vertical Score (z-scored)",
    width=1000,
    height=450,
    barmode="group",
)
fig.show()
Click to see the expected output
Why depth control matters

Without depth control, raw vertical scores are biased by position. Each query token's attention sums to 1 over all earlier tokens, so rows with little preceding context pay a lot of attention to each available sentence, while rows late in the trace spread their attention thinly. Early sentences are attended to by rows that include this low-context, high-baseline regime, so they automatically get higher raw scores. For example, in a 20-sentence trace:

  • Sentence 0 can be attended to by 19 future sentences, the earliest of which split their attention across only a handful of sentences
  • Sentence 15 can be attended to by only 4 future sentences, all of which spread their attention across 15+ earlier sentences

Depth control (rank normalization) fixes this by converting each row's attention values to ranks and dividing by the number of valid positions, so every row is on the same [0, 1] scale regardless of how much context it has.

Key insight from the paper: With depth control enabled, the vertical scores better identify genuine "thought anchors" - sentences that are disproportionately important for specific reasoning steps, not just sentences that happen to be early in the trace.

Exercise - find top receiver heads

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
> You should spend up to 10-15 minutes on this exercise.

Receiver heads are identified by high kurtosis in their vertical attention scores across multiple problems.
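
As a quick intuition check (toy numbers, not real scores): a spiky vector of scores has much higher excess kurtosis than a flat one.

import numpy as np
from scipy import stats

flat = np.array([0.20, 0.21, 0.19, 0.20, 0.20, 0.21, 0.19, 0.20])
spiky = np.array([0.02, 0.02, 0.90, 0.02, 0.02, 0.02, 0.02, 0.02])
print(stats.kurtosis(flat, fisher=True))  # negative (flatter than a Gaussian)
print(stats.kurtosis(spiky, fisher=True))  # large and positive (one dominant spike)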

def compute_head_kurtosis(vertical_scores: Float[np.ndarray, " n_sentences"]) -> float:
    """
    Compute kurtosis of vertical attention scores.
    Higher kurtosis = more "spiky" attention = better receiver head candidate.
    """
    raise NotImplementedError("Compute kurtosis of vertical scores")


tests.test_compute_head_kurtosis(compute_head_kurtosis)
Solution
def compute_head_kurtosis(vertical_scores: Float[np.ndarray, " n_sentences"]) -> float:
    """
    Compute kurtosis of vertical attention scores.
    Higher kurtosis = more "spiky" attention = better receiver head candidate.
    """
    # Remove NaN values
    valid_scores = vertical_scores[~np.isnan(vertical_scores)]
    if len(valid_scores) < 4:
        return np.nan

    # Compute Fisher's kurtosis (excess kurtosis)
    return stats.kurtosis(valid_scores, fisher=True, bias=True)

Exercise - identify receiver heads from real CoT

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵⚪
> You should spend up to 25-30 minutes on this exercise.

Now let's implement the full receiver head selection pipeline: extract attention from all heads, compute vertical scores, rank by kurtosis, and select top-k heads.

Note: This is computationally expensive (it requires a forward pass for each layer/head combination) - for a 1.5B model with 28 layers × 12 heads, that's 336 forward passes. We'll use caching and work with a short example.
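
One possible optimization (not required here, and not what the reference solution does): a single forward pass with output_attentions=True already returns attention for every layer and head at once, so, memory permitting, you could cache them all instead of re-running the model per head. A rough sketch:

inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model(**inputs, use_cache=False, output_attentions=True)

# out.attentions is a tuple of length n_layers; each entry has shape (batch, n_heads, seq_len, seq_len).
# For long traces this is a lot of data, so consider downcasting or keeping it on CPU.
all_attn = [a[0].float().cpu().numpy() for a in out.attentions]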

def find_receiver_heads(
    text: str,
    sentences: list[str],
    model,
    tokenizer,
    top_k: int = 10,
    proximity_ignore: int = 16,
) -> tuple[Int[np.ndarray, "top_k 2"], Float[np.ndarray, " top_k"]]:
    """
    Identify top-k receiver heads by kurtosis of vertical attention scores.

    Args:
        text: Full reasoning trace
        sentences: List of sentences in the trace
        model: Loaded model
        tokenizer: Corresponding tokenizer
        top_k: Number of top receiver heads to return
        proximity_ignore: Proximity parameter for vertical scores

    Returns:
        receiver_heads: Shape (top_k, 2) array of (layer, head) pairs
        kurtosis_scores: Shape (top_k,) array of kurtosis values for each head
    """
    raise NotImplementedError("Implement receiver head identification")


# Load example CoT
text_full, sentences_full, _ = utils.get_whitebox_example_data(problem_data)

n_chunks = 74
sentences_subset = sentences_full[:n_chunks]
# Find where the n_chunks-th sentence ends in the original text to extract the correct substring
# This preserves the original text formatting and ensures tokenization consistency
end_char_pos = 0
for sent in sentences_subset:
    # Find this sentence in the text starting from where we left off
    sent_pos = text_full.find(sent, end_char_pos)
    if sent_pos == -1:
        # Try with stripped version
        sent_pos = text_full.find(sent.strip(), end_char_pos)
    if sent_pos != -1:
        end_char_pos = sent_pos + len(sent)

text_subset = text_full[:end_char_pos]

print(f"Analyzing first {len(sentences_subset)} sentences...")
print(f"Text length: {len(text_subset)} characters")

# Find top-20 receiver heads (paper recommends 16-32)
torch.cuda.empty_cache()
receiver_heads, receiver_kurts = find_receiver_heads(
    text_subset,
    sentences_subset,
    model,
    tokenizer,
    top_k=20,
    proximity_ignore=4,
)

print(f"\nTop-{len(receiver_heads)} Receiver Heads:")
for i, ((layer, head), kurt) in enumerate(zip(receiver_heads, receiver_kurts)):
    print(f"  {i + 1:2d}. Layer {layer:2d}, Head {head:2d} | Kurtosis: {kurt:.3f}")

# Visualize kurtosis distribution
# Recompute kurtosis matrix for visualization
n_layers = len(model.model.layers)
n_heads = model.config.num_attention_heads
kurtosis_viz = np.zeros((n_layers, n_heads))

for layer, head in receiver_heads:
    kurtosis_viz[layer, head] = receiver_kurts[receiver_heads.tolist().index([layer, head])]

fig = px.imshow(
    kurtosis_viz,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    labels=dict(x="Head", y="Layer", color="Kurtosis"),
    title="Top Receiver Heads by Kurtosis",
    aspect="auto",
    width=800,
    height=600,
)
fig.show()
print("Receiver heads should attend 'spikily' to specific important sentences.")
Click to see the expected output
Interpretation: What are receiver heads?

Receiver heads are attention heads that:

  1. Have high kurtosis in their vertical attention scores across sentences
  2. Attend strongly to a few specific sentences rather than uniformly to all past sentences
  3. Correlate with sentences identified as "thought anchors" by black-box methods

Why they matter:

  • They reveal the mechanistic implementation of how the model uses key reasoning steps
  • Different heads may specialize in different types of anchors (planning, backtracking, etc.)
  • They provide a whitebox validation of black-box importance metrics

Key finding from paper: Receiver head scores correlate with counterfactual importance (r ≈ 0.4-0.6), showing that whitebox and blackbox methods identify the same underlying phenomenon.
Solution
def find_receiver_heads(
    text: str,
    sentences: list[str],
    model,
    tokenizer,
    top_k: int = 10,
    proximity_ignore: int = 16,
) -> tuple[Int[np.ndarray, "top_k 2"], Float[np.ndarray, " top_k"]]:
    """
    Identify top-k receiver heads by kurtosis of vertical attention scores.
    Args:
        text: Full reasoning trace
        sentences: List of sentences in the trace
        model: Loaded model
        tokenizer: Corresponding tokenizer
        top_k: Number of top receiver heads to return
        proximity_ignore: Proximity parameter for vertical scores
    Returns:
        receiver_heads: Shape (top_k, 2) array of (layer, head) pairs
        kurtosis_scores: Shape (top_k,) array of kurtosis values for each head
    """
    n_layers = len(model.model.layers)
    n_heads = model.config.num_attention_heads
    n_sentences = len(sentences)

    print(f"Computing vertical scores for {n_layers} layers × {n_heads} heads = {n_layers * n_heads} heads...")
    print("This may take a few minutes...")

    # Get sentence boundaries once (expensive operation)
    boundaries = utils.get_sentence_token_boundaries(text, sentences, tokenizer)

    # Store vertical scores for all heads
    all_vert_scores = np.zeros((n_layers, n_heads, n_sentences))

    # Extract attention and compute vertical scores for each head
    for layer in tqdm(range(n_layers), desc="Layers"):
        for head in range(n_heads):
            # Extract attention matrix
            attn_matrix, _ = extract_attention_matrix(text, model, tokenizer, layer, head)

            # Average to sentence level
            avg_attn = utils.average_attention_by_sentences(attn_matrix, boundaries)

            # Compute vertical scores with depth control
            vert_scores = get_vertical_scores(avg_attn, proximity_ignore=proximity_ignore, control_depth=True)

            all_vert_scores[layer, head] = vert_scores

        torch.cuda.empty_cache()

    # Compute kurtosis for each head (over sentences dimension)
    kurtosis_matrix = np.zeros((n_layers, n_heads))
    for layer in range(n_layers):
        for head in range(n_heads):
            kurtosis_matrix[layer, head] = compute_head_kurtosis(all_vert_scores[layer, head])

    # Find top-k heads with highest kurtosis
    flat_kurts = kurtosis_matrix.flatten()
    valid_indices = ~np.isnan(flat_kurts)

    if valid_indices.sum() < top_k:
        print(f"Warning: Only {valid_indices.sum()} valid heads found")
        top_k = valid_indices.sum()

    # Get indices of top-k highest kurtosis values
    valid_kurts = flat_kurts[valid_indices]
    valid_flat_indices = np.where(valid_indices)[0]

    top_k_in_valid = np.argpartition(valid_kurts, -top_k)[-top_k:]
    top_k_in_valid = top_k_in_valid[np.argsort(-valid_kurts[top_k_in_valid])]  # Sort descending

    top_flat_indices = valid_flat_indices[top_k_in_valid]

    # Convert flat indices back to (layer, head) pairs
    receiver_heads = np.array(np.unravel_index(top_flat_indices, (n_layers, n_heads))).T
    receiver_kurts = flat_kurts[top_flat_indices]

    return receiver_heads, receiver_kurts

Exercise - compute receiver head scores

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪
> You should spend up to 15-20 minutes on this exercise.

Now that we've identified receiver heads, let's compute the final receiver head score for each sentence by averaging the vertical attention scores across the top-k receiver heads.

def compute_receiver_head_scores(
    text: str,
    sentences: list[str],
    receiver_heads: Int[np.ndarray, "k 2"],
    model,
    tokenizer,
    proximity_ignore: int = 4,
) -> Float[np.ndarray, " n_sentences"]:
    """
    Compute final receiver head scores by averaging vertical scores from top-k heads.

    Args:
        text: Full reasoning trace
        sentences: List of sentences
        receiver_heads: Shape (k, 2) array of (layer, head) pairs
        model: Loaded model
        tokenizer: Corresponding tokenizer
        proximity_ignore: Proximity parameter

    Returns:
        receiver_scores: Shape (n_sentences,) - final importance scores
    """
    raise NotImplementedError("Implement receiver head score computation")


# Use the receiver heads we just found
print(f"Computing scores from {len(receiver_heads)} receiver heads...")

receiver_scores = compute_receiver_head_scores(
    text_subset,
    sentences_subset,
    receiver_heads,
    model,
    tokenizer,
    proximity_ignore=4,
)

print(f"\nReceiver head scores computed for {len(receiver_scores)} sentences")
print(f"Mean score: {np.nanmean(receiver_scores):.4f}")
print(f"Std score: {np.nanstd(receiver_scores):.4f}")

# Visualize
fig = go.Figure()
fig.add_trace(
    go.Bar(
        x=list(range(len(receiver_scores))),
        y=receiver_scores,
        marker_color="indianred",
        hovertemplate="<b>Sentence %{x}</b><br>Score: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
        customdata=[s[:100] for s in sentences_subset],
    )
)
fig.add_hline(
    y=np.nanmean(receiver_scores),
    line_dash="dash",
    line_color="black",
    annotation_text="Mean",
)
fig.update_layout(
    title="Receiver Head Scores (Averaged Across Top-k Heads)",
    xaxis_title="Sentence Index",
    yaxis_title="Receiver Head Score",
    width=800,
    height=400,
)
fig.show()

print("\nTop-3 sentences by receiver head score:")
# Filter out NaN values before selecting top sentences
valid_mask = ~np.isnan(receiver_scores)
valid_indices = np.where(valid_mask)[0]
valid_scores = receiver_scores[valid_mask]
top_k = min(3, len(valid_scores))
top_in_valid = np.argsort(valid_scores)[-top_k:][::-1]
top_indices = valid_indices[top_in_valid]

for idx in top_indices:
    print(f"  Sentence {idx}: {receiver_scores[idx]:.4f}")
    print(f"    Text: {sentences_subset[idx][:100]}...")

print("Receiver head scores computed! These aggregate the attention from all high-kurtosis receiver heads.")
Click to see the expected output
Why average across receiver heads?

Averaging across multiple receiver heads is more robust than using a single head because:

  1. Reduces noise: Individual heads may have idiosyncratic attention patterns
  2. Captures multiple aspects: Different receiver heads may focus on different types of important sentences (e.g., planning vs. backtracking)
  3. More stable: Less sensitive to architectural choices or specific head initialization

The paper found that using the top 16-32 receiver heads gives the most stable correlation with black-box importance metrics.

Solution
def compute_receiver_head_scores(
    text: str,
    sentences: list[str],
    receiver_heads: Int[np.ndarray, "k 2"],
    model,
    tokenizer,
    proximity_ignore: int = 4,
) -> Float[np.ndarray, " n_sentences"]:
    """
    Compute final receiver head scores by averaging vertical scores from top-k heads.
    Args:
        text: Full reasoning trace
        sentences: List of sentences
        receiver_heads: Shape (k, 2) array of (layer, head) pairs
        model: Loaded model
        tokenizer: Corresponding tokenizer
        proximity_ignore: Proximity parameter
    Returns:
        receiver_scores: Shape (n_sentences,) - final importance scores
    """
    boundaries = utils.get_sentence_token_boundaries(text, sentences, tokenizer)

    all_vert_scores = []

    print(f"Computing scores from {len(receiver_heads)} receiver heads...")
    for layer, head in receiver_heads:
        # Extract attention for this receiver head
        attn_matrix, _ = extract_attention_matrix(text, model, tokenizer, layer, head)

        # Average to sentence level
        avg_attn = utils.average_attention_by_sentences(attn_matrix, boundaries)

        # Compute vertical scores
        vert_scores = get_vertical_scores(avg_attn, proximity_ignore=proximity_ignore, control_depth=True)

        all_vert_scores.append(vert_scores)

    # Average across all receiver heads
    all_vert_scores = np.array(all_vert_scores)  # Shape: (k, n_sentences)
    receiver_scores = np.nanmean(all_vert_scores, axis=0)  # Average over heads

    return receiver_scores

Attention Suppression

If receiver heads are important for propagating information from thought anchors, then suppressing attention to those sentences should change the model's output.

Exercise - implement RoPE attention

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪
> You should spend up to 15-20 minutes on this exercise.

Qwen uses Rotary Position Embeddings (RoPE), which is a method for encoding positional information directly into the attention mechanism. Instead of adding position embeddings to token embeddings (like in vanilla transformers), RoPE rotates the query and key vectors based on their positions.

The RoPE formula:

For a query or key vector split into pairs of dimensions $(x_i, x_{i+1})$, RoPE applies a rotation based on position $m$ and dimension index $i$:

$$ \begin{bmatrix} q_i^m \\ q_{i+1}^m \end{bmatrix} = \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix} \begin{bmatrix} q_i \\ q_{i+1} \end{bmatrix} $$

where $\theta_i = 10000^{-2i/d}$ is a frequency that decreases with dimension index.

Intuitively, RoPE encodes relative positions into the attention scores. When you compute $Q \cdot K^T$, the dot product naturally captures the distance between positions because both queries and keys have been rotated by position-dependent angles. This means attention scores automatically depend on relative position differences, not absolute positions.
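
To see this concretely for a single 2D pair, write the rotation by angle $a$ as $R(a)$; rotation matrices satisfy $R(a)^\top R(b) = R(b - a)$, so the contribution of that pair to the attention score is

$$ \big(R(m\theta_i)\, q\big)^\top \big(R(n\theta_i)\, k\big) \;=\; q^\top R\big((n - m)\theta_i\big)\, k, $$

which depends on the query and key positions $m, n$ only through their difference $n - m$.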

Implementation: The formula above can be simplified. For the full vector (not just pairs), we can write:

$$ \text{RoPE}(x) = x \odot \cos(m\theta) + \text{rotate\_half}(x) \odot \sin(m\theta) $$

where $\odot$ is element-wise multiplication, and rotate_half swaps and negates pairs of dimensions.

Your task: Implement the apply_rotary_pos_emb function below. You need to:

  1. Apply the RoPE formula to both the query (q) and key (k) tensors
  2. Use the cos and sin tensors, which are already computed for you (these are the $\cos(m\theta)$ and $\sin(m\theta)$ terms)
  3. Use the provided rotate_half helper function
  4. Return the rotated query and key tensors

How to know you're done: The tests below (tests.test_rotary_pos_emb()) will verify that your implementation matches the expected behavior. Once the tests pass, you've successfully implemented RoPE!

A few notes:

  • We've given you a helper function rotate_half that rotates half the hidden dimensions (swapping pairs and negating the first of each pair), which implements the rotation effect.
  • The reason q_heads and kv_heads are distinguished is Grouped Query Attention (GQA): each KV head serves a group of multiple query heads. You don't need to worry much about this for the purposes of our exercises here though.
def rotate_half(x: Float[Tensor, "... d"]) -> Float[Tensor, " ... d"]:
    """Rotates half the hidden dims of the input (for RoPE)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: Float[Tensor, "batch q_heads seq head_dim"],
    k: Float[Tensor, "batch kv_heads seq head_dim"],
    cos: Float[Tensor, "batch seq head_dim"],
    sin: Float[Tensor, "batch seq head_dim"],
    position_ids: torch.Tensor | None = None,
    unsqueeze_dim: int = 1,
) -> tuple[Float[Tensor, "batch q_heads seq head_dim"], Float[Tensor, "batch kv_heads seq head_dim"]]:
    """Applies Rotary Position Embedding to query and key tensors."""

    del position_ids  # unused

    raise NotImplementedError()


tests.test_rotary_pos_emb(apply_rotary_pos_emb)
Solution
def rotate_half(x: Float[Tensor, "... d"]) -> Float[Tensor, " ... d"]:
    """Rotates half the hidden dims of the input (for RoPE)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
    q: Float[Tensor, "batch q_heads seq head_dim"],
    k: Float[Tensor, "batch kv_heads seq head_dim"],
    cos: Float[Tensor, "batch seq head_dim"],
    sin: Float[Tensor, "batch seq head_dim"],
    position_ids: torch.Tensor | None = None,
    unsqueeze_dim: int = 1,
) -> tuple[Float[Tensor, "batch q_heads seq head_dim"], Float[Tensor, "batch kv_heads seq head_dim"]]:
    """Applies Rotary Position Embedding to query and key tensors."""
    del position_ids  # unused

    cos = cos.unsqueeze(unsqueeze_dim)  # Shape: (batch, 1, seq, head_dim)
    sin = sin.unsqueeze(unsqueeze_dim)  # Shape: (batch, 1, seq, head_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
tests.test_rotary_pos_emb(apply_rotary_pos_emb)

Exercise - compute suppressed attention

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
> You should spend up to 15-25 minutes on this exercise

Now you'll implement the core logic for attention suppression. The function compute_suppressed_attention takes query and key states and computes attention weights, but with specific token ranges masked out (set to $-\infty$ before softmax, so they get zero attention weight after softmax).

How attention suppression works:

  1. Compute attention scores: $\text{scores} = QK^T / \sqrt{d_k}$
  2. Before softmax, mask specific token positions by setting their attention scores to $-\infty$ (or the minimum float value)
  3. Apply the attention mask if provided
  4. Apply softmax to get attention weights

Why mask before softmax? Setting attention scores to $-\infty$ before softmax ensures they become 0 after softmax, effectively removing those positions from the attention computation.

Key arguments:

  • token_ranges: List of (start, end) tuples indicating which tokens to suppress
  • heads_mask: If provided, only suppress in specific attention heads (otherwise suppress in all heads)
  • attention_mask: Standard causal/padding mask to apply after suppression

Fill in the function below. The main logic you need to implement is:

  1. Compute base attention scores (scaled dot product)
  2. For each token range, set the corresponding attention scores to the minimum value
  3. Apply the attention mask if provided
  4. Apply softmax

Hint: Use torch.finfo(attn_weights.dtype).min to get the minimum value for the tensor's dtype.
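
A tiny demonstration of the idea (toy scores): setting a pre-softmax score to the dtype's minimum drives its post-softmax weight to zero, and the remaining weights renormalize among themselves.

import torch

scores = torch.tensor([2.0, 1.0, 3.0])
scores[1] = torch.finfo(scores.dtype).min
print(torch.softmax(scores, dim=-1))  # middle entry ~0.0, others ~[0.269, 0.731]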

def compute_suppressed_attention(
    query_states: Float[Tensor, "batch heads seq head_dim"],
    key_states: Float[Tensor, "batch heads seq head_dim"],
    token_ranges: list[tuple[int, int]],
    head_dim: int,  # used to scale the attention scores by 1/sqrt(head_dim)
    query_len: int,  # in case token_ranges exceed query length
    attention_mask: Float[Tensor, "batch 1 seq seq"] | None = None,
    heads_mask: Int[Tensor, " heads"] | None = None,  # indices (or boolean mask) of heads to suppress; None = all heads
) -> Float[Tensor, "batch heads seq seq"]:
    # YOUR CODE HERE - compute attention scores, apply suppression, apply attention mask, then softmax
    raise NotImplementedError()


tests.test_compute_suppressed_attention(compute_suppressed_attention)
Solution
def compute_suppressed_attention(
    query_states: Float[Tensor, "batch heads seq head_dim"],
    key_states: Float[Tensor, "batch heads seq head_dim"],
    token_ranges: list[tuple[int, int]],
    head_dim: int,  # used to scale the attention scores by 1/sqrt(head_dim)
    query_len: int,  # in case token_ranges exceed query length
    attention_mask: Float[Tensor, "batch 1 seq seq"] | None = None,
    heads_mask: Int[Tensor, " heads"] | None = None,  # indices (or boolean mask) of heads to suppress; None = all heads
) -> Float[Tensor, "batch heads seq seq"]:
    # Compute attention scores
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

    # Apply suppression mask BEFORE softmax (key step!)
    for start, end in token_ranges:
        effective_start = min(start, query_len)
        effective_end = min(end, query_len)
        if effective_start < effective_end:
            mask_value = torch.finfo(attn_weights.dtype).min
            if heads_mask is None:
                # Suppress in all heads
                attn_weights[..., effective_start:effective_end] = mask_value
            else:
                # Suppress only in specific heads
                attn_weights[:, heads_mask, :, effective_start:effective_end] = mask_value

    # Apply attention mask if provided
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Softmax
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    return attn_weights
tests.test_compute_suppressed_attention(compute_suppressed_attention)

Exercise - understand full attention suppression hook

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
> You should spend up to 10-20 minutes reading this and making sure you understand how it works.

Now read and understand the full attention suppression pipeline below. This code patches the forward method of Qwen's attention modules to use your apply_rotary_pos_emb and compute_suppressed_attention functions.

What this function does:

  1. Finds attention modules: Searches through the model for self_attn modules in each layer
  2. Stores original forward methods: Saves them so we can restore them later
  3. Creates masked forward functions: For each attention module, creates a new forward method that:
     • Projects to Q, K, V
     • Applies RoPE (using your apply_rotary_pos_emb)
     • Computes suppressed attention (using your compute_suppressed_attention)
     • Applies attention to values and projects the output
  4. Replaces forward methods: Uses MethodType to bind the new forward method to each module

Key implementation details:

  • layer_to_heads allows suppressing only specific heads in specific layers (if None, suppresses all heads in all layers)
  • The function uses closures (create_masked_forward) to capture layer-specific parameters
  • repeat_kv handles grouped-query attention where key/value heads are fewer than query heads

How to use it:

# Apply suppression to tokens 10-20 in all heads
suppression_info = apply_qwen_attention_suppression(model, token_ranges=[(10, 20)])

# Run model with suppression active
output = model(inputs)

# Restore original behavior
remove_qwen_attention_suppression(model, suppression_info)

There's no code to fill in here, just read through the implementation and make sure you understand:

  • How the two functions you implemented are used
  • How the patching mechanism works
  • What role token_ranges and layer_to_heads play

def repeat_kv(
    hidden_states: Float[Tensor, "batch kv_heads seq head_dim"], n_rep: int
) -> Float[Tensor, "batch heads seq head_dim"]:
    """Expands key/value tensors for grouped-query attention."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def apply_qwen_attention_suppression(
    model,
    token_ranges: list[tuple[int, int]] | tuple[int, int],
    layer_to_heads: dict[int, list[int]] | None = None,
) -> dict[str, Any]:
    """
    Suppresses attention to specific token positions by replacing forward methods.

    Args:
        model: The model to apply suppression to
        token_ranges: Token range(s) to suppress - single tuple or list of tuples
        layer_to_heads: Dict mapping layer indices to lists of head indices to suppress (None = all heads)

    Returns:
        Dict with 'original_forwards' for restoration

    Adapted from thought-anchors: whitebox-analyses/pytorch_models/hooks.py
    """
    raise NotImplementedError("Implement Qwen attention suppression")


def remove_qwen_attention_suppression(model, suppression_info: dict[str, Any]):
    """Restores original forward methods after attention suppression."""
    original_forwards = suppression_info.get("original_forwards", {})
    if not original_forwards:
        return

    for name, module in model.named_modules():
        if name in original_forwards:
            module.forward = original_forwards[name]
Solution
def repeat_kv(
    hidden_states: Float[Tensor, "batch kv_heads seq head_dim"], n_rep: int
) -> Float[Tensor, "batch heads seq head_dim"]:
    """Expands key/value tensors for grouped-query attention."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def apply_qwen_attention_suppression(
    model,
    token_ranges: list[tuple[int, int]] | tuple[int, int],
    layer_to_heads: dict[int, list[int]] | None = None,
) -> dict[str, Any]:
    """
    Suppresses attention to specific token positions by replacing forward methods.
    Args:
        model: The model to apply suppression to
        token_ranges: Token range(s) to suppress - single tuple or list of tuples
        layer_to_heads: Dict mapping layer indices to lists of head indices to suppress (None = all heads)
    Returns:
        Dict with 'original_forwards' for restoration
    Adapted from thought-anchors: whitebox-analyses/pytorch_models/hooks.py
    """
    # Normalize token_ranges to list of tuples
    if isinstance(token_ranges, tuple):
        token_ranges = [token_ranges]
# Find rotary embedding module
    rotary_emb_module = None
    if hasattr(model, "model") and hasattr(model.model, "rotary_emb"):
        rotary_emb_module = model.model.rotary_emb
    # Find attention modules to patch
    target_modules = []
    for name, module in model.named_modules():
        if name.startswith("model.layers") and name.endswith("self_attn"):
            try:
                layer_idx = int(name.split(".")[2])
                if layer_to_heads is None or layer_idx in layer_to_heads:
                    if all(hasattr(module, attr) for attr in ["config", "q_proj", "k_proj", "v_proj", "o_proj"]):
                        target_modules.append((name, module, layer_idx))
            except (IndexError, ValueError):
                continue

    if not target_modules:
        print("Warning: No Qwen attention modules found to patch")
        return {"original_forwards": {}}

    # Store original forward methods
    original_forwards = {}

    # Create and apply masked forward functions
    for name, attn_module, layer_idx in target_modules:
        original_forwards[name] = attn_module.forward
        heads_mask = layer_to_heads[layer_idx] if layer_to_heads is not None else None

        # Create masked forward function
        def create_masked_forward(orig_forward, layer_idx, rotary_ref, heads_mask):
            def masked_forward(
                self,
                hidden_states: torch.Tensor,
                attention_mask: torch.Tensor | None = None,
                position_ids: torch.LongTensor | None = None,
                past_key_value: tuple[torch.Tensor] | None = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                cache_position: torch.LongTensor | None = None,
                **kwargs,
            ) -> tuple[torch.Tensor, torch.Tensor | None]:
                bsz, q_len, _ = hidden_states.size()
                config = self.config
                device = hidden_states.device

                # Project to Q, K, V
                query_states = self.q_proj(hidden_states)
                key_states = self.k_proj(hidden_states)
                value_states = self.v_proj(hidden_states)

                # Reshape for multi-head attention
                num_heads = config.num_attention_heads
                head_dim = config.hidden_size // num_heads
                num_key_value_heads = config.num_key_value_heads
                num_key_value_groups = num_heads // num_key_value_heads

                query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
                key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
                value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

                # Apply RoPE
                if position_ids is None:
                    position_ids = torch.arange(0, q_len, dtype=torch.long, device=device).unsqueeze(0)

                if rotary_ref is not None and callable(rotary_ref):
                    cos, sin = rotary_ref(value_states, position_ids=position_ids)
                    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

                # Repeat K/V for grouped-query attention
                key_states = repeat_kv(key_states, num_key_value_groups)
                value_states = repeat_kv(value_states, num_key_value_groups)

                attn_weights = compute_suppressed_attention(
                    query_states,
                    key_states,
                    token_ranges,
                    head_dim,
                    q_len,
                    attention_mask,
                    heads_mask,
                )

                # Apply attention to values
                attn_output = torch.matmul(attn_weights, value_states)

                # Reshape and project output
                attn_output = attn_output.transpose(1, 2).contiguous()
                attn_output = attn_output.reshape(bsz, q_len, config.hidden_size)
                attn_output = self.o_proj(attn_output)

                return attn_output, attn_weights if output_attentions else None

            return masked_forward

        attn_module.forward = MethodType(
            create_masked_forward(attn_module.forward, layer_idx, rotary_emb_module, heads_mask), attn_module
        )

    return {"original_forwards": original_forwards}


def remove_qwen_attention_suppression(model, suppression_info: dict[str, Any]):
    """Restores original forward methods after attention suppression."""
    original_forwards = suppression_info.get("original_forwards", {})
    if not original_forwards:
        return

    for name, module in model.named_modules():
        if name in original_forwards:
            module.forward = original_forwards[name]
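
Since `apply_qwen_attention_suppression` patches the attention modules' forward methods in place, it's easy to leave the model in a suppressed state if something raises mid-experiment. A minimal usage sketch (the token range, head dict, and `inputs` here are placeholders for whatever you pass in):

# Hypothetical usage pattern: always restore the original forward methods, even on error
suppression_info = apply_qwen_attention_suppression(model, sentence_token_range, layer_to_heads)
try:
    with torch.no_grad():
        suppressed_logits = model(**inputs).logits
finally:
    remove_qwen_attention_suppression(model, suppression_info)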

Exercise - measure suppression effect with KL divergence

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
> You should spend up to 15-20 minutes on this exercise.

To quantify the effect of suppression on a real model, we measure how much suppressing attention to a sentence changes the model's output distribution. Implement compute_suppression_kl below, which computes the KL divergence between the original and suppressed next-token distributions (after a temperature-scaled softmax).
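
Concretely, if $p$ and $q$ are the temperature-scaled softmax distributions over the vocabulary for the original logits $z$ and the suppressed logits $\tilde{z}$, the quantity to compute is:

$$D_{\mathrm{KL}}(p \,\|\, q) = \sum_i p_i \log \frac{p_i}{q_i}, \qquad p_i = \frac{e^{z_i / T}}{\sum_k e^{z_k / T}} \quad \text{(and analogously for } q \text{ from } \tilde{z}\text{)}$$

This is measured in nats since we use the natural log; note that it is asymmetric, with the original distribution $p$ taken as the reference. In practice you'll want a small epsilon inside the log for numerical stability.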

def compute_suppression_kl(
    original_logits: Float[np.ndarray, " vocab"],
    suppressed_logits: Float[np.ndarray, " vocab"],
    temperature: float = 1.0,
) -> float:
    """
    Compute KL divergence between original and suppressed output distributions.

    Args:
        original_logits: Logits from original forward pass
        suppressed_logits: Logits from suppressed forward pass
        temperature: Temperature for softmax

    Returns:
        KL divergence (in nats)
    """

    raise NotImplementedError("Compute KL divergence between distributions")


tests.test_compute_suppression_kl(compute_suppression_kl)
Solution
def compute_suppression_kl(
    original_logits: Float[np.ndarray, " vocab"],
    suppressed_logits: Float[np.ndarray, " vocab"],
    temperature: float = 1.0,
) -> float:
    """
    Compute KL divergence between original and suppressed output distributions.
    Args:
        original_logits: Logits from original forward pass
        suppressed_logits: Logits from suppressed forward pass
        temperature: Temperature for softmax
    Returns:
        KL divergence (in nats)
    """
    def softmax(x, temp):
        x = x / temp
        exp_x = np.exp(x - np.max(x))
        return exp_x / exp_x.sum()

    p = softmax(original_logits, temperature)
    q = softmax(suppressed_logits, temperature)

    # Add small epsilon for numerical stability
    eps = 1e-10
    return np.sum(p * np.log((p + eps) / (q + eps)))
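
As an optional sanity check, the result should agree (up to the stability epsilon) with scipy.stats.entropy, which computes the KL divergence when given two distributions:

from scipy.stats import entropy

rng = np.random.default_rng(0)
logits_a, logits_b = rng.normal(size=50), rng.normal(size=50)

p = np.exp(logits_a - logits_a.max()); p /= p.sum()
q = np.exp(logits_b - logits_b.max()); q /= q.sum()

assert np.isclose(compute_suppression_kl(logits_a, logits_b), entropy(p, q), atol=1e-6)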

Testing Attention Suppression

Now let's properly test attention suppression using forward method replacement.

Key insight: We need to mask attention scores BEFORE softmax, not after. The implementation above replaces the entire forward method to intervene at the right point.
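
A toy example of why the ordering matters (a schematic of the idea, not the notebook's exact implementation): masking scores with -inf before the softmax lets the remaining keys renormalize, whereas zeroing probabilities after the softmax would leave each row summing to less than 1 and shrink the attention output.

import math
import torch

# Toy single-head example: 6 query/key positions, suppress attention to keys 2 and 3
torch.manual_seed(0)
q, k = torch.randn(6, 8), torch.randn(6, 8)
scores = q @ k.T / math.sqrt(8)

scores[:, 2:4] = float("-inf")    # mask BEFORE softmax
attn = scores.softmax(dim=-1)
print(attn.sum(dim=-1))           # every row still sums to 1: the remaining keys renormalize
print(attn[:, 2:4].abs().max())   # suppressed keys receive exactly zero attention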

Expected result: Suppressing high-importance sentences should cause larger KL divergence than suppressing low-importance sentences.

# Select high and low importance sentences
valid_mask = ~np.isnan(receiver_scores)
valid_indices = np.where(valid_mask)[0]
valid_scores = receiver_scores[valid_mask]

high_idx = valid_indices[np.argmax(valid_scores)]
low_idx = valid_indices[np.argmin(valid_scores)]

print(f"\nHigh importance sentence {high_idx}: score={receiver_scores[high_idx]:.4f}")
print(f"  Text: {sentences_subset[high_idx][:100]}...")
print(f"\nLow importance sentence {low_idx}: score={receiver_scores[low_idx]:.4f}")
print(f"  Text: {sentences_subset[low_idx][:100]}...")

# Get token boundaries
high_sent_range = boundaries[high_idx]
low_sent_range = boundaries[low_idx]

# Prepare input
inputs = tokenizer(text_subset, return_tensors="pt")
if model.device.type == "cuda":
    inputs = {k: v.cuda() for k, v in inputs.items()}

# Original forward pass (no suppression)
with torch.no_grad():
    original_output = model(**inputs)
    original_logits = original_output.logits[0, -1].cpu().numpy()

layer_to_heads = {}
for layer_idx, head_idx in receiver_heads[:5]:
    if layer_idx not in layer_to_heads:
        layer_to_heads[layer_idx] = []
    layer_to_heads[layer_idx].append(int(head_idx))

# Test 1: Suppress high-importance sentence in top receiver heads
suppression_info = apply_qwen_attention_suppression(model, high_sent_range, layer_to_heads)

with torch.no_grad():
    suppressed_high_output = model(**inputs)
    suppressed_high_logits = suppressed_high_output.logits[0, -1].cpu().numpy()

remove_qwen_attention_suppression(model, suppression_info)

kl_high = compute_suppression_kl(original_logits, suppressed_high_logits)

# Test 2: Suppress low-importance sentence
suppression_info = apply_qwen_attention_suppression(model, low_sent_range, layer_to_heads)

with torch.no_grad():
    suppressed_low_output = model(**inputs)
    suppressed_low_logits = suppressed_low_output.logits[0, -1].cpu().numpy()

remove_qwen_attention_suppression(model, suppression_info)

kl_low = compute_suppression_kl(original_logits, suppressed_low_logits)

# Compare results
print(f"\nKL divergence when suppressing HIGH importance sentence: {kl_high:.4e}")
print(f"KL divergence when suppressing LOW importance sentence:  {kl_low:.4e}")
print(f"Ratio (high/low): {kl_high / (kl_low + 1e-10):.2f}x")

if kl_high > kl_low * 10:  # Require at least a 10x difference
    print("\nSuccess! Suppressing high-importance sentences causes larger output changes")
    print("  This causally validates that receiver head scores identify important sentences")
Click to see the expected output
High importance sentence 68: score=0.9711
  Text: Wait, but 66666 in hex is 5 digits, so 5 * 4 = 20 bits....

Low importance sentence 58: score=0.2037
  Text: 419,424 + 6 = 419,430....

KL divergence when suppressing HIGH importance sentence: 8.7300e-05
KL divergence when suppressing LOW importance sentence:  4.0152e-07
Ratio (high/low): 217.37x

Success! Suppressing high-importance sentences causes larger output changes
  This causally validates that receiver head scores identify important sentences

What this demonstrates:

The attention suppression test provides causal validation of the receiver head scores:

  1. We suppress attention to specific sentences in the receiver heads
  2. We measure how much the model's output distribution changes (KL divergence)
  3. High-importance sentences (high receiver scores) cause larger changes when suppressed

This validates that:

- Receiver head scores correctly identify important sentences
- These sentences causally affect the model's output (the evidence is causal, not merely correlational)
- The attention mechanism is the actual pathway for this information flow

Implementation details:

- We replace the forward method of attention modules (not post-hooks!)
- We mask attention scores BEFORE softmax by setting them to -inf
- This works specifically for the Qwen/Qwen2 architecture with the helper functions provided
- For other architectures, you'd need to adapt the forward method replacement

Note: This approach is from the thought-anchors paper's whitebox analysis code.

Bonus Exercises: Replicating Paper Whitebox Results

The correlation analyses above produce weak results on a single reasoning trace — this is expected, and reflects a broader methodological point: robust claims about thought anchors require aggregating evidence across many problems. The paper's whitebox experiments run over a large dataset of math rollouts (available on HuggingFace) and the key results replicate across two model families (Qwen-14B and Llama-8B).

The bonus exercises below each target a specific figure from the paper. Each has a starting point using the infrastructure already built in this notebook, and a stretch goal that scales up to use the full dataset and the paper's code from the whitebox-analyses/ directory.

Exercise - Replicate the sentence-to-sentence suppression KL matrix (Figure 6)

Difficulty: 🔴🔴🔴🔴🔴
Importance: 🔵🔵🔵🔵🔵

The paper's most direct causal whitebox result is a sentence-to-sentence suppression matrix: for each pair of sentences $(j, i)$, suppress attention to sentence $j$ across all attention heads, then measure the mean KL divergence of the model's token-level predictions within sentence $i$. This produces an $N \times N$ matrix where entry $[i, j]$ captures how much sentence $j$ causally influences the predictions at positions in sentence $i$. This is the whitebox complement of the blackbox counterfactual importance computed earlier in the notebook, and Figure 6 of the paper visualizes it as a heatmap for a case study.

Starting point (single problem): Using the model and suppression tools from this notebook:

  1. Run a baseline forward pass and collect all token-level logits (not just the final position) — you will need the full logit tensor, not just the last position slice.
  2. For each sentence $j$, call apply_qwen_attention_suppression across all attention heads (set layer_to_heads to include every head in every layer), run a forward pass to collect token-level logits, then restore with remove_qwen_attention_suppression. This is compute-intensive — start with a subset of sentences.
  3. For each pair $(j, i)$, compute the mean KL divergence between baseline and suppressed logits over all token positions that fall within sentence $i$'s token boundaries (use compute_suppression_kl from earlier in the notebook). Store results in an $N \times N$ matrix.
  4. Mask the upper triangle and diagonal with np.nan (future sentences cannot be causally affected by later ones), then visualize as a heatmap. Try subtracting each row's mean to normalize for sentence depth. Use a white-to-red colormap.
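
A minimal sketch of steps 1–3, reusing the model, inputs, boundaries, and the suppression/KL helpers from earlier in the notebook. It assumes boundaries[i] is the (start, end) token range for sentence $i$ (as used in the suppression test above); restrict n_sentences to a handful at first, since each suppressed sentence requires a full forward pass.

n_sentences = len(boundaries)
all_heads = {l: list(range(model.config.num_attention_heads)) for l in range(model.config.num_hidden_layers)}

with torch.no_grad():
    base_logits = model(**inputs).logits[0].float().cpu().numpy()  # token-level logits: (seq_len, vocab)

kl_matrix = np.full((n_sentences, n_sentences), np.nan)
for j in range(n_sentences):  # sentence whose incoming attention we suppress
    info = apply_qwen_attention_suppression(model, boundaries[j], all_heads)
    try:
        with torch.no_grad():
            supp_logits = model(**inputs).logits[0].float().cpu().numpy()
    finally:
        remove_qwen_attention_suppression(model, info)
    for i in range(j + 1, n_sentences):  # later sentences, whose token-level predictions we measure
        start, end = boundaries[i]  # assumed (start, end) token range for sentence i
        kl_matrix[i, j] = np.mean([compute_suppression_kl(base_logits[t], supp_logits[t]) for t in range(start, end)])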

What to look for: Thought anchors appear as columns with elevated KL values stretching down across many rows — they causally affect many later sentences. These columns typically correspond to plan generation or explicit backtracking steps. Compare the column positions to the receiver head scores computed earlier and the blackbox counterfactual importances from Section 2.

Stretch goal (full dataset): The paper's implementation is get_suppression_KL_matrix in whitebox-analyses/attention_analysis/attn_supp_funcs.py. It uses nucleus-sampling logit compression (p_nucleus=0.9999) to store only the top-p logits efficiently. Run it on 10–20 problems and average the row-normalized matrices. The plot_suppression_matrix.py script handles visualization. Compare your result to Figure 6 in the paper.

Exercise - Identify receiver heads via kurtosis and verify split-half reliability (Figure 4)

Difficulty: 🔴🔴🔴🔴🔴
Importance: 🔵🔵🔵🔵🔵

The paper identifies "receiver heads" by computing the kurtosis of each attention head's vertical attention scores across a set of reasoning traces. The vertical score for sentence $s$ under head $(l, h)$ is the mean attention weight that sentence $s$ receives from all subsequent token positions, after ignoring a local proximity window (to filter out trivially local attention patterns). High kurtosis means that a head's attention is concentrated on a small number of "anchor" sentences rather than spread broadly. The paper finds that only a small fraction of attention heads have high kurtosis, these heads are consistent across problems, and they correspond to the receiver heads identified qualitatively earlier in the notebook.

Starting point (single problem): Using the current reasoning trace and loaded model:

  1. For each attention head $(l, h)$, extract the full attention matrix and compute sentence-level vertical scores: for each sentence $s$, average the attention weights directed toward sentence $s$ from all token positions that are at least proximity_ignore=16 positions after sentence $s$'s end boundary. The paper implements this in get_vertical_scores in whitebox-analyses/attention_analysis/attn_funcs.py.
  2. Compute scipy.stats.kurtosis of the resulting per-sentence score array for each head.
  3. Produce a scatter plot of kurtosis versus layer index. You should see a small number of outlier heads with substantially higher kurtosis than the bulk of heads (which cluster near zero). These are candidate receiver heads.
  4. Check whether the top-kurtosis heads from this single trace agree with the receiver head set used in this notebook (which was pre-computed from many traces).
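
A sketch of steps 1–2 for a single head, assuming attn is a (seq_len, seq_len) array of that head's attention weights and boundaries holds (start, end) token ranges per sentence; the paper's get_vertical_scores handles the general multi-head case, and its exact averaging convention may differ slightly from this sketch.

from scipy.stats import kurtosis

def vertical_scores_one_head(attn, boundaries, proximity_ignore: int = 16):
    """Mean attention each sentence receives from token positions well after it."""
    scores = np.full(len(boundaries), np.nan)
    for s, (start, end) in enumerate(boundaries):
        q_start = end + proximity_ignore  # skip trivially local attention just after the sentence
        if q_start >= attn.shape[0]:
            continue
        # attention each later query position puts on this sentence's tokens, averaged over those queries
        scores[s] = attn[q_start:, start:end].sum(axis=-1).mean()
    return scores

head_scores = vertical_scores_one_head(attn, boundaries)
head_kurtosis = kurtosis(head_scores[~np.isnan(head_scores)])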

Stretch goal — split-half reliability (paper result: r = 0.84): Replicate the paper's reliability analysis across many reasoning traces:

  1. Download the dataset from HuggingFace (uzaymacar/math-rollouts) and process a set of problems using get_all_problems_vert_scores from the paper's receiver_head_funcs.py.
  2. Split problems into two halves (odd/even indices). For each half, average per-head kurtosis using get_3d_ar_kurtosis.
  3. Scatter plot first-half vs. second-half mean kurtosis (one point per head). Compute Pearson correlation using scipy.stats.pearsonr. The paper reports r = 0.84, showing the same heads consistently function as receivers across different reasoning problems.
  4. Select the top-$k$ heads by mean kurtosis (paper uses $k = 20$–$32$) as the receiver head set. Note which layers these heads concentrate in.
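
For steps 2–3, the reliability check itself is a simple correlation over per-head mean kurtosis. A sketch assuming kurtosis_by_problem is an array of shape (n_problems, n_layers, n_heads) built in step 2 (the array name is illustrative, not from the paper's code):

from scipy.stats import pearsonr

half_a = kurtosis_by_problem[0::2].mean(axis=0).ravel()  # per-head mean kurtosis, even-indexed problems
half_b = kurtosis_by_problem[1::2].mean(axis=0).ravel()  # per-head mean kurtosis, odd-indexed problems
r, p = pearsonr(half_a, half_b)
print(f"Split-half reliability: r = {r:.2f} (paper reports r = 0.84)")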

Exercise - Receiver head scores by sentence taxonomic category (Figure 5)

Difficulty: 🔴🔴🔴🔴🔴
Importance: 🔵🔵🔵🔵🔵

The paper's most interpretable whitebox result is that receiver head attention is not uniformly distributed across sentence types. Plan generation and uncertainty management sentences receive significantly higher receiver-head scores than active computation sentences (p < 0.001 for all pairwise comparisons across many problems). This directly links the functional role of a sentence (planning, pivoting) to its mechanistic role (being selectively attended to by specialized receiver heads), providing mechanistic evidence for the concept of thought anchors.

The sentence categories used in the paper are: plan_generation, active_computation, fact_retrieval, uncertainty_management, result_consolidation, self_checking, problem_setup, and final_answer_emission. Each sentence is labeled by an LLM auto-labeler using a structured prompt (DAG_PROMPT in prompts.py in the paper's repository).

Starting point (single problem):

  1. Use an LLM (e.g. via the Anthropic API) with the DAG_PROMPT from the paper's prompts.py to classify each sentence in the reasoning trace from this notebook into one or more function tags. The prompt asks the model to act as a "chain-of-thought function tagger" and assign categories based on the sentence's role in the reasoning.
  2. Compute receiver head scores for each sentence using compute_receiver_head_scores (already defined in this notebook).
  3. Group sentences by their assigned function tag, compute mean receiver-head score per group, and plot a bar chart. Even on a single trace, plan generation and uncertainty management sentences should score noticeably higher than active computation sentences.
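
A sketch of step 3, assuming sentence_tags is a list with one function tag per sentence (from the auto-labeler in step 1) and receiver_scores is the per-sentence score array from step 2:

import matplotlib.pyplot as plt

tags = sorted(set(sentence_tags))
mean_scores = [np.nanmean([s for t, s in zip(sentence_tags, receiver_scores) if t == tag]) for tag in tags]

plt.barh(tags, mean_scores)
plt.xlabel("Mean receiver-head score")
plt.title("Receiver-head scores by sentence category (single trace)")
plt.tight_layout()
plt.show()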

Stretch goal — statistical validation across many problems (paper result: p < 0.001):

  1. Use the paper's generate_rec_csvs.py script to produce CSV files containing receiver head scores and taxonomic labels for every problem in the dataset.
  2. For each problem, aggregate receiver head scores by sentence category (per-problem median, to avoid pseudoreplication across sentences within a problem).
  3. Run paired t-tests using scipy.stats.ttest_rel, comparing pairs of categories using only problems that have examples of both categories. Exclude final_answer_emission and problem_setup from pairwise comparisons.
  4. Plot boxplots stratified by category (as in plot_rec_taxonomy.py from the paper code). Compare to Figure 5 of the paper. The result provides statistical support for the core claim: plan generation and uncertainty management sentences are the primary thought anchors — they are the steps that most influence downstream reasoning, both causally (blackbox) and via attention (whitebox).