3️⃣ White-box Methods
Learning Objectives
- Understand receiver heads: attention heads that aggregate information from reasoning steps into the final answer
- Compute vertical attention scores: quantifying how much each token position attends to specific reasoning steps
- Implement RoPE (Rotary Position Embeddings) to understand positional encoding in attention
- Build attention suppression interventions to causally validate which attention patterns matter
- Compare white-box attention metrics with black-box importance scores to validate mechanistic understanding
In Section 2, we used black-box methods (resampling, counterfactuals) to identify which sentences are important for the final answer. These told us what matters, but not why or how the model implements this mechanistically.
Now we'll open the black box and look inside the model. The key question: if certain sentences are behaviorally important (high counterfactual importance), can we find attention heads that attend strongly to those same sentences? If so, we've found a mechanistic explanation for the black-box patterns.
The hypothesis is that attention is the mechanism by which information from earlier sentences flows to later positions. If a sentence is important for the final answer, we should see attention heads at later positions attending heavily to it. These are called receiver heads. In this section, we'll compute vertical attention scores (how much future tokens attend to each sentence), identify receiver heads (those with spiky attention, i.e. high kurtosis), compare attention patterns to black-box importance scores, and use attention suppression to causally validate the patterns.
By the end, we'll have shown that the black-box patterns (which sentences matter) correspond to white-box mechanisms (which sentences get attended to).
First, let's define a helper function for getting whitebox data from our loaded problem_data. We only need the full CoT and its chunks, plus the counterfactual importance scores (used later for comparison with the whitebox analysis).
Receiver Head Analysis
Two key concepts: the vertical attention score for a sentence measures how much all future sentences attend to it, and receiver heads are attention heads with high kurtosis in their vertical attention scores (they attend spikily to specific sentences rather than uniformly).
For these exercises, we'll use a few helper functions: utils.get_whitebox_example_data() extracts a tuple from problem_data containing the full CoT reasoning trace, list of chunks, and counterfactual importance scores. utils.get_sentence_token_boundaries() maps sentences to token positions, and utils.average_attention_by_sentences() aggregates token-level attention to sentence-level.
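To make the sentence-level aggregation concrete, here's a minimal sketch of what utils.average_attention_by_sentences could look like (the actual helper may differ in details such as how it handles empty spans): it simply averages the token-level attention block corresponding to each (query sentence, key sentence) pair.
import numpy as np

def average_attention_by_sentences_sketch(
    token_attention: np.ndarray,  # (seq, seq) token-level attention
    boundaries: list[tuple[int, int]],  # boundaries[i] = (start, end) token span of sentence i
) -> np.ndarray:
    """Hypothetical sketch: average token-level attention into an (n_sentences, n_sentences) matrix."""
    n = len(boundaries)
    out = np.full((n, n), np.nan)
    for i, (qs, qe) in enumerate(boundaries):
        for j, (ks, ke) in enumerate(boundaries):
            block = token_attention[qs:qe, ks:ke]
            if block.size > 0:
                out[i, j] = block.mean()
    return out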
Exercise - extract attention matrices
The first step in whitebox analysis is extracting attention patterns from the model's forward pass. You can use output_attentions=True to get attention weights, accessible as output.attentions[layer][batch_idx, head_idx] (note this requires setting the model's config to output attention - done in the code cell below where the model gets loaded).
extract_attention_matrix should return the attention weights for a particular layer & head, plus the list of string tokens for convenience.
def extract_attention_matrix(
text: str,
model,
tokenizer,
layer: int,
head: int,
) -> tuple[Float[np.ndarray, "seq seq"], list[str]]:
"""
    Extract the attention matrix from a specific layer and head via output_attentions.
Args:
text: Input text to analyze
model: HuggingFace model (already loaded)
tokenizer: Corresponding tokenizer
layer: Which layer to extract from (0-indexed)
head: Which attention head to extract from (0-indexed)
Returns:
attention_matrix: Shape (seq_len, seq_len) attention weights for the specified head
tokens: List of token strings for visualization
"""
raise NotImplementedError("Implement attention extraction using hooks")
# Load model and tokenizer (with attention output enabled - important!)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_1B)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_1B, device_map="auto", dtype=dtype)
# Flash Attention (the default for efficiency) fuses the attention computation into a single GPU
# kernel and never materializes the full attention matrix, so it can't return attention weights.
# We switch to "eager" mode which computes attention weights explicitly, allowing us to extract
# and analyze them for the whitebox methods in this section.
model.config._attn_implementation = "eager"
model.config.output_attentions = True
model.eval()
# Test with a short example
test_text = "The cat sat on the mat. It was sleeping."
# Extract attention from middle layer
n_layers = len(model.model.layers)
layer_to_check = n_layers // 2
head_to_check = 8 # head 9 is better for the 1.5b model
print(f"\nExtracting attention from head L{layer_to_check}H{head_to_check}")
attention_matrix, tokens = extract_attention_matrix(
test_text, model, tokenizer, layer=layer_to_check, head=head_to_check
)
assert attention_matrix.shape[0] == len(tokens), "Shape mismatch"
assert attention_matrix.shape[0] == attention_matrix.shape[1], "Not square"
assert np.allclose(attention_matrix.sum(axis=1), 1.0, atol=1e-5), "Rows don't sum to 1"
# Check causal structure (upper triangle should be mostly zeros)
upper_triangle_sum = np.triu(attention_matrix, k=1).sum()
assert upper_triangle_sum < 1e-5, "Upper triangle has non-zero values, not causal attention"
# Visualize
display(
cv.attention.attention_heads(
attention=attention_matrix[None],
tokens=tokens,
)
)
Click to see the expected output
Hint #1
Set model.config.output_attentions = True before the forward pass, then access the attention weights from the model output via output.attentions[layer]. Each element has shape (batch, num_heads, seq_len, seq_len).
Solution
def extract_attention_matrix(
text: str,
model,
tokenizer,
layer: int,
head: int,
) -> tuple[Float[np.ndarray, "seq seq"], list[str]]:
"""
    Extract the attention matrix from a specific layer and head via output_attentions.
Args:
text: Input text to analyze
model: HuggingFace model (already loaded)
tokenizer: Corresponding tokenizer
layer: Which layer to extract from (0-indexed)
head: Which attention head to extract from (0-indexed)
Returns:
attention_matrix: Shape (seq_len, seq_len) attention weights for the specified head
tokens: List of token strings for visualization
"""
# Tokenize input
inputs = tokenizer(text, return_tensors="pt")
if model.device.type == "cuda":
inputs = {k: v.cuda() for k, v in inputs.items()}
    # Forward pass - output_attentions=True makes the model return attention weights
with t.no_grad():
output = model(**inputs, use_cache=False, output_attentions=True)
    # output.attentions[layer] has shape (batch_size, num_heads, seq_len, seq_len)
attn_tensor = output.attentions[layer][0, head] # Select batch 0, specific head
# Convert to numpy
attention_matrix = attn_tensor.cpu().numpy().astype(np.float32)
# Get token strings
token_ids = inputs["input_ids"][0].tolist()
tokens = tokenizer.convert_ids_to_tokens(token_ids)
return attention_matrix, tokens
# Load model and tokenizer (with attention output enabled - important!)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_1B)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_1B, device_map="auto", dtype=dtype)
# Flash Attention (the default for efficiency) fuses the attention computation into a single GPU
# kernel and never materializes the full attention matrix, so it can't return attention weights.
# We switch to "eager" mode which computes attention weights explicitly, allowing us to extract
# and analyze them for the whitebox methods in this section.
model.config._attn_implementation = "eager"
model.config.output_attentions = True
model.eval()
# Test with a short example
test_text = "The cat sat on the mat. It was sleeping."
# Extract attention from middle layer
n_layers = len(model.model.layers)
layer_to_check = n_layers // 2
head_to_check = 8 # head 9 is better for the 1.5b model
print(f"\nExtracting attention from head L{layer_to_check}H{head_to_check}")
attention_matrix, tokens = extract_attention_matrix(
test_text, model, tokenizer, layer=layer_to_check, head=head_to_check
)
assert attention_matrix.shape[0] == len(tokens), "Shape mismatch"
assert attention_matrix.shape[0] == attention_matrix.shape[1], "Not square"
assert np.allclose(attention_matrix.sum(axis=1), 1.0, atol=1e-5), "Rows don't sum to 1"
# Check causal structure (upper triangle should be mostly zeros)
upper_triangle_sum = np.triu(attention_matrix, k=1).sum()
assert upper_triangle_sum < 1e-5, "Upper triangle has non-zero values, not causal attention"
# Visualize
display(
cv.attention.attention_heads(
attention=attention_matrix[None],
tokens=tokens,
)
)
Exercise - compute vertical attention scores (Part 1: Simple version)
Vertical attention scores measure how much future sentences attend back to each sentence. This is the core metric for identifying "thought anchors" - sentences that shape downstream reasoning.
In this first part, implement a simple version that computes the basic vertical attention scores without any normalization.
def get_vertical_scores_simple(
avg_attention_matrix: Float[np.ndarray, "n_sentences n_sentences"],
proximity_ignore: int = 4,
) -> Float[np.ndarray, " n_sentences"]:
"""
Compute basic vertical attention scores for each sentence (no normalization).
Args:
avg_attention_matrix: Shape (n_sentences, n_sentences), where entry (i, j)
is the average attention from sentence i to sentence j
proximity_ignore: Ignore this many nearby sentences (to avoid trivial patterns)
Returns:
Array of shape (n_sentences,) with vertical scores
The vertical score for sentence j is the mean attention it receives from
    all sentences i where i >= j + proximity_ignore.
Adapted from thought-anchors: attn_funcs.py:get_vertical_scores
"""
raise NotImplementedError("Implement basic vertical attention scores")
tests.test_get_vertical_scores_simple(get_vertical_scores_simple)
Solution
def get_vertical_scores_simple(
avg_attention_matrix: Float[np.ndarray, "n_sentences n_sentences"],
proximity_ignore: int = 4,
) -> Float[np.ndarray, " n_sentences"]:
"""
Compute basic vertical attention scores for each sentence (no normalization).
Args:
avg_attention_matrix: Shape (n_sentences, n_sentences), where entry (i, j)
is the average attention from sentence i to sentence j
proximity_ignore: Ignore this many nearby sentences (to avoid trivial patterns)
Returns:
Array of shape (n_sentences,) with vertical scores
The vertical score for sentence j is the mean attention it receives from
    all sentences i where i >= j + proximity_ignore.
Adapted from thought-anchors: attn_funcs.py:get_vertical_scores
"""
n = avg_attention_matrix.shape[0]
mat = avg_attention_matrix.copy()
# Step 1: Clean matrix - set upper triangle to NaN (can't attend to future)
mat[np.triu_indices_from(mat, k=1)] = np.nan
# Step 2: Remove proximity - set near-diagonal to NaN (ignore nearby sentences)
mat[np.triu_indices_from(mat, k=-proximity_ignore + 1)] = np.nan
# Step 3: Compute vertical scores: mean attention received from future sentences
vert_scores = []
for j in range(n):
# Extract "vertical line" - attention from all future sentences to sentence j
future_attention = mat[j + proximity_ignore :, j]
if len(future_attention) == 0 or np.all(np.isnan(future_attention)):
vert_scores.append(np.nan)
else:
vert_scores.append(np.nanmean(future_attention))
return np.array(vert_scores)
Let's test the simple version on real CoT attention patterns.
# Get example problem data
text, sentences, _ = utils.get_whitebox_example_data(problem_data)
# Extract attention from a middle layer
layer, head = 15, 8
print(f"\nExtracting attention from head L{layer}H{head}...")
token_attention, tokens = extract_attention_matrix(text, model, tokenizer, layer, head)
# Convert to sentence-level attention
boundaries = utils.get_sentence_token_boundaries(text, sentences, tokenizer)
sentence_attention = utils.average_attention_by_sentences(token_attention, boundaries)
# Compute vertical scores (simple version)
vert_scores_simple = get_vertical_scores_simple(sentence_attention, proximity_ignore=4)
# Visualize
hover_texts = [s[:80] + "..." if len(s) > 80 else s for s in sentences]
fig = go.Figure()
fig.add_trace(
go.Bar(
x=list(range(len(vert_scores_simple))),
y=vert_scores_simple,
marker_color="steelblue",
hovertemplate="<b>Sentence %{x}</b><br>Score: %{y:.4f}<br>%{customdata}<extra></extra>",
customdata=hover_texts,
)
)
fig.add_hline(y=0, line_color="black", line_width=0.5)
fig.update_layout(
title="Simple Vertical Scores (No Normalization)",
xaxis_title="Sentence Index",
yaxis_title="Vertical Score",
width=1000,
height=450,
)
fig.show()
Click to see the expected output
Exercise - compute vertical attention scores (Part 2: Full version with depth control)
The challenge is that later sentences spread their attention over more previous positions than earlier sentences, which biases raw scores. The paper addresses this with depth control: a rank-based normalization applied to each row.
Now implement the full algorithm including depth control and optional z-score normalization.
Hint: Understanding the depth-control normalization
The depth bias problem: sentence 20 has ~20 previous sentences to attend to, while sentence 5 has only ~5. So raw attention values aren't comparable across rows.
The fix uses rank normalization per row:
1. For each row $i$, count the number of valid (non-NaN) positions: n_valid = sum(~isnan(row))
2. Replace each attention value with its rank among valid values in that row
3. Divide by n_valid to normalize to [0, 1]
This way, every row has the same scale regardless of how many positions are available. You can use scipy.stats.rankdata with nan_policy="omit" to rank values while ignoring NaNs.
After depth control, the vertical score computation (step 4) and optional z-scoring (step 5) are the same as in the simple version.
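As a concrete illustration of the rank-normalization step (a toy row, not part of the exercise):
import numpy as np
from scipy import stats

row = np.array([0.05, 0.20, 0.10, np.nan, np.nan])  # NaNs are masked (future/nearby) positions
n_valid = np.sum(~np.isnan(row))                     # 3 valid positions
ranks = stats.rankdata(row, nan_policy="omit")       # -> [1., 3., 2., nan, nan]
print(ranks / n_valid)                               # -> [0.333, 1.0, 0.667, nan, nan]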
def get_vertical_scores(
avg_attention_matrix: Float[np.ndarray, "n_sentences n_sentences"],
proximity_ignore: int = 4,
control_depth: bool = True,
return_z_scores: bool = False,
) -> Float[np.ndarray, " n_sentences"]:
"""
Compute vertical attention scores for each sentence with optional depth control and z-scoring.
Args:
avg_attention_matrix: Shape (n_sentences, n_sentences), where entry (i, j)
is the average attention from sentence i to sentence j
proximity_ignore: Ignore this many nearby sentences (to avoid trivial patterns)
control_depth: Apply rank normalization to control for depth effects
return_z_scores: If True, return z-score normalized values
Returns:
Array of shape (n_sentences,) with vertical scores
The vertical score for sentence j is the mean attention it receives from
    all sentences i where i >= j + proximity_ignore.
Depth control (when enabled): Normalizes each row by ranking attention values
and dividing by the count of valid (non-NaN) positions. This ensures fair
comparison between early sentences (fewer tokens to attend to) and late
sentences (many tokens to attend to).
Adapted from thought-anchors: attn_funcs.py:get_vertical_scores
"""
# Step 1: Copy matrix and set upper triangle to NaN (can't attend to future)
# Step 2: Set near-diagonal to NaN using proximity_ignore (ignore nearby sentences)
# Step 3: If control_depth, rank-normalize each row:
# - Count valid (non-NaN) values per row
# - Use stats.rankdata(mat, axis=1, nan_policy="omit") to rank within rows
# - Divide by the per-row valid counts to normalize to [0, 1]
# Step 4: Compute vertical scores (same as simple version)
# Step 5: If return_z_scores, z-score normalize the result
raise NotImplementedError()
tests.test_get_vertical_scores(get_vertical_scores)
Solution
def get_vertical_scores(
avg_attention_matrix: Float[np.ndarray, "n_sentences n_sentences"],
proximity_ignore: int = 4,
control_depth: bool = True,
return_z_scores: bool = False,
) -> Float[np.ndarray, " n_sentences"]:
"""
Compute vertical attention scores for each sentence with optional depth control and z-scoring.
Args:
avg_attention_matrix: Shape (n_sentences, n_sentences), where entry (i, j)
is the average attention from sentence i to sentence j
proximity_ignore: Ignore this many nearby sentences (to avoid trivial patterns)
control_depth: Apply rank normalization to control for depth effects
return_z_scores: If True, return z-score normalized values
Returns:
Array of shape (n_sentences,) with vertical scores
The vertical score for sentence j is the mean attention it receives from
    all sentences i where i >= j + proximity_ignore.
Depth control (when enabled): Normalizes each row by ranking attention values
and dividing by the count of valid (non-NaN) positions. This ensures fair
comparison between early sentences (fewer tokens to attend to) and late
sentences (many tokens to attend to).
Adapted from thought-anchors: attn_funcs.py:get_vertical_scores
"""
n = avg_attention_matrix.shape[0]
mat = avg_attention_matrix.copy()
# Step 1: Clean matrix - set upper triangle to NaN (can't attend to future)
mat[np.triu_indices_from(mat, k=1)] = np.nan
# Step 2: Remove proximity - set near-diagonal to NaN (ignore nearby sentences)
mat[np.triu_indices_from(mat, k=-proximity_ignore + 1)] = np.nan
# Step 3: Depth control normalization (critical for fair comparison!)
if control_depth:
# Count non-NaN values per row (available positions to attend to)
per_row = np.sum(~np.isnan(mat), axis=1)
# Rank-normalize each row: convert attention values to ranks,
# then divide by number of valid positions
# This puts all rows on the same scale regardless of depth
mat = stats.rankdata(mat, axis=1, nan_policy="omit") / per_row[:, None]
# Step 4: Compute vertical scores: mean attention received from future sentences
vert_scores = []
for j in range(n):
# Extract "vertical line" - attention from all future sentences to sentence j
future_attention = mat[j + proximity_ignore :, j]
if len(future_attention) == 0 or np.all(np.isnan(future_attention)):
vert_scores.append(np.nan)
else:
vert_scores.append(np.nanmean(future_attention))
vert_scores = np.array(vert_scores)
# Step 5: Optional z-score normalization
if return_z_scores:
vert_scores = (vert_scores - np.nanmean(vert_scores)) / np.nanstd(vert_scores)
return vert_scores
Let's test vertical scores on real CoT attention patterns and see the effect of depth control.
# Compute vertical scores WITH and WITHOUT depth control (using z-scores for comparison)
vert_scores_no_control = get_vertical_scores(
sentence_attention, proximity_ignore=4, control_depth=False, return_z_scores=True
)
vert_scores_with_control = get_vertical_scores(
sentence_attention, proximity_ignore=4, control_depth=True, return_z_scores=True
)
# Get top-3 sentences (filtering NaN values)
valid_mask = ~np.isnan(vert_scores_no_control) & ~np.isnan(vert_scores_with_control)
valid_indices = np.where(valid_mask)[0]
top_3_no_control = valid_indices[np.argsort(vert_scores_no_control[valid_mask])[-3:][::-1]]
top_3_with_control = valid_indices[np.argsort(vert_scores_with_control[valid_mask])[-3:][::-1]]
print(f"Top-3 sentences (no depth control): {top_3_no_control}")
print(f"Top-3 sentences (with depth control): {top_3_with_control}")
print(
f"Correlation between methods: {np.corrcoef(vert_scores_no_control[valid_mask], vert_scores_with_control[valid_mask])[0, 1]:.3f}"
)
# Plot comparison (already z-scored from the function)
hover_texts = [s[:80] + "..." if len(s) > 80 else s for s in sentences]
fig = go.Figure()
for name, scores, color in [
("No depth control", vert_scores_no_control, "steelblue"),
("With depth control", vert_scores_with_control, "darkorange"),
]:
fig.add_trace(
go.Bar(
x=list(range(len(scores))),
y=scores,
name=name,
opacity=0.7,
marker_color=color,
hovertemplate="<b>Sentence %{x}</b><br>Z-score: %{y:.2f}<br>%{customdata}<extra></extra>",
customdata=hover_texts,
)
)
fig.add_hline(y=0, line_color="black", line_width=0.5)
fig.update_layout(
title="Vertical Scores Comparison (Z-scored)",
xaxis_title="Sentence Index",
yaxis_title="Vertical Score (z-scored)",
width=1000,
height=450,
barmode="group",
)
fig.show()
Click to see the expected output
Why depth control matters
Without depth control, earlier sentences appear more important simply because the rows that attend to them spread attention over fewer positions, so each position gets a larger share. For example, in a 20-sentence trace:
- Sentence 0's score includes attention from rows as shallow as sentence 4, which divides its attention among only ~5 positions (large per-position values)
- Sentence 15's score comes only from deep rows like sentence 19, which divides its attention among ~20 positions (small per-position values)
This creates a bias where early sentences automatically get higher scores. Depth control (rank normalization) fixes this by ranking each row's attention values and dividing by the number of positions that row could attend to.
With depth control enabled, the vertical scores better identify genuine "thought anchors" - sentences that are disproportionately important for specific reasoning steps, not just sentences that happen to be early in the trace.
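You can see this bias directly on a synthetic trace where every sentence attends uniformly over its valid past, so no sentence is genuinely special (a quick illustrative check, not from the paper):
n = 20
uniform = np.zeros((n, n))
for i in range(n):
    uniform[i, : i + 1] = 1.0 / (i + 1)  # each row spreads attention evenly over its past

raw = get_vertical_scores(uniform, proximity_ignore=4, control_depth=False)
controlled = get_vertical_scores(uniform, proximity_ignore=4, control_depth=True)
print("raw:       ", np.round(raw[:8], 3))         # decreases with index - spurious early bias
print("controlled:", np.round(controlled[:8], 3))  # much flatter - bias largely removed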
Kurtosis as a receiver head signal
Receiver heads are identified by high kurtosis in their vertical attention scores across multiple problems. We use kurtosis rather than variance because we want to identify heads with a few very high attention scores (spiky distributions) rather than heads with generally spread-out attention. High kurtosis means the head is highly selective about which previous sentences it attends to.
def compute_head_kurtosis(vertical_scores: Float[np.ndarray, " n_sentences"]) -> float:
"""
Compute kurtosis of vertical attention scores.
Higher kurtosis = more "spiky" attention = better receiver head candidate.
"""
# Remove NaN values
valid_scores = vertical_scores[~np.isnan(vertical_scores)]
if len(valid_scores) < 4:
return np.nan
# Compute Fisher's kurtosis (excess kurtosis)
return stats.kurtosis(valid_scores, fisher=True, bias=True)
tests.test_compute_head_kurtosis(compute_head_kurtosis)
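To build intuition for why kurtosis picks out receiver heads, compare a spiky score vector against a diffuse one (illustrative only):
rng = np.random.default_rng(0)
spiky = rng.normal(0, 0.05, size=50)
spiky[[10, 30]] = 1.0                 # two strong attention spikes
diffuse = rng.uniform(0, 1, size=50)  # spread-out attention, no spikes
print(f"spiky kurtosis:   {compute_head_kurtosis(spiky):.2f}")   # large and positive
print(f"diffuse kurtosis: {compute_head_kurtosis(diffuse):.2f}") # near zero or negative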
Exercise - identify receiver heads from real CoT
Most of this function is given to you. Your job is to fill in the inner loop: iterate over layers and heads from output.attentions, average each head's attention matrix to the sentence level, and compute vertical scores with depth control. Store the result in all_vert_scores[layer, head].
def find_receiver_heads(
text: str,
sentences: list[str],
model,
tokenizer,
top_k: int = 10,
proximity_ignore: int = 4,
) -> tuple[Int[np.ndarray, "top_k 2"], Float[np.ndarray, " top_k"]]:
"""
Identify top-k receiver heads by kurtosis of vertical attention scores.
Args:
text: Full reasoning trace
sentences: List of sentences in the trace
model: Loaded model
tokenizer: Corresponding tokenizer
top_k: Number of top receiver heads to return
proximity_ignore: Proximity parameter for vertical scores
Returns:
receiver_heads: Shape (top_k, 2) array of (layer, head) pairs
kurtosis_scores: Shape (top_k,) array of kurtosis values for each head
"""
n_layers = len(model.model.layers)
n_heads = model.config.num_attention_heads
n_sentences = len(sentences)
print(f"Computing vertical scores for {n_layers} layers × {n_heads} heads = {n_layers * n_heads} heads...")
# Get sentence boundaries once
boundaries = utils.get_sentence_token_boundaries(text, sentences, tokenizer)
# Single forward pass to get attention weights from ALL layers and heads at once
inputs = tokenizer(text, return_tensors="pt")
if model.device.type == "cuda":
inputs = {k: v.cuda() for k, v in inputs.items()}
with t.no_grad():
output = model(**inputs, use_cache=False, output_attentions=True)
# Store vertical scores for all heads
all_vert_scores = np.zeros((n_layers, n_heads, n_sentences))
    # YOUR CODE HERE: fill in `all_vert_scores` by iterating through layers and heads, using
    # `output.attentions` to get average sentence-level attention, then computing vertical scores
    # with depth control.
# Free attention memory
del output
t.cuda.empty_cache()
# Compute kurtosis for each head (over sentences dimension)
kurtosis_matrix = np.zeros((n_layers, n_heads))
for layer in range(n_layers):
for head in range(n_heads):
kurtosis_matrix[layer, head] = compute_head_kurtosis(all_vert_scores[layer, head])
# Find top-k heads with highest kurtosis
flat_kurts = kurtosis_matrix.flatten()
valid_indices = ~np.isnan(flat_kurts)
if valid_indices.sum() < top_k:
print(f"Warning: Only {valid_indices.sum()} valid heads found")
top_k = valid_indices.sum()
# Get indices of top-k highest kurtosis values
valid_kurts = flat_kurts[valid_indices]
valid_flat_indices = np.where(valid_indices)[0]
top_k_in_valid = np.argpartition(valid_kurts, -top_k)[-top_k:]
top_k_in_valid = top_k_in_valid[np.argsort(-valid_kurts[top_k_in_valid])] # Sort descending
top_flat_indices = valid_flat_indices[top_k_in_valid]
# Convert flat indices back to (layer, head) pairs
receiver_heads = np.array(np.unravel_index(top_flat_indices, (n_layers, n_heads))).T
receiver_kurts = flat_kurts[top_flat_indices]
return receiver_heads, receiver_kurts
# Load example CoT
text_full, sentences_full, _ = utils.get_whitebox_example_data(problem_data)
n_chunks = 74
sentences_subset = sentences_full[:n_chunks]
# Find where the n_chunks-th sentence ends in the original text to extract the correct substring
# This preserves the original text formatting and ensures tokenization consistency
end_char_pos = 0
for sent in sentences_subset:
# Find this sentence in the text starting from where we left off
sent_pos = text_full.find(sent, end_char_pos)
if sent_pos == -1:
# Try with stripped version
sent_pos = text_full.find(sent.strip(), end_char_pos)
if sent_pos != -1:
end_char_pos = sent_pos + len(sent)
text_subset = text_full[:end_char_pos]
print(f"Analyzing first {len(sentences_subset)} sentences...")
print(f"Text length: {len(text_subset)} characters")
# Find top-20 receiver heads (paper recommends ~16)
t.cuda.empty_cache()
receiver_heads, receiver_kurts = find_receiver_heads(
text_subset,
sentences_subset,
model,
tokenizer,
top_k=20,
proximity_ignore=4,
)
print(f"\nTop-{len(receiver_heads)} Receiver Heads:")
for i, ((layer, head), kurt) in enumerate(zip(receiver_heads, receiver_kurts)):
print(f" {i + 1:2d}. Layer {layer:2d}, Head {head:2d} | Kurtosis: {kurt:.3f}")
# Visualize kurtosis distribution
# Recompute kurtosis matrix for visualization
n_layers = len(model.model.layers)
n_heads = model.config.num_attention_heads
kurtosis_viz = np.zeros((n_layers, n_heads))
for (layer, head), kurt in zip(receiver_heads, receiver_kurts):
    kurtosis_viz[layer, head] = kurt
fig = px.imshow(
kurtosis_viz,
color_continuous_scale="RdBu",
color_continuous_midpoint=0.0,
labels=dict(x="Head", y="Layer", color="Kurtosis"),
title="Top Receiver Heads by Kurtosis",
aspect="auto",
width=800,
height=600,
)
fig.show()
print("Receiver heads should attend 'spikily' to specific important sentences.")
Click to see the expected output
Interpretation: What are receiver heads?
Receiver heads are attention heads that have high kurtosis in their vertical attention scores across sentences, meaning they attend strongly to a few specific sentences rather than uniformly to all past sentences. They correlate with sentences identified as "thought anchors" by black-box methods.
Why they matter: they reveal the mechanistic implementation of how the model uses key reasoning steps; different heads may specialize in different types of anchors (planning, backtracking, etc.); and they provide whitebox validation of black-box importance metrics.
Solution
def find_receiver_heads(
text: str,
sentences: list[str],
model,
tokenizer,
top_k: int = 10,
proximity_ignore: int = 4,
) -> tuple[Int[np.ndarray, "top_k 2"], Float[np.ndarray, " top_k"]]:
"""
Identify top-k receiver heads by kurtosis of vertical attention scores.
Args:
text: Full reasoning trace
sentences: List of sentences in the trace
model: Loaded model
tokenizer: Corresponding tokenizer
top_k: Number of top receiver heads to return
proximity_ignore: Proximity parameter for vertical scores
Returns:
receiver_heads: Shape (top_k, 2) array of (layer, head) pairs
kurtosis_scores: Shape (top_k,) array of kurtosis values for each head
"""
n_layers = len(model.model.layers)
n_heads = model.config.num_attention_heads
n_sentences = len(sentences)
print(f"Computing vertical scores for {n_layers} layers × {n_heads} heads = {n_layers * n_heads} heads...")
# Get sentence boundaries once
boundaries = utils.get_sentence_token_boundaries(text, sentences, tokenizer)
# Single forward pass to get attention weights from ALL layers and heads at once
inputs = tokenizer(text, return_tensors="pt")
if model.device.type == "cuda":
inputs = {k: v.cuda() for k, v in inputs.items()}
with t.no_grad():
output = model(**inputs, use_cache=False, output_attentions=True)
# Store vertical scores for all heads
all_vert_scores = np.zeros((n_layers, n_heads, n_sentences))
# Process attention from each layer and head
for layer in tqdm(range(n_layers), desc="Layers"):
layer_attn = output.attentions[layer][0] # shape [n_heads, seq_len, seq_len]
for head in range(n_heads):
attn_matrix = layer_attn[head].cpu().numpy().astype(np.float32)
# Average to sentence level
avg_attn = utils.average_attention_by_sentences(attn_matrix, boundaries)
# Compute vertical scores with depth control
vert_scores = get_vertical_scores(avg_attn, proximity_ignore=proximity_ignore, control_depth=True)
all_vert_scores[layer, head] = vert_scores
# Free attention memory
del output
t.cuda.empty_cache()
# Compute kurtosis for each head (over sentences dimension)
kurtosis_matrix = np.zeros((n_layers, n_heads))
for layer in range(n_layers):
for head in range(n_heads):
kurtosis_matrix[layer, head] = compute_head_kurtosis(all_vert_scores[layer, head])
# Find top-k heads with highest kurtosis
flat_kurts = kurtosis_matrix.flatten()
valid_indices = ~np.isnan(flat_kurts)
if valid_indices.sum() < top_k:
print(f"Warning: Only {valid_indices.sum()} valid heads found")
top_k = valid_indices.sum()
# Get indices of top-k highest kurtosis values
valid_kurts = flat_kurts[valid_indices]
valid_flat_indices = np.where(valid_indices)[0]
top_k_in_valid = np.argpartition(valid_kurts, -top_k)[-top_k:]
top_k_in_valid = top_k_in_valid[np.argsort(-valid_kurts[top_k_in_valid])] # Sort descending
top_flat_indices = valid_flat_indices[top_k_in_valid]
# Convert flat indices back to (layer, head) pairs
receiver_heads = np.array(np.unravel_index(top_flat_indices, (n_layers, n_heads))).T
receiver_kurts = flat_kurts[top_flat_indices]
return receiver_heads, receiver_kurts
# Load example CoT
text_full, sentences_full, _ = utils.get_whitebox_example_data(problem_data)
n_chunks = 74
sentences_subset = sentences_full[:n_chunks]
# Find where the n_chunks-th sentence ends in the original text to extract the correct substring
# This preserves the original text formatting and ensures tokenization consistency
end_char_pos = 0
for sent in sentences_subset:
# Find this sentence in the text starting from where we left off
sent_pos = text_full.find(sent, end_char_pos)
if sent_pos == -1:
# Try with stripped version
sent_pos = text_full.find(sent.strip(), end_char_pos)
if sent_pos != -1:
end_char_pos = sent_pos + len(sent)
text_subset = text_full[:end_char_pos]
print(f"Analyzing first {len(sentences_subset)} sentences...")
print(f"Text length: {len(text_subset)} characters")
# Find top-20 receiver heads (paper recommends ~16)
t.cuda.empty_cache()
receiver_heads, receiver_kurts = find_receiver_heads(
text_subset,
sentences_subset,
model,
tokenizer,
top_k=20,
proximity_ignore=4,
)
print(f"\nTop-{len(receiver_heads)} Receiver Heads:")
for i, ((layer, head), kurt) in enumerate(zip(receiver_heads, receiver_kurts)):
print(f" {i + 1:2d}. Layer {layer:2d}, Head {head:2d} | Kurtosis: {kurt:.3f}")
# Visualize kurtosis distribution
# Recompute kurtosis matrix for visualization
n_layers = len(model.model.layers)
n_heads = model.config.num_attention_heads
kurtosis_viz = np.zeros((n_layers, n_heads))
for (layer, head), kurt in zip(receiver_heads, receiver_kurts):
    kurtosis_viz[layer, head] = kurt
fig = px.imshow(
kurtosis_viz,
color_continuous_scale="RdBu",
color_continuous_midpoint=0.0,
labels=dict(x="Head", y="Layer", color="Kurtosis"),
title="Top Receiver Heads by Kurtosis",
aspect="auto",
width=800,
height=600,
)
fig.show()
print("Receiver heads should attend 'spikily' to specific important sentences.")
Exercise - compute receiver head scores
Now that we've identified receiver heads, let's compute the final receiver head score for each sentence by averaging the vertical attention scores across the top-k receiver heads.
def compute_receiver_head_scores(
text: str,
sentences: list[str],
receiver_heads: Int[np.ndarray, "k 2"],
model,
tokenizer,
proximity_ignore: int = 4,
) -> Float[np.ndarray, " n_sentences"]:
"""
Compute final receiver head scores by averaging vertical scores from top-k heads.
Args:
text: Full reasoning trace
sentences: List of sentences
receiver_heads: Shape (k, 2) array of (layer, head) pairs
model: Loaded model
tokenizer: Corresponding tokenizer
proximity_ignore: Proximity parameter
Returns:
receiver_scores: Shape (n_sentences,) - final importance scores
"""
raise NotImplementedError("Implement receiver head score computation")
# Use the receiver heads we just found
receiver_scores = compute_receiver_head_scores(
text_subset,
sentences_subset,
receiver_heads,
model,
tokenizer,
proximity_ignore=4,
)
print(f"\nReceiver head scores computed for {len(receiver_scores)} sentences")
print(f"Mean score: {np.nanmean(receiver_scores):.4f}")
print(f"Std score: {np.nanstd(receiver_scores):.4f}")
# With depth-controlled rank normalization, baseline vertical scores center around 0.5
# (since rank/n_valid ≈ 0.5 on average). Thought anchors should be noticeably above this.
mean_score = np.nanmean(receiver_scores)
assert 0.3 < mean_score < 0.7, f"Mean receiver score {mean_score:.3f} unexpectedly far from ~0.5 baseline"
# Visualize
fig = go.Figure()
fig.add_trace(
go.Bar(
x=list(range(len(receiver_scores))),
y=receiver_scores,
marker_color="indianred",
hovertemplate="<b>Sentence %{x}</b><br>Score: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
customdata=[s[:100] for s in sentences_subset],
)
)
fig.add_hline(
y=np.nanmean(receiver_scores),
line_dash="dash",
line_color="black",
annotation_text="Mean",
)
fig.update_layout(
title="Receiver Head Scores (Averaged Across Top-k Heads)",
xaxis_title="Sentence Index",
yaxis_title="Receiver Head Score",
width=800,
height=400,
)
fig.show()
print("\nTop-3 sentences by receiver head score:")
# Filter out NaN values before selecting top sentences
valid_mask = ~np.isnan(receiver_scores)
valid_indices = np.where(valid_mask)[0]
valid_scores = receiver_scores[valid_mask]
top_k = min(3, len(valid_scores))
top_in_valid = np.argsort(valid_scores)[-top_k:][::-1]
top_indices = valid_indices[top_in_valid]
for idx in top_indices:
print(f" Sentence {idx}: {receiver_scores[idx]:.4f}")
print(f" Text: {sentences_subset[idx][:100]}...")
print("Receiver head scores computed! These aggregate the attention from all high-kurtosis receiver heads.")
Click to see the expected output
Why average across receiver heads?
Averaging across multiple receiver heads is more robust than using a single head: it reduces noise from idiosyncratic attention patterns, captures multiple aspects (different heads may focus on different types of important sentences), and is more stable across architectures.
The paper found that using the top 16 receiver heads gives the most stable correlation with black-box importance metrics.
Solution
def compute_receiver_head_scores(
text: str,
sentences: list[str],
receiver_heads: Int[np.ndarray, "k 2"],
model,
tokenizer,
proximity_ignore: int = 4,
) -> Float[np.ndarray, " n_sentences"]:
"""
Compute final receiver head scores by averaging vertical scores from top-k heads.
Args:
text: Full reasoning trace
sentences: List of sentences
receiver_heads: Shape (k, 2) array of (layer, head) pairs
model: Loaded model
tokenizer: Corresponding tokenizer
proximity_ignore: Proximity parameter
Returns:
receiver_scores: Shape (n_sentences,) - final importance scores
"""
boundaries = utils.get_sentence_token_boundaries(text, sentences, tokenizer)
all_vert_scores = []
print(f"Computing scores from {len(receiver_heads)} receiver heads...")
for layer, head in receiver_heads:
# Extract attention for this receiver head
attn_matrix, _ = extract_attention_matrix(text, model, tokenizer, layer, head)
# Average to sentence level
avg_attn = utils.average_attention_by_sentences(attn_matrix, boundaries)
# Compute vertical scores
vert_scores = get_vertical_scores(avg_attn, proximity_ignore=proximity_ignore, control_depth=True)
all_vert_scores.append(vert_scores)
# Average across all receiver heads
all_vert_scores = np.array(all_vert_scores) # Shape: (k, n_sentences)
receiver_scores = np.nanmean(all_vert_scores, axis=0) # Average over heads
return receiver_scores
# Use the receiver heads we just found
receiver_scores = compute_receiver_head_scores(
text_subset,
sentences_subset,
receiver_heads,
model,
tokenizer,
proximity_ignore=4,
)
print(f"\nReceiver head scores computed for {len(receiver_scores)} sentences")
print(f"Mean score: {np.nanmean(receiver_scores):.4f}")
print(f"Std score: {np.nanstd(receiver_scores):.4f}")
# With depth-controlled rank normalization, baseline vertical scores center around 0.5
# (since rank/n_valid ≈ 0.5 on average). Thought anchors should be noticeably above this.
mean_score = np.nanmean(receiver_scores)
assert 0.3 < mean_score < 0.7, f"Mean receiver score {mean_score:.3f} unexpectedly far from ~0.5 baseline"
# Visualize
fig = go.Figure()
fig.add_trace(
go.Bar(
x=list(range(len(receiver_scores))),
y=receiver_scores,
marker_color="indianred",
hovertemplate="<b>Sentence %{x}</b><br>Score: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
customdata=[s[:100] for s in sentences_subset],
)
)
fig.add_hline(
y=np.nanmean(receiver_scores),
line_dash="dash",
line_color="black",
annotation_text="Mean",
)
fig.update_layout(
title="Receiver Head Scores (Averaged Across Top-k Heads)",
xaxis_title="Sentence Index",
yaxis_title="Receiver Head Score",
width=800,
height=400,
)
fig.show()
print("\nTop-3 sentences by receiver head score:")
# Filter out NaN values before selecting top sentences
valid_mask = ~np.isnan(receiver_scores)
valid_indices = np.where(valid_mask)[0]
valid_scores = receiver_scores[valid_mask]
top_k = min(3, len(valid_scores))
top_in_valid = np.argsort(valid_scores)[-top_k:][::-1]
top_indices = valid_indices[top_in_valid]
for idx in top_indices:
print(f" Sentence {idx}: {receiver_scores[idx]:.4f}")
print(f" Text: {sentences_subset[idx][:100]}...")
print("Receiver head scores computed! These aggregate the attention from all high-kurtosis receiver heads.")
Attention Suppression
If receiver heads are important for propagating information from thought anchors, then suppressing attention to those sentences should change the model's output.
Exercise - implement RoPE attention
We need to implement RoPE (Rotary Position Embeddings) for our white-box attention suppression metrics. This is a tangent from the main thought anchors material - the key concept is that positional information is encoded via rotation of the query and key vectors, rather than by adding position embeddings. If you're not interested in the implementation details, feel free to skip this exercise and use the provided solution.
Qwen uses RoPE, which rotates query and key vectors based on their positions. The implementation formula is:

$$q' = q \odot \cos(m\theta) + \text{rotate\_half}(q) \odot \sin(m\theta), \qquad k' = k \odot \cos(m\theta) + \text{rotate\_half}(k) \odot \sin(m\theta)$$

where $\odot$ is element-wise multiplication, $m$ is the position, $\theta_i = 10000^{-2i/d}$, and $\text{rotate\_half}$ swaps and negates pairs of dimensions.
Your task: Apply this formula to both query and key tensors using the provided cos, sin, and rotate_half helper.
A few notes:
- We've given you a helper function `rotate_half` that rotates half the hidden dimensions (swapping pairs and negating the first of each pair), which implements the rotation effect.
- The reason `q_heads` and `kv_heads` are distinguished is because of Grouped Query Attention (GQA): each KV head serves a group of multiple query heads. You don't need to worry much about this for the purposes of our exercises here, though.
def rotate_half(x: Float[Tensor, "... d"]) -> Float[Tensor, " ... d"]:
"""Rotates half the hidden dims of the input (for RoPE)."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return t.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: Float[Tensor, "batch q_heads seq head_dim"],
k: Float[Tensor, "batch kv_heads seq head_dim"],
cos: Float[Tensor, "batch seq head_dim"],
sin: Float[Tensor, "batch seq head_dim"],
position_ids: t.Tensor | None = None,
unsqueeze_dim: int = 1,
) -> tuple[Float[Tensor, "batch q_heads seq head_dim"], Float[Tensor, "batch kv_heads seq head_dim"]]:
"""Applies Rotary Position Embedding to query and key tensors."""
del position_ids # unused
raise NotImplementedError()
tests.test_rotary_pos_emb(apply_rotary_pos_emb)
Solution
def rotate_half(x: Float[Tensor, "... d"]) -> Float[Tensor, " ... d"]:
"""Rotates half the hidden dims of the input (for RoPE)."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return t.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: Float[Tensor, "batch q_heads seq head_dim"],
k: Float[Tensor, "batch kv_heads seq head_dim"],
cos: Float[Tensor, "batch seq head_dim"],
sin: Float[Tensor, "batch seq head_dim"],
position_ids: t.Tensor | None = None,
unsqueeze_dim: int = 1,
) -> tuple[Float[Tensor, "batch q_heads seq head_dim"], Float[Tensor, "batch kv_heads seq head_dim"]]:
"""Applies Rotary Position Embedding to query and key tensors."""
del position_ids # unused
cos = cos.unsqueeze(unsqueeze_dim) # Shape: (batch, 1, seq, head_dim)
sin = sin.unsqueeze(unsqueeze_dim) # Shape: (batch, 1, seq, head_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
tests.test_rotary_pos_emb(apply_rotary_pos_emb)
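One nice sanity check (not part of the exercise): with RoPE applied, the query-key dot product depends only on the relative offset between positions. A minimal sketch, assuming the HF convention of repeating the frequencies across both halves of the head dimension:
d = 8
inv_freq = 1.0 / (10000 ** (t.arange(0, d, 2).float() / d))  # theta_i = 10000^(-2i/d)

def rope_1d(x, pos):
    freqs = pos * inv_freq
    emb = t.cat((freqs, freqs))  # repeat frequencies across both halves, as HF does
    return x * emb.cos() + rotate_half(x) * emb.sin()

q, k = t.randn(d), t.randn(d)
score_a = rope_1d(q, 3.0) @ rope_1d(k, 1.0)  # positions (3, 1): offset 2
score_b = rope_1d(q, 7.0) @ rope_1d(k, 5.0)  # positions (7, 5): same offset 2
assert t.allclose(score_a, score_b, atol=1e-5)
print(f"{score_a.item():.4f} == {score_b.item():.4f}")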
Exercise - compute suppressed attention
Now you'll implement the core logic for attention suppression. compute_suppressed_attention takes query and key states and computes attention weights, but with specific token ranges masked out (set to $-\infty$ before softmax, so they get zero attention weight after softmax).
How it works: (1) compute attention scores $\text{scores} = QK^T / \sqrt{d_k}$, (2) before softmax, mask specific token positions by setting their scores to $-\infty$ (minimum float value), (3) apply the attention mask if provided, (4) apply softmax. Setting scores to $-\infty$ before softmax ensures they become 0 after softmax, effectively removing those positions from the computation.
The key arguments are token_ranges (list of (start, end) tuples for which tokens to suppress), heads_mask (if provided, only suppress in specific heads), and attention_mask (standard causal/padding mask).
Hint: use t.finfo(attn_weights.dtype).min to get the minimum value for the tensor's dtype.
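Here's the core mechanic in miniature (a toy tensor, not the exercise itself): scores set to the dtype's minimum before softmax end up with effectively zero weight, and the remaining positions renormalize.
scores = t.tensor([[2.0, 1.0, 0.5, 0.1]])
masked = scores.clone()
masked[:, 1:3] = t.finfo(scores.dtype).min  # suppress positions 1 and 2
print(t.softmax(scores, dim=-1))  # ~[0.57, 0.21, 0.13, 0.09] - all positions attended
print(t.softmax(masked, dim=-1))  # ~[0.87, 0.00, 0.00, 0.13] - mass shifts to positions 0 and 3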
def compute_suppressed_attention(
    query_states: Float[Tensor, "batch heads seq head_dim"],
    key_states: Float[Tensor, "batch heads seq head_dim"],
    token_ranges: list[tuple[int, int]],
    head_dim: int,  # attention scores are scaled by 1 / sqrt(head_dim)
    query_len: int,  # in case token_ranges exceed query length
    attention_mask: Float[Tensor, "batch 1 seq seq"] | None = None,
    heads_mask: Int[Tensor, " heads"] | None = None,  # indices of heads to suppress (None = all heads)
) -> Float[Tensor, "batch heads seq seq"]:
    # YOUR CODE HERE - compute attention scores, apply suppression, apply attention mask, then softmax
    raise NotImplementedError()
tests.test_compute_suppressed_attention(compute_suppressed_attention)
Solution
def compute_suppressed_attention(
query_states: Float[Tensor, "batch heads seq head_dim"],
key_states: Float[Tensor, "batch heads seq head_dim"],
token_ranges: list[tuple[int, int]],
    head_dim: int,  # attention scores are scaled by 1 / sqrt(head_dim)
    query_len: int,  # in case token_ranges exceed query length
    attention_mask: Float[Tensor, "batch 1 seq seq"] | None = None,
    heads_mask: Int[Tensor, " heads"] | None = None,  # indices of heads to suppress (None = all heads)
) -> Float[Tensor, "batch heads seq seq"]:
# Compute attention scores
attn_weights = t.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
# Apply suppression mask BEFORE softmax (key step!)
for start, end in token_ranges:
effective_start = min(start, query_len)
effective_end = min(end, query_len)
if effective_start < effective_end:
mask_value = t.finfo(attn_weights.dtype).min
if heads_mask is None:
# Suppress in all heads
attn_weights[..., effective_start:effective_end] = mask_value
else:
# Suppress only in specific heads
attn_weights[:, heads_mask, :, effective_start:effective_end] = mask_value
# Apply attention mask if provided
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# Softmax
attn_weights = t.nn.functional.softmax(attn_weights, dim=-1, dtype=t.float32).to(query_states.dtype)
return attn_weights
tests.test_compute_suppressed_attention(compute_suppressed_attention)
Recommended reading: the full attention suppression implementation
You don't need to implement this function from scratch. Instead, read through the implementation below and make sure you understand how each step works.
This code patches the forward method of Qwen's attention modules to use your apply_rotary_pos_emb and compute_suppressed_attention functions. It finds attention modules in each layer, stores original forward methods (for restoration later), creates masked forward functions that project to Q/K/V, apply RoPE, compute suppressed attention, and project output, then replaces the forward methods using MethodType.
A few implementation details: layer_to_heads allows suppressing only specific heads in specific layers (None = all heads in all layers), the function uses closures (create_masked_forward) to capture layer-specific parameters, and repeat_kv handles grouped-query attention where key/value heads are fewer than query heads.
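If the MethodType patching pattern is unfamiliar, here's the idea in isolation (a toy class, not the actual attention module):
from types import MethodType

class Attn:
    def forward(self):
        return "original"

m = Attn()
saved = m.forward                                  # keep a handle for later restoration
m.forward = MethodType(lambda self: "patched", m)  # bind a replacement method to this instance
print(m.forward())  # -> patched
m.forward = saved                                  # restore, as remove_qwen_attention_suppression does
print(m.forward())  # -> original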
How to use it:
# Apply suppression to tokens 10-20 in all heads
suppression_info = apply_qwen_attention_suppression(model, token_ranges=[(10, 20)])
# Run model with suppression active
output = model(inputs)
# Restore original behavior
remove_qwen_attention_suppression(model, suppression_info)
As you read, pay attention to:
- How the two functions you implemented (
apply_rotary_pos_embandcompute_suppressed_attention) are used within the masked forward pass - How the patching mechanism works (replacing
forwardmethods viaMethodType) - What role
token_rangesandlayer_to_headsplay in controlling the suppression
def repeat_kv(
hidden_states: Float[Tensor, "batch kv_heads seq head_dim"], n_rep: int
) -> Float[Tensor, "batch heads seq head_dim"]:
"""Expands key/value tensors for grouped-query attention."""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def apply_qwen_attention_suppression(
model,
token_ranges: list[tuple[int, int]] | tuple[int, int],
layer_to_heads: dict[int, list[int]] | None = None,
) -> dict[str, Any]:
"""
Suppresses attention to specific token positions by replacing forward methods.
Args:
model: The model to apply suppression to
token_ranges: Token range(s) to suppress - single tuple or list of tuples
layer_to_heads: Dict mapping layer indices to lists of head indices to suppress (None = all heads)
Returns:
Dict with 'original_forwards' for restoration
Adapted from thought-anchors: whitebox-analyses/pytorch_models/hooks.py
"""
# Normalize token_ranges to list of tuples
if isinstance(token_ranges, tuple):
token_ranges = [token_ranges]
# Find rotary embedding module
rotary_emb_module = None
if hasattr(model, "model") and hasattr(model.model, "rotary_emb"):
rotary_emb_module = model.model.rotary_emb
# Find attention modules to patch
target_modules = []
for name, module in model.named_modules():
if name.startswith("model.layers") and name.endswith("self_attn"):
try:
layer_idx = int(name.split(".")[2])
if layer_to_heads is None or layer_idx in layer_to_heads:
if all(hasattr(module, attr) for attr in ["config", "q_proj", "k_proj", "v_proj", "o_proj"]):
target_modules.append((name, module, layer_idx))
except (IndexError, ValueError):
continue
if not target_modules:
print("Warning: No Qwen attention modules found to patch")
return {"original_forwards": {}}
# Store original forward methods
original_forwards = {}
# Create and apply masked forward functions
for name, attn_module, layer_idx in target_modules:
original_forwards[name] = attn_module.forward
heads_mask = layer_to_heads[layer_idx] if layer_to_heads is not None else None
# Create masked forward function
def create_masked_forward(orig_forward, layer_idx, rotary_ref, heads_mask):
def masked_forward(
self,
hidden_states: t.Tensor,
attention_mask: t.Tensor | None = None,
position_ids: t.LongTensor | None = None,
past_key_value: tuple[t.Tensor] | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: t.LongTensor | None = None,
**kwargs,
) -> tuple[t.Tensor, t.Tensor | None]:
bsz, q_len, _ = hidden_states.size()
config = self.config
device = hidden_states.device
# Project to Q, K, V
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Reshape for multi-head attention
num_heads = config.num_attention_heads
head_dim = config.hidden_size // num_heads
num_key_value_heads = config.num_key_value_heads
num_key_value_groups = num_heads // num_key_value_heads
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
# Apply RoPE
if position_ids is None:
position_ids = t.arange(0, q_len, dtype=t.long, device=device).unsqueeze(0)
if rotary_ref is not None and callable(rotary_ref):
cos, sin = rotary_ref(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# Repeat K/V for grouped-query attention
key_states = repeat_kv(key_states, num_key_value_groups)
value_states = repeat_kv(value_states, num_key_value_groups)
attn_weights = compute_suppressed_attention(
query_states,
key_states,
token_ranges,
head_dim,
q_len,
attention_mask,
heads_mask,
)
# Apply attention to values
attn_output = t.matmul(attn_weights, value_states)
# Reshape and project output
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, config.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights if output_attentions else None
return masked_forward
attn_module.forward = MethodType(
create_masked_forward(attn_module.forward, layer_idx, rotary_emb_module, heads_mask), attn_module
)
return {"original_forwards": original_forwards}
def remove_qwen_attention_suppression(model, suppression_info: dict[str, Any]):
"""Restores original forward methods after attention suppression."""
original_forwards = suppression_info.get("original_forwards", {})
if not original_forwards:
return
for name, module in model.named_modules():
if name in original_forwards:
module.forward = original_forwards[name]
Comprehension questions
Now that you've read through the implementation, try answering these questions to check your understanding.
1. Why do we need to apply rotary position embeddings (RoPE) before computing attention scores?
Answer
Rotary position embeddings encode relative position information into the query and key vectors. Without RoPE, the attention scores would have no awareness of token positions, so the model wouldn't know which tokens come before or after others. RoPE must be applied before computing attention because the attention pattern depends on the relative positions of the query and key tokens. If we skipped this step, the suppression mask would be applied to position-unaware attention scores, which wouldn't match the model's normal behavior.
2. What would happen if we suppressed ALL sentence tokens rather than just the target sentence?
Answer
Suppressing all sentence tokens would set all attention weights to negative infinity before softmax, leaving only the attention to non-sentence tokens (e.g. special tokens or whitespace). The model's output would become essentially meaningless since almost all semantic content is carried by sentence tokens. The point of suppressing a single target sentence is to isolate its causal contribution: we want to measure how much removing attention to that specific sentence changes the output, while keeping the rest of the context intact.
3. Why do we compute KL divergence rather than just checking if the top predicted token changes?
Answer
Checking only the top predicted token is a very coarse measure. A sentence could substantially shift the probability distribution (e.g. redistributing probability mass among the top-5 tokens) without changing which token has the highest probability. KL divergence captures the full distributional shift, including subtle changes in confidence and probability rankings. This is especially important for thought anchors: a sentence might not change the single most likely next token, but it could significantly affect the model's uncertainty and the tail of the distribution, which matters for downstream reasoning over many token positions.
Exercise - measure suppression effect with KL divergence
To quantify the effect of suppression, we compare the model's next-token distribution before and after suppressing attention to a sentence. The KL divergence between the two distributions measures how much that sentence's information drives the output.
def compute_suppression_kl(
original_logits: Float[np.ndarray, " vocab"],
suppressed_logits: Float[np.ndarray, " vocab"],
temperature: float = 1.0,
) -> float:
"""
Compute KL divergence between original and suppressed output distributions.
Args:
original_logits: Logits from original forward pass
suppressed_logits: Logits from suppressed forward pass
temperature: Temperature for softmax
Returns:
KL divergence (in nats)
"""
raise NotImplementedError("Compute KL divergence between distributions")
tests.test_compute_suppression_kl(compute_suppression_kl)
Solution
def compute_suppression_kl(
original_logits: Float[np.ndarray, " vocab"],
suppressed_logits: Float[np.ndarray, " vocab"],
temperature: float = 1.0,
) -> float:
"""
Compute KL divergence between original and suppressed output distributions.
Args:
original_logits: Logits from original forward pass
suppressed_logits: Logits from suppressed forward pass
temperature: Temperature for softmax
Returns:
KL divergence (in nats)
"""
def softmax(x, temp):
x = x / temp
exp_x = np.exp(x - np.max(x))
return exp_x / exp_x.sum()
p = softmax(original_logits, temperature)
q = softmax(suppressed_logits, temperature)
# Add small epsilon for numerical stability
eps = 1e-10
return np.sum(p * np.log((p + eps) / (q + eps)))
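A quick sanity check on the implementation: identical logits should give (near-)zero KL, while a shifted distribution gives a clearly positive value.
logits = np.array([2.0, 1.0, 0.0])
print(compute_suppression_kl(logits, logits))         # ~0.0 (identical distributions)
print(compute_suppression_kl(logits, logits[::-1]))   # ~1.15 (probability mass reversed)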
Testing attention suppression
Now let's test attention suppression using forward method replacement. The key point is that we mask attention scores BEFORE softmax, not after - the implementation above replaces the entire forward method to intervene at the right point. We expect that suppressing high-importance sentences causes larger KL divergence than suppressing low-importance sentences.
# Select high and low importance sentences
valid_mask = ~np.isnan(receiver_scores)
valid_indices = np.where(valid_mask)[0]
valid_scores = receiver_scores[valid_mask]
high_idx = valid_indices[np.argmax(valid_scores)]
low_idx = valid_indices[np.argmin(valid_scores)]
print(f"\nHigh importance sentence {high_idx}: score={receiver_scores[high_idx]:.4f}")
print(f" Text: {sentences_subset[high_idx][:100]}...")
print(f"\nLow importance sentence {low_idx}: score={receiver_scores[low_idx]:.4f}")
print(f" Text: {sentences_subset[low_idx][:100]}...")
# Get token boundaries (recompute for the subset text so indices match sentences_subset)
boundaries = utils.get_sentence_token_boundaries(text_subset, sentences_subset, tokenizer)
high_sent_range = boundaries[high_idx]
low_sent_range = boundaries[low_idx]
# Prepare input
inputs = tokenizer(text_subset, return_tensors="pt")
if model.device.type == "cuda":
inputs = {k: v.cuda() for k, v in inputs.items()}
# Original forward pass (no suppression)
with t.no_grad():
original_output = model(**inputs)
original_logits = original_output.logits[0, -1].cpu().numpy()
layer_to_heads = {}
for layer_idx, head_idx in receiver_heads[:5]:
if layer_idx not in layer_to_heads:
layer_to_heads[layer_idx] = []
layer_to_heads[layer_idx].append(int(head_idx))
# Test 1: Suppress high-importance sentence in top receiver heads
suppression_info = apply_qwen_attention_suppression(model, high_sent_range, layer_to_heads)
with t.no_grad():
suppressed_high_output = model(**inputs)
suppressed_high_logits = suppressed_high_output.logits[0, -1].cpu().numpy()
remove_qwen_attention_suppression(model, suppression_info)
kl_high = compute_suppression_kl(original_logits, suppressed_high_logits)
# Test 2: Suppress low-importance sentence
suppression_info = apply_qwen_attention_suppression(model, low_sent_range, layer_to_heads)
with t.no_grad():
suppressed_low_output = model(**inputs)
suppressed_low_logits = suppressed_low_output.logits[0, -1].cpu().numpy()
remove_qwen_attention_suppression(model, suppression_info)
kl_low = compute_suppression_kl(original_logits, suppressed_low_logits)
# Compare results
print(f"\nKL divergence when suppressing HIGH importance sentence: {kl_high:.4e}")
print(f"KL divergence when suppressing LOW importance sentence: {kl_low:.4e}")
print(f"Ratio (high/low): {kl_high / (kl_low + 1e-10):.2f}x")
if kl_high > kl_low * 10:
print("\nSuccess! Suppressing high-importance sentences causes larger output changes")
print(" This causally validates that receiver head scores identify important sentences")
Click to see the expected output
High importance sentence 68: score=0.9711
  Text: Wait, but 66666 in hex is 5 digits, so 5 * 4 = 20 bits....

Low importance sentence 58: score=0.2037
  Text: 419,424 + 6 = 419,430....

KL divergence when suppressing HIGH importance sentence: 8.7300e-05
KL divergence when suppressing LOW importance sentence: 4.0152e-07
Ratio (high/low): 217.37x

Success! Suppressing high-importance sentences causes larger output changes
  This causally validates that receiver head scores identify important sentences
What this shows: suppressing attention to high-importance sentences (as identified by receiver head scores) causes much larger KL divergence than suppressing low-importance ones. This is a causal validation - it confirms that the receiver head scores aren't just correlational, but that attention to these sentences actually drives the model's output.
Note: the suppression works by replacing attention module forward methods (not post-hooks) and masking attention scores to -inf before softmax. This is specific to the Qwen/Qwen2 architecture.
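As a rough illustration of that pattern (a sketch, not the notebook's actual implementation), the apply/remove pair amounts to save-replace-restore around each targeted module's forward method. The module path model.model.layers[i].self_attn is the usual HuggingFace layout for Qwen2-style models but is an assumption here, and make_suppressed_forward is a hypothetical factory that would rebuild the attention computation with -inf added to suppressed key positions before softmax:

def apply_suppression_sketch(model, sent_range, layer_to_heads, make_suppressed_forward):
    """Swap in suppressed forward methods; return the originals for restoration."""
    saved_forwards = {}
    for layer_idx, heads in layer_to_heads.items():
        attn_module = model.model.layers[layer_idx].self_attn  # assumed HF Qwen2 layout
        saved_forwards[layer_idx] = attn_module.forward
        # Hypothetical factory: builds a forward that masks scores pre-softmax
        attn_module.forward = make_suppressed_forward(attn_module, sent_range, heads)
    return saved_forwards

def remove_suppression_sketch(model, saved_forwards):
    """Restore the original forward methods saved by apply_suppression_sketch."""
    for layer_idx, original_forward in saved_forwards.items():
        model.model.layers[layer_idx].self_attn.forward = original_forward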
Bonus Exercises: Replicating Paper Whitebox Results
The correlation analyses above produce weak results on a single reasoning trace - this is expected. Robust claims about thought anchors require aggregating evidence across many problems. The paper's whitebox experiments run over a large dataset of math rollouts (available on HuggingFace) and replicate across two model families (Qwen-14B and Llama-8B).
The bonus exercises below each target a specific figure from the paper. Each has a starting point using the infrastructure already built in this notebook, and a stretch goal that scales up to the full dataset and the paper's code from the whitebox-analyses/ directory.
Replicate the sentence-to-sentence suppression KL matrix (Figure 6)
The paper's most direct causal whitebox result is a sentence-to-sentence suppression matrix: for each pair of sentences $(j, i)$, suppress attention to sentence $j$ across all attention heads, then measure the mean KL divergence of the model's token-level predictions within sentence $i$. This produces an $N \times N$ matrix where entry $[i, j]$ captures how much sentence $j$ causally influences predictions at positions in sentence $i$. This is the whitebox complement of the blackbox counterfactual importance computed earlier, and Figure 6 visualizes it as a heatmap.
Starting point (single problem): using the model and suppression tools from this notebook:
- Run a baseline forward pass and collect the full token-level logit tensor, not just the final-position slice.
- For each sentence $j$, call apply_qwen_attention_suppression across all attention heads (set layer_to_heads to include every head in every layer), run a forward pass to collect token-level logits, then restore with remove_qwen_attention_suppression. This is compute-intensive - start with a subset of sentences.
- For each pair $(j, i)$, compute the mean KL divergence between baseline and suppressed logits over all token positions that fall within sentence $i$'s token boundaries, using compute_suppression_kl from earlier in the notebook (see the sketch after this list). Store results in an $N \times N$ matrix.
- Mask the upper triangle and diagonal with np.nan (future sentences cannot be causally affected by later ones), then visualize as a heatmap. Try subtracting each row's mean to normalize for sentence depth. Use a white-to-red colormap.
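Here is a minimal sketch of the per-pair KL step referenced above, assuming baseline_logits and suppressed_logits are full (seq, vocab) numpy arrays and sentence boundaries follow the (start, end) token convention used elsewhere in the notebook:

import numpy as np

def mean_kl_for_pair(
    baseline_logits: np.ndarray,  # (seq, vocab), unsuppressed forward pass
    suppressed_logits: np.ndarray,  # (seq, vocab), with sentence j suppressed
    sentence_i_range: tuple[int, int],  # token boundaries of sentence i
) -> float:
    """Mean KL over the token positions belonging to sentence i."""
    start, end = sentence_i_range
    kls = [
        compute_suppression_kl(baseline_logits[pos], suppressed_logits[pos])
        for pos in range(start, end)
    ]
    return float(np.mean(kls))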
What to look for: thought anchors appear as columns with elevated KL values stretching down across many rows - they causally affect many later sentences. These columns typically correspond to plan generation or explicit backtracking steps. Compare the column positions to the receiver head scores and blackbox counterfactual importances from Section 2.
Stretch goal (full dataset): the paper's implementation is get_suppression_KL_matrix in whitebox-analyses/attention_analysis/attn_supp_funcs.py. It uses nucleus-sampling logit compression (p_nucleus=0.9999) to store only the top-p logits efficiently. Run it on 10-20 problems and average the row-normalized matrices. The plot_suppression_matrix.py script handles visualization. Compare your result to Figure 6.
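The logit-compression trick is ordinary top-p truncation: keep only the smallest set of tokens whose probabilities sum to at least p, and store just those token ids and logits. A rough sketch of the idea (the paper's code implements its own storage format):

import numpy as np
from scipy.special import softmax

def compress_logits_top_p(logits: np.ndarray, p: float = 0.9999) -> tuple[np.ndarray, np.ndarray]:
    """Keep the smallest set of token logits whose probabilities sum to >= p."""
    probs = softmax(logits)
    order = np.argsort(probs)[::-1]  # token ids, sorted by descending probability
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), p)) + 1
    keep = order[:cutoff]
    return keep, logits[keep]  # kept token ids and their logits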
def replicate_figure_6(
    model,
    tokenizer,
    input_ids: Tensor,
    sentence_boundaries: list[tuple[int, int]],
    layer_to_heads: dict[int, list[int]] | None = None,
    row_normalize: bool = True,
) -> np.ndarray:
    """
    Compute the sentence-to-sentence suppression KL matrix, replicating
    Figure 6 from the Thought Anchors paper.

    For a single reasoning trace with N sentences:
    1. Run a baseline forward pass and collect all token-level logits
       (the full logit tensor, not just the last position)
    2. For each sentence j (the suppressed sentence), call
       apply_qwen_attention_suppression to suppress attention to sentence j's
       tokens across all heads, run a forward pass to get suppressed logits,
       then restore with remove_qwen_attention_suppression
    3. For each pair (j, i), compute mean KL divergence between baseline and
       suppressed logits over all token positions within sentence i's
       boundaries, using compute_suppression_kl
    4. Store results in an N x N matrix where entry [i, j] is the causal
       influence of sentence j on predictions at sentence i
    5. Mask the upper triangle and diagonal with np.nan (future sentences
       cannot be causally affected by later ones)
    6. Optionally subtract each row's mean to normalize for sentence depth
       effects

    Returns the N x N suppression KL matrix (lower-triangular, with np.nan
    on and above the diagonal).
    """
    raise NotImplementedError()
Identify receiver heads via kurtosis and verify split-half reliability (Figure 4)
The paper identifies "receiver heads" by computing the kurtosis of each attention head's vertical attention scores across a set of reasoning traces. The vertical score for sentence $s$ under head $(l, h)$ is the mean attention weight that $s$ receives from all subsequent token positions, after ignoring a local proximity window. High kurtosis means a head's attention is concentrated on a small number of "anchor" sentences rather than spread broadly. The paper finds that only a small fraction of heads have high kurtosis, these heads are consistent across problems, and they correspond to the receiver heads identified qualitatively earlier.
Starting point (single problem): using the current reasoning trace and loaded model:
- For each attention head $(l, h)$, extract the full attention matrix and compute sentence-level vertical scores: for each sentence $s$, average the attention weights directed toward sentence $s$ from all token positions that are at least proximity_ignore=16 positions after sentence $s$'s end boundary. The paper implements this in get_vertical_scores in whitebox-analyses/attention_analysis/attn_funcs.py. (A minimal sketch follows this list.)
- Compute scipy.stats.kurtosis of the resulting per-sentence score array for each head.
- Produce a scatter plot of kurtosis versus layer index. You should see a small number of outlier heads with substantially higher kurtosis than the bulk of heads (which cluster near zero). These are candidate receiver heads.
- Check whether the top-kurtosis heads from this single trace agree with the receiver head set used in this notebook (which was pre-computed from many traces).
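Here is the minimal sketch referenced in the list above, assuming attn is a (seq, seq) attention matrix for a single head (e.g. from output.attentions) and boundaries uses the (start, end) token convention:

import numpy as np
from scipy.stats import kurtosis

def vertical_scores_for_head(
    attn: np.ndarray,  # (seq, seq) attention weights for one head
    sentence_boundaries: list[tuple[int, int]],
    proximity_ignore: int = 16,
) -> np.ndarray:
    seq_len = attn.shape[0]
    scores = []
    for start, end in sentence_boundaries:
        query_start = end + proximity_ignore  # skip the local proximity window
        if query_start >= seq_len:
            scores.append(np.nan)  # no sufficiently distant queries
            continue
        # Mean attention from all later query positions onto this sentence's tokens
        scores.append(attn[query_start:, start:end].mean())
    return np.array(scores)

# High kurtosis = attention concentrated on a few anchor sentences
head_kurtosis = kurtosis(vertical_scores_for_head(attn, boundaries), nan_policy="omit")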
Stretch goal - split-half reliability (paper result: r = 0.84): replicate the paper's reliability analysis across many reasoning traces:
- Download the dataset from HuggingFace (uzaymacar/math-rollouts) and process a set of problems using get_all_problems_vert_scores from the paper's receiver_head_funcs.py.
- Split problems into two halves (odd/even indices). For each half, average per-head kurtosis using get_3d_ar_kurtosis.
- Scatter plot first-half vs. second-half mean kurtosis (one point per head). Compute the Pearson correlation using scipy.stats.pearsonr; the paper reports r = 0.84, showing that the same heads consistently function as receivers across different reasoning problems. (A sketch of this step follows the list.)
- Select the top-$k$ heads by mean kurtosis (the paper uses $k = 20$) as the receiver head set. Note which layers these heads concentrate in.
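A sketch of the split-half step, where kurtosis_by_head is a hypothetical array of shape (num_problems, num_layers, num_heads) holding per-problem kurtosis for every head:

import numpy as np
from scipy.stats import pearsonr

# kurtosis_by_head: hypothetical (num_problems, num_layers, num_heads) array
half_a = kurtosis_by_head[0::2].mean(axis=0).ravel()  # mean kurtosis, even-indexed problems
half_b = kurtosis_by_head[1::2].mean(axis=0).ravel()  # mean kurtosis, odd-indexed problems
r, p_value = pearsonr(half_a, half_b)
print(f"Split-half reliability: r = {r:.2f}")  # paper reports r = 0.84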
def replicate_figure_4(
    model,
    tokenizer,
    input_ids: Tensor,
    sentence_boundaries: list[tuple[int, int]],
    proximity_ignore: int = 16,
    top_k: int = 20,
) -> tuple[list[tuple[int, int, float]], np.ndarray]:
    """
    Identify receiver heads via kurtosis of vertical attention scores,
    replicating Figure 4 from the Thought Anchors paper.

    For a single reasoning trace:
    1. Run a forward pass with output_attentions=True to get attention
       matrices for every layer and head
    2. For each head (layer, head_idx), compute sentence-level vertical
       scores: for each sentence s, average the attention weights directed
       toward sentence s from all token positions that are at least
       proximity_ignore positions after sentence s's end boundary
    3. Compute scipy.stats.kurtosis of the per-sentence vertical score array
       for each head, storing results in a (num_layers, num_heads) matrix
    4. Identify the top_k heads with highest kurtosis as candidate receiver
       heads
    5. Produce a scatter plot of kurtosis vs layer index to visualize
       outlier heads

    Returns a tuple of:
    - List of (layer, head, kurtosis) for the top_k heads, sorted by
      kurtosis descending
    - The full kurtosis matrix of shape (num_layers, num_heads)
    """
    raise NotImplementedError()
Receiver head scores by sentence taxonomic category (Figure 5)
The paper's most interpretable whitebox result is that receiver head attention is not uniformly distributed across sentence types. Plan generation and uncertainty management sentences receive significantly higher receiver-head scores than active computation sentences (p < 0.001 for all pairwise comparisons). This directly links the functional role of a sentence (planning, pivoting) to its mechanistic role (being selectively attended to by specialized receiver heads).
The sentence categories used in the paper are: plan_generation, active_computation, fact_retrieval, uncertainty_management, result_consolidation, self_checking, problem_setup, and final_answer_emission. Each sentence is labeled by an LLM auto-labeler using a structured prompt (DAG_PROMPT in prompts.py in the paper's repository).
Starting point (single problem):
- Use an LLM (e.g. via the Anthropic API) with the DAG_PROMPT from the paper's prompts.py to classify each sentence into function tags based on its role in the reasoning.
- Compute receiver head scores for each sentence using compute_receiver_head_scores (already defined in this notebook).
- Group sentences by function tag, compute the mean receiver-head score per group, and plot a bar chart (see the sketch after this list). Even on a single trace, plan generation and uncertainty management should score noticeably higher than active computation.
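A minimal sketch of the grouping step referenced in the list above, assuming function_tags and receiver_scores are aligned per-sentence sequences:

import numpy as np
from collections import defaultdict

# function_tags: per-sentence labels from the LLM classifier
# receiver_scores: per-sentence receiver head scores (may contain NaN)
scores_by_tag: dict[str, list[float]] = defaultdict(list)
for tag, score in zip(function_tags, receiver_scores):
    if not np.isnan(score):
        scores_by_tag[tag].append(float(score))

for tag, scores in sorted(scores_by_tag.items(), key=lambda kv: -np.mean(kv[1])):
    print(f"{tag:25s} mean={np.mean(scores):.4f} (n={len(scores)})")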
Stretch goal - statistical validation across many problems (paper result: p < 0.001):
- Use the paper's generate_rec_csvs.py script to produce CSV files containing receiver head scores and taxonomic labels for every problem in the dataset.
- For each problem, aggregate receiver head scores by sentence category (per-problem median, to avoid pseudoreplication across sentences within a problem).
- Run paired t-tests using scipy.stats.ttest_rel, comparing pairs of categories using only problems that have examples of both categories (see the sketch after this list). Exclude final_answer_emission and problem_setup from pairwise comparisons.
- Plot boxplots stratified by category (as in plot_rec_taxonomy.py from the paper code) and compare to Figure 5 of the paper. The result provides statistical support for the core claim: plan generation and uncertainty management sentences are the primary thought anchors - they are the steps that most influence downstream reasoning, both causally (blackbox) and via attention (whitebox).
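A sketch of a single pairwise comparison, where medians_by_problem is a hypothetical dict mapping each problem id to its per-category median receiver-head scores:

from scipy.stats import ttest_rel

# medians_by_problem: hypothetical dict {problem_id: {category: median score}}
cat_a, cat_b = "plan_generation", "active_computation"
paired = [
    (m[cat_a], m[cat_b])
    for m in medians_by_problem.values()
    if cat_a in m and cat_b in m  # only problems with examples of both categories
]
a_vals, b_vals = zip(*paired)
t_stat, p_value = ttest_rel(a_vals, b_vals)
print(f"{cat_a} vs {cat_b}: t={t_stat:.2f}, p={p_value:.2g} (n={len(paired)})")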
def replicate_figure_5(
    model,
    tokenizer,
    sentences: list[str],
    sentence_boundaries: list[tuple[int, int]],
    receiver_heads: list[tuple[int, int]],
    function_tags: list[str],
    categories: list[str] | None = None,
) -> dict[str, list[float]]:
    """
    Compute receiver head scores grouped by sentence taxonomic category,
    replicating Figure 5 from the Thought Anchors paper.

    For a single reasoning trace:
    1. Compute receiver head vertical scores for each sentence using
       compute_receiver_head_scores (already defined in this notebook)
    2. Group sentences by their function tag (from LLM classification)
    3. For each category, collect the receiver head scores of all sentences
       in that category
    4. Plot a bar chart (or boxplot) of mean receiver head score per category
    5. Plan generation and uncertainty management should score noticeably
       higher than active computation

    Returns a dict mapping each category to a list of receiver head scores
    for sentences in that category.
    """
    raise NotImplementedError()