2️⃣ Black-box Analysis
Learning Objectives
- Understand forced answer importance: measuring impact by directly forcing model answers at specific steps
- Implement resampling importance: measuring impact by replacing steps with samples from other reasoning traces
- Implement counterfactual importance: measuring impact by resampling conditioned on keeping the same answer
- Learn when each metric is appropriate (causal vs correlational analysis)
- Replicate key paper figures showing which reasoning steps are most critical for final answers
Now that we've split reasoning traces into sentences and categorized them, we can treat each sentence as a single unit and ask: which sentences are most important for the model's final answer?
This section introduces three black-box methods for measuring sentence importance. These are called "black-box" because they only require observing inputs and outputs - we don't need to look inside the model at activations or attention patterns. By intervening on individual sentences (forcing answers, resampling, or replacing with counterfactuals), we can measure how much each sentence shapes the trajectory of the model's reasoning and its final conclusion.
The paper used three different methods to measure sentence importance:
1. Forced Answer Importance. For each sentence $S_i$ in the CoT, we interrupt the model and append text, inducing a final output: Therefore, the final answer is \\boxed{...}. We then measure the model's accuracy when forced to answer immediately. If $A^f_i$ is the final answer when we force the model to answer immediately after $S_i$, and $P(A)$ is the probability of answer $A$ being correct, then the formula for forced answer importance is:
$$\text{importance}_f(S_i) := P(A^f_i) - P(A^f_{i-1})$$
Can you see a flaw in this approach when it comes to identifying sentences which are critical for the model's reasoning process?
Suppose some sentence $S$ is necessary for the final answer but comes late in the reasoning process. Then forcing an answer after any sentence before $S$ gives low accuracy, so every earlier sentence gets near-zero importance by this metric, however essential it was. We only pick up sentences whose inclusion, in combination with all previous sentences, gets us from the wrong answer to the right one.
For people who have completed the IOI material / done much work with model internals, this is the equivalent of finding the most important model components by seeing which ones write the final answer to the residual stream. It's a good start, but doesn't tell you much about the steps of the computation beyond the last one.
2. Resampling Importance. To address the flaw above, for each sentence $S_i$ in the CoT, we can resample a whole trajectory $(S_1, S_2, \ldots, S_{i-1}, S_i, S'_{i+1}, \ldots, S'_N, A^r_i)$ (where $A^r_i$ is the final answer we get from resampling chunks after $S_i$). By comparing it to the corresponding trajectory $(S_1, S_2, \ldots, S_{i-1}, S'_i, \ldots, S'_{N}, A^r_{i-1})$, which we get from resampling chunks from $S_i$ onwards (so $S_i$ itself is resampled), we can get a sense for how important $S_i$ was for producing the final answer. Our metric is:
$$\text{importance}_r(S_i) := P(A^r_i) - P(A^r_{i-1})$$
3. Counterfactual Importance. An issue with resampling importance is that $S'_i$ will often be very similar to $S_i$, if the reasoning context strongly constrains what can be expressed at that position. To fix this, we can keep only the rollouts where these two sentences are fairly different, using a semantic similarity metric derived from our embedding model. This gives us a counterfactual importance metric which is identical to the previous one, but computed on the filtered rollouts:
$$\text{importance}_c(S_i) := P(A^c_i) - P(A^c_{i-1})$$
where $A^c_i$ is defined like $A^r_i$, but only over the filtered rollouts.
In these sections, we'll work through each of these metrics in turn. The focus will be computing these metrics on the dataset we've already been provided, with the full replication (including your own implementation of resampling) coming in the next section.
Looking closer at our dataset
Before getting into the exercises, let's look closer at our dataset to see what we're working with. As mentioned, we'll be setting aside our actual model and chunking functions temporarily, so we can focus on analysing the resampled rollouts already provided to us in this dataset.
The code below loads in all the data necessary for analyzing a single problem. This contains:
- problem, which contains the problem statement along with some basic metadata (including the answer)
- chunks_labeled[i], data for the i-th chunk (e.g. what category it is, plus some metrics)
- chunk_solutions_forced[i][j], data for the j-th rollout we get from forcing an answer immediately after the first i chunks (so i=0 means forcing an answer before any chunks are included)
- chunk_solutions[i][j], data for the j-th rollout we get from resampling immediately after the first i chunks (so chunk i itself is resampled)
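Because both chunk_solutions and chunk_solutions_forced are indexed by how many original chunks the rollouts start after, taking differences between consecutive positions gives one importance value per chunk except the last. A minimal sketch with made-up accuracies (the numbers and chunk names are hypothetical, not from the dataset):

```python
import numpy as np

# Hypothetical P(correct) for rollouts starting after the first i chunks (list index i),
# matching how chunk_solutions and chunk_solutions_forced are indexed. Made-up numbers.
p_correct = [0.30, 0.35, 0.80, 0.78, 0.95]
chunks = ["chunk 0", "chunk 1", "chunk 2", "chunk 3", "chunk 4"]

# importance[i] = p_correct[i + 1] - p_correct[i]: the effect of including chunk i in the prefix.
# There is one value fewer than there are chunks, so it lines up with chunks[:-1].
importance = np.diff(p_correct)

for chunk, imp in zip(chunks[:-1], importance):
    print(f"{chunk}: {imp:+.2f}")
```

This is why the plotting code later pairs importance values with problem_data["chunks_labeled"][:-1].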
We'll be using this data to implement the three importance metrics described above.
Exercise - inspect the dataset
Run the code below to load in the data for a single problem (note we're using generations from deepseek-r1-distill-qwen-14b rather than our Llama-8b model here, so that we match the case study in the paper's appendix).
Make sure you understand the structure of the data, since this will make the following exercises easier.
Here are a few investigative questions you might try to answer when inspecting this problem's data:
- How has the chunking worked here? Can you see any obvious issues with how it's been applied, e.g. places where a chunk has been split that shouldn't have been?
- Do the categories look reasonable? You can try comparing them to each of your autoraters, and see what fraction match.
- Inspect chunk_solutions_forced, and see how early the model generally manages to get the answer correct.
- Inspect chunk_solutions, and see how much variety the model's completions have when resampled from various stages in the reasoning process.
import json
from concurrent.futures import ThreadPoolExecutor, as_completed

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
from tqdm import tqdm


def load_single_file(file_path: str):
    # Download one JSON file from the dataset repo (DATASET_NAME is set in the setup code) and parse it
    local_path = hf_hub_download(repo_id=DATASET_NAME, filename=file_path, repo_type="dataset")
    with open(local_path, "r") as f:
        return json.load(f)


def load_problem_data(problem_id: int, model_name: str = "deepseek-r1-distill-qwen-14b", verbose: bool = True):
    # Hide the per-file download bars; we show a single progress bar over chunks instead
    disable_progress_bars()
    problem_dir = "correct_base_solution"
    problem_dir_forced = "correct_base_solution_forced_answer"
    problem_path = f"{model_name}/temperature_0.6_top_p_0.95/{problem_dir}/problem_{problem_id}"
    problem_path_forced = f"{model_name}/temperature_0.6_top_p_0.95/{problem_dir_forced}/problem_{problem_id}"

    base_solution = load_single_file(f"{problem_path}/base_solution.json")
    problem = load_single_file(f"{problem_path}/problem.json")
    chunks_labeled = load_single_file(f"{problem_path}/chunks_labeled.json")

    n_chunks = len(chunks_labeled)
    chunk_solutions = [None] * n_chunks
    chunk_solutions_forced = [None] * n_chunks

    def load_chunk(chunk_idx):
        # Resampled rollouts and forced-answer rollouts for this chunk position
        sol = load_single_file(f"{problem_path}/chunk_{chunk_idx}/solutions.json")
        forced = load_single_file(f"{problem_path_forced}/chunk_{chunk_idx}/solutions.json")
        return chunk_idx, sol, forced

    # Download all chunk files in parallel; results arrive out of order, so store them by index
    with ThreadPoolExecutor(max_workers=16) as executor:
        futures = [executor.submit(load_chunk, i) for i in range(n_chunks)]
        for future in tqdm(as_completed(futures), total=n_chunks, disable=not verbose, desc="Loading chunks"):
            idx, sol, forced = future.result()
            chunk_solutions[idx] = sol
            chunk_solutions_forced[idx] = forced

    enable_progress_bars()
    return {
        "problem": problem,
        "base_solution": base_solution,
        "chunks_labeled": chunks_labeled,
        "chunk_solutions_forced": chunk_solutions_forced,
        "chunk_solutions": chunk_solutions,
    }
# Load a problem
problem_data = load_problem_data(PROBLEM_ID)
# Inspect the problem structure
print("Problem:", problem_data["problem"]["problem"][:200], "...")
print(f"\nGround truth answer: {problem_data['problem']['gt_answer']}")
print(f"Number of chunks: {len(problem_data['chunks_labeled'])}")
print(f"Rollouts per chunk: {len(problem_data['chunk_solutions'][0])}")
# Show first few chunks
for i, c in enumerate(problem_data["chunks_labeled"][:5]):
    print(f"{i}. [{c['function_tags'][0]}] {c['chunk'][:80]}...")
Expected output:
Problem: When the base-16 number $66666_{16}$ is written in base 2, how many base-2 digits (bits) does it have? ...
Ground truth answer: 19
Number of chunks: 145
Rollouts per chunk: 100
0. [problem_setup] Okay, so I have this problem where I need to find out how many bits the base-16 ...
1. [problem_setup] Hmm, let's see....
2. [fact_retrieval] I remember that each hexadecimal digit corresponds to exactly 4 binary digits, o...
3. [plan_generation] So, maybe I can just figure out how many hexadecimal digits there are and multip...
4. [uncertainty_management] Let me check that....
Metric (1/3): Forced Answer Importance
For each sentence $S_i$, we interrupt the model and force it to output a final answer. We measure how accuracy changes:
$$\text{importance}_f(S_i) := P(A^f_i) - P(A^f_{i-1})$$
What's the limitation of this approach?
This only captures the final step that tips the scales - when all previous sentences together finally produce enough information for the correct answer. It doesn't tell us which earlier sentences were critical for setting up that final step.
Analogy: It's like measuring which brick "caused" a wall to reach a certain height - technically only the last brick did, but all previous bricks were necessary too.
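A toy sketch makes the brick analogy concrete. The trace and accuracies below are made up (not from the dataset): we assume sentences 0, 1 and 2 are all required for a correct answer and sentence 3 is filler, so forced accuracy only jumps once the last required sentence is in the prefix.

```python
import numpy as np

# Hypothetical toy trace: sentences 0, 1 and 2 are ALL needed for the answer, sentence 3 is filler.
required = {0, 1, 2}
n_sentences = 4

# forced_accuracy[k] = accuracy when forcing an answer after the first k sentences
# (k = 0 means no sentences at all). We pretend accuracy is 1.0 once every required
# sentence is in the prefix, and 0.0 otherwise.
forced_accuracy = [1.0 if required <= set(range(k)) else 0.0 for k in range(n_sentences + 1)]

# Forced answer importance of sentence i = accuracy with it minus accuracy without it.
forced_importance = np.diff(forced_accuracy)

print(forced_accuracy)  # [0.0, 0.0, 0.0, 1.0, 1.0]
print(forced_importance.tolist())  # [0.0, 0.0, 1.0, 0.0]
```

Sentences 0 and 1 are just as necessary as sentence 2, but forced answer importance gives them zero credit.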
Exercise - calculate forced answer importance
You should fill in the function below to compute the forced answer importance for each chunk. Note that the function takes a list of lists of full CoT rollouts, meaning you'll need to parse the model's answer out of each CoT yourself.
Note that the format of the model's answer can vary: for example, in a question about interest rates with the answer 6.17, the model sometimes responds with e.g. 6.17% or r = 6.17. In this case study we're dealing with a problem that has an integer solution, so this isn't really an issue, but it's still good practice to handle these kinds of cases when you're parsing the answer.
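One way to handle these cases is sketched below. It is an assumption-laden sketch, not the parser used to build the dataset: it takes the last \\boxed{...} (matching nested braces), then pulls out the first number it contains, so it only suits problems whose answers are single numbers. Fractions or symbolic answers would need a proper math-equivalence check.

```python
from __future__ import annotations

import re


def extract_boxed(text: str) -> str | None:
    """Return the contents of the LAST \\boxed{...} in `text`, matching nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    content_start = i = start + len("\\boxed{")
    depth = 1
    while i < len(text):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:i]
        i += 1
    return None  # unbalanced braces, e.g. a truncated rollout


def normalize_numeric_answer(ans: str | None) -> str | None:
    """Pull the first (optionally negative, optionally decimal) number out of e.g. '6.17\\%' or 'r = 6.17'."""
    if ans is None:
        return None
    match = re.search(r"-?\d+(?:\.\d+)?", ans)
    return match.group(0) if match else None


print(normalize_numeric_answer(extract_boxed(r"so \boxed{r = 6.17\%}")))
```

Comparing the normalized strings still treats 6.17 and 6.170 as different; converting both sides with float() would avoid that.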
def calculate_answer_importance(full_cot_list: list[list[str]], answer: str) -> list[float]:
    """
    Calculate importance for chunks based on accuracy differences.

    Args:
        full_cot_list: List of lists of rollouts. full_cot_list[i][j] is the j-th rollout
            generated by forcing an answer after the i-th chunk.
        answer: The ground truth answer.

    Returns:
        List of importance scores (one fewer than chunks, since we measure differences).
    """
    raise NotImplementedError("Implement answer importance calculation")
tests.test_calculate_answer_importance(calculate_answer_importance)
Solution
def calculate_answer_importance(full_cot_list: list[list[str]], answer: str) -> list[float]:
    """
    Calculate importance for chunks based on accuracy differences.

    Args:
        full_cot_list: List of lists of rollouts. full_cot_list[i][j] is the j-th rollout
            generated by forcing an answer after the i-th chunk.
        answer: The ground truth answer.

    Returns:
        List of importance scores (one fewer than chunks, since we measure differences).
    """

    def extract_answer_from_cot(cot: str) -> str:
        # Take the contents of the last \boxed{...} and keep only digits and decimal points
        ans = cot.split("\\boxed{")[-1].split("}")[0]
        return "".join(char for char in ans if char.isdigit() or char == ".")

    # Get P(correct) for each chunk position
    probabilities = [
        sum(extract_answer_from_cot(cot) == answer for cot in cot_list) / len(cot_list) for cot_list in full_cot_list
    ]

    # Return differences: P(A_{S_i}) - P(A_{S_{i-1}})
    return np.diff(probabilities).tolist()
tests.test_calculate_answer_importance(calculate_answer_importance)
When you've done this, run the cell below to get your results and plot them. Note, this should not match the paper's "Figure 2" results yet, since we're using forced answer importance, not resampled or counterfactual importance.
What do you notice about these results? What sentences are necessary for the model to start getting greater-than-zero accuracy? Are there any sentences which significantly drop or raise the model's accuracy, and can you explain why?
What patterns do you notice?
You should see that:
- Most early chunks have ~0 importance (the model can't answer correctly yet)
- There are sharp spikes near the end when critical computations are made
- Some chunks have negative importance: they might introduce confusion or errors
For example, if a model is near the end of its reasoning, it might follow a pattern like "maybe X, wait no actually Y", and this will result in wildly careening forced answer importance values for those last few sentences, which aren't really representative of how important the sentence was in the actual reasoning process.
The limitation is clear: this only captures the "final step" that tips from wrong to right. Earlier important reasoning steps get 0 importance even if they were essential.
Note on the small discrepancy from the actual dataset metrics
A small discrepancy between your results and those in the dataset is fine, and expected. The currently uploaded version of the dataset seems to have a bug in the "forced answer" metric data. For example, it classifies the following rollout:
'Therefore, the final answers is \\boxed{20}. However, upon re-examining ... so the correct answer is \\boxed{19}.'
as having a final answer of 20 rather than 19, hence incorrectly classifying the answer as wrong.
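The quoted rollout contains two boxed answers, so any parser has to pick one. The sketch below shows that taking the first one gives 20 while taking the last one (as the split-based extract_answer_from_cot in the solution above does) gives 19. That the dataset's forced-answer labels used the first boxed answer is our guess at the cause, not something we've confirmed.

```python
import re

rollout = (
    "Therefore, the final answers is \\boxed{20}. However, upon re-examining ... "
    "so the correct answer is \\boxed{19}."
)

# All non-nested \boxed{...} contents, in order of appearance
boxed = re.findall(r"\\boxed\{([^{}]*)\}", rollout)
first_answer = boxed[0]  # what a "first boxed answer" parser would return
last_answer = boxed[-1]  # what a "last boxed answer" parser would return

# The split-based parser from the solution above also lands on the last answer
ours = rollout.split("\\boxed{")[-1].split("}")[0]

print(boxed, first_answer, last_answer, ours)
```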
# Calculate forced answer importance
full_cot_list = [
    [rollout["full_cot"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions_forced"]
]
answer = problem_data["problem"]["gt_answer"]
forced_importances = calculate_answer_importance(full_cot_list, answer)
# Get chunk texts for hover data
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:-1]]
# Plot with plotly
fig = go.Figure()
fig.add_trace(
    go.Bar(
        x=list(range(len(forced_importances))),
        y=forced_importances,
        opacity=0.7,
        hovertemplate="<b>Chunk %{x}</b><br>Importance: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
        customdata=[chunk[:100] + "..." if len(chunk) > 100 else chunk for chunk in chunks_for_hover],
    )
)
fig.add_hline(y=0, line_color="black", line_width=0.5)
fig.update_layout(
    title="Forced Answer Importance by Chunk",
    xaxis_title="Chunk Index",
    yaxis_title="Forced Answer Importance",
    width=900,
    height=400,
)
fig.show()
Metric (2/3): Resampling Importance
To address the limitation described above, we resample immediately after each sentence and measure how this changes the final answer distribution:
$$\text{importance}_r(S_i) := P(A^r_i) - P(A^r_{i-1})$$
The key difference: $A^r_i$ comes from a full resampled trajectory starting after $S_i$, not a forced early answer. This captures ways a sentence can matter beyond immediately tipping the answer, for example by steering what the model reasons about next.
Exercise - compare resampling importance
The same calculate_answer_importance function works - we just apply it to different data!
Before computing this, think for a while about what you expect to see. How do you think the resampling importance results will compare to the forced answer importances?
# YOUR CODE HERE - compute `resampling_importances` using `calculate_answer_importance`
# on the resampled rollouts
resampling_importances = []
# Compare with precomputed values from dataset (they used a slightly different method to us, but
# we should get an answer within 1% of theirs)
precomputed = [chunk["resampling_importance_accuracy"] for chunk in problem_data["chunks_labeled"][:-1]]
avg_diff = np.abs(np.subtract(resampling_importances, precomputed)).mean()
assert avg_diff < 0.01, "Error above 1% threshold"
# Plot comparison between these two metrics
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:-1]]
hover_texts = [chunk[:100] + "..." if len(chunk) > 100 else chunk for chunk in chunks_for_hover]
fig = make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.1,
    subplot_titles=("Forced Answer Importance", "Resampling Importance"),
)
for row, color, y in [(1, "cornflowerblue", forced_importances), (2, "orange", resampling_importances)]:
    fig.add_trace(
        go.Bar(
            x=list(range(len(y))),
            y=y,
            opacity=0.7,
            name="Forced" if color == "cornflowerblue" else "Resampling",
            marker_color=color,
            hovertemplate="<b>Chunk %{x}</b><br>Importance: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
            customdata=hover_texts,
        ),
        row=row,
        col=1,
    )
fig.add_hline(y=0, line_color="black", line_width=0.5, row=1, col=1)
fig.add_hline(y=0, line_color="black", line_width=0.5, row=2, col=1)
fig.update_layout(width=900, height=500, showlegend=False)
fig.update_xaxes(title_text="Chunk Index", row=2, col=1)
fig.update_yaxes(title_text="Importance", row=1, col=1)
fig.update_yaxes(title_text="Importance", row=2, col=1)
fig.show()
Solution & discussion
# Calculate resampling importance (same function, different data!)
full_cot_list_resampled = [
    [rollout["full_cot"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions"]
]
resampling_importances = calculate_answer_importance(full_cot_list_resampled, answer)
# Compare with precomputed values from dataset (they used a slightly different method to us, but
# we should get an answer within 1% of theirs)
precomputed = [chunk["resampling_importance_accuracy"] for chunk in problem_data["chunks_labeled"][:-1]]
avg_diff = np.abs(np.subtract(resampling_importances, precomputed)).mean()
assert avg_diff < 0.01, "Error above 1% threshold"
# Plot comparison between these two metrics
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:-1]]
hover_texts = [chunk[:100] + "..." if len(chunk) > 100 else chunk for chunk in chunks_for_hover]
fig = make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.1,
    subplot_titles=("Forced Answer Importance", "Resampling Importance"),
)
for row, color, y in [(1, "cornflowerblue", forced_importances), (2, "orange", resampling_importances)]:
    fig.add_trace(
        go.Bar(
            x=list(range(len(y))),
            y=y,
            opacity=0.7,
            name="Forced" if color == "cornflowerblue" else "Resampling",
            marker_color=color,
            hovertemplate="<b>Chunk %{x}</b><br>Importance: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
            customdata=hover_texts,
        ),
        row=row,
        col=1,
    )
fig.add_hline(y=0, line_color="black", line_width=0.5, row=1, col=1)
fig.add_hline(y=0, line_color="black", line_width=0.5, row=2, col=1)
fig.update_layout(width=900, height=500, showlegend=False)
fig.update_xaxes(title_text="Chunk Index", row=2, col=1)
fig.update_yaxes(title_text="Importance", row=1, col=1)
fig.update_yaxes(title_text="Importance", row=2, col=1)
fig.show()
You should see the resampling importance values are often higher at earlier steps, petering out by the second half of the rollout chunks.
Moreover, the highest-importance sentences according to this metric are the ones that involve some amount of planning. The most important (nearly 4x higher than any other positive sentence) is one that appears in the first 15% of the rollout chunks, and which outlines the plan the model will use for solving this problem. The most negative sentence is the one which appears immediately before that one, which expresses a slightly different plan.
Semantic Similarity in Resampling
Before we look at the last metric (counterfactual importance), let's revisit the notion of embedding cosine similarity. Since we have data on a bunch of resampled rollouts at different chunks, we can compute the average cosine similarity between a chunk and all of its resampled chunks (i.e. $S_i$ and $S'_i$ in the notation above). Run the cells below to compute these cosine similarities and plot them.
Which kinds of sentences seem like their resamples have the highest or lowest cosine similarity? Can you explain why?
# Compute cosine similarity between original and resampled chunks
chunks_removed = [chunk["chunk"] for chunk in problem_data["chunks_labeled"]]
embeddings_original = embedding_model.encode(chunks_removed)
chunks_resampled = [
    [rollout["chunk_resampled"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions"]
]
embeddings_resampled = np.stack([embedding_model.encode(r) for r in chunks_resampled])
# Compute similarities (dot products equal cosine similarities only for unit-normalized embeddings)
cos_sims = einops.einsum(embeddings_original, embeddings_resampled, "chunk d, chunk resample d -> chunk resample")
cos_sims_mean = cos_sims.mean(axis=1)
# Plot by category with plotly
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"]]
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"]]
df = pd.DataFrame({"Label": chunk_labels, "Cosine Similarity": cos_sims_mean, "Chunk Text": chunks_for_hover})
fig = go.Figure()
for label in df["Label"].unique():
    subset = df[df["Label"] == label]
    hover_texts = [text[:100] + "..." if len(text) > 100 else text for text in subset["Chunk Text"]]
    fig.add_trace(
        go.Bar(
            x=subset.index.tolist(),
            y=subset["Cosine Similarity"].tolist(),
            name=label,
            marker_color=utils.CATEGORY_COLORS.get(label, "#9E9E9E"),
            hovertemplate="<b>Chunk %{x}</b><br>Category: "
            + label
            + "<br>Cosine Similarity: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
            customdata=hover_texts,
        )
    )
fig.update_layout(
    title="How Similar are Resampled Chunks to Originals?",
    xaxis_title="Chunk Index",
    yaxis_title="Mean Cosine Similarity to Resamples",
    width=900,
    height=400,
    legend=dict(x=1.02, y=1, xanchor="left"),
    bargap=0,
)
fig.show()
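The einsum above computes raw dot products, which equal cosine similarities only when the embeddings are unit-length. Many sentence-transformers models normalize their outputs, and SentenceTransformer.encode also accepts normalize_embeddings=True, but whether your embedding_model does is worth checking. The sketch below normalizes explicitly; random vectors stand in for real embeddings, and np.einsum is used in place of einops.einsum so the block has no extra dependencies.

```python
import numpy as np

rng = np.random.default_rng(0)
n_chunks, n_resamples, d = 4, 10, 16

# Stand-ins for embedding outputs, deliberately NOT unit-length (scaled by 3)
emb_original = 3.0 * rng.normal(size=(n_chunks, d))
emb_resampled = 3.0 * rng.normal(size=(n_chunks, n_resamples, d))


def unit(x: np.ndarray) -> np.ndarray:
    """Rescale each vector (along the last axis) to unit L2 norm."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


# Same contraction as "chunk d, chunk resample d -> chunk resample", on normalized vectors
cos_sims = np.einsum("cd,crd->cr", unit(emb_original), unit(emb_resampled))
raw_dots = np.einsum("cd,crd->cr", emb_original, emb_resampled)

# Only cos_sims is bounded in [-1, 1]; raw_dots scales with the vectors' norms
print(cos_sims.mean(axis=1))
```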
What patterns do you notice?
You should find that some of the highest average cosine similarities are for:
- Final Answer Emission chunks, which makes sense since there's only a limited number of ways to express an answer (and given the context, it's likely obvious whether the model is about to emit an answer, as well as what the answer will be)
- Active Computation and Result Consolidation chunks, especially those which come in batches (e.g. chunks 23-33, 51-59, and 91-97 in the plot above). This makes sense, since these chunks often involve a series of similar steps (e.g. applying an iterative formula, or doing a series of multiplications with different inputs), and once the model starts one of these processes it'll be very constrained in how it acts until it's finished.
Some of the lowest average cosine similarities are for:
- Plan Generation chunks which represent changes in trajectory, for example:
- Chunk 13: "Alternatively, maybe I can calculate..."
- Chunk 49: "Let me convert it step by step again", which is a decision to fix a theorized error earlier in reasoning
- Uncertainty Management chunks which also represent a re-evaluation and could change the subsequent trajectory, for example:
- Chunk 45: "I must have made a mistake somewhere"
- Chunk 84: "So, which is correct?"
- Chunk 139: "Wait, but that's not correct because..."
Chunk 28 also has a low average similarity, but inspecting problem_data["chunk_solutions"][28] shows that this is actually an artifact of incorrect chunking: the resamples here all follow the pattern "Now, adding all these up:" followed by an equation, and this has low similarity to the original chunk, which (correctly) splits at the : and so doesn't include the equation. If you want to fix this, you can try using our split_solution_into_chunks function from earlier to process the resampled chunks before plotting them. Moral of the story: this kind of string parsing is finicky and easy to get wrong.
Metric (3/3): Counterfactual Importance
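Finally, we'll implement counterfactual importance:
$$\text{importance}_c(S_i) := P(A^c_i) - P(A^c_{i-1})$$
where $A^c_i$ is defined like $A^r_i$, but only over rollouts where the resampled sentence was different enough from the original.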
This is the same as the resampling importance (and we'll use the same data), but with one difference: we keep only the resampled rollouts where the first resampled chunk $S'_i$ is sufficiently different from the chunk $S_i$ which it replaces. In this case, sufficiently different means having an embedding cosine similarity of less than 0.8 with $S_i$, using our embedding model from earlier.
The intuition for this metric: if resampling importance tells us the effect when we choose a different sentence than $S_i$, then counterfactual importance tells us the effect when we choose a different reasoning path from the one represented by $S_i$. Low cosine similarity in this case is a proxy for the reasoning paths being very different (rather than just light rephrasings of what is essentially the same reasoning step).
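A toy sketch of the filtering step for a single chunk position, with made-up similarities and correctness labels. The threshold matches the 0.8 above; min_samples is lowered to 3 just for this tiny example.

```python
import numpy as np

threshold, min_samples = 0.8, 3

# Made-up data for ONE chunk position: cosine similarity of each resampled first chunk
# to the original chunk, and whether that rollout reached the correct answer (1) or not (0)
cos_sims = np.array([0.95, 0.91, 0.42, 0.55, 0.97, 0.30, 0.88, 0.61])
correct = np.array([1, 1, 0, 1, 1, 0, 1, 0])

# Resampling importance uses every rollout...
p_all = correct.mean()

# ...counterfactual importance keeps only rollouts whose first chunk is "different enough"
keep = cos_sims < threshold
p_counterfactual = correct[keep].mean() if keep.sum() >= min_samples else None

print(p_all, p_counterfactual)
```

Here the resamples that stay close to the original chunk are all correct, while the genuinely different ones mostly aren't, so the filtered estimate (0.25) is much lower than the unfiltered one (0.625).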
Exercise - compute counterfactual importance
def calculate_counterfactual_importance(
    chunks_removed: list[str],
    chunks_resampled: list[list[str]],
    full_cot_list: list[list[str]],
    answer: str,
    threshold: float = 0.8,
    min_samples: int = 5,
    embedding_model: SentenceTransformer | None = None,
) -> list[float]:
    """
    Calculate counterfactual importance by filtering for low-similarity resamples.

    Args:
        chunks_removed: Original chunks that were removed
        chunks_resampled: List of resampled chunks for each position
        full_cot_list: Full CoT rollouts for each position
        answer: Ground truth answer
        threshold: Maximum cosine similarity to count as "different"
        min_samples: Minimum samples needed to compute probability
        embedding_model: Sentence embedding model

    Returns:
        List of counterfactual importance scores
    """
    raise NotImplementedError("Implement counterfactual importance calculation")
Solution
def calculate_counterfactual_importance(
    chunks_removed: list[str],
    chunks_resampled: list[list[str]],
    full_cot_list: list[list[str]],
    answer: str,
    threshold: float = 0.8,
    min_samples: int = 5,
    embedding_model: SentenceTransformer | None = None,
) -> list[float]:
    """
    Calculate counterfactual importance by filtering for low-similarity resamples.

    Args:
        chunks_removed: Original chunks that were removed
        chunks_resampled: List of resampled chunks for each position
        full_cot_list: Full CoT rollouts for each position
        answer: Ground truth answer
        threshold: Maximum cosine similarity to count as "different"
        min_samples: Minimum samples needed to compute probability
        embedding_model: Sentence embedding model

    Returns:
        List of counterfactual importance scores
    """

    def extract_answer_from_cot(cot: str) -> str:
        ans = cot.split("\\boxed{")[-1].split("}")[0]
        return "".join(char for char in ans if char.isdigit() or char == ".")

    def get_filtered_indices(chunk_removed: str, resampled: list[str]) -> np.ndarray:
        # Indices of resamples that are "different enough" from the original chunk.
        # The dot product is a cosine similarity only if the embeddings are unit-normalized.
        emb_original = embedding_model.encode(chunk_removed)
        emb_resampled = embedding_model.encode(resampled)
        cos_sims = emb_original @ emb_resampled.T
        return np.where(cos_sims < threshold)[0]

    # Get filtered indices for each chunk
    filtered_indices = [
        get_filtered_indices(chunk, resampled) for chunk, resampled in zip(chunks_removed, chunks_resampled)
    ]

    # Compute P(correct) using only filtered samples
    probabilities = []
    for cot_list, indices in zip(full_cot_list, filtered_indices):
        if len(indices) >= min_samples:
            correct = sum(extract_answer_from_cot(cot_list[i]) == answer for i in indices)
            probabilities.append(correct / len(indices))
        else:
            probabilities.append(None)

    # Positions with too few filtered samples (None) copy the previous valid value
    # (or the next valid value, if there is no previous one)
    probabilities = pd.Series(probabilities).ffill().bfill().tolist()
    return np.diff(probabilities).tolist()
When you've filled this in, run the cells below to compute and plot the counterfactual importance scores next to your resampling importance scores.
You should find the two metrics (resampling and counterfactual) are mostly quite similar for this example. They differ most on sentences which the plot above showed to have low cosine similarity to their resamples (i.e. high semantic variance), because these are our thought anchors: sentences which guide the entire reasoning process, so changing them to something different in embedding space has a large effect on the subsequent trajectory.
For example, chunk 3 "So, maybe I can just figure out how many hexadecimal digits there are..." has a higher counterfactual importance than resampling importance (about 50% higher). This is because the chunk represents a key part of the overall reasoning process: when the model doesn't say this, it often expresses a different plan such as "So, maybe I can start by converting each digit one by one..." or "So maybe I can just multiply the number of hex digits by 4..." - you can confirm this by looking at the dataset of rollouts yourself. Even if some of these plans will end up at the same place, the exact plan & phrasing that the model produces in this step will significantly affect its trajectory.
Discussion of some other chunks with very different resampling & counterfactual values
Chunks 53-58: active computations. These chunks are a series of active computations. The counterfactual metric should be zero on most or all of these chunks, because the resampled chunks here are semantically very similar to the originals (the model was essentially forced at this point to carry out a multi-stage computation in a very specific way). Typically too few of them pass the similarity filter, so the metric carries over the previous position's value and the difference is zero. This shows our counterfactual metric is working as intended, because we want to identify these kinds of reasoning steps as not particularly important.
Chunks 43-44, and 28: "chunking bugs". Chunks 43-44 say "Now, according to this, it's 19 bits. There's a discrepancy here." Inspecting the dataset reveals we've caught the model at a reasoning crossroads: about half the time at this point it says "it's 19 bits" and the other half it just says "there's a discrepancy here". So this is an example of when our counterfactual importance metric can still be misleading depending on how we chunk our rollouts. There's a similar story for chunk 28: "Now, adding these all up:". Sometimes the sequence is chunked as "Now, adding these all up: (gives expression)" and sometimes the text and expression are in different chunks. Moral of the story: this kind of string parsing is finicky and can easily cause issues!
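To see why a run like chunks 53-58 can come out at exactly zero, here is a toy sketch (hypothetical probabilities) of the forward-fill step in calculate_counterfactual_importance: positions where fewer than min_samples resamples pass the filter are stored as None and then copied from the previous valid position.

```python
import numpy as np
import pandas as pd

# Hypothetical filtered P(correct) per position; None = too few resamples passed the
# similarity filter (e.g. a run of near-deterministic computation steps)
probabilities = [0.40, 0.55, None, None, None, 0.90]

filled = pd.Series(probabilities).ffill().bfill().tolist()
importance = np.diff(filled).tolist()

print(filled)
print(importance)
```

One side effect worth keeping in mind: whatever change happens across the gap is credited entirely to the last filled chunk (index 4 here), which may not be the chunk that actually caused it.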
# Calculate counterfactual importance
chunks_removed = [chunk["chunk"] for chunk in problem_data["chunks_labeled"]]
chunks_resampled = [
    [rollout["chunk_resampled"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions"]
]
full_cot_list = [
    [rollout["full_cot"] for rollout in chunk_rollouts] for chunk_rollouts in problem_data["chunk_solutions"]
]
counterfactual_importances = calculate_counterfactual_importance(
    chunks_removed, chunks_resampled, full_cot_list, answer, embedding_model=embedding_model
)
# Compare with precomputed (note the sign flip)
precomputed_cf = [-chunk["counterfactual_importance_accuracy"] for chunk in problem_data["chunks_labeled"][:-1]]
avg_diff = np.abs(np.subtract(counterfactual_importances, precomputed_cf)).mean()
assert avg_diff < 0.025, "Error above 2.5% threshold"
print("All tests for `calculate_counterfactual_importance` passed!")
# Plot comparison of all three metrics with subplots (like previous bar chart)
chunks_for_hover = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:-1]]
hover_texts = [chunk[:100] + "..." if len(chunk) > 100 else chunk for chunk in chunks_for_hover]
fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=("Forced Answer Importance", "Resampling Importance", "Counterfactual Importance"),
)
for row, (name, importances, color) in enumerate(
    [
        ("Forced", forced_importances, "cornflowerblue"),
        ("Resampling", resampling_importances, "orange"),
        ("Counterfactual", counterfactual_importances, "seagreen"),
    ],
    start=1,
):
    fig.add_trace(
        go.Bar(
            x=list(range(len(importances))),
            y=importances,
            name=name,
            opacity=0.8,
            marker_color=color,
            hovertemplate="<b>Chunk %{x}</b><br>"
            + name
            + " Importance: %{y:.4f}<br>Text: %{customdata}<extra></extra>",
            customdata=hover_texts,
        ),
        row=row,
        col=1,
    )
    fig.add_hline(y=0, line_color="black", line_width=0.5, row=row, col=1)
fig.update_layout(
    title="Comparison of Importance Metrics",
    width=1000,
    height=700,
    showlegend=False,
)
fig.show()
Exercise - replicate Figure 3b (sentence category effect)
As an open-ended challenge, try replicating Figure 3b from the paper. We've already got all the data you need to do this, so it's just a matter of understanding and plotting the data correctly.
Note that we've so far defined our metrics in terms of accuracy rather than KL divergence, so you should expect to see the metrics looking slightly different (also we're only averaging over a single prompt's chunks rather than over many prompts, so the data will be noisier). But even with these restrictions, you should still be able to get a qualitatively similar plot to Figure 3b. Signs of life you should look for in your plot are:
- (sanity check) Result Consolidation has the highest normalized position in trace
- (core result) Plan Generation and Uncertainty Management have the highest counterfactual importance
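For reference, here is one way a KL-divergence version could be computed: compare the distribution of final answers in rollouts where a sentence was kept against rollouts where it was resampled. This is a minimal sketch, assuming answers have already been extracted as strings. The additive smoothing (alpha) and the argument order are our own choices for illustration, not necessarily what the paper used.

```python
from collections import Counter

import numpy as np


def answer_kl(answers_p: list[str], answers_q: list[str], alpha: float = 1.0) -> float:
    """KL(P || Q) between two empirical answer distributions.

    alpha is additive (Laplace) smoothing over the union of observed answers, so an
    answer seen in only one set doesn't give an infinite divergence. This smoothing
    choice is an assumption for illustration.
    """
    support = sorted(set(answers_p) | set(answers_q))
    counts_p, counts_q = Counter(answers_p), Counter(answers_q)
    p = np.array([counts_p[a] + alpha for a in support], dtype=float)
    q = np.array([counts_q[a] + alpha for a in support], dtype=float)
    p /= p.sum()
    q /= q.sum()
    return float(np.sum(p * np.log(p / q)))


# Identical answer distributions give zero divergence
print(answer_kl(["19", "19", "20"], ["19", "20", "19"]))  # 0.0
# Shifting probability mass away from "19" gives a positive divergence
print(answer_kl(["19"] * 9 + ["20"], ["19"] * 5 + ["20"] * 5))
```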
We leave it as an exercise to the reader to scale up the analysis to many different prompts, and add in some good error bars!
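For the error bars, one simple option is a percentile bootstrap over the importances within each category (ideally resampling whole prompts once you have many). The helper below is a generic sketch rather than part of the course utils; the resulting (mean - lower, upper - mean) pairs could be passed as the yerr argument of matplotlib's ax.errorbar.

```python
import numpy as np


def bootstrap_mean_ci(
    values: np.ndarray, n_boot: int = 2000, ci: float = 0.95, seed: int = 0
) -> tuple[float, float, float]:
    """Return (mean, lower, upper) using a percentile bootstrap over `values`."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # Resample indices with replacement: shape (n_boot, len(values))
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    boot_means = values[idx].mean(axis=1)
    lower, upper = np.percentile(boot_means, [100 * (1 - ci) / 2, 100 * (1 + ci) / 2])
    return float(values.mean()), float(lower), float(upper)


# Usage with made-up importances for a single category
mean, lower, upper = bootstrap_mean_ci(np.array([0.10, 0.25, 0.05, 0.30, 0.15]))
print(f"{mean:.3f} [{lower:.3f}, {upper:.3f}]")
```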
# YOUR CODE HERE - replicate figure 3b!
Solution
# Create dataframe with position and importance
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"]]
n_chunks = len(chunk_labels) - 1
df = pd.DataFrame(
{
"Label": chunk_labels[:-1],
"Importance": counterfactual_importances,
"Position": np.arange(n_chunks) / n_chunks,
}
)
# Get top 5 most common categories
top_labels = df["Label"].value_counts().head(5).index.tolist()
df_filtered = df[df["Label"].isin(top_labels)]
# Group and calculate means
grouped = df_filtered.groupby("Label")[["Importance", "Position"]].mean().reset_index()
# Plot
fig, ax = plt.subplots(figsize=(8, 6))
for _, row in grouped.iterrows():
ax.scatter(
row["Position"], row["Importance"], s=150, label=row["Label"], color=utils.CATEGORY_COLORS.get(row["Label"])
)
ax.set_xlabel("Normalized Position in Trace (0-1)")
ax.set_ylabel("Mean Counterfactual Importance")
ax.set_title("Sentence Category Effect (Figure 3b)")
ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plt.tight_layout()
plt.show()
Sentence-to-Sentence Causal Graph
So far we've measured importance only for the final answer — asking "how much does removing sentence $i$ affect whether the model gets the right answer?" But this misses the rich causal structure within the reasoning trace itself. A planning sentence might not directly affect the final answer, but it might be critical for producing a later computation sentence that does affect the answer.
The sentence-to-sentence causal graph captures this structure by asking: how important is sentence $i$ for whether sentence $j$ appears later? This connects our black-box importance metrics to the white-box attention patterns we'll study in the next section — if sentence $i$ causally influences sentence $j$ in the black-box sense, we'd expect the model's attention heads to also attend strongly from $j$ back to $i$.
The paper computes this as follows:
1. For each target sentence $j$, look at rollouts where sentence $i$ was kept vs. removed.
2. Check how often sentence $j$ (or something semantically similar) appears in each set, using embedding cosine similarity.
3. The difference between these two rates is the causal importance of $i$ on $j$.
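Written out using the quantities from this notebook's solution code (our notation, not necessarily the paper's): let $R_k$ be the set of rollouts resampled from the prefix ending just before chunk $k$, $e_j$ the embedding of target chunk $j$, $e_s$ the embedding of a rollout sentence $s$, and $\tau$ the cosine similarity threshold (0.7 by default). Then

$$
\mathrm{match}_k(j) = \frac{1}{|R_k|} \sum_{r \in R_k} \mathbb{1}\!\left[\max_{s \in r} \cos(e_s, e_j) \ge \tau\right],
\qquad
\mathrm{imp}(i \to j) = \mathrm{match}_{i+1}(j) - \mathrm{match}_{i}(j), \quad i < j.
$$

Rollouts in $R_{i+1}$ keep chunk $i$ in their prefix while rollouts in $R_i$ resample it, so a positive value means keeping chunk $i$ makes chunk $j$ (or a paraphrase of it) more likely to appear later.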
This produces a matrix of pairwise causal effects, which can be visualized as a directed graph — revealing the "backbone" of the reasoning trace.
Before diving into the exercise, let's first understand the visualization tool we'll be using. The function utils.chunk_graph_html takes an importance matrix, chunk labels, and chunk texts, and renders an interactive causal graph. Here's a demo with random data so you can see what the output looks like:
# Demo: visualize a causal graph with sample data
n_chunks = 5
sample_labels = ["Problem Setup", "Plan Generation", "Active Computation", "Self Checking", "Final Answer Emission"]
sample_texts = [
"Let me start by understanding the problem...",
"I notice that this is equivalent to...",
"Computing the value: 3 * 7 = 21",
"Let me verify: 21 / 7 = 3, correct",
"Therefore the answer is 21",
]
# Create a sample importance matrix (i, j) = how much does chunk i causally influence chunk j
# We expect later chunks to be influenced by earlier ones. In our demo, each step influences
# the next one, plus plan generation influences everything!
sample_importance = np.zeros((n_chunks, n_chunks))
for i in range(4):
sample_importance[i, i + 1] += 0.5
for i in range(2, 5):
sample_importance[1, i] += 0.5
html_str = utils.chunk_graph_html(
edge_weights=sample_importance,
chunk_labels=sample_labels,
chunk_texts=sample_texts,
)
display(HTML(html_str))
Exercise - replicate causal graph
This exercise is broken into three clear steps, each with its own hints. We recommend tackling them in order — completing each step before moving on to the next.
Some notes before you start:
- The problem_data we've loaded corresponds to the "Case study transcript" problem in Appendix C.1 of the Thought Anchors paper. Note that the authors only compute the sentence-to-sentence causal graph for chunks 0-73 inclusive, since they observe (as you should have found in the counterfactual importance graphs generated above) that the resampling importance drops off to zero or near zero past this point. To save time, we recommend following this convention, i.e. only computing the graph up to chunk 73.
- For reference, the solution code takes about 6 minutes to run on an A100 GPU, split roughly evenly between computing the rollout chunks' embeddings and computing the pairwise importances (i.e. the batched cosine similarity calculations). This was probably under-optimized, and a more efficient solution is likely possible.
Step 1: Compute pairwise importance matrix
For each pair of chunks (i, j) where i < j, compute the counterfactual importance of chunk i on whether chunk j appears in rollouts. Specifically, compare the "match rate" (fraction of rollouts containing a sentence similar to chunk j) between rollouts where chunk i was kept vs removed. Store these in an importance_matrix of shape (n_chunks, n_chunks).
We recommend computing all the embeddings up front so that they can be batched: embedding_model.encode batches automatically if you give it a list of sentences and pass in the batch_size argument.
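Most of the work here is bookkeeping: flatten every rollout's sentences into one list, encode that list in a single call, then slice the result back apart per rollout. Below is a minimal sketch of just that pattern. fake_encode is a hypothetical stand-in for embedding_model.encode so the example runs without downloading a model.

```python
import numpy as np


def fake_encode(sentences: list[str], batch_size: int = 128) -> np.ndarray:
    """Hypothetical stand-in for SentenceTransformer.encode: one 2-dim vector per sentence."""
    return np.array([[len(s), s.count(" ")] for s in sentences], dtype=float)


def flatten_encode_unflatten(rollouts: list[list[str]], batch_size: int = 128) -> list[np.ndarray]:
    """Encode all sentences from all rollouts in one call, then split the result per rollout."""
    flat = [sentence for rollout in rollouts for sentence in rollout]
    # Offset where each rollout's sentences end in the flat list, e.g. [2, 3, 6]
    ends = np.cumsum([len(rollout) for rollout in rollouts])
    embeddings = fake_encode(flat, batch_size=batch_size)
    # np.split cuts at every end offset except the last one (which is just the total length)
    return np.split(embeddings, ends[:-1])


rollouts = [["Let me think.", "So x = 3."], ["Check: 3 * 7 = 21."], ["A.", "B.", "C."]]
per_rollout = flatten_encode_unflatten(rollouts)
print([e.shape for e in per_rollout])  # [(2, 2), (1, 2), (3, 2)]
```

With a real model you would swap fake_encode for embedding_model.encode(flat, batch_size=batch_size); the slicing logic stays the same.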
This is the hardest step. We recommend breaking it into sub-functions:
Hint 1a - overall approach for Step 1
You'll need to:
- Precompute all target embeddings, from problem_data["chunks_labeled"][:n_chunks]
- Precompute the embeddings from all chunked rollouts, using problem_data["chunk_solutions"] and your own chunk splitting function to split apart each rollout
- Iterate through every pair (i, j) of chunks from problem_data["chunks_labeled"][:n_chunks], and:
  - Compute the include rate, which is the fraction of rollouts from chunk_{i+1} containing a sentence sufficiently similar to chunk_j
  - Compute the exclude rate, which is the same but starting from chunk_i (i.e. we don't force the inclusion of chunk i)
  - Compute the importance, which is the include rate minus the exclude rate
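If the pairwise loop turns out to be slow, one possible speed-up (our own sketch, not the reference solution) is to compute the match rate for every target at once: a single matrix product gives all sentence-target cosine similarities, and np.maximum.reduceat takes the maximum within each rollout. It assumes rollout_embeddings_by_chunk[k] has the same layout as the rollout_embeddings described above, and that every rollout has at least one sentence, since reduceat silently returns a single row (rather than an empty max) for zero-length segments.

```python
import numpy as np


def match_rates_all_targets(
    rollout_embeddings: list[np.ndarray], target_embeddings: np.ndarray, threshold: float = 0.7
) -> np.ndarray:
    """Fraction of rollouts containing a sentence similar to each target, shape (n_targets,).

    rollout_embeddings: list of (n_sentences_r, d) arrays, each with n_sentences_r >= 1.
    target_embeddings: (n_targets, d) array.
    """

    def normalize(x: np.ndarray) -> np.ndarray:
        return x / (np.linalg.norm(x, axis=-1, keepdims=True) + 1e-8)

    all_sentences = normalize(np.vstack(rollout_embeddings))  # (total_sentences, d)
    sims = all_sentences @ normalize(target_embeddings).T  # (total_sentences, n_targets)
    # Row index where each rollout's block of sentences starts
    starts = np.cumsum([0] + [len(e) for e in rollout_embeddings[:-1]])
    max_sims = np.maximum.reduceat(sims, starts, axis=0)  # (n_rollouts, n_targets)
    return (max_sims >= threshold).mean(axis=0)


def importance_row(
    source_idx: int,
    rollout_embeddings_by_chunk: list[list[np.ndarray]],
    target_embeddings: np.ndarray,
    threshold: float = 0.7,
) -> np.ndarray:
    """Include rate (source kept) minus exclude rate (source resampled) for every target."""
    include = match_rates_all_targets(rollout_embeddings_by_chunk[source_idx + 1], target_embeddings, threshold)
    exclude = match_rates_all_targets(rollout_embeddings_by_chunk[source_idx], target_embeddings, threshold)
    row = include - exclude
    row[: source_idx + 1] = 0.0  # only later chunks (j > i) can be influenced
    return row


targets = np.array([[1.0, 0.0], [0.0, 1.0]])  # toy embeddings for target chunks 0 and 1
rollouts_by_chunk = [
    # chunk 0 resampled: one rollout resembles target 0, the other resembles target 1
    [np.array([[1.0, 0.0]]), np.array([[0.0, 1.0]])],
    # chunk 0 kept: both rollouts contain something close to target 1
    [np.array([[0.0, 1.0], [1.0, 0.0]]), np.array([[0.2, 1.0]])],
]
row = importance_row(0, rollouts_by_chunk, targets)
print(row.tolist())  # [0.0, 0.5]
```

Filling importance_matrix[i] = importance_row(i, ...) for each i then replaces the inner loop over j.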
Hint 1b - function stubs to fill in
Here are the function signatures you should implement. Try filling in each one:
def precompute_rollout_embeddings(
chunk_solutions: list[list[dict]],
embedding_model: SentenceTransformer,
batch_size: int = 128,
) -> tuple[list[list[np.ndarray]], list[list[list[str]]]]:
'''
Precompute embeddings for all sentences in all rollouts using batched encoding. Does this by
first collecting every single rollout from each chunk, then batch encode them all.
Args:
chunk_solutions: List of chunk solutions, where chunk_solutions[k] contains
rollouts resampled from the prefix ending just before chunk k (i.e. chunk k is replaced)
embedding_model: Sentence embedding model
batch_size: Batch size for embedding model encoding
Returns:
rollout_embeddings: rollout_embeddings[i][j] = array of shape (n_sentences, embed_dim)
rollout_sentences: rollout_sentences[i][j] = list of sentence strings
'''
...
return rollout_embeddings, rollout_sentences
def precompute_target_embeddings(
chunks: list[dict],
embedding_model: SentenceTransformer,
n_chunks: int,
) -> Float[np.ndarray, " n_chunks embed_dim"]:
'''
Precompute embeddings for all target chunk sentences.
Args:
chunks: List of chunk dictionaries with "chunk" key
embedding_model: Sentence embedding model
n_chunks: Number of chunks to process
Returns:
Array of shape (n_chunks, embed_dim) with target embeddings
'''
...
return embedding_model.encode(chunk_texts)
def compute_match_rate_from_embeddings(
target_embedding: Float[np.ndarray, " embed_dim"],
rollout_embeddings_list: list[Float[np.ndarray, " n_sentences embed_dim"]],
similarity_threshold: float = 0.7,
) -> float:
'''
Compute fraction of rollouts containing a sentence similar to target.
Args:
target_embedding: Embedding of target sentence, shape (embed_dim,)
rollout_embeddings_list: List of arrays, each shape (n_sentences, embed_dim)
similarity_threshold: Minimum cosine similarity for a match
Returns:
Fraction of rollouts with a matching sentence (0 to 1)
'''
...
return matches / len(rollout_boundaries)
def compute_pairwise_importance(
source_idx: int,
target_idx: int,
target_embeddings: Float[np.ndarray, " n_chunks embed_dim"],
rollout_embeddings: list[list[Float[np.ndarray, " n_sentences embed_dim"]]],
similarity_threshold: float = 0.7,
) -> float:
'''
Compute causal importance of sentence source_idx on sentence target_idx.
Uses precomputed embeddings for efficiency.
Compares:
- Rollouts from chunk_{source_idx + 1} (source was KEPT)
- Rollouts from chunk_{source_idx} (source was REMOVED)
'''
...
return include_rate - exclude_rate
Step 2: Generate chunk labels and texts
Extract the category label and text content for each chunk from problem_data["chunks_labeled"][:n_chunks]. The category labels are used for coloring nodes in the graph, and the text content is shown in tooltips.
Hint 2 - extracting labels and texts
Each element of problem_data["chunks_labeled"] is a dict with keys including "function_tags" (a list of category strings) and "chunk" (the text). Use the first tag from "function_tags" as the label for each chunk, mapped through the CATEGORIES dict, e.g.:
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"][:n_chunks]]
chunk_texts = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:n_chunks]]
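To make the data structure concrete, here is the same extraction on a toy list of chunks. TOY_CATEGORIES is a hypothetical stand-in for the notebook's CATEGORIES mapping (the real tag names and display names may differ), and the Counter at the end is just a quick sanity check on the label distribution.

```python
from collections import Counter

# Hypothetical stand-ins for CATEGORIES and problem_data["chunks_labeled"]
TOY_CATEGORIES = {"plan_generation": "Plan Generation", "active_computation": "Active Computation"}
toy_chunks_labeled = [
    {"chunk": "First, convert each hex digit to 4 bits.", "function_tags": ["plan_generation"]},
    {"chunk": "6 in binary is 0110.", "function_tags": ["active_computation"]},
    {"chunk": "So there are 5 hex digits.", "function_tags": ["active_computation", "result_consolidation"]},
]

n_chunks = len(toy_chunks_labeled)
# Only the first tag is used as the label, even when a chunk has several
chunk_labels = [TOY_CATEGORIES[c["function_tags"][0]] for c in toy_chunks_labeled[:n_chunks]]
chunk_texts = [c["chunk"] for c in toy_chunks_labeled[:n_chunks]]
print(Counter(chunk_labels))  # Counter({'Active Computation': 2, 'Plan Generation': 1})
```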
Step 3: Visualize with utils.chunk_graph_html
Pass your importance matrix, labels, and texts to create the interactive visualization. You can use the demo above as a reference for how the visualization function works: pass in edge_weights (a 2D array of pairwise importance values), chunk_labels (a list of category names, e.g. "Uncertainty Management") and chunk_texts (the sentence text shown in tooltips), and it will render the interactive graph for you.
Hint 3 - calling the visualization
html_str = utils.chunk_graph_html(
edge_weights=importance_matrix,
chunk_labels=chunk_labels,
chunk_texts=chunk_texts,
n_top_edges_per_direction=3,
)
display(HTML(html_str))
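With 74 nodes there are up to 74 × 73 / 2 = 2701 candidate edges, so the graph only draws the strongest ones. The sketch below shows one plausible reading of a "top edges per direction" rule: keep each node's top-k outgoing and top-k incoming edges by absolute importance. It is a hypothetical illustration, and utils.chunk_graph_html may prune edges differently.

```python
import numpy as np


def top_k_edges(importance: np.ndarray, k: int = 3) -> set[tuple[int, int]]:
    """Union of each node's top-k outgoing and top-k incoming edges by |weight| (zero weights skipped)."""
    keep: set[tuple[int, int]] = set()
    for node in range(importance.shape[0]):
        out_w = np.abs(importance[node, :])  # edges node -> other
        in_w = np.abs(importance[:, node])  # edges other -> node
        # Stable sort so ties are broken by index, making the result deterministic
        for other in np.argsort(-out_w, kind="stable")[:k]:
            if out_w[other] > 0:
                keep.add((node, int(other)))
        for other in np.argsort(-in_w, kind="stable")[:k]:
            if in_w[other] > 0:
                keep.add((int(other), node))
    return keep


w = np.array(
    [
        [0.0, 0.9, 0.1, 0.2],
        [0.0, 0.0, 0.8, 0.05],
        [0.0, 0.0, 0.0, 0.7],
        [0.0, 0.0, 0.0, 0.0],
    ]
)
print(sorted(top_k_edges(w, k=1)))  # [(0, 1), (1, 2), (2, 3)]
```

With k=1 only the chain 0 → 1 → 2 → 3 survives; the weaker skip edges (0, 2), (0, 3) and (1, 3) are pruned.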
Checking your answer and interpreting the graph
You can check your answer by going to thought-anchors.com and selecting "Hex to binary" in the Problem dropdown. Their demo shows the full 144 nodes rather than the subset of 74 we recommend, but you should still be able to tell if yours is correct (for example, there's one planning node which appears early on in the graph and has very high downstream importance - can you replicate this finding in your graph?).
Once you've done this, explore the graph and see whether it reveals any interesting patterns in how sentences causally influence each other. For example, can you spot any of the following?
- A planning sentence A boosts a computation sentence B, where B is something you'd only compute if you were following the plan in A
- An uncertainty management sentence C boosts the sentence D which resolves that uncertainty (highlighting the importance of that uncertainty management step)
- Short term planning blocks, e.g. uncertainty management sentence E boosting a close-range sentence F which proposes a specific plan for how to proceed
- Blocks of active computation chunks which don't really affect each other causally (highlighting how our counterfactual resampling method is better than just normal resampling)
At the same time, don't read too much into this single sequence - concrete takeaways only come from averaging out results over a large number of sequences and seeing which kinds of patterns emerge.
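One way to start aggregating across many sequences (our own suggestion, not a method from the paper) is to collapse each importance matrix into a category-by-category table: the mean importance from sentences of one category to later sentences of another. Tables like this, computed per problem, can then be averaged. A minimal sketch with made-up numbers:

```python
import numpy as np
import pandas as pd


def category_edge_means(importance: np.ndarray, labels: list[str]) -> pd.DataFrame:
    """Mean importance over pairs (i, j) with i < j, grouped by (labels[i], labels[j])."""
    n = len(labels)
    rows, cols = np.triu_indices(n, k=1)  # all pairs with i < j
    edges = pd.DataFrame(
        {
            "source": [labels[i] for i in rows],
            "target": [labels[j] for j in cols],
            "importance": importance[rows, cols],
        }
    )
    return edges.pivot_table(index="source", columns="target", values="importance", aggfunc="mean")


labels = ["Plan Generation", "Active Computation", "Active Computation"]
importance = np.array([[0.0, 0.4, 0.6], [0.0, 0.0, 0.1], [0.0, 0.0, 0.0]])
table = category_edge_means(importance, labels)
print(table.loc["Plan Generation", "Active Computation"])  # 0.5
```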
def precompute_rollout_embeddings(
chunk_solutions: list[list[dict]],
embedding_model: SentenceTransformer,
batch_size: int = 128,
) -> tuple[list[list[np.ndarray]], list[list[list[str]]]]:
"""
Precompute embeddings for all sentences in all rollouts using batched encoding. Does this by
first collecting every single rollout from each chunk, then batch encode them all.
Args:
chunk_solutions: List of chunk solutions, where chunk_solutions[k] contains
rollouts resampled from the prefix ending just before chunk k (i.e. chunk k is replaced)
embedding_model: Sentence embedding model
batch_size: Batch size for embedding model encoding
Returns:
rollout_embeddings: rollout_embeddings[i][j] = array of shape (n_sentences, embed_dim)
rollout_sentences: rollout_sentences[i][j] = list of sentence strings
"""
# First pass: collect all sentences and track their positions
all_sentences: list[str] = []
# position_map[i][j] = (start_idx, end_idx) slice into all_sentences for rollout j of chunk i
position_map: list[list[tuple[int, int]]] = []
rollout_sentences: list[list[list[str]]] = []
for i in tqdm(range(len(chunk_solutions)), desc="Precomputing embeddings"):
chunk_positions = []
chunk_sentences = []
for rollout_dict in chunk_solutions[i]:
rollout_text = rollout_dict["rollout"]
sentences = split_solution_into_chunks(rollout_text)
assert sentences, "Expected at least one sentence per rollout"
all_sentences.extend(sentences)
chunk_positions.append((len(all_sentences) - len(sentences), len(all_sentences)))
chunk_sentences.append(sentences)
position_map.append(chunk_positions)
rollout_sentences.append(chunk_sentences)
# Batch embed all sentences at once
all_embeddings = embedding_model.encode(all_sentences, batch_size=batch_size, show_progress_bar=True)
# Reconstruct per-rollout embeddings using position map
rollout_embeddings: list[list[np.ndarray]] = []
for i in range(len(chunk_solutions)):
chunk_embeddings = []
for position in position_map[i]:
start_idx, end_idx = position
chunk_embeddings.append(all_embeddings[start_idx:end_idx])
rollout_embeddings.append(chunk_embeddings)
return rollout_embeddings, rollout_sentences
def precompute_target_embeddings(
chunks: list[dict],
embedding_model: SentenceTransformer,
n_chunks: int,
) -> Float[np.ndarray, "n_chunks embed_dim"]:
"""
Precompute embeddings for all target chunk sentences.
Args:
chunks: List of chunk dictionaries with "chunk" key
embedding_model: Sentence embedding model
n_chunks: Number of chunks to process
Returns:
Array of shape (n_chunks, embed_dim) with target embeddings
"""
chunk_texts = [chunks[i]["chunk"] for i in range(n_chunks)]
return embedding_model.encode(chunk_texts)
def compute_match_rate_from_embeddings(
target_embedding: Float[np.ndarray, " embed_dim"],
rollout_embeddings_list: list[Float[np.ndarray, "n_sentences embed_dim"]],
similarity_threshold: float = 0.7,
) -> float:
"""
Compute fraction of rollouts containing a sentence similar to target.
Args:
target_embedding: Embedding of target sentence, shape (embed_dim,)
rollout_embeddings_list: List of arrays, each shape (n_sentences, embed_dim)
similarity_threshold: Minimum cosine similarity for a match
Returns:
Fraction of rollouts with a matching sentence (0 to 1)
"""
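# Track each rollout's (start, end) slice into the concatenated array. No filtering is needed,
# since precompute_rollout_embeddings asserts that every rollout has at least one sentence.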
valid_embeddings = []
rollout_boundaries = [] # (start, end) indices into concatenated array
current_idx = 0
for embeddings in rollout_embeddings_list:
valid_embeddings.append(embeddings)
rollout_boundaries.append((current_idx, current_idx + len(embeddings)))
current_idx += len(embeddings)
# Concatenate all embeddings: (total_sentences, embed_dim)
all_embeddings = np.vstack(valid_embeddings)
# Normalize target and all embeddings for cosine similarity
target_norm = target_embedding / (np.linalg.norm(target_embedding) + 1e-8)
all_norms = np.linalg.norm(all_embeddings, axis=1, keepdims=True) + 1e-8
all_embeddings_norm = all_embeddings / all_norms
# Compute all cosine similarities at once: (total_sentences,)
all_similarities = all_embeddings_norm @ target_norm
# Count matches per rollout using boundaries
matches = sum(1 for start, end in rollout_boundaries if np.max(all_similarities[start:end]) >= similarity_threshold)
return matches / len(rollout_boundaries)
def compute_pairwise_importance(
source_idx: int,
target_idx: int,
target_embeddings: Float[np.ndarray, "n_chunks embed_dim"],
rollout_embeddings: list[list[Float[np.ndarray, "n_sentences embed_dim"]]],
similarity_threshold: float = 0.7,
) -> float:
"""
Compute causal importance of sentence source_idx on sentence target_idx.
Uses precomputed embeddings for efficiency.
Compares:
- Rollouts from chunk_{source_idx + 1} (source was KEPT)
- Rollouts from chunk_{source_idx} (source was REMOVED)
"""
if source_idx >= target_idx:
return 0.0
target_embedding = target_embeddings[target_idx]
# Rollouts where source was kept
include_embeddings = rollout_embeddings[source_idx + 1]
# Rollouts where source was removed
exclude_embeddings = rollout_embeddings[source_idx]
include_rate = compute_match_rate_from_embeddings(target_embedding, include_embeddings, similarity_threshold)
exclude_rate = compute_match_rate_from_embeddings(target_embedding, exclude_embeddings, similarity_threshold)
return include_rate - exclude_rate
# Precompute all embeddings upfront
n_chunks = 74
target_embeddings = precompute_target_embeddings(problem_data["chunks_labeled"], embedding_model, n_chunks=n_chunks)
rollout_embeddings, rollout_sentences = precompute_rollout_embeddings(
problem_data["chunk_solutions"], embedding_model
)
# Compute importance matrix
importance_matrix = np.zeros((n_chunks, n_chunks))
for i in tqdm(range(n_chunks - 1), desc="Computing pairwise importance"):
for j in range(i + 1, n_chunks):
importance_matrix[i, j] = compute_pairwise_importance(i, j, target_embeddings, rollout_embeddings)
# Keep only the upper triangle (i < j, i.e. edges from earlier to later chunks), which is where the
# pairwise values live. The lower triangle is already zero, so this is just a safeguard.
mask = np.triu(np.ones_like(importance_matrix, dtype=bool), k=1)
importance_matrix = np.where(mask, importance_matrix, 0.0)
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"][:n_chunks]]
chunk_texts = [chunk["chunk"] for chunk in problem_data["chunks_labeled"][:n_chunks]]
html_str = utils.chunk_graph_html(
edge_weights=importance_matrix,
chunk_labels=chunk_labels,
chunk_texts=chunk_texts,
n_top_edges_per_direction=3,
)
display(HTML(html_str))
Bonus: Implementing Your Own Resampling (local models)
Finally, we end with a bonus exercise: implementing your own resampling using locally loaded models. This is not how the authors generated their rollouts (they used paid API access rather than local models), but we provide it as an optional exercise if you're interested in the nuts and bolts of how resampling works.
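The fiddliest detail in resampling is building the prefix that ends just before a given chunk, because the same sentence (e.g. "Hmm, okay.") can appear several times in one trace. A minimal sketch of one way to handle this, splitting on the k-th occurrence of the chunk (the function name is hypothetical):

```python
def prefix_before_kth_occurrence(text: str, chunk: str, k: int) -> str:
    """Return everything in `text` before the k-th (1-indexed) occurrence of `chunk`."""
    # With maxsplit=k, the last piece is everything after the k-th occurrence
    pieces = text.split(chunk, maxsplit=k)
    if len(pieces) <= k:
        raise ValueError(f"{chunk!r} occurs fewer than {k} times")
    # Re-join the earlier pieces with the chunk to restore occurrences 1..k-1
    return chunk.join(pieces[:-1])


trace = "Hmm, okay. Let me compute. Hmm, okay. Done."
print(repr(prefix_before_kth_occurrence(trace, "Hmm, okay.", 2)))  # 'Hmm, okay. Let me compute. '
```

In a full implementation you would prepend the prompt to this prefix and keep a per-chunk counter of k as you walk through the trace.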
Exercise - implement resampling (optional)
class StopOnThink(StoppingCriteria):
    """Helper class for stopping generation when the <think>...</think> tags are closed."""
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.think_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0]
    def __call__(self, input_ids, scores, **kwargs):
        # One bool per sequence (shape (batch_size,)), so each sequence in a batch stops
        # independently rather than the whole batch stopping when sequence 0 does
        return input_ids[:, -1] == self.think_token_id
def get_resampled_rollouts(
    prompt: str,
    model: transformers.models.llama.modeling_llama.LlamaForCausalLM,
    tokenizer: transformers.models.llama.tokenization_llama.LlamaTokenizer,
    num_resamples_per_chunk: int = 100,
    batch_size: int = 4,
    max_new_tokens: int = 2048,
    up_to_n_chunks: int | None = None,
) -> tuple[str, list[str], list[list[dict]]]:
    """
    Generates a full answer and splits it into chunks, then for each chunk computes a number of
    resampled rollouts starting from the prefix that ends just before that chunk.
    Args:
        prompt: The initial problem prompt (which ends with a <think> tag).
        model: The model used for generation.
        tokenizer: The model's tokenizer.
        num_resamples_per_chunk: Number of resamples to compute for each chunk.
        batch_size: Number of resamples to generate in parallel.
        max_new_tokens: Maximum number of new tokens per generation.
        up_to_n_chunks: If not None, only resample the first up_to_n_chunks chunks.
    Returns:
        Tuple of (full_answer, chunks, resampled_rollouts) where the latter is a list of lists of
        dicts (one for each chunk & resample on that chunk).
    """
    raise NotImplementedError()
    return full_answer, chunks, chunk_rollouts
# Load the model for generation (only if you want to try this)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_8B, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_8B)
# The function returns a tuple, so unpack it before iterating over the per-chunk rollouts
full_answer, chunks, resampled_rollouts = get_resampled_rollouts(
    prompt=problem_data["base_solution"]["prompt"],
    model=model,
    tokenizer=tokenizer,
    num_resamples_per_chunk=2,
    batch_size=2,
    max_new_tokens=2048,
    up_to_n_chunks=4,
)
for i, resamples in enumerate(resampled_rollouts):
    print("Replaced chunk: " + repr(resamples[0]["chunk_replaced"]))
    for j, r in enumerate(resamples):
        print(f" Resample {j}: " + repr(r["chunk_resampled"]))
    print()
Expected output from the cell above
Replaced chunk: 'Alright, so I have this problem here:'
Resample 0: 'Okay, so I need to figure out how many bits are in the binary representation of the hexadecimal number 66666₁₆.'
Resample 1: "Alright, so I've got this problem here:"
Replaced chunk: 'When the base-16 number \\(66666_{16}\\) is written in base 2, how many base-2 digits (bits) does it have?'
Resample 0: 'when the base-16 number \\(66666_{16}\\) is written in base 2, how many base-2 digits (bits) does it have?'
Resample 1: 'when the base-16 number 66666₁₆ is written in base 2, how many base-2 digits (bits) does it have?'
Replaced chunk: 'Hmm, okay.'
Resample 0: 'Hmm, okay.'
Resample 1: 'Hmm, okay.'
Replaced chunk: 'Let me try to figure this out step by step.'
Resample 0: 'Let me think about how to approach this.'
Resample 1: 'I need to figure out how many bits are required to represent this hexadecimal number in binary.'
Solution
class StopOnThink(StoppingCriteria):
    """Helper class for stopping generation when the <think>...</think> tags are closed."""
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.think_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0]
    def __call__(self, input_ids, scores, **kwargs):
        # One bool per sequence (shape (batch_size,)), so each sequence in a batch stops
        # independently rather than the whole batch stopping when sequence 0 does
        return input_ids[:, -1] == self.think_token_id
def get_resampled_rollouts(
    prompt: str,
    model: transformers.models.llama.modeling_llama.LlamaForCausalLM,
    tokenizer: transformers.models.llama.tokenization_llama.LlamaTokenizer,
    num_resamples_per_chunk: int = 100,
    batch_size: int = 4,
    max_new_tokens: int = 2048,
    up_to_n_chunks: int | None = None,
) -> tuple[str, list[str], list[list[dict]]]:
    """
    Generates a full answer and splits it into chunks, then for each chunk computes a number of
    resampled rollouts starting from the prefix that ends just before that chunk.
    Args:
        prompt: The initial problem prompt (which ends with a <think> tag).
        model: The model used for generation.
        tokenizer: The model's tokenizer.
        num_resamples_per_chunk: Number of resamples to compute for each chunk (rounded down to a
            multiple of batch_size).
        batch_size: Number of resamples to generate in parallel.
        max_new_tokens: Maximum number of new tokens per generation.
        up_to_n_chunks: If not None, only resample the first up_to_n_chunks chunks.
    Returns:
        Tuple of (full_answer, chunks, resampled_rollouts) where the latter is a list of lists of
        dicts (one for each chunk & resample on that chunk).
    """
    @torch.inference_mode()
    def generate(inputs):
        # Unpack the tokenizer output so both input_ids and attention_mask reach generate
        return model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            stopping_criteria=[StopOnThink(tokenizer)],
            pad_token_id=tokenizer.eos_token_id,
        )
    # First, generate a completion which we'll split up into chunks.
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    output = generate(inputs)
    full_answer = tokenizer.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=False)
    chunks = split_solution_into_chunks(full_answer)
    # Second, generate resamples at each chunk.
    chunk_rollouts = []
    n_chunk_instances = defaultdict(int)
    for chunk in tqdm(chunks[:up_to_n_chunks]):
        chunk_rollouts.append([])
        # To get the answer before the chunk, we split on the correct instance of the chunk (since
        # this chunk might have appeared multiple times in the answer).
        n_chunk_instances[chunk] += 1
        full_answer_split = (
            prompt + chunk.join(full_answer.split(chunk, maxsplit=n_chunk_instances[chunk])[:-1]).strip()
        )
        inputs = tokenizer([full_answer_split] * batch_size, return_tensors="pt").to(DEVICE)
        for _ in range(num_resamples_per_chunk // batch_size):
            output_batch = generate(inputs)
            for output_generation in output_batch:
                generated_text = tokenizer.decode(
                    output_generation[inputs["input_ids"].shape[1] :], skip_special_tokens=False
                )
                # Sequences that finish before the rest of the batch are padded with EOS; drop it
                generated_text = generated_text.replace(tokenizer.eos_token, "")
                chunk_rollouts[-1].append(
                    {
                        "full_cot": full_answer_split + generated_text,
                        "chunk_resampled": split_solution_into_chunks(generated_text)[0],
                        "chunk_replaced": chunk,
                        "rollout": generated_text,
                    }
                )
    return full_answer, chunks, chunk_rollouts
# Load the model for generation (only if you want to try this)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_8B, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_8B)
# The function returns a tuple, so unpack it before iterating over the per-chunk rollouts
full_answer, chunks, resampled_rollouts = get_resampled_rollouts(
    prompt=problem_data["base_solution"]["prompt"],
    model=model,
    tokenizer=tokenizer,
    num_resamples_per_chunk=2,
    batch_size=2,
    max_new_tokens=2048,
    up_to_n_chunks=4,
)
for i, resamples in enumerate(resampled_rollouts):
    print("Replaced chunk: " + repr(resamples[0]["chunk_replaced"]))
    for j, r in enumerate(resamples):
        print(f" Resample {j}: " + repr(r["chunk_resampled"]))
    print()
When you've got this working, we leave the rest as an exercise for you to explore! Generating sufficiently many rollouts for statistical analysis is beyond the scope of these exercises, but it might be a fun project to try out if you want more practice working at larger scale and performing full paper replications.