2️⃣ Black-box Analysis

Learning Objectives
  • Understand the baseline of forced answer importance used for reasoning chunks (and its limitations)
  • Learn about resampling methods for measuring sentence importance, and implement your own resampling metric calculations on a given resampled rollout
  • Learn how we can filter for low cosine similarity resamplings to find sentences which are critical for shaping the model's reasoning trajectory
  • Reproduce several of the paper's key figures analysing the importance of different categories of reasoning steps
  • Implement your own resampling methods using LLM generation

The paper used three different methods to measure sentence importance:

1. Forced Answer Importance. For each sentence $S_i$ in the CoT, we interrupt the model immediately after $S_i$ and append text that induces a final output: "Therefore, the final answer is \boxed{...}". We then measure the model's accuracy when forced to answer at this point. If $A^f_i$ is the final answer when we force the model to answer immediately after $S_i$, and $P(A)$ is the probability of answer $A$ being correct, then the formula for forced answer importance is:

$$ \text{importance}_f := P(A^f_i) - P(A^f_{i-1}) $$

Can you see a flaw in this approach, when it comes to identifying sentences which are critical for the model's reasoning process?

Suppose some sentence $S$ is necessary for the final answer but comes late in the reasoning process. Forcing an answer at any point before $S$ then gives low accuracy, so every sentence before $S$ gets near-zero importance by this metric, even if those sentences were essential for reaching $S$. We will only pick up sentences whose inclusion, in combination with all previous sentences, takes us from the wrong answer to the right one.

For people who have completed the IOI material / done much work with model internals, this is the equivalent of finding the most important model components by seeing which ones write the final answer to the residual stream. It's a good start, but doesn't tell you much about the steps of the computation beyond the last one.

2. Resampling Importance. To address the flaw above, for each sentence $S_i$ in the CoT we can resample whole trajectories $(S_1, \ldots, S_{i-1}, S_i, S'_{i+1}, \ldots, S'_N, A^r_i)$, where $A^r_i$ is the final answer we get from resampling all chunks after $S_i$. We compare these to the corresponding trajectories $(S_1, \ldots, S_{i-1}, S'_i, \ldots, S'_N, A^r_{i-1})$, which we get by also resampling $S_i$. This gives us a sense of how important $S_i$ was for producing the final answer. Our metric is:

$$ \text{importance}_r := P(A^r_i) - P(A^r_{i-1}) $$

3. Counterfactual Importance. An issue with resampling importance is that $S'_i$ will often be very similar to $S_i$ if the reasoning context strongly constrains what can be expressed at that position. To fix this, we can filter for cases where these two sentences are fairly different, using a semantic similarity metric derived from our embedding model. This gives us a counterfactual importance metric. It is identical to the previous one, except that the answers $A^c_i$ come from the filtered rollouts only:

$$ \text{importance}_c := P(A^c_i) - P(A^c_{i-1}) $$
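Here is a minimal numpy sketch that makes the metrics concrete on made-up numbers. The accuracies, correctness flags and similarities are hypothetical, not taken from the dataset, and the helper names are illustrative.

The sketch uses the same 0-indexed convention as the code later in this section: `accuracies[i]` is the accuracy of rollouts that keep chunks 0, ..., i-1, so the importance of chunk i is `accuracies[i + 1] - accuracies[i]`.

- The first example shows the flaw of forced answer importance: if the answer only appears once the last chunk is included, all the credit lands on that chunk.
- The second example shows the counterfactual filter, which keeps only resamples whose first new chunk has cosine similarity below the threshold.

```python
import numpy as np


def importance_from_accuracies(accuracies):
    """Importance of chunk i = accuracy with chunk i included minus accuracy without it."""
    return np.diff(np.asarray(accuracies, dtype=float))


def filtered_accuracy(correct, cos_sims, threshold=0.8):
    """Accuracy over the rollouts whose first resampled chunk is 'sufficiently different'
    (cosine similarity below threshold) from the chunk it replaced."""
    correct = np.asarray(correct, dtype=bool)
    keep = np.asarray(cos_sims, dtype=float) < threshold
    return float(correct[keep].mean()) if keep.any() else float("nan")


# Hypothetical forced-answer accuracies: the model only answers correctly once the final
# chunk is included, so every earlier chunk gets zero importance
forced_accuracies = [0.0, 0.0, 0.0, 0.0, 1.0]
print(importance_from_accuracies(forced_accuracies))  # [0. 0. 0. 1.]

# Hypothetical resamples at one position: the two near-paraphrases (similarity >= 0.8)
# are dropped, leaving 3 rollouts of which 1 is correct
correct = [1, 1, 0, 0, 1]
cos_sims = [0.95, 0.90, 0.40, 0.55, 0.70]
print(filtered_accuracy(correct, cos_sims))  # 0.3333333333333333
```

With the real rollouts, each correctness flag would come from parsing the model's final answer out of the full CoT.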

In these sections, we'll work through each of these metrics in turn. The focus will be computing these metrics on the dataset we've already been provided, with the full replication (including your own implementation of resampling) coming in the next section.

Looking closer at our dataset

Before getting into the exercises, let's look closer at our dataset to see what we're working with. As mentioned, we'll be setting aside our actual model and chunking functions temporarily, so we can focus on analysing the resampled rollouts already provided to us in this dataset.

The code below loads in all the data necessary for analyzing a single problem. This contains:

  • problem, which contains the problem statement along with some basic metadata (including the answer)
  • base_solution, the original full rollout whose chunks we analyse
  • chunks_labeled[i], data for the i-th chunk (e.g. what category it is, plus some metrics)
  • chunk_solutions_forced[i][j], data for the j-th rollout we get from forcing an answer immediately after chunks 0, ..., i-1 (so for i=0 this means forcing an answer before any chunks are included)
  • chunk_solutions[i][j], data for the j-th rollout we get from keeping chunks 0, ..., i-1 and resampling everything from the i-th chunk onwards (the replacement for the i-th chunk is stored as chunk_resampled)

We'll be using this data to implement the three importance metrics described above.

Exercise - inspect the dataset

Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 10-15 minutes on this exercise.

Run the code below to load in the data for a single problem (note we're using generations from deepseek-r1-distill-qwen-14b rather than our Llama-8b model here, so that we match the case study in the paper's appendix).

Make sure you understand the structure of the data, since this will make the following exercises easier.

Here are a few investigative questions you might try to answer when inspecting this problem's data:

  • How has the chunking worked here? Can you see any obvious issues with how it's been applied, e.g. places where a chunk has been split that shouldn't have been?
  • Do the categories look reasonable? You can try comparing them to each of your autoraters, and see what fraction match.
  • Inspect chunk_solutions_forced, and see how early the model generally manages to get the answer correct.
  • Inspect chunk_solutions, and see how much variety the model's completions have when resampled from various stages in the reasoning process.
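import json

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
from tqdm.auto import tqdm


def load_single_file(file_path: str):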
    local_path = hf_hub_download(repo_id=DATASET_NAME, filename=file_path, repo_type="dataset")
    with open(local_path, "r") as f:
        return json.load(f)


def load_problem_data(problem_id: int, verbose: bool = True):
    disable_progress_bars()

    problem_dir = "correct_base_solution"
    problem_dir_forced = "correct_base_solution_forced_answer"

    problem_path = f"deepseek-r1-distill-qwen-14b/temperature_0.6_top_p_0.95/{problem_dir}/problem_{problem_id}"
    problem_path_forced = f"deepseek-r1-distill-qwen-14b/temperature_0.6_top_p_0.95/{problem_dir_forced}/problem_{problem_id}"

    base_solution = load_single_file(f"{problem_path}/base_solution.json")
    problem = load_single_file(f"{problem_path}/problem.json")
    chunks_labeled = load_single_file(f"{problem_path}/chunks_labeled.json")
    chunk_solutions = []
    chunk_solutions_forced = []

    for chunk_idx in tqdm(range(len(chunks_labeled)), disable=not verbose):
        chunk_solutions.append(load_single_file(f"{problem_path}/chunk_{chunk_idx}/solutions.json"))
        chunk_solutions_forced.append(
            load_single_file(f"{problem_path_forced}/chunk_{chunk_idx}/solutions.json")
        )

    enable_progress_bars()

    return {
        "problem": problem,
        "base_solution": base_solution,
        "chunks_labeled": chunks_labeled,
        "chunk_solutions_forced": chunk_solutions_forced,
        "chunk_solutions": chunk_solutions,
    }


problem_data = load_problem_data(PROBLEM_ID)
print(problem_data["problem"])
print()

for i, c in enumerate(problem_data["chunks_labeled"]):
    print(f"{i}. {c['chunk']!r}")
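For the last two investigative questions, it can help to tabulate three things for each chunk position: the accuracy, how many distinct answers the rollouts produce, and the most common answer.

This is a minimal sketch with a hypothetical function name and made-up answers. It assumes you've already parsed each rollout's final answer into a string, for example with an answer parser like the one in the next exercise.

```python
from collections import Counter


def summarise_answers(answers_per_chunk, gt_answer):
    """For each chunk position, return (accuracy, number of distinct answers, most common answer)."""
    rows = []
    for answers in answers_per_chunk:
        if not answers:
            rows.append((float("nan"), 0, None))
            continue
        counts = Counter(answers)
        accuracy = counts[gt_answer] / len(answers)
        most_common_answer = counts.most_common(1)[0][0]
        rows.append((accuracy, len(counts), most_common_answer))
    return rows


# Hypothetical parsed answers for two chunk positions (3 rollouts each)
toy_answers = [["20", "20", "19"], ["19", "19", "19"]]
for i, (acc, n_distinct, top) in enumerate(summarise_answers(toy_answers, "19")):
    print(f"chunk {i}: accuracy={acc:.2f}, distinct answers={n_distinct}, most common={top!r}")
```

A position with many distinct answers is one where the model's trajectory is still open. A single dominant answer suggests the outcome is already largely determined.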

Exercise - calculating answer importance

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 10-15 minutes on this exercise.

You should fill in the function below, to compute the forced answer importance on a set of chunks and their labelled categories. Note that the function takes a list of lists of full CoT rollouts, meaning you'll need to parse out the model's answer from the CoT yourself.

Note, the format of the model's answer can vary. For example, in a question about interest rates with the answer 6.17, the model sometimes responds with e.g. 6.17% or r = 6.17. In this case study we're dealing with a problem that has an integer solution, so this isn't really an issue, but it's still good practice to handle these kinds of cases when you're parsing the answer.
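One way to handle these variants is to pull out the first number in the extracted answer string and compare numerically rather than as strings. That way `6.17%`, `r = 6.17` and `6.170` all match `6.17`.

This is a minimal sketch with hypothetical helper names. It assumes the ground truth is numeric, and would need extending for symbolic answers such as fractions or expressions.

```python
import re
from typing import Optional

NUMBER_PATTERN = re.compile(r"-?\d+(?:\.\d+)?")


def normalize_answer(raw: str) -> Optional[str]:
    """Return the first number in strings like '6.17%', 'r = 6.17' or '1,024' (None if absent)."""
    match = NUMBER_PATTERN.search(raw.replace(",", ""))
    return match.group(0) if match else None


def answers_match(raw: str, gt_answer: str, tol: float = 1e-6) -> bool:
    """Compare a raw model answer to a numeric ground truth, ignoring formatting differences."""
    pred, gt = normalize_answer(raw), normalize_answer(gt_answer)
    if pred is None or gt is None:
        return False
    return abs(float(pred) - float(gt)) <= tol


print(answers_match("6.17%", "6.17"), answers_match("r = 6.17", "6.17"), answers_match("20", "19"))
# True True False
```

Comparing as floats also means that `19.0` and `19` count as the same answer. A plain string comparison of the extracted digits would treat them as different.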

import numpy as np


def calculate_answer_importance(full_cot_list: list[list[str]], answer: str) -> list[float]:
    """
    Calculate importance for chunks, based on accuracy differences.

    Args:
        full_cot_list: List of rollouts; the [i][j]-th element is the j-th rollout which was
          generated by answer-forcing (or resampling) immediately after chunks 0, ..., i-1 (e.g.
          the [0][0]-th element is the 0th rollout, which doesn't include any chunks of the
          model's reasoning).
        answer: The ground truth answer to the problem.

    Returns:
        list[float]: One score per chunk except the last; the i-th score is the change in
          accuracy from including the i-th chunk.
    """

    def extract_answer_from_cot(cot: str) -> str:
        answer = cot.split("\\boxed{")[-1].split("}")[0]
        return "".join(char for char in answer if char.isdigit() or char == ".")

    # Get list of P(A_{S_i}) values for each chunk
    probabilities = [
        sum(extract_answer_from_cot(cot) == answer for cot in cot_list) / len(cot_list)
        for cot_list in full_cot_list
    ]

    # Convert these to importance scores: P(A_{S_i}) - P(A_{S_{i-1}})
    return np.diff(probabilities).tolist()

When you've done this, run the cell below to get your results and plot them. Note, this should not match the paper's "Figure 2" results, since we're using forced answer importance, not resampled or counterfactual importance.

What do you notice about these results? What sentences are necessary for the model to start getting greater-than-zero accuracy? Are there any sentences which significantly drop or raise the model's accuracy, and can you explain why?

Answer - what you should see

The model starts getting greater-than-zero accuracy around chunk 40, when it computes the number of digits in the final answer. The accuracy drops significantly in a couple of spots, including near the very end when it says "That's 20 bits". The accuracy immediately recovers after this, because the next chunk clarifies that the answer is actually 19 bits because the leading digit is zero. Some of the other sharp negative spikes in accuracy are also due to this dropping-leading-zero confusion being raised then resolved in the following chunk.

Note, a discrepancy between your results and those in the dataset is fine. The current version of the dataset that is uploaded seems to have a bug in the "forced answer" metric data, for example it will classify the following rollout:

'Therefore, the final answers is \\boxed{20}. However, upon re-examining ... so the correct answer is \\boxed{19}.'

as having a final answer of 20 rather than 19, hence incorrectly classifying the answer as wrong.
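The parser in the solution above already takes the last `\boxed{...}`, which avoids this bug. However, splitting on the first `}` truncates answers with nested braces, such as `\boxed{\frac{1}{2}}`.

A brace-matching extractor is a slightly more robust alternative. This is a sketch with a hypothetical function name, not the code used to build the dataset.

```python
def extract_boxed_answers(text: str) -> list[str]:
    """Return the contents of every \\boxed{...} in text, in order, matching nested braces."""
    token = "\\boxed{"
    answers = []
    start = text.find(token)
    while start != -1:
        i = j = start + len(token)
        depth = 1
        while j < len(text) and depth > 0:
            if text[j] == "{":
                depth += 1
            elif text[j] == "}":
                depth -= 1
            j += 1
        if depth == 0:  # skip an unclosed box at the end of a truncated rollout
            answers.append(text[i : j - 1])
        start = text.find(token, j)
    return answers


rollout = (
    "Therefore, the final answers is \\boxed{20}. However, upon re-examining ... "
    "so the correct answer is \\boxed{19}."
)
boxed = extract_boxed_answers(rollout)
print(boxed[0], boxed[-1])  # 20 19  (first vs last boxed answer)
print(extract_boxed_answers("\\boxed{\\frac{1}{2}}"))  # ['\\frac{1}{2}']
```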

full_cot_list = [
    [rollout["full_cot"] for rollout in chunk_rollouts]
    for chunk_rollouts in problem_data["chunk_solutions_forced"]
]
answer = problem_data["problem"]["gt_answer"]

forced_answer_importances = calculate_answer_importance(full_cot_list, answer)
forced_answer_importances_cumsum = np.cumsum(forced_answer_importances)
forced_answer_importances_cumsum += 1 - forced_answer_importances_cumsum[-1]

df = pd.DataFrame(
    {
        "Forced answer importance": forced_answer_importances,
        "Accuracy": forced_answer_importances_cumsum,
        "chunk": [d["chunk"] for d in problem_data["chunks_labeled"][:-1]],
        "tags": [d["function_tags"][0] for d in problem_data["chunks_labeled"][:-1]],
    }
)

fig = px.line(
    df,
    labels={"index": "Chunk index", "value": "Importance", "variable": "Metric"},
    y=["Forced answer importance", "Accuracy"],
    hover_data=["chunk", "tags"],
)
fig.update_layout(title="Forced answer importance")
fig.show()
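The "Accuracy" curve in this plot is rebuilt from the importance scores. `np.cumsum` of the diffs only recovers the accuracies up to an additive constant (it gives $P_{i+1} - P_0$).

To fix the constant, the cell shifts the curve so that its final value is 1. This assumes the model is always correct at the last position we have rollouts for (just before the final chunk). That's plausible for these correct-base-solution problems, but it is an assumption.

Here is a quick check on made-up accuracies:

```python
import numpy as np

# Hypothetical accuracies P_0, ..., P_5 (one per chunk position); the last one is 1.0
accuracies = np.array([0.2, 0.1, 0.1, 0.6, 0.4, 1.0])
importances = np.diff(accuracies)  # what calculate_answer_importance returns

accuracy_curve = np.cumsum(importances)  # P_{i+1} - P_0
accuracy_curve += 1 - accuracy_curve[-1]  # same shift as in the cell above

# Row i of the plot shows the accuracy with chunks 0, ..., i included
print(np.allclose(accuracy_curve, accuracies[1:]))  # True
```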

Exercise - compare resampled answer importance

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise. This doesn't involve changing any code, just running the cells below and interpreting the results.

Now that we've implemented the calculate_answer_importance function, we get the resampled answer importance for free by applying it to a different set of rollouts. You can run the cells below, which are lightly adapted versions of the cells above (based on the resampled trajectories rather than the forced-answer trajectories).

How do these answers compare to the forced answer importance? Are there specific sentences which have higher metric scores here than in the forced answer case? Can you reproduce the results from the paper, in section 2.4 Case Study, and do you understand the paper's description of why we get those results?

full_cot_list = [
    [rollout["full_cot"] for rollout in chunk_rollouts]
    for chunk_rollouts in problem_data["chunk_solutions"]
]
answer = problem_data["problem"]["gt_answer"]

resampling_answer_importances = calculate_answer_importance(full_cot_list, answer)

resampling_answer_importances_precomputed = [
    chunk_data["resampling_importance_accuracy"]
    for chunk_data in problem_data["chunks_labeled"][:-1]
]

avg_diff = np.abs(
    np.subtract(resampling_answer_importances, resampling_answer_importances_precomputed)
).mean()
assert avg_diff < 0.01, f"Your implementation may be incorrect: {avg_diff=:.4f}"
print(f"Average difference: {avg_diff:.4f}")
resampling_answer_importances_cumsum = np.cumsum(resampling_answer_importances)
resampling_answer_importances_cumsum += 1 - resampling_answer_importances_cumsum[-1]

df = pd.DataFrame(
    {
        "Resampling answer importance": resampling_answer_importances,
        "Accuracy": resampling_answer_importances_cumsum,
        "chunk": [d["chunk"] for d in problem_data["chunks_labeled"][:-1]],
        "tags": [d["function_tags"][0] for d in problem_data["chunks_labeled"][:-1]],
    }
)

fig = px.line(
    df,
    labels={"index": "Chunk index", "value": "Importance", "variable": "Metric"},
    y=["Resampling answer importance", "Accuracy"],
    hover_data=["chunk", "tags"],
)
fig.update_layout(title="Resampling answer importance")
fig.show()

Semantic similarity in resampling

Before we look at the last metric (counterfactual importance), let's revisit the notion of embedding cosine similarity. Since we have data on a bunch of resampled rollouts at different chunks, we can compute the average cosine similarity between a chunk and all of its resampled chunks (i.e. $S_i$ and $S'_i$ in the notation above). Run the cells below to compute these cosine similarities and plot them.

Which kinds of sentences seem like their resamples have the highest or lowest cosine similarity? Can you explain why?

Answer

You should find that some of the highest average cosine similarities are for:

- Final Answer Emission chunks. This makes sense, since there's only a limited number of ways to express an answer (and given the context, it's likely obvious whether the model is about to emit an answer, as well as what the answer will be).
- Active Computation and Result Consolidation chunks, especially those which come in batches (e.g. chunks 23-33, 51-59, and 91-97 in the plot below). This makes sense, since these chunks often involve a series of similar steps (e.g. applying an iterative formula, or doing a series of multiplications with different inputs). Once the model starts one of these processes, it'll be very constrained in how it acts until it's finished.

Some of the lowest average cosine similarities are for:

- Plan Generation chunks which represent changes in trajectory, for example:
  - Chunk 13: "Alternatively, maybe I can calculate..."
  - Chunk 49: "Let me convert it step by step again", which is a decision to fix a suspected error earlier in the reasoning
- Uncertainty Management chunks, which also represent a re-evaluation and could change the subsequent trajectory, for example:
  - Chunk 45: "I must have made a mistake somewhere"
  - Chunk 84: "So, which is correct?"
  - Chunk 139: "Wait, but that's not correct because..."

Note, there is one chunk (28) which is classified as "result consolidation" and is something of an outlier, with extremely low average cosine similarity to its resamples. However, inspecting problem_data["chunk_solutions"][28] shows that this is actually an artifact of incorrect chunking. The resamples here all follow the pattern "Now, adding all these up:" followed by an equation. This has low similarity to the original chunk, which (correctly) splits at the ":" and so doesn't include the equation. If you want to fix this, you can try using our split_solution_into_chunks function from earlier to process the resampled chunks before plotting them. Moral of the story: this kind of string parsing is finicky and easy to get wrong.
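If you want a quick stand-in before reaching for `split_solution_into_chunks`, one crude option is to cut each resampled chunk right after its first colon, mirroring how the original chunk was split.

This is a hypothetical helper. It would wrongly truncate chunks that contain colons for other reasons (e.g. ratios like `3:4`), so treat it as a diagnostic rather than a fix.

```python
def truncate_after_first_colon(chunk: str) -> str:
    """Hypothetical helper: keep text up to and including the first ':' (whole chunk if none)."""
    idx = chunk.find(":")
    return chunk if idx == -1 else chunk[: idx + 1]


resample = "Now, adding all these up: 1 + 2 + 4 = 7"
print(truncate_after_first_colon(resample))  # Now, adding all these up:
```

You could apply it to the resamples for chunk 28 only, re-embed them, and compare the resulting average similarity with the value in the plot.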
# Get the embeddings of S_i, the chunks we'll be resampling (unit-normalised, so that dot
# products between embeddings are cosine similarities)
chunks_removed = [chunk_data["chunk"] for chunk_data in problem_data["chunks_labeled"]]
embeddings_S_i = embedding_model.encode(chunks_removed, normalize_embeddings=True)  # (N_chunks, d_embed)

# Get the embeddings of T_i (written S'_i in the text above), the resampled chunks
chunks_resampled = [
    [rollout["chunk_resampled"] for rollout in chunk_rollouts]
    for chunk_rollouts in problem_data["chunk_solutions"]
]
embeddings_T_i = np.stack(
    [embedding_model.encode(r, normalize_embeddings=True) for r in chunks_resampled]
)  # (N_chunks, N_resamples, d_embed)

# Get the cosine similarities
cos_sims = einops.einsum(
    embeddings_S_i, embeddings_T_i, "chunk d_embed, chunk resample d_embed -> chunk resample"
)
print(f"Computed cosine similarities for {cos_sims.shape[0]} chunks, {cos_sims.shape[1]} resamples")
# Group the cosine similarity data into a dataframe
cos_sims_mean = cos_sims.mean(axis=1)
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"]]
df = pd.DataFrame(
    {
        "Label": chunk_labels,
        "Cosine similarity": cos_sims_mean,
        "chunk": [chunk["chunk"] for chunk in problem_data["chunks_labeled"]],
    }
)

# Bar chart of average cosine similarity for each chunk
px.bar(
    df,
    labels={"index": "Chunk index", "value": "Cosine similarity"},
    color="Label",
    color_discrete_map=CATEGORY_COLORS,
    hover_data=["chunk"],
    title="Cosine similarity between removed and resampled chunks",
).show()

# Boxplot of cosine similarities, grouped by chunk label
label_order = df.groupby("Label")["Cosine similarity"].mean().sort_values().index.tolist()
px.box(
    df,
    x="Label",
    y="Cosine similarity",
    labels={"x": "Label", "y": "Cosine similarity", "color": "Label"},
    color="Label",
    color_discrete_map=CATEGORY_COLORS,
    category_orders={"Label": label_order},
    boxmode="overlay",
    width=1000,
    title="Cosine similarity, grouped by chunk label",
).update_xaxes(tickangle=45).update_layout(boxgap=0.3).show()
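The einsum above, and the `@` in the next exercise, only give cosine similarities if the embeddings are unit-norm. Some sentence-transformers models normalise their outputs already. If yours doesn't, you can pass `normalize_embeddings=True` to `encode`, or normalise manually.

A small numpy sketch of the difference (the helper name is illustrative):

```python
import numpy as np


def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between every row of a (n, d) and every row of b (m, d)."""
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return a @ b.T  # (n, m)


a = np.array([[3.0, 4.0]])
b = np.array([[3.0, 4.0], [4.0, -3.0]])
print(a @ b.T)  # [[25.  0.]]  raw dot products also depend on vector length
print(cosine_similarity_matrix(a, b))  # [[1. 0.]]
```

Without normalisation, a fixed threshold like 0.8 would mean different things for long and short embedding vectors.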

Exercise - compute counterfactual importance

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-20 minutes on this exercise.

Finally, we'll implement counterfactual importance. This is the same as the resampling importance (and we'll use the same data), but with one difference. We keep only the resampled rollouts where the first resampled chunk $T_i$ is sufficiently different from the chunk $S_i$ which it replaces, and discard the rest. In this case, sufficiently different means having an embedding cosine similarity of less than 0.8, using our embedding model from earlier.

The intuition for this metric is as follows. Resampling importance tells us the effect of choosing a different sentence from $S_i$. Counterfactual importance tells us the effect of choosing a different reasoning path from the one represented by $S_i$. Low cosine similarity is a proxy for the reasoning paths being genuinely different, rather than light rephrasings of essentially the same reasoning step.

def calculate_counterfactual_answer_importance(
    chunks_removed: list[str],
    chunks_resampled: list[list[str]],
    full_cot_list: list[list[str]],
    answer: str,
    threshold: float = 0.8,
    min_indices: int = 5,
    embedding_model: sentence_transformers.SentenceTransformer = embedding_model,
) -> list[float]:
    """
    Calculate importance for chunks, based on accuracy differences, after filtering for low
    new-generation cosine similarity.

    Args:
        chunks_removed: List of chunks $S_i$ which were removed from the rollouts.
        chunks_resampled: List of chunks $T_i$ which were resampled for each of the multiple rollouts.
        full_cot_list: List of resampled rollouts; the [i][j]-th element is the j-th rollout which
          keeps chunks 0, ..., i-1 and resamples everything from the i-th chunk onwards.
        answer: The ground truth answer to the problem.
        threshold: Rollouts whose first resampled chunk has embedding cosine similarity below
          this value (relative to the chunk it replaces) count as "sufficiently different".
        min_indices: Minimum number of rollouts we need after filtering to compute a score.
        embedding_model: Embedding model to use for calculating cosine similarity.

    Returns:
        list[float]: Counterfactual importance score for each chunk except the last.
    """

    def get_filtered_indices(chunk_removed: str, chunks_resampled: list[str]) -> list[int]:
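        # Unit-normalise the embeddings so that the dot product below is a cosine similarity
        embedding_S_i = embedding_model.encode(chunk_removed, normalize_embeddings=True)  # (d_embed,)
        embeddings_T_i = embedding_model.encode(chunks_resampled, normalize_embeddings=True)  # (N, d_embed)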
        cos_sims = embedding_S_i @ embeddings_T_i.T  # (N,)
        return np.where(cos_sims < threshold)[0]

    def extract_answer_from_cot(cot: str) -> str:
        answer = cot.split("\\boxed{")[-1].split("}")[0]
        return "".join(char for char in answer if char.isdigit() or char == ".")

    filtered_indices = [
        get_filtered_indices(chunk_removed, _chunks_resampled)
        for chunk_removed, _chunks_resampled in zip(chunks_removed, chunks_resampled)
    ]

    # Get list of P(A^c_{S_i}) values for each chunk (or None if can't be computed)
    probabilities = [
        sum(extract_answer_from_cot(cot_list[idx]) == answer for idx in indices) / len(indices)
        if len(indices) >= min_indices
        else None
        for cot_list, indices in zip(full_cot_list, filtered_indices)
    ]

    # Forward-fill (then back-fill any leading gaps) to replace the None values
    probabilities = pd.Series(probabilities).ffill().bfill().tolist()

    # Return diffs
    return np.diff(probabilities).tolist()

When you've filled this in, run the cells below to compute and plot the counterfactual importance scores next to your resampling importance scores.

You should find that the two metrics (resampling and counterfactual) are mostly quite similar for this example. They differ most on sentences which the plot above showed to have high semantic variance (i.e. low average cosine similarity to their resamples). These are our thought anchors: sentences which guide the entire reasoning process, so changing them to something different in embedding space has a large effect on the subsequent trajectory.

For example, chunk 3 "So, maybe I can just figure out how many hexadecimal digits there are..." has a higher counterfactual importance than resampling importance (93% vs 88% accuracy when you resample counterfactually vs without filtering at this chunk). This is because the chunk represents a key part of the overall reasoning process: when the model doesn't say this, it often expresses a different plan such as "So, maybe I can start by converting each digit one by one..." or "So maybe I can just multiply the number of hex digits by 4...". Even if some of these plans will end up at the same place, the exact plan & phrasing that the model produces in this step will significantly affect its trajectory.

However, many of the chunks have very similar counterfactual and resampling importance scores. Only by doing a much larger statistical analysis would we be able to parse out any meaningful differences between the two (ideally we'd include more sequences and more resamples per sequence).
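As a rough sense of scale, the normal-approximation standard error of an accuracy estimated from $n$ rollouts is $\sqrt{p(1-p)/n}$. Filtering leaves fewer rollouts for the counterfactual estimate, so its error bars are wider.

Some caveats on the sketch below:

- The sample sizes are hypothetical. For the real values, check `len(problem_data["chunk_solutions"][i])` and how many rollouts survive the filter.
- The two estimates aren't independent, since the filtered rollouts are a subset of the unfiltered ones.

So treat this as a back-of-the-envelope check rather than a significance test.

```python
import math


def accuracy_interval(p: float, n: int, z: float = 1.96) -> tuple[float, float]:
    """Normal-approximation 95% interval for an accuracy p estimated from n rollouts."""
    se = math.sqrt(p * (1 - p) / n)
    return max(0.0, p - z * se), min(1.0, p + z * se)


# Hypothetical sample sizes: fewer rollouts survive the counterfactual filter
counterfactual = accuracy_interval(0.93, n=30)
resampling = accuracy_interval(0.88, n=100)
print(f"counterfactual: {counterfactual[0]:.2f}-{counterfactual[1]:.2f}")
print(f"resampling:     {resampling[0]:.2f}-{resampling[1]:.2f}")
```

With these made-up sample sizes the two intervals overlap heavily, so a 93% vs 88% gap on its own wouldn't be convincing.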

chunks_removed = [chunk_data["chunk"] for chunk_data in problem_data["chunks_labeled"]]
chunks_resampled = [
    [rollout["chunk_resampled"] for rollout in chunk_rollouts]
    for chunk_rollouts in problem_data["chunk_solutions"]
]
full_cot_list = [
    [rollout["full_cot"] for rollout in chunk_rollouts]
    for chunk_rollouts in problem_data["chunk_solutions"]
]
answer = problem_data["problem"]["gt_answer"]

counterfactual_answer_importances = calculate_counterfactual_answer_importance(
    chunks_removed, chunks_resampled, full_cot_list, answer
)

counterfactual_answer_importances_precomputed = [
    -chunk_data["counterfactual_importance_accuracy"]
    for chunk_data in problem_data["chunks_labeled"][:-1]
]

avg_diff = np.abs(
    np.subtract(counterfactual_answer_importances, counterfactual_answer_importances_precomputed)
).mean()
assert avg_diff < 0.1, f"Your implementation may be incorrect: {avg_diff=:.4f}"
print(f"Average difference: {avg_diff:.4f}")
counterfactual_answer_importances_cumsum = np.cumsum(counterfactual_answer_importances)
counterfactual_answer_importances_cumsum += 1 - counterfactual_answer_importances_cumsum[-1]

df = pd.DataFrame(
    {
        "Resampling answer importance": resampling_answer_importances,
        "Resampling answer importance (accuracy)": resampling_answer_importances_cumsum,
        "Counterfactual answer importance": counterfactual_answer_importances,
        "Counterfactual answer importance (accuracy)": counterfactual_answer_importances_cumsum,
        "chunk": [d["chunk"] for d in problem_data["chunks_labeled"][:-1]],
        "tags": [d["function_tags"][0] for d in problem_data["chunks_labeled"][:-1]],
    }
)

fig = px.line(
    df,
    labels={"index": "Chunk index", "value": "Importance", "variable": "Metric"},
    y=[
        "Resampling answer importance",
        "Resampling answer importance (accuracy)",
        "Counterfactual answer importance",
        "Counterfactual answer importance (accuracy)",
    ],
    hover_data=["chunk", "tags"],
)
fig.update_layout(title="Resampling vs counterfactual answer importance")
fig.show()

Exercise - replicate Figure 3b (without KL divergence)

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪
You should spend up to 10-15 minutes on this exercise.

As an open-ended challenge, try replicating Figure 3b from the paper. We've already got all the data you need to do this, so it's just a matter of understanding and plotting the data correctly.

Note that we've so far defined our metrics in terms of accuracy rather than KL divergence, so you should expect to see the metrics looking slightly different (also we're only averaging over a single prompt's chunks rather than over many prompts). We recommend leaving the error bars off, for that reason. But even with these restrictions, you should still be able to get a qualitatively similar plot to Figure 3b. We leave it as an exercise to the reader to scale up the analysis to many different prompts, and add the error bars back!

<!--

The next 2 cells do this at scale, with error bars. Note, this will take about 10 minutes to run (most of that time is spent loading the data).

# Get the list of all files in the data directory

files = list_repo_files(repo_id=DATASET_NAME, repo_type="dataset")
print(len(files))
# -> 29030

problem_ids = set()
for file in files:
    if match := re.search(r"/problem_(\d+)/", file):
        problem_ids.add(int(match.group(1)))
problem_ids = sorted(problem_ids)
print(problem_ids)
# -> [330, 1591, 2050, 2137, 2189, 2236, 2238, 2870, 3360, 3448, 3550, 3916, 3935, 4019, 4164, 4605, 4682, 6481, 6596, 6998]

chunk_labels_all = []
counterfactual_importances_all = []

for problem_id in tqdm(problem_ids[:10]):
    # Load data
    data = load_problem_data(problem_id)

    # Compute counterfactual importance for all chunks in this problem
    chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in data["chunks_labeled"]]
    chunks_removed = [chunk_data["chunk"] for chunk_data in data["chunks_labeled"]]
    chunks_resampled = [
        [rollout["chunk_resampled"] for rollout in chunk_rollouts]
        for chunk_rollouts in data["chunk_solutions"]
    ]
    full_cot_list = [
        [rollout["full_cot"] for rollout in chunk_rollouts]
        for chunk_rollouts in data["chunk_solutions"]
    ]
    answer = data["problem"]["gt_answer"]

    counterfactual_answer_importances = calculate_counterfactual_answer_importance(
        chunks_removed, chunks_resampled, full_cot_list, answer
    )

    # Add them to the lists
    chunk_labels_all.append(chunk_labels)
    counterfactual_importances_all.append(counterfactual_answer_importances)

# Plotting all of them together:
all_data = []
for i, (labels, importances) in enumerate(zip(chunk_labels_all, counterfactual_importances_all)):
    n = len(labels)
    for j, (label, importance) in enumerate(zip(labels, importances)):
        all_data.append({"Label": label, "Importance": importance, "position": j / n})

df_all = pd.DataFrame(all_data)

# Get the 5 most common labels
top_5_labels = df_all["Label"].value_counts().head(5).index.tolist()
df_filtered = df_all[df_all["Label"].isin(top_5_labels)]

# Group and calculate mean and standard error
grouped = (
    df_filtered.groupby("Label")
    .agg({"Importance": ["mean", "sem"], "position": ["mean", "sem"]})
    .reset_index()
)
grouped.columns = ["Label", "importance_mean", "importance_sem", "pos_mean", "pos_sem"]

# Plot
fig, ax = plt.subplots(figsize=(10, 6))
for _, row in grouped.iterrows():
    ax.errorbar(
        row["pos_mean"],
        row["importance_mean"],
        xerr=row["pos_sem"],
        yerr=row["importance_sem"],
        fmt="o",
        label=row["Label"],
        color=CATEGORY_COLORS.get(row["Label"]),
        markersize=10,
        capsize=5,
    )

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlabel("Normalized position in trace (0-1)")
ax.set_ylabel("Counterfactual importance")
ax.set_title("Sentence category effect")
ax.legend()
plt.show()

-->


```python
# Get the 5 most common labels
label_counts = pd.Series(chunk_labels[:-1]).value_counts()
top_5_labels = label_counts.head(5).index.tolist()

# Create dataframe and filter
df_filtered = pd.DataFrame(
    {
        "Label": chunk_labels[:-1],
        "Importance": counterfactual_answer_importances,
        "position": np.arange(len(chunk_labels[:-1])) / len(chunk_labels[:-1]),
    }
)
df_filtered = df_filtered[df_filtered["Label"].isin(top_5_labels)]
grouped = df_filtered.groupby("Label")[["Importance", "position"]].mean().reset_index()

# Plot
fig, ax = plt.subplots(figsize=(6, 4))
for label in grouped["Label"]:
    row = grouped[grouped["Label"] == label]
    ax.scatter(row["position"], row["Importance"], label=label, color=CATEGORY_COLORS.get(label))

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlabel("Normalized position in trace (0-1)")
ax.set_ylabel("Counterfactual importance")
ax.set_title("Sentence category effect")
ax.legend()
plt.show()
```

Exercise - replicate Figure 3b (with KL divergence)

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵⚪⚪⚪⚪
You should spend up to 10-30 minutes on this exercise, if you choose to complete it.

Now, you can try to replicate this figure using KL divergence instead of accuracy!

So far we've only been working with accuracy, i.e. how much the model's tendency to give the correct answer changes when we resample a particular chunk. But we may care about how much a sentence shapes the model's trajectory, rather than just how much it helps the model get to the correct answer. In that case it makes sense to look at an absolute measure of how much the answers change. KL divergence is a natural tool for measuring the difference between two distributions.

In the cells below, you should rewrite the calculate_counterfactual_answer_importance function to use KL divergence instead of accuracy. In other words, the i-th output of your function should be the KL divergence between the distributions of answers when the i-th and (i-1)-th chunks are removed: $D_{\mathrm{KL}}\left(p(A^c_i) \,\|\, p(A^c_{i-1})\right)$. Some notes on this:

  • When computing the KL divergence, your distributions can be over the discrete set of answers produced by resamplings at either the i-th or i-1-th chunk.
  • To avoid division-by-zero errors, you should use the paper's method of Laplace smoothing: we add some value alpha to all the answer counts, so that when we convert them to distributions none of them are zero. The terms in our KL div calculation change from $p \log \frac{p}{q}$ to $p \log \frac{p'}{q'}$ where $p', q'$ are the smoothed distributions (meaning the terms will still be zero when $p$ is zero). We've given you the alpha parameter in the function signature.
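Here is one way the smoothed divergence described above could look, as a sketch over two non-empty lists of parsed answers. The function name is hypothetical. It follows the description above:

- $p$ and $q$ are empirical answer frequencies over the union of answers seen at either position.
- $p'$ and $q'$ add `alpha` to every count before normalising.
- Terms with $p = 0$ contribute nothing.

```python
import math
from collections import Counter


def smoothed_kl(answers_p: list[str], answers_q: list[str], alpha: float = 1.0) -> float:
    """Sum over answers a of p(a) * log(p'(a) / q'(a)), with p', q' Laplace-smoothed.

    Both answer lists must be non-empty.
    """
    support = set(answers_p) | set(answers_q)
    counts_p, counts_q = Counter(answers_p), Counter(answers_q)
    n_p, n_q, k = len(answers_p), len(answers_q), len(support)
    kl = 0.0
    for answer in support:
        p = counts_p[answer] / n_p
        if p == 0:
            continue  # the unsmoothed weight is zero, so this term vanishes
        p_smooth = (counts_p[answer] + alpha) / (n_p + alpha * k)
        q_smooth = (counts_q[answer] + alpha) / (n_q + alpha * k)
        kl += p * math.log(p_smooth / q_smooth)
    return kl


print(smoothed_kl(["19"] * 4, ["19"] * 4))  # 0.0
print(smoothed_kl(["19"] * 4, ["20"] * 4))  # log(5) ≈ 1.609
```

Because the weights use the unsmoothed $p$ while the log uses the smoothed $p'$, the result can occasionally dip slightly below zero.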

You can generate the same scatter plot as above, and see if the results more closely match Figure 3b from the paper. You can also regenerate the plot from the previous exercise, and compare the two metrics (KL div based vs accuracy based). Do you see any differences?

Answer - what you should expect to see

In my implementation, the two scatter plots look almost identical (although the y-axis obviously differs, since KL divergence values are non-negative; technically they can be slightly negative because of the Laplace smoothing, but only very close to zero, and the group means are all positive).

The line plot was also quite similar, the main difference being that the KL divergence based metric was less noisy, i.e. there was more separation between the large spikes and the smaller values. This makes sense, since KL divergence is convex: increasing the difference between two probability distributions leads to a more than proportional increase in the KL divergence. For example, the KL divergence between (0.5, 0.5) and (0.1, 0.9) is more than twice the KL divergence between (0.5, 0.5) and (0.3, 0.7), even though the probabilities shift only twice as far.
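As a quick hand check of the convexity example (our own arithmetic, using natural logarithms and taking $(0.5, 0.5)$ as the first argument in both cases):

$$ D\big((0.5, 0.5) \,||\, (0.3, 0.7)\big) = 0.5 \log\frac{0.5}{0.3} + 0.5 \log\frac{0.5}{0.7} = \frac{1}{2}\log\frac{25}{21} \approx 0.087 $$

$$ D\big((0.5, 0.5) \,||\, (0.1, 0.9)\big) = 0.5 \log\frac{0.5}{0.1} + 0.5 \log\frac{0.5}{0.9} = \frac{1}{2}\log\frac{25}{9} \approx 0.511 $$

Doubling the shift in probability (from 0.2 to 0.4) multiplies the KL divergence by almost six.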
from collections import Counter

import numpy as np
import sentence_transformers


def calculate_counterfactual_answer_importance_kl_div(
    chunks_removed: list[str],
    chunks_resampled: list[list[str]],
    full_cot_list: list[list[str]],
    threshold: float = 0.8,
    min_indices: int = 5,
    alpha: float = 1,
    embedding_model: sentence_transformers.SentenceTransformer = embedding_model,
) -> list[float]:
    """
    Calculate importance for chunks, based on the KL divergence between answer distributions,
    after filtering for low new-generation cosine similarity.

    Args:
        chunks_removed: List of chunks $S_i$ which were removed from the rollouts.
        chunks_resampled: List of chunks $T_i$ which were resampled for each of the multiple rollouts.
        full_cot_list: List of resampled rollouts; the [i][j]-th element is the j-th rollout which
          was generated by resampling in place of the i-th chunk (e.g. the [0][0]-th element is
          the 0th rollout which doesn't include any chunks of the model's reasoning).
        threshold: Rollouts whose resampled chunk has embedding cosine similarity below this value
          (relative to the removed chunk) count as "sufficiently different" and are kept.
        min_indices: Minimum number of indices we can have post-filtering to count this score.
        alpha: Laplace smoothing parameter
        embedding_model: Embedding model to use for calculating cosine similarity

    Returns:
        list[float]: KL-divergence-based importance score for each chunk except the last
    """

    def kl_div(answers_1: list[str], answers_2: list[str]) -> float:
        # Score 0 if either side was filtered out (too few sufficiently different rollouts)
        if not answers_1 or not answers_2:
            return 0.0
        # Vocabulary: every answer produced in either set of rollouts
        vocab = set(answers_1) | set(answers_2)
        # Counts & (unsmoothed) probabilities for the first distribution
        count_1 = Counter(answers_1)
        count_2 = Counter(answers_2)
        p_1 = {k: v / len(answers_1) for k, v in count_1.items()}
        # Laplace smoothing: add alpha to every count in the shared vocabulary
        count_1_smoothed = {k: count_1.get(k, 0) + alpha for k in vocab}
        count_2_smoothed = {k: count_2.get(k, 0) + alpha for k in vocab}
        total_1 = sum(count_1_smoothed.values())
        total_2 = sum(count_2_smoothed.values())
        p_1_smoothed = {k: v / total_1 for k, v in count_1_smoothed.items()}
        p_2_smoothed = {k: v / total_2 for k, v in count_2_smoothed.items()}
        # KL divergence: sum of p * log(p' / q') over answers with p > 0
        return sum(p_1[k] * np.log(p_1_smoothed[k] / p_2_smoothed[k]) for k in p_1)

    def get_filtered_indices(chunk_removed: str, chunks_resampled: list[str]) -> np.ndarray:
        # Normalizing the embeddings makes the dot product below a cosine similarity
        embedding_S_i = embedding_model.encode(chunk_removed, normalize_embeddings=True)  # (d_embed,)
        embeddings_T_i = embedding_model.encode(chunks_resampled, normalize_embeddings=True)  # (N, d_embed)
        cos_sims = embedding_S_i @ embeddings_T_i.T  # (N,)
        return np.where(cos_sims < threshold)[0]

    def extract_answer_from_cot(cot: str) -> str:
        # Take the contents of the last \boxed{...} and keep only digits and decimal points
        answer = cot.split("\\boxed{")[-1].split("}")[0]
        return "".join(char for char in answer if char.isdigit() or char == ".")

    filtered_indices = [
        get_filtered_indices(chunk_removed, _chunks_resampled)
        for chunk_removed, _chunks_resampled in zip(chunks_removed, chunks_resampled)
    ]

    # Get a list of the different answers returned in each of the resampled (filtered) rollouts
    all_answers = [
        [extract_answer_from_cot(cot_list[idx]) for idx in indices]
        if len(indices) >= min_indices
        else []
        for cot_list, indices in zip(full_cot_list, filtered_indices)
    ]

    # The i-th score compares rollouts which keep chunk i (resampled at i+1) against rollouts
    # which resample in place of chunk i
    kl_divs = [kl_div(a1, a0) for a1, a0 in zip(all_answers[1:], all_answers)]

    return kl_divs

# Get the new KL div-based importance scores
chunk_labels = [CATEGORIES[chunk["function_tags"][0]] for chunk in problem_data["chunks_labeled"]]
chunks_removed = [chunk_data["chunk"] for chunk_data in problem_data["chunks_labeled"]]
chunks_resampled = [
    [rollout["chunk_resampled"] for rollout in chunk_rollouts]
    for chunk_rollouts in problem_data["chunk_solutions"]
]
full_cot_list = [
    [rollout["full_cot"] for rollout in chunk_rollouts]
    for chunk_rollouts in problem_data["chunk_solutions"]
]
counterfactual_answer_importances_kl_div = calculate_counterfactual_answer_importance_kl_div(
    chunks_removed, chunks_resampled, full_cot_list
)

# Re-run the plotting code from above
label_counts = pd.Series(chunk_labels[:-1]).value_counts()
top_5_labels = label_counts.head(5).index.tolist()
df_filtered = pd.DataFrame(
    {
        "Label": chunk_labels[:-1],
        "Importance": counterfactual_answer_importances_kl_div,
        "position": np.arange(len(chunk_labels[:-1])) / len(chunk_labels[:-1]),
    }
)
df_filtered = df_filtered[df_filtered["Label"].isin(top_5_labels)]
grouped = df_filtered.groupby("Label")[["Importance", "position"]].mean().reset_index()
fig, ax = plt.subplots(figsize=(6, 4))
for label in grouped["Label"]:
    row = grouped[grouped["Label"] == label]
    ax.scatter(row["position"], row["Importance"], label=label, color=CATEGORY_COLORS.get(label))

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlabel("Normalized position in trace (0-1)")
ax.set_ylabel("Counterfactual importance")
ax.set_title("Sentence category effect")
ax.legend()
plt.show()
df = pd.DataFrame(
    {
        "Counterfactual answer importance": counterfactual_answer_importances,
        "Counterfactual answer importance (KL div)": counterfactual_answer_importances_kl_div,
        "chunk": [d["chunk"] for d in problem_data["chunks_labeled"][:-1]],
        "tags": [d["function_tags"][0] for d in problem_data["chunks_labeled"][:-1]],
    }
)

fig = px.line(
    df,
    labels={"index": "Chunk index", "value": "Importance", "variable": "Metric"},
    y=[
        "Counterfactual answer importance",
        "Counterfactual answer importance (KL div)",
    ],
    hover_data=["chunk", "tags"],
)
fig.update_layout(title="Counterfactual answer importance: accuracy vs KL divergence")
fig.show()

Resampling

We'll now move on to the final part of this section, where you get to implement your own resampling method for producing rollouts like the ones we've been analyzing above.

Exercise - implement your own resampling method

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-20 minutes on this exercise.

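We'll start by implementing the get_resampled_rollouts function below. This function takes in a prompt, generates a full chain of thought from it, and splits that chain of thought into chunks. It then returns the full answer, the list of chunks, and a list of lists of dicts, where output[i][j] contains the data for the j-th resampled rollout at the i-th chunk (i.e. generated from the text immediately before that chunk). These dicts should have the following keys:

  • full_cot: The full new chain of thought for this rollout: the prompt, the reasoning before the replaced chunk, and the newly generated text (which begins with the resampled chunk).
  • chunk_resampled: The first chunk of the resampled rollout (you can find this with your split_solution_into_chunks function).
  • chunk_replaced: The chunk that was replaced.
  • rollout: The newly generated text on its own.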

Note that this solution will probably be hackier than the more carefully designed generation & chunking code used in the actual paper, but this is fine - the purpose here is just to get an MVP which works and we could build on if we wanted to.
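One fiddly step is finding the text that comes before a given chunk, since the same chunk (e.g. "Hmm, okay.") can appear several times in one answer. Below is a minimal sketch of one way to handle this, using a hypothetical helper prefix_before_nth_occurrence and a made-up toy trace: count how many times each chunk has been seen so far, then split the answer on that occurrence.

```python
from collections import defaultdict


def prefix_before_nth_occurrence(text: str, chunk: str, n: int) -> str:
    """Hypothetical helper: everything in `text` before the n-th (1-indexed) occurrence of `chunk`."""
    # Splitting with maxsplit=n yields n + 1 pieces if `chunk` occurs at least n times
    pieces = text.split(chunk, maxsplit=n)
    if len(pieces) <= n:
        raise ValueError(f"chunk occurs fewer than {n} times in text")
    # Re-join the first n pieces, restoring the n - 1 earlier occurrences of `chunk`
    return chunk.join(pieces[:-1])


# Toy reasoning trace in which the chunk "Hmm, okay." appears twice
answer = "Let me think. Hmm, okay. So x = 2. Hmm, okay. So the answer is 4."
chunks = ["Let me think.", "Hmm, okay.", "So x = 2.", "Hmm, okay.", "So the answer is 4."]

seen = defaultdict(int)
prefixes = []
for chunk in chunks:
    seen[chunk] += 1
    prefixes.append(prefix_before_nth_occurrence(answer, chunk, seen[chunk]).strip())

for chunk, prefix in zip(chunks, prefixes):
    print(repr(chunk), "comes after", repr(prefix))
```

This counts occurrences of the chunk's text rather than chunk positions, so it can go wrong if a chunk also appears as a substring inside an earlier, different chunk; that is one of the hacks an MVP like this tolerates.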

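from collections import defaultdict

import torch
import transformers
from tqdm.auto import tqdm


def get_resampled_rollouts(
    prompt: str,
    num_resamples_per_chunk: int = 100,
    batch_size: int = 4,
    max_new_tokens: int = 2048,
    up_to_n_chunks: int | None = None,
    model: transformers.models.llama.modeling_llama.LlamaForCausalLM = model,
) -> tuple[str, list[str], list[list[dict]]]:
    """
    Generates a full answer to `prompt`, splits it into chunks, then computes a number of
    resampled rollouts at each chunk (each starting from the text immediately before that chunk).

    Args:
        prompt: The initial problem prompt (which ends with a <think> tag).
        num_resamples_per_chunk: Number of resamples to compute for each chunk (rounded down to a
          multiple of `batch_size`).
        batch_size: Number of resamples generated in parallel per call to `model.generate`.
        max_new_tokens: Maximum number of new tokens per generation.
        up_to_n_chunks: If given, only resample at the first `up_to_n_chunks` chunks.
        model: The model to generate with.

    Returns:
        Tuple of (full_answer, chunks, resampled_rollouts) where the latter is a list of lists of
        dicts (one for each chunk & resample on that chunk).
    """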

    @torch.inference_mode()
    def generate(inputs):
        return model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            stopping_criteria=[StopOnThink()],
            pad_token_id=tokenizer.eos_token_id,
        )

    # First, generate a completion which we'll split up into chunks.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = generate(inputs)
    full_answer = tokenizer.decode(
        output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=False
    )
    chunks = split_solution_into_chunks(full_answer)

    # Second, generate resamples at each chunk.
    chunk_rollouts = []
    n_chunk_instances = defaultdict(int)
    for chunk in tqdm(chunks[:up_to_n_chunks]):
        chunk_rollouts.append([])

        # To get the answer before the chunk, we split on the correct instance of the chunk (since
        # this chunk might have appeared multiple times in the answer).
        n_chunk_instances[chunk] += 1
        full_answer_split = (
            prompt
            + chunk.join(full_answer.split(chunk, maxsplit=n_chunk_instances[chunk])[:-1]).strip()
        )
        inputs = tokenizer([full_answer_split] * batch_size, return_tensors="pt").to(device)

        for _ in range(num_resamples_per_chunk // batch_size):
            output_batch = generate(inputs)
            for output_generation in output_batch:
                generated_text = tokenizer.decode(
                    output_generation[inputs["input_ids"].shape[1] :], skip_special_tokens=False
                )
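                # Guard against empty generations, which produce no chunks
                new_chunks = split_solution_into_chunks(generated_text)
                chunk_rollouts[-1].append(
                    {
                        "full_cot": full_answer_split + generated_text,
                        "chunk_resampled": new_chunks[0] if new_chunks else "",
                        "chunk_replaced": chunk,
                        "rollout": generated_text,
                    }
                )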

    return full_answer, chunks, chunk_rollouts

# The function returns a tuple, so unpack it to get the per-chunk rollouts
full_answer, answer_chunks, resampled_rollouts = get_resampled_rollouts(
    prompt=problem_data["base_solution"]["prompt"],
    num_resamples_per_chunk=2,
    batch_size=2,
    max_new_tokens=2048,
    up_to_n_chunks=4,
)

for resamples in resampled_rollouts:
    print("Replaced chunk: " + repr(resamples[0]["chunk_replaced"]))
    for j, r in enumerate(resamples):
        print(f"    Resample {j}: " + repr(r["chunk_resampled"]))
    print()

# Replaced chunk: 'Alright, so I have this problem here:'
#     Resample 0: 'Okay, so I need to figure out how many bits are in the binary representation of the hexadecimal number 66666₁₆.'
#     Resample 1: "Alright, so I've got this problem here:"

# Replaced chunk: 'When the base-16 number \\(66666_{16}\\) is written in base 2, how many base-2 digits (bits) does it have?'
#     Resample 0: 'when the base-16 number \\(66666_{16}\\) is written in base 2, how many base-2 digits (bits) does it have?'
#     Resample 1: 'when the base-16 number 66666₁₆ is written in base 2, how many base-2 digits (bits) does it have?'

# Replaced chunk: 'Hmm, okay.'
#     Resample 0: 'Hmm, okay.'
#     Resample 1: 'Hmm, okay.'

# Replaced chunk: 'Let me try to figure this out step by step.'
#     Resample 0: 'Let me think about how to approach this.'
#     Resample 1: 'I need to figure out how many bits are required to represent this hexadecimal number in binary.'

When you've got this working, we leave the rest as an exercise for you to explore! Generating sufficiently many rollouts for statistical analysis is beyond the scope of this Colab notebook, but it might be a fun project to try out if you want more practice working at larger scale and performing full paper replications.