2️⃣ Black-box Analysis
Learning Objectives
- Understand forced answer importance: measuring impact by directly forcing model answers at specific steps
- Implement resampling importance: measuring impact by having the model regenerate its own continuation from that point forward
- Implement counterfactual importance: measuring impact by resampling then filtering to steps where the regenerated sentence is semantically different from the original
- Learn when each metric is appropriate (causal vs correlational analysis)
- Replicate key paper figures showing which reasoning steps are most critical for final answers
Now that we've split reasoning traces into sentences and categorized them, we can treat each sentence as a unit and ask: which sentences are most important for the model's final answer?
This section introduces three black-box methods for measuring sentence importance. "Black-box" here means we only observe inputs and outputs - we don't look inside the model at activations or attention patterns. By intervening on individual sentences (forcing answers, resampling, or replacing with counterfactuals), we can measure how much each sentence shapes the model's reasoning trajectory and final conclusion.
The paper uses three methods to measure sentence importance:
1. Forced Answer Importance. For each sentence $S_i$ in the CoT, we interrupt the model and force a final output: "Therefore, the final answer is \\boxed{...}". We measure accuracy when forced to answer immediately. If $A^f_i$ is the final answer when we force after $S_i$, and $P(A)$ is the probability of answer $A$ being correct:

$$\text{importance}^{\text{forced}}(S_i) = P(A^f_i) - P(A^f_{i-1})$$
Can you see a flaw in this approach, when it comes to identifying sentences which are critical for the model's reasoning process? Click here to reveal the answer.
Some sentence $S$ might be necessary for the final answer, but comes late in the reasoning process, meaning all sentences before $S$ will result in low accuracy by this metric. We will only pick up sentences whose inclusion in combination with all previous sentences gets us from the wrong answer to the right one.
For people who have completed the IOI material / done much work with model internals, this is the equivalent of finding the most important model components by seeing which ones write the final answer to the residual stream. It's a good start, but doesn't tell you much about the steps of the computation beyond the last one.
2. Resampling Importance. To address this flaw, for each sentence $S_i$ we resample a whole trajectory $(S_1, S_2, \ldots, S_{i-1}, S_i, S'_{i+1}, \ldots, S'_N, A^r_i)$ (where $A^r_i$ is the final answer from resampling the chunks after $S_i$). Comparing this to the trajectory $(S_1, S_2, \ldots, S_{i-1}, S'_i, \ldots, S'_N, A^r_{i-1})$ from resampling the chunks from $S_i$ onwards (i.e. including $S_i$), we get:

$$\text{importance}^{\text{resample}}(S_i) = P(A^r_i) - P(A^r_{i-1})$$
3. Counterfactual Importance. An issue with resampling importance is that $S'_i$ will often be very similar to $S_i$, if the reasoning context strongly constrains what can be expressed at that position. We fix this by filtering for cases where the two sentences are substantially different, using a semantic similarity metric from our embedding model. This gives us a counterfactual importance metric identical to the previous one, but computed only over the filtered rollouts:

$$\text{importance}^{\text{counterfactual}}(S_i) = P(A^c_i) - P(A^c_{i-1})$$

where $A^c_i$ is defined like $A^r_i$, but restricted to rollouts whose resampled sentence passes the dissimilarity filter.
We'll work through each of these metrics in turn, computing them on the dataset we've already been provided. The full replication (including your own implementation of resampling) comes in the next section.
Looking closer at our dataset
Before getting into the exercises, let's look closer at our dataset. We'll set aside our actual model and chunking functions temporarily, focusing on the resampled rollouts already provided in the dataset.
The code below loads all the data for a single problem:
- `problem`, which contains the problem statement, along with some basic metadata (including the answer)
- `chunks_labeled[i]`, data for the `i`-th chunk (e.g. what category it is, plus some metrics)
- `chunk_solutions_forced[i][j]`, data for the `j`-th rollout we get from forcing an answer just before the `i`-th chunk (e.g. for `i=0` this means forcing an answer before any chunks are included)
- `chunk_solutions[i][j]`, data for the `j`-th rollout we get from resampling from the `i`-th chunk onwards (i.e. chunk `i` is the first chunk replaced)
We'll be using this data to implement the three importance metrics described above.
Run the code below to load in the data for a single problem (note we're using generations from deepseek-r1-distill-qwen-14b rather than our Llama-8b model here, so that we match the case study in the paper's appendix).
Make sure you understand the structure of the data, since this will make the following exercises easier. Here are a few investigative questions you might try to answer when inspecting this problem's data:
- How has the chunking worked here? Can you see any obvious issues with how it's been applied, e.g. places where a chunk has been split that shouldn't have been?
- Do the categories look reasonable? You can try comparing them to each of your autoraters, and see what fraction match.
- Inspect `chunk_solutions_forced`, and see how early the model generally manages to get the answer correct.
- Inspect `chunk_solutions`, and see how much variety the model's completions have when resampled from various stages in the reasoning process. (For both of these, see the sketch after the expected output below.)
def load_single_file(file_path: str):
local_path = hf_hub_download(repo_id=DATASET_NAME, filename=file_path, repo_type="dataset")
with open(local_path, "r") as f:
return json.load(f)
def load_problem_data(problem_id: int, model_name: str = "deepseek-r1-distill-qwen-14b", verbose: bool = True):
disable_progress_bars()
problem_dir = "correct_base_solution"
problem_dir_forced = "correct_base_solution_forced_answer"
problem_path = f"{model_name}/temperature_0.6_top_p_0.95/{problem_dir}/problem_{problem_id}"
problem_path_forced = f"{model_name}/temperature_0.6_top_p_0.95/{problem_dir_forced}/problem_{problem_id}"
base_solution = load_single_file(f"{problem_path}/base_solution.json")
problem = load_single_file(f"{problem_path}/problem.json")
chunks_labeled = load_single_file(f"{problem_path}/chunks_labeled.json")
n_chunks = len(chunks_labeled)
chunk_solutions = [None] * n_chunks
chunk_solutions_forced = [None] * n_chunks
def load_chunk(chunk_idx):
sol = load_single_file(f"{problem_path}/chunk_{chunk_idx}/solutions.json")
forced = load_single_file(f"{problem_path_forced}/chunk_{chunk_idx}/solutions.json")
return chunk_idx, sol, forced
with ThreadPoolExecutor(max_workers=16) as executor:
futures = [executor.submit(load_chunk, i) for i in range(n_chunks)]
for future in tqdm(as_completed(futures), total=n_chunks, disable=not verbose, desc="Loading chunks"):
idx, sol, forced = future.result()
chunk_solutions[idx] = sol
chunk_solutions_forced[idx] = forced
enable_progress_bars()
return {
"problem": problem,
"base_solution": base_solution,
"chunks_labeled": chunks_labeled,
"chunk_solutions_forced": chunk_solutions_forced,
"chunk_solutions": chunk_solutions,
}
# Load a problem
problem_data = load_problem_data(PROBLEM_ID)
# Inspect the problem structure
print("Problem:", problem_data["problem"]["problem"][:200], "...")
print(f"\nGround truth answer: {problem_data['problem']['gt_answer']}")
print(f"Number of chunks: {len(problem_data['chunks_labeled'])}")
print(f"Rollouts per chunk: {len(problem_data['chunk_solutions'][0])}")
# Show first few chunks
for i, c in enumerate(problem_data["chunks_labeled"][:5]):
print(f"{i}. [{c['function_tags'][0]}] {c['chunk'][:80]}...")
Click to see the expected output
Problem: When the base-16 number $66666_{16}$ is written in base 2, how many base-2 digits (bits) does it have? ...
Ground truth answer: 19
Number of chunks: 145
Rollouts per chunk: 100
0. [problem_setup] Okay, so I have this problem where I need to find out how many bits the base-16 ...
1. [problem_setup] Hmm, let's see....
2. [fact_retrieval] I remember that each hexadecimal digit corresponds to exactly 4 binary digits, o...
3. [plan_generation] So, maybe I can just figure out how many hexadecimal digits there are and multip...
4. [uncertainty_management] Let me check that....
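As a starting point for the last two investigative questions, here's a minimal exploratory sketch (the `full_cot` and `chunk_resampled` fields are the same ones we rely on in the exercises below; the substring check on `\boxed{19}` is just a rough heuristic, not a proper answer parser):

```python
# How early can the model answer correctly when forced? (rough substring check)
for i in [0, 40, 80, 120]:
    cots = [r["full_cot"] for r in problem_data["chunk_solutions_forced"][i]]
    n_correct = sum("\\boxed{19}" in cot for cot in cots)
    print(f"Forced with first {i} chunks: {n_correct}/{len(cots)} correct")

# How varied are the resamples at an early planning chunk vs a late one?
for i in [3, 100]:
    print(f"\nResamples replacing chunk {i}:")
    for r in problem_data["chunk_solutions"][i][:3]:
        print("  " + repr(r["chunk_resampled"][:80]))
```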
Metric (1/3): Forced Answer Importance
For each sentence $S_i$, we interrupt the model and force it to output a final answer. We measure how accuracy changes as each successive sentence is included:

$$\text{importance}^{\text{forced}}(S_i) = P(A^f_i) - P(A^f_{i-1})$$
What's the limitation of this approach?
This only captures the final step that tips the scales - when all previous sentences together finally produce enough information for the correct answer. It doesn't tell us which earlier sentences were critical for setting up that final step.
Analogy: It's like measuring which brick "caused" a wall to reach a certain height - technically only the last brick did, but all previous bricks were necessary too.
def extract_answer_from_cot(cot: str) -> str:
"""Extract the numerical answer from a chain-of-thought solution by parsing the \\boxed{} expression."""
ans = cot.split("\\boxed{")[-1].split("}")[0]
return "".join(char for char in ans if char.isdigit() or char == ".")
def get_filtered_indices(
chunk_removed: str, resampled: list[str], embedding_model: "SentenceTransformer", threshold: float = 0.7
) -> list[int]:
"""Return indices of resampled rollouts whose chunk is sufficiently dissimilar from the original."""
emb_original = embedding_model.encode(chunk_removed)
emb_resampled = embedding_model.encode(resampled)
cos_sims = emb_original @ emb_resampled.T
return np.where(cos_sims < threshold)[0]
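Here's a quick sanity check of both helpers before we use them (the example strings are made up, and the second assertion previews the limitation discussed in the exercise below):

```python
# The last \boxed{} expression wins, and only digits/decimal points survive
assert extract_answer_from_cot("So the answer is \\boxed{19}.") == "19"
assert extract_answer_from_cot("We get \\boxed{6.17\\%}") == "6.17"  # the % sign is silently dropped

# Identical resamples should be filtered out (cosine similarity 1.0 >= threshold)
idxs = get_filtered_indices(
    "Each hex digit corresponds to 4 bits.",
    ["Each hex digit corresponds to 4 bits.", "Alternatively, I could convert to decimal first."],
    embedding_model,
)
print(idxs)  # we'd expect this to contain 1 but not 0
```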
Exercise - calculate forced answer importance
You should fill in the function below, to compute the forced answer importance on a set of chunks and their labelled categories. Note that the function takes a list of lists of full CoT rollouts, meaning you'll need to parse out the model's answer from the CoT yourself using the extract_answer_from_cot helper defined above.
Note: extract_answer_from_cot is a simplification that strips everything except digits and decimal points from the \boxed{} expression. This works fine for this case study (which has an integer answer), but for other problem types you'd need a more robust parser - for example, answers like 6.17% or r = 6.17 would lose their context. Keep this limitation in mind if you adapt this code to other datasets.
def calculate_answer_importance(full_cot_list: list[list[str]], answer: str) -> list[float]:
"""
Calculate importance for chunks based on accuracy differences.
Args:
full_cot_list: List of lists of rollouts. full_cot_list[i][j] is the j-th rollout
generated by forcing an answer after the i-th chunk.
answer: The ground truth answer.
Returns:
List of importance scores (one fewer than chunks, since we measure differences).
"""
raise NotImplementedError("Implement answer importance calculation")
tests.test_calculate_answer_importance(calculate_answer_importance)
Solution
def calculate_answer_importance(full_cot_list: list[list[str]], answer: str) -> list[float]:
"""
Calculate importance for chunks based on accuracy differences.
Args:
full_cot_list: List of lists of rollouts. full_cot_list[i][j] is the j-th rollout
generated by forcing an answer after the i-th chunk.
answer: The ground truth answer.
Returns:
List of importance scores (one fewer than chunks, since we measure differences).
"""
# Get P(correct) for each chunk position
probabilities = [
sum(extract_answer_from_cot(cot) == answer for cot in cot_list) / len(cot_list) for cot_list in full_cot_list
]
# Return differences: P(A_{S_i}) - P(A_{S_{i-1}})
return np.diff(probabilities).tolist()
tests.test_calculate_answer_importance(calculate_answer_importance)
When you've done this, run the cell below to get your results and plot them. This should not match the paper's "Figure 2" results yet, since we're using forced answer importance, not resampled or counterfactual importance.
What do you notice? What sentences are necessary for the model to start getting greater-than-zero accuracy? Are there any sentences which significantly drop or raise accuracy, and can you explain why?
What patterns do you notice? (click to see discussion)
You should see that:
- Most early chunks have ~0 importance (the model can't answer correctly yet)
- There are sharp spikes near the end when critical computations are made
- Some chunks have negative importance - they might introduce confusion or errors
For example, if a model is near the end of its reasoning, it might follow a pattern like "maybe X, wait no actually Y", and this will result in wildly careening forced answer importance values for those last few sentences, which aren't really representative of how important the sentence was in the actual reasoning process.
Earlier important reasoning steps get 0 importance even if they were essential - each intermediate step only shows importance when it contributes to tipping the final answer from wrong to right in that iteration.
Note on the small discrepancy from the actual dataset metrics
A small discrepancy between your results and those in the dataset is fine, and expected. The currently uploaded version of the dataset seems to have a bug in the "forced answer" metric data; for example, it will classify the following rollout:
'Therefore, the final answers is \\boxed{20}. However, upon re-examining ... so the correct answer is \\boxed{19}.'
as having a final answer of 20 rather than 19, hence incorrectly classifying the answer as wrong.
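Our extract_answer_from_cot helper doesn't share this bug, since it parses the last \boxed{} expression rather than the first:

```python
cot = "Therefore, the final answers is \\boxed{20}. However, upon re-examining ... so the correct answer is \\boxed{19}."
print(extract_answer_from_cot(cot))  # '19'
```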
# Calculate forced answer importance
full_cot_list = [
[rollout["full_cot"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions_forced"]
]
answer = problem_data["problem"]["gt_answer"]
forced_importances = calculate_answer_importance(full_cot_list, answer)
# Get chunk texts for hover data
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:-1]]
# Plot with plotly
fig = go.Figure()
fig.add_trace(
go.Bar(
x=list(range(len(forced_importances))),
y=forced_importances,
opacity=0.7,
hovertemplate="<b>Chunk %{x}</b><br>Importance: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
customdata=[chunk[:100] + "..." if len(chunk) > 100 else chunk for chunk in chunks_for_hover],
)
)
fig.add_hline(y=0, line_color="black", line_width=0.5)
fig.update_layout(
title="Forced Answer Importance by Chunk",
xaxis_title="Chunk Index",
yaxis_title="Forced Answer Importance",
width=900,
height=400,
)
fig.show()
Click to see the expected output
Metric (2/3): Resampling Importance
To address the limitation described in the dropdown above, we resample the trace from each sentence onwards and measure how this changes the final answer distribution:

$$\text{importance}^{\text{resample}}(S_i) = P(A^r_i) - P(A^r_{i-1})$$
The key difference: $A^r_i$ comes from a full resampled trajectory continuing after $S_i$, not from a forced early answer. This captures broader ways in which a sentence can matter, not just whether it immediately tips the final answer.
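To make the arithmetic concrete, here's a toy example with made-up accuracy values (not real data):

```python
# probs[i] = hypothetical P(correct) when resampling from chunk i onwards (keeping chunks 0..i-1)
probs = [0.10, 0.12, 0.60, 0.58]
print(np.diff(probs))  # [0.02, 0.48, -0.02] -> keeping chunk 1 is what raises accuracy
```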
Exercise - compare resampling importance
The same calculate_answer_importance function works - we just apply it to different data!
Before computing this, make a concrete prediction: which chunks do you think will have the highest resampling importance? Will they be the same chunks that had high forced answer importance, or different ones? Will the importance be concentrated near the end of the trace (like forced answer importance) or more spread out? Write down your prediction, then compute the results and see how well your intuition matched.
# YOUR CODE HERE - compute `resampling_importances` using `calculate_answer_importance`
# on the resampled rollouts
resampling_importances = []
# Compare with precomputed values from dataset (they used a slightly different method to us, but
# we should get an answer within 1% of theirs)
precomputed = [chunk["resampling_importance_accuracy"] for chunk in problem_data["chunks_labeled"][:-1]]
avg_diff = np.abs(np.subtract(resampling_importances, precomputed)).mean()
assert avg_diff < 0.01, "Error above 1% threshold"
# Plot comparison between these two metrics
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:-1]]
hover_texts = [chunk[:100] + "..." if len(chunk) > 100 else chunk for chunk in chunks_for_hover]
fig = make_subplots(
rows=2,
cols=1,
shared_xaxes=True,
vertical_spacing=0.1,
subplot_titles=("Forced Answer Importance", "Resampling Importance"),
)
for row, color, y in [(1, "cornflowerblue", forced_importances), (2, "orange", resampling_importances)]:
fig.add_trace(
go.Bar(
x=list(range(len(y))),
y=y,
opacity=0.7,
name="Forced" if color == "cornflowerblue" else "Resampling",
marker_color=color,
hovertemplate="<b>Chunk %{x}</b><br>Importance: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
customdata=hover_texts,
),
row=row,
col=1,
)
fig.add_hline(y=0, line_color="black", line_width=0.5, row=1, col=1)
fig.add_hline(y=0, line_color="black", line_width=0.5, row=2, col=1)
fig.update_layout(width=900, height=500, showlegend=False)
fig.update_xaxes(title_text="Chunk Index", row=2, col=1)
fig.update_yaxes(title_text="Importance", row=1, col=1)
fig.update_yaxes(title_text="Importance", row=2, col=1)
fig.show()
Click to see the expected output
Solution & discussion
# Calculate resampling importance (same function, different data!)
full_cot_list_resampled = [
[rollout["full_cot"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions"]
]
resampling_importances = calculate_answer_importance(full_cot_list_resampled, answer)
# Compare with precomputed values from dataset (they used a slightly different method to us, but
# we should get an answer within 1% of theirs)
precomputed = [chunk["resampling_importance_accuracy"] for chunk in problem_data["chunks_labeled"][:-1]]
avg_diff = np.abs(np.subtract(resampling_importances, precomputed)).mean()
assert avg_diff < 0.01, "Error above 1% threshold"
# Plot comparison between these two metrics
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:-1]]
hover_texts = [chunk[:100] + "..." if len(chunk) > 100 else chunk for chunk in chunks_for_hover]
fig = make_subplots(
rows=2,
cols=1,
shared_xaxes=True,
vertical_spacing=0.1,
subplot_titles=("Forced Answer Importance", "Resampling Importance"),
)
for row, color, y in [(1, "cornflowerblue", forced_importances), (2, "orange", resampling_importances)]:
fig.add_trace(
go.Bar(
x=list(range(len(y))),
y=y,
opacity=0.7,
name="Forced" if color == "cornflowerblue" else "Resampling",
marker_color=color,
hovertemplate="<b>Chunk %{x}</b><br>Importance: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
customdata=hover_texts,
),
row=row,
col=1,
)
fig.add_hline(y=0, line_color="black", line_width=0.5, row=1, col=1)
fig.add_hline(y=0, line_color="black", line_width=0.5, row=2, col=1)
fig.update_layout(width=900, height=500, showlegend=False)
fig.update_xaxes(title_text="Chunk Index", row=2, col=1)
fig.update_yaxes(title_text="Importance", row=1, col=1)
fig.update_yaxes(title_text="Importance", row=2, col=1)
fig.show()
You should see the resampling importance values are often higher at earlier steps, petering out by the second half of the rollout chunks.
Moreover, the highest-importance sentences according to this metric are the ones that involve some amount of planning. The most important (nearly 4x higher than any other positive sentence) is one that appears in the first 15% of the rollout chunks, and which outlines the plan the model will use for solving this problem. The most negative sentence is the one which appears immediately before that one, which expresses a slightly different plan.
Semantic Similarity in Resampling
Before we look at the last metric (counterfactual importance), let's revisit the notion of embedding cosine similarity. Since we have data on a bunch of resampled rollouts at different chunks, we can compute the average cosine similarity between a chunk and all of its resampled chunks (i.e. $S_i$ and $S'_i$ in the notation above). Run the cells below to compute these cosine similarities and plot them.
Which kinds of sentences seem like their resamples have the highest or lowest cosine similarity? Can you explain why?
# Compute cosine similarity between original and resampled chunks
chunks_removed = [chunk["chunk"] for chunk in problem_data["chunks_labeled"]]
embeddings_original = embedding_model.encode(chunks_removed)
chunks_resampled = [
[rollout["chunk_resampled"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions"]
]
embeddings_resampled = np.stack([embedding_model.encode(r) for r in chunks_resampled])
# Compute similarities
cos_sims = einops.einsum(embeddings_original, embeddings_resampled, "chunk d, chunk resample d -> chunk resample")
cos_sims_mean = cos_sims.mean(axis=1)
# Plot by category with plotly
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"]]
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"]]
df = pd.DataFrame({"Label": chunk_labels, "Cosine Similarity": cos_sims_mean, "Chunk Text": chunks_for_hover})
fig = go.Figure()
for label in df["Label"].unique():
subset = df[df["Label"] == label]
hover_texts = [text[:100] + "..." if len(text) > 100 else text for text in subset["Chunk Text"]]
fig.add_trace(
go.Bar(
x=subset.index.tolist(),
y=subset["Cosine Similarity"].tolist(),
name=label,
marker_color=utils.CATEGORY_COLORS.get(label, "#9E9E9E"),
hovertemplate="<b>Chunk %{x}</b><br>Category: "
+ label
+ "<br>Cosine Similarity: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
customdata=hover_texts,
)
)
fig.update_layout(
title="How Similar are Resampled Chunks to Originals?",
xaxis_title="Chunk Index",
yaxis_title="Mean Cosine Similarity to Resamples",
width=900,
height=400,
legend=dict(x=1.02, y=1, xanchor="left"),
bargap=0,
)
fig.show()
Click to see the expected output
What patterns do you notice? (click to see discussion)
Some of the highest average cosine similarities are for Final Answer Emission chunks (there are only so many ways to express an answer), and Active Computation / Result Consolidation chunks that come in batches (e.g. chunks 23-33, 51-59, 91-97). These make sense because once the model starts an iterative computation, it's very constrained in what it says until it's finished.
Some of the lowest average cosine similarities are for Plan Generation chunks that represent changes in trajectory (e.g. chunk 13: "Alternatively, maybe I can calculate...", chunk 49: "Let me convert it step by step again"), and Uncertainty Management chunks that represent re-evaluation (e.g. chunk 45: "I must have made a mistake somewhere", chunk 84: "So, which is correct?", chunk 139: "Wait, but that's not correct because...").
This motivates counterfactual importance: we only count resamples that are genuinely different (because it doesn't tell us much if our resamples are all the same as the original - we want to figure out how the rollout changes when the resample is not the same!).
Note, there is one chunk (28) which is classified as "result consolidation" and is something of an outlier, with extremely low average cosine similarity to its resamples. However, inspection of problem_data["chunk_solutions"][28] shows that this is actually an artifact of incorrect chunking: the resamples here all follow the pattern "Now, adding all these up:" followed by an equation, and this has low similarity to the original chunk which (correctly) splits at : and so doesn't include the equation. If you want to fix this, you can try using our split_solution_into_chunks function from earlier to process the resampled chunks before plotting them (see the sketch below). Moral of the story - this kind of string parsing is finicky and easy to get wrong.
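A sketch of that fix, re-chunking each resample and keeping only its first sentence before embedding (assuming split_solution_into_chunks behaves as in the previous section):

```python
# Re-split each resampled chunk, so equations spilling past the ":" are dropped
chunks_resampled_fixed = [
    [(split_solution_into_chunks(r) or [r])[0] for r in resamples]
    for resamples in chunks_resampled
]
```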
Metric (3/3): Counterfactual Importance
Finally, we'll implement counterfactual importance:

$$\text{importance}^{\text{counterfactual}}(S_i) = P(A^c_i) - P(A^c_{i-1})$$

where $A^c$ only includes rollouts where the resampled sentence was different enough.
This is the same as the resampling importance (and we'll use the same data), but with one difference: we only keep resampled rollouts where the first resampled chunk $S'_i$ is sufficiently different from the chunk $S_i$ it replaces. Specifically, we discard rollouts where the embedding cosine similarity is 0.8 or above (i.e. too similar), keeping only those where the resampled chunk actually represents a different reasoning path.
The intuition for this metric: if resampling importance told us the effect when we choose a different sentence than $S_i$, then counterfactual importance tells us the effect when we choose a different reasoning path than represented by $S_i$. Low cosine similarity in this case is a proxy for the reasoning paths being very different (rather than just light rephrasings of what is essentially the same reasoning step).
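Before implementing the metric, it's worth checking how aggressive this filter actually is. A sketch reusing chunks_removed and chunks_resampled from the cosine similarity section above (exact counts will depend on your embedding model):

```python
# How many of the resamples at each position survive the 0.8 dissimilarity filter?
n_kept = [
    len(get_filtered_indices(chunk, resampled, embedding_model, threshold=0.8))
    for chunk, resampled in zip(chunks_removed, chunks_resampled)
]
print(f"Median resamples kept per chunk: {np.median(n_kept):.0f} / {len(chunks_resampled[0])}")
print(f"Positions with < 5 kept (these will need forward-filling): {sum(n < 5 for n in n_kept)}")
```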
Exercise - compute counterfactual importance
The function takes a generic score_fn parameter rather than hardcoding answer extraction. For math accuracy, you'll pass score_fn=lambda cot: extract_answer_from_cot(cot) == answer. We'll reuse this same function later for blackmail rate analysis with a different scoring function.
Note: some chunk positions may have fewer than min_samples filtered rollouts (because nearly all resampled chunks are too similar to the original). You should handle these positions by setting their probability to None, then forward-filling from the previous valid probability with pd.Series.ffill() (and back-filling with .bfill() in case the first position is missing).
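If the forward/back-filling pattern is unfamiliar, here's a minimal demo with hypothetical values:

```python
probs = pd.Series([None, 0.4, None, 0.7])
print(probs.ffill().bfill().tolist())  # [0.4, 0.4, 0.4, 0.7] - gaps inherit the previous valid value
```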
def calculate_counterfactual_importance(
chunks_removed: list[str],
chunks_resampled: list[list[str]],
rollout_data: list[list],
score_fn: Callable,
embedding_model: SentenceTransformer,
threshold: float = 0.8,
min_samples: int = 5,
) -> list[float]:
"""
Calculate counterfactual importance by filtering for low-similarity resamples.
This is a generic function: `rollout_data[i]` is a list of per-rollout values at chunk
position i, and `score_fn(rollout_data_item)` returns True/False for whether that rollout
counts as a "success". For math accuracy, rollout_data contains CoT strings and score_fn
checks if the extracted answer matches the ground truth. For blackmail rate, rollout_data
contains boolean labels and score_fn is just the identity.
Args:
chunks_removed: Original chunks that were removed
chunks_resampled: List of resampled chunks for each position
rollout_data: Per-rollout data for each chunk position (passed to score_fn)
score_fn: Function mapping a single rollout data item to True/False
threshold: Maximum cosine similarity to count as "different"
min_samples: Minimum samples needed to compute probability
embedding_model: Sentence embedding model
Returns:
List of counterfactual importance scores
"""
raise NotImplementedError("Implement counterfactual importance calculation")
Solution
def calculate_counterfactual_importance(
chunks_removed: list[str],
chunks_resampled: list[list[str]],
rollout_data: list[list],
score_fn: Callable,
embedding_model: SentenceTransformer,
threshold: float = 0.8,
min_samples: int = 5,
) -> list[float]:
"""
Calculate counterfactual importance by filtering for low-similarity resamples.
This is a generic function: `rollout_data[i]` is a list of per-rollout values at chunk
position i, and `score_fn(rollout_data_item)` returns True/False for whether that rollout
counts as a "success". For math accuracy, rollout_data contains CoT strings and score_fn
checks if the extracted answer matches the ground truth. For blackmail rate, rollout_data
contains boolean labels and score_fn is just the identity.
Args:
chunks_removed: Original chunks that were removed
chunks_resampled: List of resampled chunks for each position
rollout_data: Per-rollout data for each chunk position (passed to score_fn)
score_fn: Function mapping a single rollout data item to True/False
threshold: Maximum cosine similarity to count as "different"
min_samples: Minimum samples needed to compute probability
embedding_model: Sentence embedding model
Returns:
List of counterfactual importance scores
"""
# Get filtered indices for each chunk (using the module-level get_filtered_indices helper)
filtered_indices = [
get_filtered_indices(chunk, resampled, embedding_model, threshold)
for chunk, resampled in zip(chunks_removed, chunks_resampled)
]
# Compute P(success) using only filtered samples
probabilities = []
for data_list, indices in zip(rollout_data, filtered_indices):
if len(indices) >= min_samples:
successes = sum(score_fn(data_list[i]) for i in indices)
probabilities.append(successes / len(indices))
else:
probabilities.append(None)
# Forward-fill None values
probabilities = pd.Series(probabilities).ffill().bfill().tolist()
return np.diff(probabilities).tolist()
When you've filled this in, run the cells below to compute and plot the counterfactual importance scores next to your resampling importance scores.
You should find the two metrics (resampling and counterfactual) are mostly quite similar for this example. They differ most on sentences which the plot above showed to have high semantic variance among their resamples, because these are our thought anchors: sentences which guide the entire reasoning process, so replacing them with something genuinely different in embedding space has a large effect on the subsequent trajectory.
For example, chunk 3 "So, maybe I can just figure out how many hexadecimal digits there are..." has a higher counterfactual importance than resampling importance (about 50% higher). This is because the chunk represents a key part of the overall reasoning process: when the model doesn't say this, it often expresses a different plan such as "So, maybe I can start by converting each digit one by one..." or "So maybe I can just multiply the number of hex digits by 4..." - you can confirm this by looking at the dataset of rollouts yourself. Even if some of these plans will end up at the same place, the exact plan & phrasing that the model produces in this step will significantly affect its trajectory.
Discussion of some other chunks with very different resampling & counterfactual values
Chunks 53-58 are a series of active computations. The counterfactual metric should be zero on most or all of these, because all resampled rollouts have very similar semantic content (the model was forced to carry out a multi-stage computation in a specific way). This shows our counterfactual metric is working as intended.
Chunks 43-44 say "Now, according to this, it's 19 bits. There's a discrepancy here." Inspecting the dataset shows we've caught the model at a reasoning crossroads: about half the time it says "it's 19 bits" and the other half just says "there's a discrepancy here". So the counterfactual importance metric can be misleading depending on how we chunk rollouts. There's a similar story for chunk 28: "Now, adding these all up:" - sometimes the sequence is chunked with the equation in the same chunk, sometimes in a different one. Moral of the story: this kind of string parsing is finicky and can easily cause issues!
# Calculate counterfactual importance
chunks_removed = [chunk["chunk"] for chunk in problem_data["chunks_labeled"]]
chunks_resampled = [
[rollout["chunk_resampled"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions"]
]
full_cot_list = [
[rollout["full_cot"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions"]
]
counterfactual_importances = calculate_counterfactual_importance(
chunks_removed,
chunks_resampled,
rollout_data=full_cot_list,
score_fn=lambda cot: extract_answer_from_cot(cot) == answer,
embedding_model=embedding_model,
)
# Compare with precomputed
# (We flip the sign because the authors store the negative of the counterfactual metric in the dataset)
precomputed_cf = [-chunk["counterfactual_importance_accuracy"] for chunk in problem_data["chunks_labeled"][:-1]]
avg_diff = np.abs(np.subtract(counterfactual_importances, precomputed_cf)).mean()
assert avg_diff < 0.025, "Error above 2.5% threshold"
print("Precomputed comparison passed!")
tests.test_calculate_counterfactual_importance(calculate_counterfactual_importance)
# Plot comparison of all three metrics with subplots (like previous bar chart)
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:-1]]
hover_texts = [chunk[:100] + "..." if len(chunk) > 100 else chunk for chunk in chunks_for_hover]
fig = make_subplots(
rows=3,
cols=1,
shared_xaxes=True,
vertical_spacing=0.08,
subplot_titles=("Forced Answer Importance", "Resampling Importance", "Counterfactual Importance"),
)
for row, (name, importances, color) in enumerate(
[
("Forced", forced_importances, "cornflowerblue"),
("Resampling", resampling_importances, "orange"),
("Counterfactual", counterfactual_importances, "seagreen"),
],
start=1,
):
fig.add_trace(
go.Bar(
x=list(range(len(importances))),
y=importances,
name=name,
opacity=0.8,
marker_color=color,
hovertemplate="<b>Chunk %{x}</b><br>"
+ name
+ " Importance: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
customdata=hover_texts,
),
row=row,
col=1,
)
fig.add_hline(y=0, line_color="black", line_width=0.5, row=row, col=1)
fig.update_layout(
title="Comparison of Importance Metrics",
width=1000,
height=700,
showlegend=False,
)
fig.show()
Click to see the expected output
Exercise - replicate Figure 3b (sentence category effect)
As an open-ended challenge, try replicating Figure 3b from the paper. We've already got all the data you need to do this, so it's just a matter of understanding and plotting the data correctly.
We've defined our metrics in terms of accuracy rather than KL divergence, so expect the metrics to look slightly different (also we're only averaging over a single prompt's chunks, so the data will be noisier). But you should still get a qualitatively similar plot. Signs of life: Result Consolidation has the highest normalized position in trace (sanity check), and Plan Generation and Uncertainty Management have the highest counterfactual importance (core result).
Scaling up the analysis to many prompts and adding error bars is left as an exercise.
# YOUR CODE HERE - replicate figure 3b!
Click to see the expected output
Solution
# Create dataframe with position and importance
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"]]
n_chunks = len(chunk_labels) - 1
df = pd.DataFrame(
{
"Label": chunk_labels[:-1],
"Importance": counterfactual_importances,
"Position": np.arange(n_chunks) / n_chunks,
}
)
# Get top 5 most common categories
top_labels = df["Label"].value_counts().head(5).index.tolist()
df_filtered = df[df["Label"].isin(top_labels)]
# Group and calculate means
grouped = df_filtered.groupby("Label")[["Importance", "Position"]].mean().reset_index()
# Plot
fig, ax = plt.subplots(figsize=(8, 6))
for _, row in grouped.iterrows():
ax.scatter(
row["Position"], row["Importance"], s=150, label=row["Label"], color=utils.CATEGORY_COLORS.get(row["Label"])
)
ax.set_xlabel("Normalized Position in Trace (0-1)")
ax.set_ylabel("Mean Counterfactual Importance")
ax.set_title("Sentence Category Effect (Figure 3b)")
ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plt.tight_layout()
plt.show()
# Sanity check: Plan Generation should have positive mean counterfactual importance
plan_gen_importance = grouped.loc[grouped["Label"] == "Plan Generation", "Importance"].values
assert len(plan_gen_importance) > 0 and plan_gen_importance[0] > 0, (
"Plan Generation should have positive mean counterfactual importance"
)
print("Sanity check passed: Plan Generation has positive counterfactual importance.")
Sentence-to-Sentence Causal Graph
So far we've only measured importance for the final answer: "how much does removing sentence $i$ affect whether the model gets the right answer?" But this misses the causal structure within the reasoning trace. A planning sentence might not directly affect the final answer, but it might be critical for producing a later computation sentence that does affect the answer.
The sentence-to-sentence causal graph captures this by asking: how important is sentence $i$ for whether sentence $j$ appears later? This connects our black-box importance metrics to the white-box attention patterns we'll study next - if sentence $i$ causally influences sentence $j$ in the black-box sense, we'd expect attention heads to attend strongly from $j$ back to $i$.
The paper computes this by: (1) for each target sentence $j$, looking at rollouts where sentence $i$ was kept vs removed, (2) checking how often sentence $j$ (or something semantically similar) appears in each set using embedding cosine similarity, and (3) taking the difference as the causal importance of $i$ on $j$. This produces a matrix of pairwise causal effects, which can be visualized as a directed graph revealing the "backbone" of the reasoning trace.
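In symbols (our paraphrase of this construction, where $\mathrm{sim}$ is embedding cosine similarity and $\tau$ is the match threshold):

$$\text{importance}(i \to j) = P\Big(\max_{s \in \text{rollout}} \mathrm{sim}(s, S_j) \ge \tau \;\Big|\; S_i \text{ kept}\Big) - P\Big(\max_{s \in \text{rollout}} \mathrm{sim}(s, S_j) \ge \tau \;\Big|\; S_i \text{ resampled}\Big)$$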
Before getting into the exercise, let's understand the visualization tool we'll use. utils.chunk_graph_html takes an importance matrix, chunk labels, and chunk texts, and renders an interactive causal graph. Here's a demo with random data:
# Demo: visualize a causal graph with sample data
n_chunks = 5
sample_labels = ["Problem Setup", "Plan Generation", "Active Computation", "Self Checking", "Final Answer Emission"]
sample_texts = [
"Let me start by understanding the problem...",
"I notice that this is equivalent to...",
"Computing the value: 3 * 7 = 21",
"Let me verify: 21 / 7 = 3, correct",
"Therefore the answer is 21",
]
# Create a sample importance matrix (i, j) = how much does chunk i causally influence chunk j
# We expect later chunks to be influenced by earlier ones. In our demo, each step influences
# the next one, plus plan generation influences everything!
sample_importance = np.zeros((n_chunks, n_chunks))
for i in range(4):
sample_importance[i, i + 1] += 0.5
for i in range(2, 5):
sample_importance[1, i] += 0.5
html_str = utils.chunk_graph_html(
edge_weights=sample_importance,
chunk_labels=sample_labels,
chunk_texts=sample_texts,
)
display(HTML(html_str))
Click to see the expected output
Exercise - replicate causal graph
We've broken this into four sub-functions for you to fill in. Tackle them in order.
Some notes before you start:
- The `problem_data` we've loaded in corresponds to the "Case study transcript" problem in appendix C.1 of the thought anchors paper. Note that the authors only compute the sentence-to-sentence causal graph for chunks 0-73 inclusive, since they observe (as you should have found in the counterfactual importance graphs generated above) that the resampling importance drops to zero or near zero past this point. To save time, we recommend you follow this convention, i.e. only compute the graph up to chunk 73.
- For reference, the solution code takes about 6 minutes to run on an A100 GPU, with about half the time spent computing the rollout chunks' embeddings and about half on the pairwise importance (i.e. batched cosine similarity calculations), although this was probably under-optimized and a more efficient solution is likely possible.
The four functions you need to fill in are:
- `precompute_rollout_embeddings`: Collect all rollout sentences, batch-embed them, and reconstruct per-rollout embedding arrays. The key idea is to collect all sentences from all rollouts into a single list, batch-encode them in one call, then use a position map to split the embeddings back into per-chunk, per-rollout arrays.
Hint - approach for precompute_rollout_embeddings
Use a two-pass approach. First pass: loop over chunks and rollouts, split each rollout text into sentences with split_solution_into_chunks, append them all to a single flat list all_sentences, and record (start_idx, end_idx) in a position map so you know which slice of the flat list corresponds to each rollout. Second pass: call embedding_model.encode(all_sentences, batch_size=batch_size) once, then use the position map to slice the resulting embedding array back into per-chunk, per-rollout arrays.
- `precompute_target_embeddings`: Extract chunk texts and embed them (straightforward).
- `compute_match_rate_from_embeddings`: Given a target sentence embedding and a list of rollout embeddings, compute the fraction of rollouts that contain at least one sentence with cosine similarity above the threshold. We've given you the setup code for tracking rollout boundaries; you need to fill in the cosine similarity computation and match counting.
- `compute_pairwise_importance`: For a given (source, target) pair, compare the match rate when the source chunk was kept vs removed. This uses `compute_match_rate_from_embeddings` on two different sets of rollouts.
Hint 2 - testing your approach incrementally
The full computation can take ~6 minutes, so test each function on a small subset before running everything. For example:
- Test `precompute_rollout_embeddings` on just the first 3 chunks: `precompute_rollout_embeddings(chunk_solutions[:3], embedding_model)`, and verify the shapes look right
- Test `precompute_target_embeddings` with `n_chunks=3` and check the output array shape
- Test `compute_match_rate_from_embeddings` on a single target embedding and a small list of rollout embeddings to confirm it returns a float in [0, 1]
- Test `compute_pairwise_importance` on a single (source=0, target=5) pair before running the full pairwise loop
Only run the full computation once you're confident each piece works correctly.
At the end, the code calls utils.chunk_graph_html with your importance matrix and the chunk labels/texts to visualize the graph.
Checking your answer and interpreting the graph
You can check your answer by going to thought-anchors.com and selecting "Hex to binary" in the Problem dropdown. Their demo shows the full 144 nodes rather than the subset of 74 we recommend, but you should still be able to tell if yours is correct (for example, there's one planning node which appears early on in the graph and has very high downstream importance - can you replicate this finding in your graph?).
Investigate the graph for interesting patterns. Can you find: a planning sentence A that boosts a computation sentence B you'd only compute if following A's plan? An uncertainty management sentence C that boosts the sentence D resolving that uncertainty? Blocks of active computation chunks that don't really affect each other causally (showing the counterfactual method is better than normal resampling)?
Don't read too much into this single sequence though - concrete takeaways only come from averaging over many sequences.
def precompute_rollout_embeddings(
chunk_solutions: list[list[dict]],
embedding_model: SentenceTransformer,
batch_size: int = 128,
) -> tuple[list[list[np.ndarray]], list[list[list[str]]]]:
"""
Precompute embeddings for all sentences in all rollouts using batched encoding. Does this by
first collecting every single rollout from each chunk, then batch encode them all.
Args:
chunk_solutions: List of chunk solutions, where chunk_solutions[k] contains
rollouts from resampling after chunk k
embedding_model: Sentence embedding model
batch_size: Batch size for embedding model encoding
Returns:
rollout_embeddings: rollout_embeddings[i][j] = array of shape (n_sentences, embed_dim)
rollout_sentences: rollout_sentences[i][j] = list of sentence strings
"""
# First pass: collect all sentences and track their positions
all_sentences: list[str] = []
# position_map[i][j] = (start_idx, end_idx) into all_sentences, or None if empty
position_map: list[list[tuple[int, int] | None]] = []
rollout_sentences: list[list[list[str]]] = []
# YOUR CODE HERE!
return rollout_embeddings, rollout_sentences
def precompute_target_embeddings(
chunks: list[dict],
embedding_model: SentenceTransformer,
n_chunks: int,
) -> Float[np.ndarray, "n_chunks embed_dim"]:
"""
Precompute embeddings for all target chunk sentences.
Args:
chunks: List of chunk dictionaries with "chunk" key
embedding_model: Sentence embedding model
n_chunks: Number of chunks to process
Returns:
Array of shape (n_chunks, embed_dim) with target embeddings
"""
# YOUR CODE HERE!
return embedding_model.encode(chunk_texts)
def compute_match_rate_from_embeddings(
target_embedding: Float[np.ndarray, " embed_dim"],
rollout_embeddings_list: list[Float[np.ndarray, "n_sentences embed_dim"]],
similarity_threshold: float = 0.7,
) -> float:
"""
Compute fraction of rollouts containing a sentence similar to target.
Args:
target_embedding: Embedding of target sentence, shape (embed_dim,)
rollout_embeddings_list: List of arrays, each shape (n_sentences, embed_dim)
similarity_threshold: Minimum cosine similarity for a match
Returns:
Fraction of rollouts with a matching sentence (0 to 1)
"""
# Filter out empty embeddings and track rollout boundaries
valid_embeddings = []
rollout_boundaries = [] # (start, end) indices into concatenated array
current_idx = 0
for embeddings in rollout_embeddings_list:
valid_embeddings.append(embeddings)
rollout_boundaries.append((current_idx, current_idx + len(embeddings)))
current_idx += len(embeddings)
# YOUR CODE HERE!
matches = ... # compute number of rollouts with a match using cosine similarity and threshold
return matches / len(rollout_boundaries)
def compute_pairwise_importance(
source_idx: int,
target_idx: int,
target_embeddings: Float[np.ndarray, "n_chunks embed_dim"],
rollout_embeddings: list[list[Float[np.ndarray, "n_sentences embed_dim"]]],
similarity_threshold: float = 0.7,
) -> float:
"""
Compute causal importance of sentence source_idx on sentence target_idx.
Uses precomputed embeddings for efficiency.
Compares:
- Rollouts from chunk_{source_idx + 1} (source was KEPT)
- Rollouts from chunk_{source_idx} (source was REMOVED)
"""
# YOUR CODE HERE!
include_rate = ... # compute match rate for rollouts where source was kept
exclude_rate = ... # compute match rate for rollouts where source was removed
return include_rate - exclude_rate
# Precompute all embeddings upfront
n_chunks = 74
target_embeddings = precompute_target_embeddings(problem_data["chunks_labeled"], embedding_model, n_chunks=n_chunks)
rollout_embeddings, rollout_sentences = precompute_rollout_embeddings(
problem_data["chunk_solutions"], embedding_model
)
# Compute importance matrix
importance_matrix = np.zeros((n_chunks, n_chunks))
for i in tqdm(range(n_chunks - 1), desc="Computing pairwise importance"):
for j in range(i + 1, n_chunks):
importance_matrix[i, j] = compute_pairwise_importance(i, j, target_embeddings, rollout_embeddings)
# Visualize the importance matrix as an interactive causal graph
# Keep only the upper triangle (j > i), where the computed pairwise values live; zero out the rest
mask = np.triu(np.ones_like(importance_matrix, dtype=bool))
importance_matrix = np.where(mask, importance_matrix, 0.0)
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"][:n_chunks]]
chunk_texts = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:n_chunks]]
html_str = utils.chunk_graph_html(
edge_weights=importance_matrix,
chunk_labels=chunk_labels,
chunk_texts=chunk_texts,
n_top_edges_per_direction=3,
)
display(HTML(html_str))
Click to see the expected output
Solution
def precompute_rollout_embeddings(
chunk_solutions: list[list[dict]],
embedding_model: SentenceTransformer,
batch_size: int = 128,
) -> tuple[list[list[np.ndarray]], list[list[list[str]]]]:
"""
Precompute embeddings for all sentences in all rollouts using batched encoding. Does this by
first collecting every single rollout from each chunk, then batch encode them all.
Args:
chunk_solutions: List of chunk solutions, where chunk_solutions[k] contains
rollouts from resampling after chunk k
embedding_model: Sentence embedding model
batch_size: Batch size for embedding model encoding
Returns:
rollout_embeddings: rollout_embeddings[i][j] = array of shape (n_sentences, embed_dim)
rollout_sentences: rollout_sentences[i][j] = list of sentence strings
"""
# First pass: collect all sentences and track their positions
all_sentences: list[str] = []
# position_map[i][j] = (start_idx, end_idx) into all_sentences, or None if empty
position_map: list[list[tuple[int, int] | None]] = []
rollout_sentences: list[list[list[str]]] = []
for i in tqdm(range(len(chunk_solutions)), desc="Precomputing embeddings"):
chunk_positions = []
chunk_sentences = []
for rollout_dict in chunk_solutions[i]:
rollout_text = rollout_dict["rollout"]
sentences = split_solution_into_chunks(rollout_text)
assert sentences, "Expected at least one sentence per rollout"
all_sentences.extend(sentences)
chunk_positions.append((len(all_sentences) - len(sentences), len(all_sentences)))
chunk_sentences.append(sentences)
position_map.append(chunk_positions)
rollout_sentences.append(chunk_sentences)
# Batch embed all sentences at once
all_embeddings = embedding_model.encode(all_sentences, batch_size=batch_size, show_progress_bar=True)
# Reconstruct per-rollout embeddings using position map
rollout_embeddings: list[list[np.ndarray]] = []
for i in range(len(chunk_solutions)):
chunk_embeddings = []
for position in position_map[i]:
start_idx, end_idx = position
chunk_embeddings.append(all_embeddings[start_idx:end_idx])
rollout_embeddings.append(chunk_embeddings)
return rollout_embeddings, rollout_sentences
def precompute_target_embeddings(
chunks: list[dict],
embedding_model: SentenceTransformer,
n_chunks: int,
) -> Float[np.ndarray, "n_chunks embed_dim"]:
"""
Precompute embeddings for all target chunk sentences.
Args:
chunks: List of chunk dictionaries with "chunk" key
embedding_model: Sentence embedding model
n_chunks: Number of chunks to process
Returns:
Array of shape (n_chunks, embed_dim) with target embeddings
"""
chunk_texts = [chunks[i]["chunk"] for i in range(n_chunks)]
return embedding_model.encode(chunk_texts)
def compute_match_rate_from_embeddings(
target_embedding: Float[np.ndarray, " embed_dim"],
rollout_embeddings_list: list[Float[np.ndarray, "n_sentences embed_dim"]],
similarity_threshold: float = 0.7,
) -> float:
"""
Compute fraction of rollouts containing a sentence similar to target.
Args:
target_embedding: Embedding of target sentence, shape (embed_dim,)
rollout_embeddings_list: List of arrays, each shape (n_sentences, embed_dim)
similarity_threshold: Minimum cosine similarity for a match
Returns:
Fraction of rollouts with a matching sentence (0 to 1)
"""
# Filter out empty embeddings and track rollout boundaries
valid_embeddings = []
rollout_boundaries = [] # (start, end) indices into concatenated array
current_idx = 0
for embeddings in rollout_embeddings_list:
valid_embeddings.append(embeddings)
rollout_boundaries.append((current_idx, current_idx + len(embeddings)))
current_idx += len(embeddings)
# Concatenate all embeddings: (total_sentences, embed_dim)
all_embeddings = np.vstack(valid_embeddings)
# Normalize target and all embeddings for cosine similarity
target_norm = target_embedding / (np.linalg.norm(target_embedding) + 1e-8)
all_norms = np.linalg.norm(all_embeddings, axis=1, keepdims=True) + 1e-8
all_embeddings_norm = all_embeddings / all_norms
# Compute all cosine similarities at once: (total_sentences,)
all_similarities = all_embeddings_norm @ target_norm
# Count matches per rollout using boundaries
matches = sum(1 for start, end in rollout_boundaries if np.max(all_similarities[start:end]) >= similarity_threshold)
return matches / len(rollout_boundaries)
def compute_pairwise_importance(
source_idx: int,
target_idx: int,
target_embeddings: Float[np.ndarray, "n_chunks embed_dim"],
rollout_embeddings: list[list[Float[np.ndarray, "n_sentences embed_dim"]]],
similarity_threshold: float = 0.7,
) -> float:
"""
Compute causal importance of sentence source_idx on sentence target_idx.
Uses precomputed embeddings for efficiency.
Compares:
- Rollouts from chunk_{source_idx + 1} (source was KEPT)
- Rollouts from chunk_{source_idx} (source was REMOVED)
"""
if source_idx >= target_idx:
return 0.0
target_embedding = target_embeddings[target_idx]
# Rollouts where source was kept
include_embeddings = rollout_embeddings[source_idx + 1]
# Rollouts where source was removed
exclude_embeddings = rollout_embeddings[source_idx]
include_rate = compute_match_rate_from_embeddings(target_embedding, include_embeddings, similarity_threshold)
exclude_rate = compute_match_rate_from_embeddings(target_embedding, exclude_embeddings, similarity_threshold)
return include_rate - exclude_rate
# Precompute all embeddings upfront
n_chunks = 74
target_embeddings = precompute_target_embeddings(problem_data["chunks_labeled"], embedding_model, n_chunks=n_chunks)
rollout_embeddings, rollout_sentences = precompute_rollout_embeddings(
problem_data["chunk_solutions"], embedding_model
)
# Compute importance matrix
importance_matrix = np.zeros((n_chunks, n_chunks))
for i in tqdm(range(n_chunks - 1), desc="Computing pairwise importance"):
for j in range(i + 1, n_chunks):
importance_matrix[i, j] = compute_pairwise_importance(i, j, target_embeddings, rollout_embeddings)
# Visualize the importance matrix as an interactive causal graph
# Keep only the upper triangle (j > i), where the computed pairwise values live; zero out the rest
mask = np.triu(np.ones_like(importance_matrix, dtype=bool))
importance_matrix = np.where(mask, importance_matrix, 0.0)
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"][:n_chunks]]
chunk_texts = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:n_chunks]]
html_str = utils.chunk_graph_html(
edge_weights=importance_matrix,
chunk_labels=chunk_labels,
chunk_texts=chunk_texts,
n_top_edges_per_direction=3,
)
display(HTML(html_str))
Bonus: Implementing Your Own Resampling (local models)
Finally, a bonus exercise for implementing your own resampling method using models loaded locally. This is not the method the authors used (they used API credits), but it's here as an optional exercise if you want to get into the nuts and bolts of how the resampling works.
Exercise - implement resampling (optional)
In the previous exercises we used pre-computed resampled rollouts from the dataset. Here, you'll implement the resampling yourself using a locally loaded model. The approach is:
- Generate a base completion: Run the model on the prompt to get a full CoT, then split it into chunks using `split_solution_into_chunks`.
- For each chunk, construct a prefix: Take the prompt plus all text up to (but not including) that chunk. Be careful with chunks that appear multiple times in the trace - you need to split on the correct instance.
- Generate resampled completions: From each prefix, generate `num_resamples_per_chunk` new completions in batches. Store the full CoT, the first resampled chunk, and the original chunk that was replaced.
We've provided a `StopOnThink` stopping-criterion helper and a `generate` inner function. You should also ensure `num_resamples_per_chunk` is divisible by `batch_size`.
class StopOnThink(StoppingCriteria):
"""Helper class for stopping generation when the <think>...</think> tags are closed."""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.think_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0]
def __call__(self, input_ids, scores, **kwargs):
return input_ids[0, -1] == self.think_token_id
def get_resampled_rollouts(
prompt: str,
model: transformers.models.llama.modeling_llama.LlamaForCausalLM,
tokenizer: transformers.models.llama.tokenization_llama.LlamaTokenizer,
num_resamples_per_chunk: int = 100,
batch_size: int = 4,
max_new_tokens: int = 2048,
up_to_n_chunks: int | None = None,
) -> tuple[str, list[str], list[list[dict]]]:
"""
After each chunk in `chunks`, computes a number of resampled rollouts.
Args:
prompt: The initial problem prompt (which ends with a <think> tag).
num_resamples_per_chunk: Number of resamples to compute for each chunk.
Returns:
Tuple of (full_answer, chunks, resampled_rollouts) where the latter is a list of lists of
dicts (one for each chunk & resample on that chunk).
"""
@t.inference_mode()
def generate(inputs):
return model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=0.6,
top_p=0.95,
do_sample=True,
stopping_criteria=[StopOnThink(tokenizer=tokenizer)],
pad_token_id=tokenizer.eos_token_id,
)
# YOUR CODE HERE - implement resampling procedure as described in the text!
return full_answer, chunks, chunk_rollouts
# Load the model for generation (only if you want to try this)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_8B, dtype=dtype, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_8B)
t0 = time.time()
full_answer_resampled, chunks_resampled, chunk_rollouts_resampled = get_resampled_rollouts(
prompt=problem_data["base_solution"]["prompt"],
model=model,
tokenizer=tokenizer,
num_resamples_per_chunk=4,
batch_size=2,
max_new_tokens=2048,
up_to_n_chunks=5,
)
print(f"Got resampled rollouts in {time.time() - t0:.2f} seconds")
chunk_rollouts_resampled[0]
for i, resamples in enumerate(chunk_rollouts_resampled):
print("Replaced chunk: " + repr(resamples[0]["chunk_replaced"]))
for j, r in enumerate(resamples):
print(f" Resample {j}: " + repr(r["chunk_resampled"]))
print()
Click to see the expected output
Got resampled rollouts in 79.88 seconds
Replaced chunk: "Okay, so I have this problem where I need to find out how many base-2 digits (bits) the base-16 number 66666₁₆ has when it's written in binary."
Resample 0: 'To determine how many base-2 digits the base-16 number \\(66666_{16}\\) has when converted to base-2, we first need to understand the relationship between hexadecimal and binary representations.'
Resample 1: 'To determine how many base-2 digits (bits) the base-16 number \\(66666_{16}\\) has when written in binary, I need to first convert this hexadecimal number to its decimal equivalent.'
Resample 2: 'Okay, so I have this math problem here:'
Resample 3: "Okay, so I need to figure out how many base-2 digits (bits) the base-16 number \\(66666_{16}\\) has when it's converted to binary."
Replaced chunk: 'Hmm, base conversions can sometimes be tricky, but I think I can handle this step by step.'
Resample 0: 'Hmm, base conversions can sometimes be tricky, but let me break it down step by step.'
Resample 1: 'Hmm, base conversions can sometimes be tricky, but let me try to work through this step by step.'
Resample 2: 'Hmm, let me think about how to approach this.'
Resample 3: 'Hmm, base conversions can sometimes be a bit tricky, but I think I can handle this step by step.'
Replaced chunk: 'Let me try to remember how base conversions work.'
Resample 0: 'Let me recall what I know about number bases.'
Resample 1: 'Let me try to remember how base conversions work, especially from hexadecimal to binary.'
Resample 2: 'Let me try to break it down.'
Resample 3: 'Let me see.'
Replaced chunk: 'First off, base-16, or hexadecimal, uses 16 different symbols to represent 0 through 15 in binary.'
Resample 0: "First, I know that base-16, or hexadecimal, is a common numbering system used in computers because it's easy to represent with just 16 symbols, each corresponding to a value from 0 to 15."
Resample 1: 'First, I know that base-16, which is hexadecimal, uses digits from 0 to 15, and each digit represents four bits in binary.'
Resample 2: 'First, I know that base-16, or hexadecimal, is a common numbering system used in computers.'
Resample 3: 'First, I know that base-16, or hexadecimal, uses digits from 0 to 9 and then A to F to represent values from 0 to 15.'
Replaced chunk: 'Each digit in a base-16 number corresponds to 4 bits in binary because 2⁴ is 16.'
Resample 0: 'The symbols are 0-9 and then A-F, where A is 10, B is 11, up to F being 15.'
Resample 1: 'Each digit in a hexadecimal number corresponds to 4 binary digits because 2⁴ is 16.'
Resample 2: 'Each digit in a hexadecimal number corresponds to 4 binary digits because 2⁴ = 16.'
Resample 3: 'Each digit in a hexadecimal number corresponds to 4 binary digits because 2⁴ = 16.'
Solution
class StopOnThink(StoppingCriteria):
    """Helper class for stopping generation when the <think>...</think> tags are closed."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Assumes "</think>" encodes to a single token (checked earlier).
        self.think_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0]

    def __call__(self, input_ids, scores, **kwargs):
        # Note: this only checks the first sequence, so in batched generation every
        # sequence stops once sequence 0 has closed its think tag.
        return input_ids[0, -1] == self.think_token_id
def get_resampled_rollouts(
    prompt: str,
    model: transformers.PreTrainedModel,
    tokenizer: transformers.PreTrainedTokenizerBase,
    num_resamples_per_chunk: int = 100,
    batch_size: int = 4,
    max_new_tokens: int = 2048,
    up_to_n_chunks: int | None = None,
) -> tuple[str, list[str], list[list[dict]]]:
    """
    Generates a base completion, then computes a number of resampled rollouts after each chunk.

    Args:
        prompt: The initial problem prompt (which ends with a <think> tag).
        model: The model to generate completions with.
        tokenizer: The tokenizer matching `model`.
        num_resamples_per_chunk: Number of resamples to compute for each chunk.
        batch_size: Number of resamples to generate in parallel (must divide num_resamples_per_chunk).
        max_new_tokens: Maximum number of new tokens per generation.
        up_to_n_chunks: If given, only resample the first `up_to_n_chunks` chunks.

    Returns:
        Tuple of (full_answer, chunks, resampled_rollouts) where the latter is a list of lists of
        dicts (one for each chunk & resample on that chunk).
    """

    @t.inference_mode()
    def generate(inputs):
        return model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            stopping_criteria=[StopOnThink(tokenizer=tokenizer)],
            pad_token_id=tokenizer.eos_token_id,
        )

    assert num_resamples_per_chunk % batch_size == 0, "batch_size must divide num_resamples_per_chunk"

    # First, generate a base completion which we'll split up into chunks.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = generate(inputs)
    full_answer = tokenizer.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=False)
    chunks = split_solution_into_chunks(full_answer)

    # Second, generate resamples at each chunk.
    chunk_rollouts = []
    n_chunk_instances = defaultdict(int)
    for chunk in tqdm(chunks[:up_to_n_chunks]):
        chunk_rollouts.append([])
        # To get the text before this chunk, we split on the correct instance of the chunk
        # (since this chunk might have appeared multiple times in the answer).
        n_chunk_instances[chunk] += 1
        full_answer_split = (
            prompt + chunk.join(full_answer.split(chunk, maxsplit=n_chunk_instances[chunk])[:-1]).strip()
        )
        # Repeat the prefix `batch_size` times, then sample until we have enough resamples.
        inputs = tokenizer([full_answer_split] * batch_size, return_tensors="pt").to(device)
        for _ in range(num_resamples_per_chunk // batch_size):
            output_batch = generate(inputs)
            for output_generation in output_batch:
                generated_text = tokenizer.decode(
                    output_generation[inputs["input_ids"].shape[1] :], skip_special_tokens=False
                )
                chunk_rollouts[-1].append(
                    {
                        "full_cot": full_answer_split + generated_text,
                        "chunk_resampled": split_solution_into_chunks(generated_text)[0],
                        "chunk_replaced": chunk,
                        "rollout": generated_text,
                    }
                )

    return full_answer, chunks, chunk_rollouts
When you've got this working, the rest is up to you. Generating enough rollouts for statistical analysis is beyond the scope of these exercises, but it could be a fun project if you want practice working at larger scale and doing full paper replications.
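If you do take it further, here's a minimal sketch of how these rollouts could feed into the resampling importance metric from earlier. It assumes you've augmented each rollout dict with a hypothetical "answer" key holding the parsed final answer - note that the generation above stops at `</think>`, so you'd need to either continue generation past that tag or parse a boxed answer out of the thinking text:

def rollout_accuracies(chunk_rollouts: list[list[dict]], correct_answer: str) -> list[float]:
    # Estimate P(correct) when resampling from each chunk onwards, i.e. the accuracy
    # over rollouts where chunk i (and everything after it) was regenerated.
    return [
        sum(r["answer"] == correct_answer for r in rollouts) / len(rollouts)
        for rollouts in chunk_rollouts
    ]

def resampling_importance(accuracies: list[float]) -> list[float]:
    # Importance of chunk i: accuracy when S_i is kept (resampling from chunk i+1)
    # minus accuracy when S_i is replaced (resampling from chunk i onwards).
    return [kept - replaced for replaced, kept in zip(accuracies[:-1], accuracies[1:])]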