1️⃣ CoT Infrastructure & Sentence Taxonomy

Learning Objectives
  • Import & understand the structure of our reasoning dataset
  • Understand how to use regex-based heuristics and autoraters to classify sentences
  • Classify sentences in reasoning traces by taxonomy

Model Setup & Dataset Inspection

Let's start by setting a few constants:

# Configuration
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
DATASET_NAME = "uzaymacar/math-rollouts"  # Pre-computed rollouts from paper
SIMILARITY_THRESHOLD = 0.8  # Median threshold from paper
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # Sentence embedding model used in paper
N_ROLLOUTS = 100  # Number of rollouts per sentence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Let's load in our embedding model. Embedding models are designed to take in text and output a fixed-length vector - we'll be using them to measure the similarity between different sentences, so we can find recurring motifs in our reasoning traces.

embedding_model = SentenceTransformer(EMBEDDING_MODEL)
print(embedding_model)

To give you an idea of how it works, let's look at some example sentences:

prompts = [
    "Wait, I think I made an error in my reasoning and need to backtrack",
    "Hold on, I believe I made a mistake in my logic and should reconsider",
    "After careful analysis, I've determined the correct answer is 42",
    "Time is an illusion. Lunchtime doubly so.",
]
labels = [x[:35] + "..." for x in prompts]

embeddings = embedding_model.encode(prompts)

# all-MiniLM-L6-v2 outputs unit-norm vectors, so the dot product gives the cosine similarity
cosine_sims = embeddings @ embeddings.T

px.imshow(
    cosine_sims,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    labels=dict(x="Prompt", y="Prompt", color="Cosine Similarity"),
    x=labels,
    y=labels,
)
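
The dot-product trick above relies on the embeddings being unit-norm. If you swap in an embedding model that doesn't normalize its outputs, you can use the util.cos_sim helper from sentence_transformers instead, which normalizes internally. As a quick sanity check, it should reproduce the matrix above:

import numpy as np
from sentence_transformers import util

# util.cos_sim normalizes internally, so this should match the dot-product heatmap above
explicit_sims = util.cos_sim(embeddings, embeddings).numpy()
print(np.allclose(cosine_sims, explicit_sims, atol=1e-5))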

We can also load in the dataset that the paper's authors have helpfully open-sourced. The full dataset is very large, but its structure is documented on the corresponding HuggingFace page, so we can use the huggingface_hub package to download just the files we want.

PROBLEM_ID = 4682

path = f"deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/correct_base_solution/problem_{PROBLEM_ID}/base_solution.json"

local_path = hf_hub_download(repo_id=DATASET_NAME, filename=path, repo_type="dataset")

with open(local_path, "r") as f:
    problem_data = json.load(f)

problem_data
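
The file is fairly large, so rather than reading the whole dump above, it can help to peek at its structure first (chunk_solutions presumably holds the pre-computed per-chunk rollouts mentioned in the dataset description):

# Top-level structure of the downloaded file
print(list(problem_data.keys()))

# Keys of one stored rollout for a single chunk
print(problem_data["chunk_solutions"][10][0].keys())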

Lastly, we'll load in the LLM we'll use to generate reasoning traces: DeepSeek-R1-Distill-Llama-8B, matching the paper's setup as closely as possible.

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,  # Use bfloat16 for efficiency
    device_map="auto",  # Automatically places the model on available GPUs (no .to(device) needed)
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print(f"Model loaded on: {model.device}")
print(f"Model has {model.config.num_hidden_layers} layers")
print(f"Model has {model.config.num_attention_heads} attention heads per layer")

We can test rollouts with this model:

problem_data["chunk_solutions"][10][0].keys()
easy_problem_text = "A rectangle has a length of 8 cm and a width of 5 cm. What is its perimeter?"

easy_prompt = f"""Solve this math problem step by step.

Problem: {easy_problem_text}

Solution:
<think>"""

inputs = tokenizer(easy_prompt, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.6,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

generated_text = tokenizer.decode(
    outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=False
)
easy_problem_full_text = easy_prompt + generated_text

print(easy_problem_full_text)

Exercise - add a stopping criterion

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise.

It's a bit annoying that the model sometimes keeps generating past the </think> token, since we often only want the output inside the <think> ... </think> tags. To fix this, let's introduce a stopping criterion.

HuggingFace's generate method supports a stopping_criteria argument, which takes a list of StoppingCriteria objects. We can implement our own by subclassing StoppingCriteria and overriding the __call__ method: it receives input_ids (the tensor of all token IDs generated so far) along with the logit scores, and should return True when we want the most recent token to be the final one.
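
Before writing the criterion, it's worth checking how the tokenizer handles </think>. For the R1 distill tokenizers it is typically a single added special token, which is what makes a simple last-token comparison sufficient; verify this rather than assuming it:

# Sanity check: "</think>" should encode to a single token ID for this tokenizer
print(tokenizer.encode("</think>", add_special_tokens=False))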


class StopOnThink(StoppingCriteria):
    def __call__(self, input_ids, scores, **kwargs):
        # YOUR CODE HERE: return True iff the model has generated "</think>"
        think_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0]
        return input_ids[0, -1] == think_token_id


with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.6,
        top_p=0.95,
        do_sample=True,
        stopping_criteria=[StopOnThink()],
        pad_token_id=tokenizer.eos_token_id,
    )

generated_text = tokenizer.decode(
    outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=False
)
easy_problem_full_text = easy_prompt + generated_text

print(easy_problem_full_text)

Exercise - implement sentence splitting

Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪
You should spend ~10 minutes on this exercise.

First, we'll need to split our CoT traces into sentence-level chunks, based on punctuation and paragraph breaks. We'll also need to handle special tokens like <think>.

You should fill in the split_solution_into_chunks function below. We've given you a few basic tests to pass; once your solution passes all of them, you can be confident you've dealt with enough edge cases. Here is the full set of rules, as defined by the edge cases:

  • The <think> ... </think> tags should be removed
  • You should split on sentence-ending punctuation (any of ., !, ?), on newlines, and on characters like :
  • Don't split on a period . that is part of a decimal number (e.g. 34.5) or a numbered list item (e.g. \n1.)
  • No chunk should be shorter than 10 characters: if one is, merge it with the next chunk
  • Each chunk should be stripped of whitespace

This is a bit of a grunt task, so feel free to use LLMs to help you!

def split_solution_into_chunks(text: str) -> list[str]:
    """Split solution into sentence-level chunks."""

    # YOUR CODE HERE - fill in the rest of the function

    # Remove thinking tags
    if "<think>" in text:
        text = text.split("<think>")[1]
    if "</think>" in text:
        text = text.split("</think>")[0]
    text = text.strip()

    # Replace the "." characters which I don't want to split on
    text = re.sub(r"(\d)\.(\d)", r"\1<DECIMAL>\2", text)  # e.g. "4.5" -> "4<DECIMAL>5"
    text = re.sub(r"\n(\d)\.(\s)", r"\n\1<DECIMAL>\2", text)  # e.g. "1.\n" -> "\n1<DECIMAL>"

    # Split on sentence endings, and combine the endings with the previous chunk
    sentences = re.split(r"([!?:\n]|(?<!\n\d)\.)", text)
    chunks = []
    for i in range(0, len(sentences) - 1, 2):
        chunks.append((sentences[i] + sentences[i + 1]).replace("\n", " "))

    # Replace <DECIMAL> back with "."
    chunks = [re.sub(r"<DECIMAL>", ".", c) for c in chunks]

    # Merge chunks that are too short
    while len(chunks) > 1 and min([len(x) for x in chunks[:-1]]) < 10:
        for i, chunk in enumerate(chunks[:-1]):
            if len(chunk) < 10:
                chunks = chunks[:i] + [chunk + chunks[i + 1]] + chunks[i + 2 :]
                break

    return [c.strip() for c in chunks if c.strip()]


test_cases = [
    # (input_text, expected_chunks)
    (
        "<think>First, I understand the problem. Next, I'll solve for x. Finally, I verify!</think>",
        ["First, I understand the problem.", "Next, I'll solve for x.", "Finally, I verify!"],
    ),
    (
        "<think>Let me break this down: 1. Convert to decimal. 2. Calculate log. 3. Apply formula.</think>",
        [
            "Let me break this down:",
            "1. Convert to decimal.",
            "2. Calculate log.",
            "3. Apply formula.",
        ],
    ),
    (
        "<think>The formula is A = πr². Wait. No. Actually, it's different.</think>",
        ["The formula is A = πr².", "Wait. No.", "Actually, it's different."],
    ),
    (
        "<think>Convert 66666₁₆ to decimal. This equals 419,430. How many bits? We need log₂(419,430) ≈ 18.7. So 19 bits!</think>",
        [
            "Convert 66666₁₆ to decimal.",
            "This equals 419,430.",
            "How many bits?",
            "We need log₂(419,430) ≈ 18.7.",
            "So 19 bits!",
        ],
    ),
    ("<think>The answer is 42. Done.</think>", ["The answer is 42.", "Done."]),
]

for input_text, expected_chunks in test_cases:
    chunks = split_solution_into_chunks(input_text)
    assert chunks == expected_chunks, f"Expected {expected_chunks}, got {chunks}"

Sentence categorization

Now that we've created a method for splitting up our reasoning traces into chunks, we can work out how to categorize them.

The paper uses a taxonomy of 8 different categories:

  1. Problem Setup: Parsing or rephrasing the problem
  2. Plan Generation: Stating or deciding on a plan of action, meta-reasoning
  3. Fact Retrieval: Recalling facts, formulas, problem details
  4. Active Computation: Algebra, calculations, manipulations
  5. Uncertainty Management: Expressing confusion, re-evaluating, backtracking
  6. Result Consolidation: Aggregating intermediate results, summarizing
  7. Self Checking: Verifying previous steps, checking calculations
  8. Final Answer Emission: Explicitly stating the final answer

Two approaches are usually taken for this kind of classification: heuristic-based (using regexes or keyword matching) and LLM-based (i.e. using an autorater). We'll try both, so we can compare the results.

Exercise - heuristic-based categorization

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise.

First, we'll implement a heuristic-based approach. You should do the following:

  • Fill out the CATEGORY_WORDS dictionary below, which maps each category to a list of words associated with that category. To get you started, we've filled out the first three categories.
  • Fill out the categorize_sentences_heuristic function below, which uses this dictionary to categorize sentences. We've given you a few example sentences to test your function - at minimum make sure your function works for these.

Once you've passed the test sentences below, you should try taking rollouts from your model above (or examples from the dataset) and see how your function performs on them. Some questions you might ask yourself:

  • Do you think this taxonomy is reasonable?
  • Are there any sentences that are misclassified, or belong to more than one category?
  • How many words do you need to add before your classification works decently?

Note that no heuristic-based classification will be perfect. The point of this exercise is to get you thinking about the different categories, and what the strengths / limitations of this kind of method are. In research, you should generally try not to reach for a tool more complicated than what you need!

CATEGORIES = {
    "problem_setup": "Problem Setup",
    "plan_generation": "Plan Generation",
    "fact_retrieval": "Fact Retrieval",
    "active_computation": "Active Computation",
    "uncertainty_management": "Uncertainty Management",
    "result_consolidation": "Result Consolidation",
    "self_checking": "Self Checking",
    "final_answer_emission": "Final Answer Emission",
    "unknown": "Unknown",
}

CATEGORY_WORDS = {
    # Note: we put the most definitive categories first, so they override the later ones
    "final_answer_emission": ["\\boxed", "final answer"],
    "problem_setup": ["need to", "problem is", "given"],
    "fact_retrieval": ["remember", "formula", "know that", "recall"],
    "active_computation": ["calculate", "compute", "solve", "=", "equals", "result", "giving"],
    "uncertainty_management": ["wait", "let me", "double check", "hmm", "actually", "reconsider"],
    "result_consolidation": ["summarize", "so", "therefore", "in summary"],
    "self_checking": ["verify", "check", "confirm", "correct"],
    "plan_generation": ["plan", "approach", "strategy", "will", "i'll", "try"],
}


def categorize_sentences_heuristic(chunks: list[str]) -> list[str]:
    """
    Categorize sentences using heuristics/keyword matching (simplified version).

    For full LLM-based labeling, see prompts.py DAG_PROMPT in the thought-anchors repo.

    Args:
        chunks: List of sentence-level chunk strings

    Returns:
        List of category names (one per chunk).
    """

    categories = []

    for idx, chunk in enumerate(chunks):
        chunk_lower = chunk.lower()

        if idx == 0:
            tag = "problem_setup"
        else:
            for category, words in CATEGORY_WORDS.items():
                if any(word in chunk_lower for word in words):
                    tag = category
                    break
            else:
                tag = "unknown"

        categories.append(CATEGORIES.get(tag))

    return categories


example_problem_text = "What is the area of a circle with radius 5?"

example_sentences, example_categories = list(
    zip(
        *[
            ("I need to find the area of a circle with radius 5.", "Problem Setup"),
            ("The formula for circle area is A = πr².", "Fact Retrieval"),
            ("Substituting r = 5: A = π × 5² = 25π.", "Active Computation"),
            ("Wait, let me look again at that calculation.", "Uncertainty Management"),
            ("So the area is 25π square units.", "Result Consolidation"),
            ("Therefore, the answer is \\boxed{25π}.", "Final Answer Emission"),
        ]
    )
)

categories = categorize_sentences_heuristic(example_sentences)

for sentence, category, expected_category in zip(example_sentences, categories, example_categories):
    assert category == expected_category, (
        f"Expected {expected_category!r}, got {category!r} for sentence: {sentence!r}"
    )

Now, testing your function on an actual rollout, does it look reasonable? If not, you can try going back and tweaking the categories or the categorization logic.

easy_chunks = split_solution_into_chunks(easy_problem_full_text)

easy_categories = categorize_sentences_heuristic(easy_chunks)

for chunk, category in zip(easy_chunks, easy_categories):
    chunk_str = chunk if len(chunk) < 80 else chunk[:60] + " ... " + chunk[-20:]
    print(f"{category:>20} | {chunk_str!r}")

Exercise - implement an autorater

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 10-15 minutes on this exercise.

We'll now progress to a slightly more advanced approach for classification, using an autorater. This essentially means we're querying an LLM to do the categorization for us, rather than relying on hardcoded rules.

We'll start by setting up a helper function to call an API (to keep things simple we're sticking with OpenAI for now, but this could easily be modified to support other providers). It includes basic retry logic with exponential backoff, so transient rate-limit errors don't derail a long labelling run.

You'll need to create a .env file in the current working directory and set the OPENAI_API_KEY environment variable in it, then you can run the cell below to see how the helper function works.
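
The .env file is just plain text, with one KEY=value pair per line, e.g.:

OPENAI_API_KEY=<your-key-here>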

load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None


def generate_api_response(
    model: str = "gpt-4.1-mini",
    messages: list[dict[str, str]] = [],
    max_tokens: int = 128,
    stop_sequences: list[str] | None = None,
    temperature: float = 0.0,
    max_retries: int = 3,
) -> str:
    """Helper function with retry logic and error handling."""

    for attempt in range(max_retries):
        try:
            # Generate response
            resp = openai_client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop_sequences if stop_sequences else None,
            )
            # Extract text from response
            return resp.choices[0].message.content or ""

        except Exception as e:
            if "rate_limit_error" in str(e) or "429" in str(e):
                if attempt < max_retries - 1:
                    # Exponential backoff: 2^attempt seconds (1, 2, 4, 8, 16...)
                    wait_time = 2**attempt
                    print(f"Rate limit hit, waiting {wait_time} seconds before retry...")
                    time.sleep(wait_time)
                    continue
                else:
                    print(f"Failed to get response after {max_retries} attempts, returning early.")
            raise e


resp = generate_api_response(messages=[{"role": "user", "content": "Who are you?"}])

print(textwrap.fill(resp))

Note the default use of temperature zero in the API helper above. This is because we're looking for reliability and reproducibility from the autorater; we don't benefit from output diversity in this case.

We've also given you the prompt we'll use for the autorater - it's a simplified version of the DAG_SYSTEM_PROMPT in the authors' file thought-anchors/prompts.py. Note that the authors allow a chunk to receive multiple tags; we assume a single tag per chunk to keep things simple.

DAG_SYSTEM_PROMPT = """
You are an expert in interpreting how language models solve math problems using multi-step reasoning. Your task is to analyze a Chain-of-Thought (CoT) reasoning trace, broken into discrete text chunks, and label each chunk with a tag that describes what this chunk is *doing* functionally in the reasoning process.

---

### Function Tags:

1. `problem_setup`: 
    Parsing or rephrasing the problem (initial reading or comprehension).

2. `plan_generation`: 
    Stating or deciding on a plan of action, or on the next step in the reasoning process.

3. `fact_retrieval`: 
    Recalling facts, formulas, problem details (without immediate computation).

4. `active_computation`: 
    Performing algebra, calculations, manipulations toward the answer.
    This only includes actual calculations being done & values computed, not stating formulas or plans.

5. `result_consolidation`: 
    Aggregating intermediate results, summarizing, or preparing final answer.

6. `uncertainty_management`: 
    Expressing confusion, re-evaluating, proposing alternative plans (includes backtracking).

7. `final_answer_emission`: 
    Explicit statement of the final boxed answer or earlier chunks that contain the final answer.

8. `self_checking`: 
    Verifying previous steps, Pythagorean checking, re-confirmations.
    This is at the object-level, whereas "uncertainty_management" is at the planning level.

9. `unknown`: 
    Use only if the chunk does not fit any of the above tags or is purely stylistic or semantic.

---

### Output Format:

Return a numbered list, one item for each chunk, consisting of the function tag that best describes the chunk.

For example:

1. problem_setup
...
5. final_answer_emission
"""

You should now complete the function categorize_sentences_autorater below, which will use the DAG_SYSTEM_PROMPT and the generate_api_response function we've already written to categorize the chunks. You'll need to supply your own user prompt to go with the system prompt above.

def categorize_sentences_autorater(problem_text: str, chunks: list[str]) -> list[str]:
    """
    Categorize sentences using heuristics/keyword matching (simplified version).

    Args:
        sentences: List of sentence strings

    Returns:
        List of categories.
    """

    # YOUR CODE HERE

    chunk_str = ""
    for i, chunk in enumerate(chunks):
        chunk_str += f"{i + 1}. {chunk}\n"

    user_prompt = f"""
Here is the math problem:

[PROBLEM]
{problem_text}

Here is the full Chain of Thought, broken into chunks:

[CHUNKS]
{chunk_str.strip()}

Now label each chunk with the single function tag that best describes it."""

    raw_response = generate_api_response(
        messages=[
            {"role": "system", "content": DAG_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=512,  # the helper's default of 128 can truncate the tag list for longer traces
    )
    response = re.split(r"\n\d{1,2}\.", "\n" + raw_response.strip())
    response = [r.strip() for r in response if r.strip()]

    assert len(response) == len(chunks), f"Length mismatch: {len(response)} != {len(chunks)}"

    return [CATEGORIES.get(r, "Unknown") for r in response]

You can run the cell below to see how your function performs on the example sentences. Does it agree with the heuristic-based approach? Do you think it's better or worse? You can also try it on the rollouts you generated earlier.

example_categories = categorize_sentences_autorater(example_problem_text, example_sentences)

for chunk, category in zip(example_sentences, example_categories):
    print(f"{category:>25} | {chunk!r}")
easy_categories_autorater = categorize_sentences_autorater(easy_problem_text, easy_chunks)

df = pd.DataFrame(
    {
        "Category (heuristics)": easy_categories,
        "Category (autorater)": easy_categories_autorater,
        "Chunk": easy_chunks,
    }
)
with pd.option_context("display.max_colwidth", None):
    display(df)
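
For a rough quantitative comparison, you can also measure how often the two methods agree. Exact agreement is a crude metric, since some chunks genuinely straddle categories, but it's a useful sanity check:

# Fraction of chunks where the heuristic and autorater labels match exactly
agreement = sum(h == a for h, a in zip(easy_categories, easy_categories_autorater)) / len(easy_chunks)
print(f"Heuristic vs autorater agreement: {agreement:.1%}")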

Before moving to the next section, we'll make a helper function to visualize reasoning traces:

CATEGORY_COLORS = {
    "Problem Setup": "#4285F4",
    "Plan Generation": "#EA4335",
    "Fact Retrieval": "#FBBC05",
    "Active Computation": "#34A853",
    "Uncertainty Management": "#9C27B0",
    "Result Consolidation": "#00BCD4",
    "Self Checking": "#FF9800",
    "Final Answer Emission": "#795548",
    "Unknown": "#9E9E9E",
}


def visualize_trace_structure(chunks: list[str], categories: list[str], problem_text: str | None = None):
    """Visualize a reasoning trace with color-coded sentence categories."""

    n_chunks = len(chunks)
    fig, ax = plt.subplots(figsize=(12, 1 + int(0.5 * n_chunks)))

    for idx, (chunk, category) in enumerate(zip(chunks, categories)):
        color = CATEGORY_COLORS.get(category, "#9E9E9E")

        y = n_chunks - idx  # Start from top

        # Category label with colored background
        ax.barh(y, 0.15, left=0, height=0.8, color=color, alpha=0.6)
        ax.text(0.075, y, f"{category}", ha="center", va="center", fontsize=9, weight="bold")

        # Sentence text
        text = chunk[:100] + ("..." if len(chunk) > 100 else "")
        ax.text(0.17, y, f"[{idx}] {text}", va="center", fontsize=9)

    ax.set_xlim(0, 1)
    ax.set_ylim(0.5, n_chunks + 0.5)
    ax.axis("off")

    if problem_text:
        fig.suptitle(f"Problem: {problem_text[:120]}...", fontsize=11, y=0.98, weight="bold")

    plt.title("Reasoning Trace Structure", fontsize=13, pad=30)
    plt.tight_layout()
    plt.show()


visualize_trace_structure(easy_chunks, easy_categories, easy_problem_text)
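
You can pass in the autorater labels too, to compare the two views of the same trace:

visualize_trace_structure(easy_chunks, easy_categories_autorater, easy_problem_text)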