
2️⃣ Implementing oracle components

Learning Objectives
  • Implement activation extraction using forward hooks
  • Understand the special token mechanism (? tokens as activation placeholders)
  • Build activation steering hooks to inject activations into the oracle
  • Create training datapoints with the correct format
  • Assemble all components to replicate the utils.run_oracle() function

Now that you've seen what oracles can do from the outside, let's crack them open and build one ourselves. Up to this point you've been calling utils.run_oracle() as a black box: you hand it some activations and a question, and it hands back an answer. But there's a lot happening behind the scenes, and understanding the internals will make you a much better user of the tool (and help you debug it when things go wrong, which they will).

Before we start, think about what pieces you'd need. You need some way to extract activations from the target model at a specific layer -- that means hooks. You need a mechanism for telling the oracle where in its input the activations should go -- that's the ? token placeholders. You need to actually inject those activations into the oracle's forward pass at the right moment -- that's the steering hook. And you need to package all of this up into a format the oracle can consume during training. We'll build each of these components in turn, and by the end you'll have a from-scratch reimplementation of the full oracle pipeline.

Activation extraction with hooks

The first step is extracting activations from the target model. We'll use PyTorch's forward hooks: a hook registered on a submodule is called with that submodule's output during the forward pass, so hooking the residual-stream submodule at the desired layer lets us capture exactly the activations we want. Two practical details: we can stop the forward pass early once the target activations are captured (no need to run the remaining layers), and we need to handle batching and padding correctly.
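
If forward hooks are new to you, here is a minimal, self-contained sketch of the capture pattern on a toy model (the toy Sequential stands in for the real LM; nothing here is specific to this codebase):

```python
import torch

# Toy two-layer model standing in for the target LM
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

captured = {}

def capture_hook(module, inputs, outputs):
    # `outputs` is the tensor leaving this submodule -- for a real LM layer,
    # this would be the residual stream at that layer
    captured["acts"] = outputs.detach()

handle = model[0].register_forward_hook(capture_hook)
try:
    _ = model(torch.randn(1, 4))
finally:
    handle.remove()  # always clean up, or the hook fires on every future forward

print(captured["acts"].shape)  # torch.Size([1, 4])
```

The `try`/`finally` cleanup is the same pattern you'll use in the exercise below: a dangling hook silently corrupts every later forward pass.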

First, we need a helper to get the right submodule for a given layer:

# Layer configuration
LAYER_COUNTS = {
    "Qwen/Qwen3-1.7B": 28,
    "Qwen/Qwen3-8B": 36,
    "Qwen/Qwen3-32B": 64,
    "google/gemma-2-9b-it": 42,
    "google/gemma-3-1b-it": 26,
    "meta-llama/Llama-3.2-1B-Instruct": 16,
    "meta-llama/Llama-3.1-8B-Instruct": 32,
    "meta-llama/Llama-3.3-70B-Instruct": 80,
}


def layer_fraction_to_layer(model_name: str, layer_fraction: float) -> int:
    """Convert a layer fraction (0.0-1.0) to a layer number."""
    max_layers = LAYER_COUNTS[model_name]
    return int(max_layers * layer_fraction)


def get_hf_submodule(model: AutoModelForCausalLM, layer: int) -> torch.nn.Module:
    """
    Gets the residual stream submodule for HuggingFace transformers.

    Args:
        model: The model
        layer: Which layer to hook

    Returns:
        The submodule to hook (the layer's output is the residual stream)
    """
    model_name = model.config._name_or_path
    assert re.search("gemma|mistral|Llama|Qwen", model_name), (
        f"Model name {model_name!r} is not supported. Supported architectures: Gemma, Mistral, Llama, Qwen."
    )
    return model.model.layers[layer]


# Check it works as expected
_ = get_hf_submodule(model, layer=LAYER_COUNTS[MODEL_NAME] - 1)
with pytest.raises(IndexError):
    _ = get_hf_submodule(model, layer=LAYER_COUNTS[MODEL_NAME])

Now we'll implement the activation collection function. We need a custom exception for early stopping:

class EarlyStopException(Exception):
    """Custom exception for stopping model forward pass early."""

    pass
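
To see why this pattern works, here is a toy sketch (the exception class is repeated so the snippet runs standalone; the hook on the second layer confirms that layer never executes):

```python
import torch

class EarlyStopException(Exception):
    """Repeated here so the sketch runs standalone."""

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
captured = {}
second_layer_ran = []

def capture_and_stop(module, inputs, outputs):
    captured["acts"] = outputs.detach()
    raise EarlyStopException("Got what we came for")

h1 = model[0].register_forward_hook(capture_and_stop)
h2 = model[1].register_forward_hook(lambda m, i, o: second_layer_ran.append(True))
try:
    model(torch.randn(1, 4))
except EarlyStopException:
    pass  # expected -- the exception unwinds the rest of the forward pass
finally:
    h1.remove()
    h2.remove()

print("acts" in captured, second_layer_ran == [])  # True True
```

Raising from inside a hook aborts the forward pass immediately, so for a deep model this saves running every layer above the one we care about.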

Exercise - Implement collect_activations_multiple_layers

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 20-25 minutes on this exercise. This is one of the most important exercises as it teaches you how to extract activations using hooks.

Implement a function that collects activations from multiple layers using forward hooks. The function should:

  1. Register forward hooks on specified submodules
  2. During the hook, store the activation tensor in a dictionary
  3. Optionally slice activations using start_offset and end_offset (negative indices from end)
  4. Raise EarlyStopException after capturing from the last layer (no need to continue forward pass)
  5. Clean up hooks in a finally block

A few details to keep in mind: the hook receives (module, inputs, outputs) and the outputs are what we want. Some models return (tensor, *rest) as outputs, so handle both cases. Use max_layer to determine when to stop early, and handle start_offset/end_offset by slicing: activations[:, start_offset:end_offset, :].

def collect_activations_multiple_layers(
    model: AutoModelForCausalLM,
    submodules: dict[int, torch.nn.Module],
    inputs_BL: dict[str, Int[Tensor, "batch seq"]],
    start_offset: int | None,
    end_offset: int | None,
) -> dict[int, Float[Tensor, "batch seq d_model"]]:
    """
    Collect activations from multiple layers using forward hooks.

    Args:
        model: The target model
        submodules: Dict mapping layer number to submodule to hook
        inputs_BL: Tokenized inputs (input_ids, attention_mask)
        start_offset: Start of the token slice (negative index from end). Only used when `end_offset`
            is also non-None; if `end_offset` is None, this must also be None (no slicing is applied).
        end_offset: End of the token slice (negative index from end, exclusive). Set both `start_offset`
            and `end_offset` to non-None values to enable token-position slicing; if both are None,
            the full sequence activations are returned.

    Returns:
        Dict mapping layer → activations tensor [batch, length, d_model]
    """
    raise NotImplementedError()


# Test the function
test_prompt = "The capital of France is"
test_inputs = tokenizer(test_prompt, return_tensors="pt", add_special_tokens=False).to(device)

# Extract from layer 18 (50% of 36 layers)
layer = layer_fraction_to_layer(MODEL_NAME, 0.5)
submodules = {layer: get_hf_submodule(model, layer)}

activations = collect_activations_multiple_layers(
    model=model,
    submodules=submodules,
    inputs_BL=test_inputs,
    start_offset=None,
    end_offset=None,
)

print(f"Extracted activations from layer {layer}")
print(f"Shape: {activations[layer].shape}")  # Should be [1, seq_len, d_model]

tests.test_collect_activations_multiple_layers(collect_activations_multiple_layers, model, tokenizer, device)
Solution
def collect_activations_multiple_layers(
    model: AutoModelForCausalLM,
    submodules: dict[int, torch.nn.Module],
    inputs_BL: dict[str, Int[Tensor, "batch seq"]],
    start_offset: int | None,
    end_offset: int | None,
) -> dict[int, Float[Tensor, "batch seq d_model"]]:
    """
    Collect activations from multiple layers using forward hooks.

    Args:
        model: The target model
        submodules: Dict mapping layer number to submodule to hook
        inputs_BL: Tokenized inputs (input_ids, attention_mask)
        start_offset: Start of the token slice (negative index from end). Only used when `end_offset`
            is also non-None; if `end_offset` is None, this must also be None (no slicing is applied).
        end_offset: End of the token slice (negative index from end, exclusive). Set both `start_offset`
            and `end_offset` to non-None values to enable token-position slicing; if both are None,
            the full sequence activations are returned.

    Returns:
        Dict mapping layer → activations tensor [batch, length, d_model]
    """
    if end_offset is not None:
        assert start_offset is not None
        assert start_offset < end_offset
        assert end_offset < 0
        assert start_offset < 0
    else:
        assert start_offset is None

    activations_BLD_by_layer = {}
    module_to_layer = {submodule: layer for layer, submodule in submodules.items()}
    max_layer = max(submodules.keys())

    def gather_target_act_hook(module, inputs, outputs):
        layer = module_to_layer[module]
        # Handle different output formats
        if isinstance(outputs, tuple):
            activations_BLD_by_layer[layer] = outputs[0]
        else:
            activations_BLD_by_layer[layer] = outputs

        # Slice if requested
        if end_offset is not None:
            activations_BLD_by_layer[layer] = activations_BLD_by_layer[layer][:, start_offset:end_offset, :]

        # Early stop after max layer
        if layer == max_layer:
            raise EarlyStopException("Early stopping after capturing activations")

    # Register hooks
    handles = []
    for layer, submodule in submodules.items():
        handles.append(submodule.register_forward_hook(gather_target_act_hook))

    try:
        with torch.no_grad():
            _ = model(**inputs_BL)
    except EarlyStopException:
        pass  # Expected
    except Exception as e:
        print(f"Unexpected error during forward pass: {str(e)}")
        raise
    finally:
        # Clean up hooks
        for handle in handles:
            handle.remove()

    return activations_BLD_by_layer


# Test the function
test_prompt = "The capital of France is"
test_inputs = tokenizer(test_prompt, return_tensors="pt", add_special_tokens=False).to(device)

# Extract from layer 18 (50% of 36 layers)
layer = layer_fraction_to_layer(MODEL_NAME, 0.5)
submodules = {layer: get_hf_submodule(model, layer)}

activations = collect_activations_multiple_layers(
    model=model,
    submodules=submodules,
    inputs_BL=test_inputs,
    start_offset=None,
    end_offset=None,
)

print(f"Extracted activations from layer {layer}")
print(f"Shape: {activations[layer].shape}")  # Should be [1, seq_len, d_model]

tests.test_collect_activations_multiple_layers(collect_activations_multiple_layers, model, tokenizer, device)

Special token mechanism

Oracles use special ? tokens as placeholders where target model activations will be injected. The oracle is trained to expect these tokens and knows to use the injected activations instead of its own computed activations at those positions.

The format is:

Layer: X
? ? ?
<your question>

Where:
  • Layer: X tells the oracle which layer the activations came from
  • ? ? ? are placeholders (one for each activation vector)
  • Your question comes after

SPECIAL_TOKEN = " ?"


def get_introspection_prefix(layer: int, num_positions: int) -> str:
    """Create the prefix for oracle prompts with ? tokens."""
    prefix = f"Layer: {layer}\n"
    prefix += SPECIAL_TOKEN * num_positions
    prefix += " \n"
    return prefix


# Test it
prefix = get_introspection_prefix(layer=18, num_positions=5)
print(f"Introspection prefix:\n{prefix!r}")

Exercise - Implement find_pattern_in_tokens

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-15 minutes on this exercise.

Implement a function that finds the positions of special ? tokens in a tokenized sequence. This is a key piece of the oracle pipeline: when we construct an oracle prompt with ? placeholders, we need to know exactly where those placeholders ended up in the token ID sequence so we can inject the target model's activations at those positions. Getting this wrong means the oracle would receive activations at the wrong positions, silently producing garbage results. If you ever need to train your own oracle or debug injection issues, understanding this mapping between prompt text and token positions is essential.

The function should:

  1. Encode the special token string to get its token ID
  2. Find all positions where this token appears in the full sequence (don't stop early, since the user's oracle prompt might contain a literal ? which would create extra matches)
  3. Verify we found exactly num_positions tokens (raise ValueError if not)
  4. Verify they're consecutive (this is a sanity check)
  5. Return the list of positions
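
As an aside on the consecutiveness check: since the matched positions are strictly increasing and the count has already been verified, comparing last-minus-first against n - 1 is sufficient. A quick sketch (the helper name are_consecutive is invented for illustration):

```python
# For a strictly increasing list of n positions, the gap between first and
# last is exactly n - 1 if and only if the positions are consecutive.
def are_consecutive(positions: list[int]) -> bool:
    n = len(positions)
    return positions[-1] - positions[0] == n - 1

print(are_consecutive([4, 5, 6]))  # True
print(are_consecutive([4, 5, 9]))  # False: a stray match creates a gap
```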

def find_pattern_in_tokens(
    token_ids: list[int],
    special_token_str: str,
    num_positions: int,
    tokenizer: AutoTokenizer,
) -> list[int]:
    """
    Find positions of special token in tokenized sequence.

    Args:
        token_ids: List of token IDs
        special_token_str: The special token string (e.g., " ?")
        num_positions: Expected number of occurrences
        tokenizer: Tokenizer to encode special token

    Returns:
        List of positions where special token appears
    """
    raise NotImplementedError()


# Test the function
test_text = "Layer: 18\n ? ? ? \nWhat is this?"
test_tokens = tokenizer.encode(test_text, add_special_tokens=False)
positions = find_pattern_in_tokens(test_tokens, SPECIAL_TOKEN, 3, tokenizer)
print(f"Found ? tokens at positions: {positions}")

tests.test_find_pattern_in_tokens(find_pattern_in_tokens, tokenizer)
Solution
def find_pattern_in_tokens(
    token_ids: list[int],
    special_token_str: str,
    num_positions: int,
    tokenizer: AutoTokenizer,
) -> list[int]:
    """
    Find positions of special token in tokenized sequence.

    Args:
        token_ids: List of token IDs
        special_token_str: The special token string (e.g., " ?")
        num_positions: Expected number of occurrences
        tokenizer: Tokenizer to encode special token

    Returns:
        List of positions where special token appears
    """
    special_token_id = tokenizer.encode(special_token_str, add_special_tokens=False)
    assert len(special_token_id) == 1, f"Expected single token, got {len(special_token_id)}"
    special_token_id = special_token_id[0]

    # Find ALL positions where this token appears (don't stop early)
    positions = [i for i, tid in enumerate(token_ids) if tid == special_token_id]

    if len(positions) != num_positions:
        raise ValueError(
            f"Expected {num_positions} occurrences of special token, but found {len(positions)}. "
            f"This can happen if your oracle prompt contains a literal '?' character."
        )
    assert positions[-1] - positions[0] == num_positions - 1, f"Positions are not consecutive: {positions}"

    return positions


# Test the function
test_text = "Layer: 18\n ? ? ? \nWhat is this?"
test_tokens = tokenizer.encode(test_text, add_special_tokens=False)
positions = find_pattern_in_tokens(test_tokens, SPECIAL_TOKEN, 3, tokenizer)
print(f"Found ? tokens at positions: {positions}")

tests.test_find_pattern_in_tokens(find_pattern_in_tokens, tokenizer)

Activation steering

Now we need to inject the target model's activations into the oracle at the ? token positions. This is done via a forward hook that intercepts the oracle's activations and replaces them at specific positions. The vectors need to be normalized to preserve original activation norms, and the injection happens at an early layer (typically layer 1) so the oracle can process them through its remaining layers.

Exercise - Implement get_hf_activation_steering_hook

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 25-30 minutes on this exercise. This is one of the key components - the hook that actually injects activations.

Implement a function that returns a forward hook for activation steering (assuming batch_size=1). The hook should:

  1. Extract the residual stream tensor from outputs (handle tuple case)
  2. Verify batch_size is 1 (raise error if not)
  3. Get the original activations at the specified positions
  4. Normalize the steering vectors to have the same norm as the originals
  5. Apply steering coefficient and add to original activations
  6. Return modified outputs in the same format (tuple or tensor)

The core formula for each position $i$ is:

$$h'_i = h_i + \|h_i\| \cdot c \cdot \frac{v_i}{\|v_i\|}$$

where $h_i$ is the original activation, $v_i$ is the steering vector, and $c$ is the steering coefficient. In other words: normalize the steering vector to unit norm, scale it to match the original activation's magnitude (times the coefficient), then add to the original. This norm-matching is important because the paper found that direct replacement caused 100,000x norm explosion (Appendix A.5).

Some implementation details: use torch.nn.functional.normalize(vector, dim=-1) for unit normalization. Detach the steering vectors before adding so no gradients flow into them. Verify the positions are within the sequence length, and return the output unchanged if L <= 1: during generation with a KV cache, every forward pass after the first processes only the single newest token, and we only want to steer the initial prompt pass.
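
Before writing the full hook, it may help to see the norm-matching formula as a standalone computation on a single position (toy dimensions, not the actual model width):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 8
h = torch.randn(d_model)  # original oracle activation at one position
v = torch.randn(d_model)  # steering vector extracted from the target model
c = 1.0                   # steering coefficient

# h' = h + ||h|| * c * v / ||v||
h_new = h + h.norm() * c * F.normalize(v, dim=-1)

# With c = 1, the injected term has the same norm as the original activation,
# which is what keeps the residual stream's scale under control
added = h_new - h
print(torch.allclose(added.norm(), h.norm(), atol=1e-4))  # True
```

Compare this with direct replacement (h_new = v): the steering vector's norm would then be unconstrained relative to the oracle's residual stream, which is the failure mode the paper describes.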

@contextlib.contextmanager
def add_hook(module: torch.nn.Module, hook: Callable):
    """Temporarily adds a forward hook to a model module."""
    handle = module.register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()


def get_hf_activation_steering_hook(
    vectors: Float[Tensor, "num_pos d_model"],
    positions: list[int],
    steering_coefficient: float,
    device: torch.device,
    dtype: torch.dtype,
) -> Callable:
    """
    Create hook that injects activations at specified positions (assumes batch_size=1).

    Args:
        vectors: Steering vectors [K, d_model] where K is number of positions
        positions: List of positions to inject at
        steering_coefficient: Multiplier for steering strength
        device: Device for tensors
        dtype: Data type for steering

    Returns:
        Hook function that modifies activations during forward pass
    """
    raise NotImplementedError()


# Test the function
# Create dummy data (batch_size=1)
test_positions = [5, 6, 7]  # Inject at positions 5, 6, 7
test_vectors = torch.randn(len(test_positions), model.config.hidden_size, device=device)

hook_fn = get_hf_activation_steering_hook(
    vectors=test_vectors,
    positions=test_positions,
    steering_coefficient=1.0,
    device=device,
    dtype=dtype,
)

# Create dummy activations
dummy_resid = torch.randn(1, 20, model.config.hidden_size, device=device)
orig_values = dummy_resid[0, test_positions, :].clone()

# Apply hook
modified_resid = hook_fn(None, None, dummy_resid)

# Check modifications occurred
new_values = modified_resid[0, test_positions, :]
assert not torch.allclose(orig_values, new_values), "Hook should modify activations"
print("Steering hook test passed!")

tests.test_get_hf_activation_steering_hook(get_hf_activation_steering_hook, device, model.config.hidden_size)
tests.test_get_hf_activation_steering_hook_matches_reference(
    get_hf_activation_steering_hook, device, model.config.hidden_size
)
Solution
@contextlib.contextmanager
def add_hook(module: torch.nn.Module, hook: Callable):
    """Temporarily adds a forward hook to a model module."""
    handle = module.register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()


def get_hf_activation_steering_hook(
    vectors: Float[Tensor, "num_pos d_model"],
    positions: list[int],
    steering_coefficient: float,
    device: torch.device,
    dtype: torch.dtype,
) -> Callable:
    """
    Create hook that injects activations at specified positions (assumes batch_size=1).

    Args:
        vectors: Steering vectors [K, d_model] where K is number of positions
        positions: List of positions to inject at
        steering_coefficient: Multiplier for steering strength
        device: Device for tensors
        dtype: Data type for steering

    Returns:
        Hook function that modifies activations during forward pass
    """
    # Normalize vectors to unit norm
    normed_vectors = torch.nn.functional.normalize(vectors, dim=-1).detach()
    positions_tensor = torch.tensor(positions, dtype=torch.long, device=device)

    def hook_fn(module, _input, output):
        # Extract residual stream tensor
        if isinstance(output, tuple):
            resid_BLD, *rest = output
            output_is_tuple = True
        else:
            resid_BLD = output
            output_is_tuple = False

        B, L, d_model = resid_BLD.shape

        if B != 1:
            raise ValueError(f"Expected batch_size=1, got B={B}")

        if L <= 1:
            return (resid_BLD, *rest) if output_is_tuple else resid_BLD

        # Check positions are valid
        assert positions_tensor.min() >= 0
        assert positions_tensor.max() < L, f"Position {positions_tensor.max()} >= sequence length {L}"

        # Get original activations at steering positions
        orig_KD = resid_BLD[0, positions_tensor, :]  # [K, d_model]
        norms_K1 = orig_KD.norm(dim=-1, keepdim=True)  # [K, 1]

        # Scale normalized steering vectors by original magnitudes
        steered_KD = (normed_vectors * norms_K1 * steering_coefficient).to(dtype)

        # Inject (add to original)
        resid_BLD[0, positions_tensor, :] = steered_KD.detach() + orig_KD

        return (resid_BLD, *rest) if output_is_tuple else resid_BLD

    return hook_fn


# Test the function
# Create dummy data (batch_size=1)
test_positions = [5, 6, 7]  # Inject at positions 5, 6, 7
test_vectors = torch.randn(len(test_positions), model.config.hidden_size, device=device)

hook_fn = get_hf_activation_steering_hook(
    vectors=test_vectors,
    positions=test_positions,
    steering_coefficient=1.0,
    device=device,
    dtype=dtype,
)

# Create dummy activations
dummy_resid = torch.randn(1, 20, model.config.hidden_size, device=device)
orig_values = dummy_resid[0, test_positions, :].clone()

# Apply hook
modified_resid = hook_fn(None, None, dummy_resid)

# Check modifications occurred
new_values = modified_resid[0, test_positions, :]
assert not torch.allclose(orig_values, new_values), "Hook should modify activations"
print("Steering hook test passed!")

tests.test_get_hf_activation_steering_hook(get_hf_activation_steering_hook, device, model.config.hidden_size)
tests.test_get_hf_activation_steering_hook_matches_reference(
    get_hf_activation_steering_hook, device, model.config.hidden_size
)

Training datapoint format

Before we assemble the full pipeline, let's understand the OracleInput structure that oracles use. This format is used both during oracle training and during inference.

@dataclass
class OracleInput:
    """Simplified datapoint for oracle inference (no training-specific fields)."""

    input_ids: list[int]
    layer: int
    steering_vectors: Float[Tensor, "num_pos d_model"]
    positions: list[int]


@dataclass
class OracleResults:
    oracle_lora_path: str | None
    target_lora_path: str | None
    target_prompt: str
    act_key: str
    oracle_prompt: str
    num_tokens: int
    token_responses: list[str | None]
    full_sequence_responses: list[str]
    segment_responses: list[str]
    target_input_ids: list[int]

Exercise - Implement create_oracle_input

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 15-20 minutes on this exercise.

Implement a function that creates an OracleInput for oracle inference. The function should:

  1. Add the introspection prefix (with ? tokens) to the prompt
  2. Format using the chat template with add_generation_prompt=True (ends with <|im_start|>assistant\n)
  3. Find ? token positions in the tokenized sequence
  4. Return an OracleInput with all fields filled in

Use tokenizer.apply_chat_template() with add_generation_prompt=True and enable_thinking=False for inference. The activations should be cloned and detached to CPU.

def create_oracle_input(
    prompt: str,
    layer: int,
    num_positions: int,
    tokenizer: AutoTokenizer,
    acts_BD: Float[Tensor, "num_pos d_model"],
) -> OracleInput:
    """
    Create an oracle input for inference.

    Args:
        prompt: Question to ask the oracle
        layer: Layer the activations came from
        num_positions: Number of ? tokens (equals length of acts_BD)
        tokenizer: Tokenizer
        acts_BD: Activation vectors [num_positions, d_model]

    Returns:
        OracleInput ready for generation
    """
    raise NotImplementedError()


# Test the function
test_activations = torch.randn(3, model.config.hidden_size)
datapoint = create_oracle_input(
    prompt="What is the model thinking about?",
    layer=18,
    num_positions=3,
    tokenizer=tokenizer,
    acts_BD=test_activations,
)

print(f"Created datapoint with {len(datapoint.input_ids)} tokens")
print(f"? tokens at positions: {datapoint.positions}")

tests.test_create_oracle_input(create_oracle_input, tokenizer, model.config.hidden_size)
Solution
def create_oracle_input(
    prompt: str,
    layer: int,
    num_positions: int,
    tokenizer: AutoTokenizer,
    acts_BD: Float[Tensor, "num_pos d_model"],
) -> OracleInput:
    """
    Create an oracle input for inference.

    Args:
        prompt: Question to ask the oracle
        layer: Layer the activations came from
        num_positions: Number of ? tokens (equals length of acts_BD)
        tokenizer: Tokenizer
        acts_BD: Activation vectors [num_positions, d_model]

    Returns:
        OracleInput ready for generation
    """
    # Add introspection prefix with ? tokens
    prefix = get_introspection_prefix(layer, num_positions)
    prompt = prefix + prompt
    input_messages = [{"role": "user", "content": prompt}]

    # Create prompt with generation template (ends with <|im_start|>assistant\n)
    input_prompt_ids = tokenizer.apply_chat_template(
        input_messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors=None,
        padding=False,
        enable_thinking=False,
    )

    # Find ? token positions in the prompt
    positions = find_pattern_in_tokens(input_prompt_ids, SPECIAL_TOKEN, num_positions, tokenizer)

    # Ensure activations are on CPU and detached
    acts_BD = acts_BD.cpu().clone().detach()

    return OracleInput(
        input_ids=input_prompt_ids,
        layer=layer,
        steering_vectors=acts_BD,
        positions=positions,
    )


# Test the function
test_activations = torch.randn(3, model.config.hidden_size)
datapoint = create_oracle_input(
    prompt="What is the model thinking about?",
    layer=18,
    num_positions=3,
    tokenizer=tokenizer,
    acts_BD=test_activations,
)

print(f"Created datapoint with {len(datapoint.input_ids)} tokens")
print(f"? tokens at positions: {datapoint.positions}")

tests.test_create_oracle_input(create_oracle_input, tokenizer, model.config.hidden_size)

Assembling the full oracle pipeline

Now we'll combine all the components to build our own version of utils.run_oracle(). This is a significant exercise that brings everything together.

Exercise - Build utils.run_oracle() from components

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 30-40 minutes on this exercise. This is the capstone of Section 2 - you're building the full oracle pipeline.

Implement a simplified version of utils.run_oracle() that only handles full-sequence queries. We've provided the scaffolding code that handles tokenization, padding extraction, batch construction, and response decoding. Your job is to fill in the core pipeline:

  1. Collect activations: switch to the base model adapter (model.set_adapter("default")), then use collect_activations_multiple_layers() to extract activations from the target model
  2. Create the oracle input: use create_oracle_input() with the extracted activations
  3. Build and apply the steering hook: create a hook with get_hf_activation_steering_hook(), switch to the oracle adapter (model.set_adapter("oracle")), and generate with the hook applied using the add_hook context manager

We inject activations at layer 1 of the oracle (not the layer we extracted from). This is by design: injecting early gives the oracle's remaining layers maximum depth to process the injected information, rather than inserting it partway through and giving the oracle fewer layers to work with.

def run_oracle(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    target_prompt: str,
    oracle_prompt: str,
    layer_fraction: float = 0.5,
    device: torch.device = device,
) -> str:
    """
    Run oracle query from scratch using components we built.

    Args:
        model: Model with oracle LoRA loaded
        tokenizer: Tokenizer
        target_prompt: Prompt to analyze (already formatted with chat template)
        oracle_prompt: Question to ask about activations
        layer_fraction: Which layer to extract from (as fraction of total, 0.0-1.0)
        device: Device

    Returns:
        Oracle's response as string
    """
    # For oracle sampling
    generation_kwargs = {"do_sample": False, "temperature": 0.0, "max_new_tokens": 50}

    # Tokenize target prompt and extract non-padding positions
    inputs_BL = tokenizer(target_prompt, return_tensors="pt", add_special_tokens=False).to(device)
    model_name = model.config._name_or_path
    act_layer = layer_fraction_to_layer(model_name, layer_fraction)
    submodules = {act_layer: get_hf_submodule(model, act_layer)}

    seq_len = inputs_BL["input_ids"].shape[1]
    attn_mask = inputs_BL["attention_mask"][0]
    real_len = int(attn_mask.sum().item())
    left_pad = seq_len - real_len

    # YOUR CODE HERE - fill in the 3 steps described in the exercise:
    # (1) Collect activations from the target model (switch to "default" adapter first)
    # (2) Create an OracleInput using create_oracle_input()
    # (3) Build a steering hook, switch to "oracle" adapter, generate with the hook applied
    raise NotImplementedError()

    # Decode response
    generated_tokens = output_ids[:, input_ids.shape[1] :]
    response = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    return response


# Test our implementation
target_prompt_dict = [{"role": "user", "content": "The capital of France is"}]
target_prompt = tokenizer.apply_chat_template(target_prompt_dict, tokenize=False, add_generation_prompt=True)
oracle_prompt = "What answer will the model give, as a single token?"

our_response = run_oracle(
    model=model,
    tokenizer=tokenizer,
    target_prompt=target_prompt,
    oracle_prompt=oracle_prompt,
    layer_fraction=0.5,
    device=device,
)

print(f"Our implementation response: {our_response!r}")

# Compare to library version
library_results = utils.run_oracle(
    model=model,
    tokenizer=tokenizer,
    device=device,
    target_prompt=target_prompt,
    target_lora_path=None,
    oracle_prompt=oracle_prompt,
    oracle_lora_path="oracle",
    oracle_input_type="full_seq",
)
library_response = library_results.full_sequence_responses[0]

print(f"Library response: {library_response!r}")
assert our_response.strip().lower() == library_response.strip().lower()
Click to see the expected output
Our implementation response: 'Paris'
Library response: 'Paris'
Solution
def run_oracle(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    target_prompt: str,
    oracle_prompt: str,
    layer_fraction: float = 0.5,
    device: torch.device = device,
) -> str:
    """
    Run oracle query from scratch using components we built.

    Args:
        model: Model with oracle LoRA loaded
        tokenizer: Tokenizer
        target_prompt: Prompt to analyze (already formatted with chat template)
        oracle_prompt: Question to ask about activations
        layer_fraction: Which layer to extract from (as fraction of total, 0.0-1.0)
        device: Device

    Returns:
        Oracle's response as string
    """
    # For oracle sampling
    generation_kwargs = {"do_sample": False, "temperature": 0.0, "max_new_tokens": 50}

    # Tokenize target prompt and extract non-padding positions
    inputs_BL = tokenizer(target_prompt, return_tensors="pt", add_special_tokens=False).to(device)
    model_name = model.config._name_or_path
    act_layer = layer_fraction_to_layer(model_name, layer_fraction)
    submodules = {act_layer: get_hf_submodule(model, act_layer)}

    seq_len = inputs_BL["input_ids"].shape[1]
    attn_mask = inputs_BL["attention_mask"][0]
    real_len = int(attn_mask.sum().item())
    left_pad = seq_len - real_len

    # Step 1: Collect activations from the target model
    model.set_adapter("default")
    acts_by_layer = collect_activations_multiple_layers(
        model=model,
        submodules=submodules,
        inputs_BL=inputs_BL,
        start_offset=None,
        end_offset=None,
    )

    # Extract activations for all non-padding positions
    num_positions = real_len
    acts_BD = acts_by_layer[act_layer][0, left_pad:, :]  # [num_positions, d_model]

    # Step 2: Create oracle input
    datapoint = create_oracle_input(
        prompt=oracle_prompt,
        layer=act_layer,
        num_positions=num_positions,
        tokenizer=tokenizer,
        acts_BD=acts_BD,
    )

    # Step 3: Build steering hook, switch to oracle, generate
    input_ids = torch.tensor([datapoint.input_ids], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    steering_vectors = datapoint.steering_vectors.to(device)
    positions = datapoint.positions

    injection_layer = 1  # Inject at layer 1 (gives oracle max processing depth)
    injection_submodule = get_hf_submodule(model, injection_layer)

    hook_fn = get_hf_activation_steering_hook(
        vectors=steering_vectors,
        positions=positions,
        steering_coefficient=1.0,
        device=device,
        dtype=dtype,
    )

    model.set_adapter("oracle")

    with add_hook(injection_submodule, hook_fn):
        output_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, **generation_kwargs)

    # Decode response
    generated_tokens = output_ids[:, input_ids.shape[1] :]
    response = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    return response


# Test our implementation
target_prompt_dict = [{"role": "user", "content": "The capital of France is"}]
target_prompt = tokenizer.apply_chat_template(target_prompt_dict, tokenize=False, add_generation_prompt=True)
oracle_prompt = "What answer will the model give, as a single token?"

our_response = run_oracle(
    model=model,
    tokenizer=tokenizer,
    target_prompt=target_prompt,
    oracle_prompt=oracle_prompt,
    layer_fraction=0.5,
    device=device,
)

print(f"Our implementation response: {our_response!r}")

# Compare to library version
library_results = utils.run_oracle(
    model=model,
    tokenizer=tokenizer,
    device=device,
    target_prompt=target_prompt,
    target_lora_path=None,
    oracle_prompt=oracle_prompt,
    oracle_lora_path="oracle",
    oracle_input_type="full_seq",
)
library_response = library_results.full_sequence_responses[0]

print(f"Library response: {library_response!r}")
assert our_response.strip().lower() == library_response.strip().lower()

You've now built the entire oracle pipeline from scratch, which means none of it is magic anymore. You should now have a concrete mental model for what happens when you call run_oracle(): activation extraction via hooks, the ? token placeholder mechanism, norm-matched steering, and adapter switching. In the next section, we'll put this machinery to work on alignment-relevant problems, starting with the question of whether you can extract information that a model has been deliberately trained to hide.