2️⃣ Implementing oracle components
Learning Objectives
- Implement activation extraction using forward hooks
- Understand the special token mechanism (`?` tokens as activation placeholders)
- Build activation steering hooks to inject activations into the oracle
- Create training datapoints with the correct format
- Assemble all components to replicate the `utils.run_oracle()` function
Now that you've seen what oracles can do from the outside, let's crack them open and build one ourselves. Up to this point you've been calling utils.run_oracle() as a black box: you hand it some activations and a question, and it hands back an answer. But there's a lot happening behind the scenes, and understanding the internals will make you a much better user of the tool (and help you debug it when things go wrong, which they will).
Before we start, think about what pieces you'd need. You need some way to extract activations from the target model at a specific layer -- that means hooks. You need a mechanism for telling the oracle where in its input the activations should go -- that's the ? token placeholders. You need to actually inject those activations into the oracle's forward pass at the right moment -- that's the steering hook. And you need to package all of this up into a format the oracle can consume during training. We'll build each of these components in turn, and by the end you'll have a from-scratch reimplementation of the full oracle pipeline.
Activation extraction with hooks
The first step is extracting activations from the target model. We'll use PyTorch's forward hooks to intercept the residual stream at specific layers: a forward hook fires after its module runs, giving us access to that module's output mid-forward-pass. Two practical details to handle: we can stop the forward pass early once the target activations are captured (no need to run the rest of the model), and we need to handle batching and padding correctly.
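As a warm-up, here's a minimal self-contained sketch of the hook mechanism on a toy two-layer model (a stand-in for illustration, not the course's target model):

```python
import torch

# Toy stand-in for a transformer: two "layers" whose intermediate output we
# want to capture. A forward hook fires after its module runs, receiving
# (module, inputs, outputs).
toy_model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
captured = {}

def capture_hook(module, inputs, outputs):
    captured["act"] = outputs.detach()  # stash the hooked layer's output

handle = toy_model[0].register_forward_hook(capture_hook)
_ = toy_model(torch.randn(3, 4))
handle.remove()  # always remove hooks when done

print(captured["act"].shape)  # torch.Size([3, 8])
```

The real pipeline below follows exactly this pattern, just with the residual-stream submodules of a HuggingFace model instead of a toy `Sequential`.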
First, we need a helper to get the right submodule for a given layer:
# Layer configuration
LAYER_COUNTS = {
"Qwen/Qwen3-1.7B": 28,
"Qwen/Qwen3-8B": 36,
"Qwen/Qwen3-32B": 64,
"google/gemma-2-9b-it": 42,
"google/gemma-3-1b-it": 26,
"meta-llama/Llama-3.2-1B-Instruct": 16,
"meta-llama/Llama-3.1-8B-Instruct": 32,
"meta-llama/Llama-3.3-70B-Instruct": 80,
}
def layer_fraction_to_layer(model_name: str, layer_fraction: float) -> int:
"""Convert a layer fraction (0.0-1.0) to a layer number."""
max_layers = LAYER_COUNTS[model_name]
return int(max_layers * layer_fraction)
def get_hf_submodule(model: AutoModelForCausalLM, layer: int) -> torch.nn.Module:
"""
Gets the residual stream submodule for HuggingFace transformers.
Args:
model: The model
layer: Which layer to hook
Returns:
The submodule to hook (the layer's output is the residual stream)
"""
model_name = model.config._name_or_path
assert re.search("gemma|mistral|Llama|Qwen", model_name), (
f"Model name {model_name!r} is not supported. Supported architectures: Gemma, Mistral, Llama, Qwen."
)
return model.model.layers[layer]
# Check it works as expected
_ = get_hf_submodule(model, layer=LAYER_COUNTS[MODEL_NAME] - 1)
with pytest.raises(IndexError):
_ = get_hf_submodule(model, layer=LAYER_COUNTS[MODEL_NAME])
Now we'll implement the activation collection function. We need a custom exception for early stopping:
class EarlyStopException(Exception):
"""Custom exception for stopping model forward pass early."""
pass
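Here's a self-contained sketch of how this exception gets used: the hook raises once it has captured what it needs, and the caller treats the exception as success (the exception class is redefined locally and the toy model is illustrative):

```python
import torch

class EarlyStopException(Exception):
    """Custom exception for stopping model forward pass early."""
    pass

toy_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
captured = {}

def capture_and_stop(module, inputs, outputs):
    captured["act"] = outputs.detach()
    raise EarlyStopException("captured; no need to run later layers")

handle = toy_model[0].register_forward_hook(capture_and_stop)
try:
    toy_model(torch.randn(2, 4))  # the second Linear never executes
except EarlyStopException:
    pass  # expected: we stopped on purpose
finally:
    handle.remove()  # clean up the hook even on failure

print(captured["act"].shape)  # torch.Size([2, 4])
```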
Exercise - Implement collect_activations_multiple_layers
Implement a function that collects activations from multiple layers using forward hooks. The function should:
- Register forward hooks on specified submodules
- During the hook, store the activation tensor in a dictionary
- Optionally slice activations using `start_offset` and `end_offset` (negative indices from end)
- Raise `EarlyStopException` after capturing from the last layer (no need to continue the forward pass)
- Clean up hooks in a `finally` block
A few details to keep in mind: the hook receives (module, inputs, outputs) and the outputs are what we want. Some models return (tensor, *rest) as outputs, so handle both cases. Use max_layer to determine when to stop early, and handle start_offset/end_offset by slicing: activations[:, start_offset:end_offset, :].
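Two of those details in isolation, on a dummy tensor (the shapes here are purely illustrative):

```python
import torch

# Dummy residual-stream tensor: [batch=2, seq=6, d_model=3].
acts = torch.arange(2 * 6 * 3, dtype=torch.float32).reshape(2, 6, 3)

# Detail 1: some models return (tensor, *rest) from a layer, others a bare tensor.
outputs = (acts, None)
resid = outputs[0] if isinstance(outputs, tuple) else outputs

# Detail 2: negative offsets slice from the end of the sequence dimension.
start_offset, end_offset = -3, -1
sliced = resid[:, start_offset:end_offset, :]  # keeps positions 3 and 4 of 0..5

print(sliced.shape)  # torch.Size([2, 2, 3])
```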
def collect_activations_multiple_layers(
model: AutoModelForCausalLM,
submodules: dict[int, torch.nn.Module],
inputs_BL: dict[str, Int[Tensor, "batch seq"]],
start_offset: int | None,
end_offset: int | None,
) -> dict[int, Float[Tensor, "batch seq d_model"]]:
"""
Collect activations from multiple layers using forward hooks.
Args:
model: The target model
submodules: Dict mapping layer number to submodule to hook
inputs_BL: Tokenized inputs (input_ids, attention_mask)
start_offset: Start of the token slice (negative index from end). Only used when `end_offset`
is also non-None; if `end_offset` is None, this must also be None (no slicing is applied).
end_offset: End of the token slice (negative index from end, exclusive). Set both `start_offset`
and `end_offset` to non-None values to enable token-position slicing; if both are None,
the full sequence activations are returned.
Returns:
Dict mapping layer → activations tensor [batch, length, d_model]
"""
raise NotImplementedError()
# Test the function
test_prompt = "The capital of France is"
test_inputs = tokenizer(test_prompt, return_tensors="pt", add_special_tokens=False).to(device)
# Extract from layer 18 (50% of 36 layers)
layer = layer_fraction_to_layer(MODEL_NAME, 0.5)
submodules = {layer: get_hf_submodule(model, layer)}
activations = collect_activations_multiple_layers(
model=model,
submodules=submodules,
inputs_BL=test_inputs,
start_offset=None,
end_offset=None,
)
print(f"Extracted activations from layer {layer}")
print(f"Shape: {activations[layer].shape}") # Should be [1, seq_len, d_model]
tests.test_collect_activations_multiple_layers(collect_activations_multiple_layers, model, tokenizer, device)
Solution
def collect_activations_multiple_layers(
model: AutoModelForCausalLM,
submodules: dict[int, torch.nn.Module],
inputs_BL: dict[str, Int[Tensor, "batch seq"]],
start_offset: int | None,
end_offset: int | None,
) -> dict[int, Float[Tensor, "batch seq d_model"]]:
"""
Collect activations from multiple layers using forward hooks.
Args:
model: The target model
submodules: Dict mapping layer number to submodule to hook
inputs_BL: Tokenized inputs (input_ids, attention_mask)
start_offset: Start of the token slice (negative index from end). Only used when `end_offset`
is also non-None; if `end_offset` is None, this must also be None (no slicing is applied).
end_offset: End of the token slice (negative index from end, exclusive). Set both `start_offset`
and `end_offset` to non-None values to enable token-position slicing; if both are None,
the full sequence activations are returned.
Returns:
Dict mapping layer → activations tensor [batch, length, d_model]
"""
if end_offset is not None:
assert start_offset is not None
assert start_offset < end_offset
assert end_offset < 0
assert start_offset < 0
else:
assert start_offset is None
activations_BLD_by_layer = {}
module_to_layer = {submodule: layer for layer, submodule in submodules.items()}
max_layer = max(submodules.keys())
def gather_target_act_hook(module, inputs, outputs):
layer = module_to_layer[module]
# Handle different output formats
if isinstance(outputs, tuple):
activations_BLD_by_layer[layer] = outputs[0]
else:
activations_BLD_by_layer[layer] = outputs
# Slice if requested
if end_offset is not None:
activations_BLD_by_layer[layer] = activations_BLD_by_layer[layer][:, start_offset:end_offset, :]
# Early stop after max layer
if layer == max_layer:
raise EarlyStopException("Early stopping after capturing activations")
# Register hooks
handles = []
for layer, submodule in submodules.items():
handles.append(submodule.register_forward_hook(gather_target_act_hook))
try:
with torch.no_grad():
_ = model(**inputs_BL)
except EarlyStopException:
pass # Expected
except Exception as e:
print(f"Unexpected error during forward pass: {str(e)}")
raise
finally:
# Clean up hooks
for handle in handles:
handle.remove()
return activations_BLD_by_layer
# Test the function
test_prompt = "The capital of France is"
test_inputs = tokenizer(test_prompt, return_tensors="pt", add_special_tokens=False).to(device)
# Extract from layer 18 (50% of 36 layers)
layer = layer_fraction_to_layer(MODEL_NAME, 0.5)
submodules = {layer: get_hf_submodule(model, layer)}
activations = collect_activations_multiple_layers(
model=model,
submodules=submodules,
inputs_BL=test_inputs,
start_offset=None,
end_offset=None,
)
print(f"Extracted activations from layer {layer}")
print(f"Shape: {activations[layer].shape}") # Should be [1, seq_len, d_model]
tests.test_collect_activations_multiple_layers(collect_activations_multiple_layers, model, tokenizer, device)
Special token mechanism
Oracles use special ? tokens as placeholders where target model activations will be injected. The oracle is trained to expect these tokens and knows to use the injected activations instead of its own computed activations at those positions.
The format is:
Layer: X
? ? ?
<your question>
Where:
- Layer: X tells the oracle which layer the activations came from
- ? ? ? are placeholders (one for each activation vector)
- Your question comes after
SPECIAL_TOKEN = " ?"
def get_introspection_prefix(layer: int, num_positions: int) -> str:
"""Create the prefix for oracle prompts with ? tokens."""
prefix = f"Layer: {layer}\n"
prefix += SPECIAL_TOKEN * num_positions
prefix += " \n"
return prefix
# Test it
prefix = get_introspection_prefix(layer=18, num_positions=5)
print(f"Introspection prefix:\n{prefix!r}")
Exercise - Implement find_pattern_in_tokens
Implement a function that finds the positions of special ? tokens in a tokenized sequence. This is a key piece of the oracle pipeline: when we construct an oracle prompt with ? placeholders, we need to know exactly where those placeholders ended up in the token ID sequence so we can inject the target model's activations at those positions. Getting this wrong means the oracle would receive activations at the wrong positions, silently producing garbage results. If you ever need to train your own oracle or debug injection issues, understanding this mapping between prompt text and token positions is essential.
The function should:
1. Encode the special token string to get its token ID
2. Find all positions where this token appears in the full sequence (don't stop early, since the user's oracle prompt might contain a literal ? which would create extra matches)
3. Verify we found exactly num_positions tokens (raise ValueError if not)
4. Verify they're consecutive (this is a sanity check)
5. Return the list of positions
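The matching-and-validation logic (steps 2-4) can be sketched in pure Python with made-up token IDs; a real tokenizer isn't needed to see the idea:

```python
# Made-up token IDs: 30 stands in for the " ?" token's ID.
special_token_id = 30
token_ids = [10, 11, 30, 30, 30, 12, 13]
num_positions = 3

# Step 2: find ALL occurrences -- don't stop at the first num_positions matches.
positions = [i for i, tid in enumerate(token_ids) if tid == special_token_id]

# Step 3: exactly the expected count?
if len(positions) != num_positions:
    raise ValueError(f"Expected {num_positions} matches, found {len(positions)}")

# Step 4: consecutive? (correct count + tight span together imply consecutiveness)
assert positions[-1] - positions[0] == num_positions - 1

print(positions)  # [2, 3, 4]
```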
def find_pattern_in_tokens(
token_ids: list[int],
special_token_str: str,
num_positions: int,
tokenizer: AutoTokenizer,
) -> list[int]:
"""
Find positions of special token in tokenized sequence.
Args:
token_ids: List of token IDs
special_token_str: The special token string (e.g., " ?")
num_positions: Expected number of occurrences
tokenizer: Tokenizer to encode special token
Returns:
List of positions where special token appears
"""
raise NotImplementedError()
# Test the function
test_text = "Layer: 18\n ? ? ? \nWhat is this?"
test_tokens = tokenizer.encode(test_text, add_special_tokens=False)
positions = find_pattern_in_tokens(test_tokens, SPECIAL_TOKEN, 3, tokenizer)
print(f"Found ? tokens at positions: {positions}")
tests.test_find_pattern_in_tokens(find_pattern_in_tokens, tokenizer)
Solution
def find_pattern_in_tokens(
token_ids: list[int],
special_token_str: str,
num_positions: int,
tokenizer: AutoTokenizer,
) -> list[int]:
"""
Find positions of special token in tokenized sequence.
Args:
token_ids: List of token IDs
special_token_str: The special token string (e.g., " ?")
num_positions: Expected number of occurrences
tokenizer: Tokenizer to encode special token
Returns:
List of positions where special token appears
"""
special_token_id = tokenizer.encode(special_token_str, add_special_tokens=False)
assert len(special_token_id) == 1, f"Expected single token, got {len(special_token_id)}"
special_token_id = special_token_id[0]
# Find ALL positions where this token appears (don't stop early)
positions = [i for i, tid in enumerate(token_ids) if tid == special_token_id]
if len(positions) != num_positions:
raise ValueError(
f"Expected {num_positions} occurrences of special token, but found {len(positions)}. "
f"This can happen if your oracle prompt contains a literal '?' character."
)
assert positions[-1] - positions[0] == num_positions - 1, f"Positions are not consecutive: {positions}"
return positions
# Test the function
test_text = "Layer: 18\n ? ? ? \nWhat is this?"
test_tokens = tokenizer.encode(test_text, add_special_tokens=False)
positions = find_pattern_in_tokens(test_tokens, SPECIAL_TOKEN, 3, tokenizer)
print(f"Found ? tokens at positions: {positions}")
tests.test_find_pattern_in_tokens(find_pattern_in_tokens, tokenizer)
Activation steering
Now we need to inject the target model's activations into the oracle at the ? token positions. This is done via a forward hook that intercepts the oracle's activations and replaces them at specific positions. The vectors need to be normalized to preserve original activation norms, and the injection happens at an early layer (typically layer 1) so the oracle can process them through its remaining layers.
Exercise - Implement get_hf_activation_steering_hook
Implement a function that returns a forward hook for activation steering (assuming batch_size=1). The hook should:
- Extract the residual stream tensor from outputs (handle tuple case)
- Verify batch_size is 1 (raise error if not)
- Get the original activations at the specified positions
- Normalize the steering vectors to have the same norm as the originals
- Apply steering coefficient and add to original activations
- Return modified outputs in the same format (tuple or tensor)
The core formula for each position $i$ is:

$$h_i' = h_i + c \cdot \|h_i\| \cdot \frac{v_i}{\|v_i\|}$$

where $h_i$ is the original activation, $v_i$ is the steering vector, and $c$ is the steering coefficient. In other words: normalize the steering vector to unit norm, scale it to match the original activation's magnitude (times the coefficient), then add to the original. This norm-matching is important because the paper found that direct replacement caused 100,000x norm explosion (Appendix A.5).
Some implementation details: use torch.nn.functional.normalize(vector, dim=-1) for unit normalization. Detach steering vectors before adding to avoid gradients. Verify positions are within sequence length, and skip if L <= 1.
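The norm-matching step in isolation, on a single toy position (hand-picked vectors so the arithmetic is easy to check by eye):

```python
import torch

h = torch.tensor([3.0, 4.0])   # original activation; norm is 5
v = torch.tensor([0.0, 2.0])   # steering vector (only its direction matters)
c = 1.0                        # steering coefficient

v_unit = torch.nn.functional.normalize(v, dim=-1)  # unit-norm direction: [0., 1.]
h_new = h + c * h.norm() * v_unit                  # add a norm-matched nudge

print(h_new)  # tensor([3., 9.])
```

Note how the added vector has the same magnitude as the original activation, so the residual stream's scale stays in a sane range.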
@contextlib.contextmanager
def add_hook(module: torch.nn.Module, hook: Callable):
"""Temporarily adds a forward hook to a model module."""
handle = module.register_forward_hook(hook)
try:
yield
finally:
handle.remove()
def get_hf_activation_steering_hook(
vectors: Float[Tensor, "num_pos d_model"],
positions: list[int],
steering_coefficient: float,
device: torch.device,
dtype: torch.dtype,
) -> Callable:
"""
Create hook that injects activations at specified positions (assumes batch_size=1).
Args:
vectors: Steering vectors [K, d_model] where K is number of positions
positions: List of positions to inject at
steering_coefficient: Multiplier for steering strength
device: Device for tensors
dtype: Data type for steering
Returns:
Hook function that modifies activations during forward pass
"""
raise NotImplementedError()
# Test the function
# Create dummy data (batch_size=1)
test_positions = [5, 6, 7] # Inject at positions 5, 6, 7
test_vectors = torch.randn(len(test_positions), model.config.hidden_size, device=device)
hook_fn = get_hf_activation_steering_hook(
vectors=test_vectors,
positions=test_positions,
steering_coefficient=1.0,
device=device,
dtype=dtype,
)
# Create dummy activations
dummy_resid = torch.randn(1, 20, model.config.hidden_size, device=device)
orig_values = dummy_resid[0, test_positions, :].clone()
# Apply hook
modified_resid = hook_fn(None, None, dummy_resid)
# Check modifications occurred
new_values = modified_resid[0, test_positions, :]
assert not torch.allclose(orig_values, new_values), "Hook should modify activations"
print("Steering hook test passed!")
tests.test_get_hf_activation_steering_hook(get_hf_activation_steering_hook, device, model.config.hidden_size)
tests.test_get_hf_activation_steering_hook_matches_reference(
get_hf_activation_steering_hook, device, model.config.hidden_size
)
Solution
@contextlib.contextmanager
def add_hook(module: torch.nn.Module, hook: Callable):
"""Temporarily adds a forward hook to a model module."""
handle = module.register_forward_hook(hook)
try:
yield
finally:
handle.remove()
def get_hf_activation_steering_hook(
vectors: Float[Tensor, "num_pos d_model"],
positions: list[int],
steering_coefficient: float,
device: torch.device,
dtype: torch.dtype,
) -> Callable:
"""
Create hook that injects activations at specified positions (assumes batch_size=1).
Args:
vectors: Steering vectors [K, d_model] where K is number of positions
positions: List of positions to inject at
steering_coefficient: Multiplier for steering strength
device: Device for tensors
dtype: Data type for steering
Returns:
Hook function that modifies activations during forward pass
"""
# Normalize vectors to unit norm
normed_vectors = torch.nn.functional.normalize(vectors, dim=-1).detach()
positions_tensor = torch.tensor(positions, dtype=torch.long, device=device)
def hook_fn(module, _input, output):
# Extract residual stream tensor
if isinstance(output, tuple):
resid_BLD, *rest = output
output_is_tuple = True
else:
resid_BLD = output
output_is_tuple = False
B, L, d_model = resid_BLD.shape
if B != 1:
raise ValueError(f"Expected batch_size=1, got B={B}")
if L <= 1:
return (resid_BLD, *rest) if output_is_tuple else resid_BLD
# Check positions are valid
assert positions_tensor.min() >= 0
assert positions_tensor.max() < L, f"Position {positions_tensor.max()} >= sequence length {L}"
# Get original activations at steering positions
orig_KD = resid_BLD[0, positions_tensor, :] # [K, d_model]
norms_K1 = orig_KD.norm(dim=-1, keepdim=True) # [K, 1]
# Scale normalized steering vectors by original magnitudes
steered_KD = (normed_vectors * norms_K1 * steering_coefficient).to(dtype)
# Inject (add to original)
resid_BLD[0, positions_tensor, :] = steered_KD.detach() + orig_KD
return (resid_BLD, *rest) if output_is_tuple else resid_BLD
return hook_fn
# Test the function
# Create dummy data (batch_size=1)
test_positions = [5, 6, 7] # Inject at positions 5, 6, 7
test_vectors = torch.randn(len(test_positions), model.config.hidden_size, device=device)
hook_fn = get_hf_activation_steering_hook(
vectors=test_vectors,
positions=test_positions,
steering_coefficient=1.0,
device=device,
dtype=dtype,
)
# Create dummy activations
dummy_resid = torch.randn(1, 20, model.config.hidden_size, device=device)
orig_values = dummy_resid[0, test_positions, :].clone()
# Apply hook
modified_resid = hook_fn(None, None, dummy_resid)
# Check modifications occurred
new_values = modified_resid[0, test_positions, :]
assert not torch.allclose(orig_values, new_values), "Hook should modify activations"
print("Steering hook test passed!")
tests.test_get_hf_activation_steering_hook(get_hf_activation_steering_hook, device, model.config.hidden_size)
tests.test_get_hf_activation_steering_hook_matches_reference(
get_hf_activation_steering_hook, device, model.config.hidden_size
)
Training datapoint format
Before we assemble the full pipeline, let's understand the OracleInput structure that oracles use. This format is used both during oracle training and during inference.
@dataclass
class OracleInput:
"""Simplified datapoint for oracle inference (no training-specific fields)."""
input_ids: list[int]
layer: int
steering_vectors: Float[Tensor, "num_pos d_model"]
positions: list[int]
@dataclass
class OracleResults:
oracle_lora_path: str | None
target_lora_path: str | None
target_prompt: str
act_key: str
oracle_prompt: str
num_tokens: int
token_responses: list[str | None]
full_sequence_responses: list[str]
segment_responses: list[str]
target_input_ids: list[int]
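To make the structure concrete, here's a hand-built `OracleInput` with toy values (the dataclass is redefined locally with simplified type hints so the snippet runs standalone; the token IDs and dimensions are made up):

```python
from dataclasses import dataclass

import torch
from torch import Tensor

@dataclass
class OracleInput:
    input_ids: list[int]
    layer: int
    steering_vectors: Tensor  # [num_pos, d_model]
    positions: list[int]

dp = OracleInput(
    input_ids=[101, 7, 30, 30, 30, 102],  # toy prompt; 30 stands in for " ?"
    layer=18,                             # layer the activations came from
    steering_vectors=torch.randn(3, 16),  # one vector per "?" position
    positions=[2, 3, 4],                  # where the "?" tokens sit
)

# Invariant: one steering vector per placeholder position.
assert len(dp.positions) == dp.steering_vectors.shape[0]
print(dp.layer)  # 18
```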
Exercise - Implement create_oracle_input
Implement a function that creates an OracleInput for oracle inference. The function should:
- Add the introspection prefix (with `?` tokens) to the prompt
- Format using the chat template with `add_generation_prompt=True` (ends with `<|im_start|>assistant\n`)
- Find `?` token positions in the tokenized sequence
- Return an `OracleInput` with all fields filled in
Use tokenizer.apply_chat_template() with add_generation_prompt=True and enable_thinking=False for inference. The activations should be cloned and detached to CPU.
def create_oracle_input(
prompt: str,
layer: int,
num_positions: int,
tokenizer: AutoTokenizer,
acts_BD: Float[Tensor, "num_pos d_model"],
) -> OracleInput:
"""
Create an oracle input for inference.
Args:
prompt: Question to ask the oracle
layer: Layer the activations came from
num_positions: Number of ? tokens (equals length of acts_BD)
tokenizer: Tokenizer
acts_BD: Activation vectors [num_positions, d_model]
Returns:
OracleInput ready for generation
"""
raise NotImplementedError()
# Test the function
test_activations = torch.randn(3, model.config.hidden_size)
datapoint = create_oracle_input(
prompt="What is the model thinking about?",
layer=18,
num_positions=3,
tokenizer=tokenizer,
acts_BD=test_activations,
)
print(f"Created datapoint with {len(datapoint.input_ids)} tokens")
print(f"? tokens at positions: {datapoint.positions}")
tests.test_create_oracle_input(create_oracle_input, tokenizer, model.config.hidden_size)
Solution
def create_oracle_input(
prompt: str,
layer: int,
num_positions: int,
tokenizer: AutoTokenizer,
acts_BD: Float[Tensor, "num_pos d_model"],
) -> OracleInput:
"""
Create an oracle input for inference.
Args:
prompt: Question to ask the oracle
layer: Layer the activations came from
num_positions: Number of ? tokens (equals length of acts_BD)
tokenizer: Tokenizer
acts_BD: Activation vectors [num_positions, d_model]
Returns:
OracleInput ready for generation
"""
# Add introspection prefix with ? tokens
prefix = get_introspection_prefix(layer, num_positions)
prompt = prefix + prompt
input_messages = [{"role": "user", "content": prompt}]
# Create prompt with generation template (ends with <|im_start|>assistant\n)
input_prompt_ids = tokenizer.apply_chat_template(
input_messages,
tokenize=True,
add_generation_prompt=True,
return_tensors=None,
padding=False,
enable_thinking=False,
)
# Find ? token positions in the prompt
positions = find_pattern_in_tokens(input_prompt_ids, SPECIAL_TOKEN, num_positions, tokenizer)
# Ensure activations are on CPU and detached
acts_BD = acts_BD.cpu().clone().detach()
return OracleInput(
input_ids=input_prompt_ids,
layer=layer,
steering_vectors=acts_BD,
positions=positions,
)
# Test the function
test_activations = torch.randn(3, model.config.hidden_size)
datapoint = create_oracle_input(
prompt="What is the model thinking about?",
layer=18,
num_positions=3,
tokenizer=tokenizer,
acts_BD=test_activations,
)
print(f"Created datapoint with {len(datapoint.input_ids)} tokens")
print(f"? tokens at positions: {datapoint.positions}")
tests.test_create_oracle_input(create_oracle_input, tokenizer, model.config.hidden_size)
Assembling the full oracle pipeline
Now we'll combine all the components to build our own version of utils.run_oracle(). This is a significant exercise that brings everything together.
Exercise - Build utils.run_oracle() from components
Implement a simplified version of utils.run_oracle() that only handles full-sequence queries. We've provided the scaffolding code that handles tokenization, padding extraction, batch construction, and response decoding. Your job is to fill in the core pipeline:
- Collect activations: switch to the base model adapter (`model.set_adapter("default")`), then use `collect_activations_multiple_layers()` to extract activations from the target model
- Create the oracle input: use `create_oracle_input()` with the extracted activations
- Build and apply the steering hook: create a hook with `get_hf_activation_steering_hook()`, switch to the oracle adapter (`model.set_adapter("oracle")`), and generate with the hook applied using the `add_hook` context manager
We inject activations at layer 1 of the oracle (not the layer we extracted from). This is by design: injecting early gives the oracle's remaining layers maximum depth to process the injected information, rather than inserting it partway through and giving the oracle fewer layers to work with.
def run_oracle(
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
target_prompt: str,
oracle_prompt: str,
layer_fraction: float = 0.5,
device: torch.device = device,
) -> str:
"""
Run oracle query from scratch using components we built.
Args:
model: Model with oracle LoRA loaded
tokenizer: Tokenizer
target_prompt: Prompt to analyze (already formatted with chat template)
oracle_prompt: Question to ask about activations
layer_fraction: Which layer to extract from (as fraction of total, 0.0-1.0)
device: Device
Returns:
Oracle's response as string
"""
# For oracle sampling
generation_kwargs = {"do_sample": False, "temperature": 0.0, "max_new_tokens": 50}
# Tokenize target prompt and extract non-padding positions
inputs_BL = tokenizer(target_prompt, return_tensors="pt", add_special_tokens=False).to(device)
model_name = model.config._name_or_path
act_layer = layer_fraction_to_layer(model_name, layer_fraction)
submodules = {act_layer: get_hf_submodule(model, act_layer)}
seq_len = inputs_BL["input_ids"].shape[1]
attn_mask = inputs_BL["attention_mask"][0]
real_len = int(attn_mask.sum().item())
left_pad = seq_len - real_len
# YOUR CODE HERE - fill in the 3 steps described in the exercise:
# (1) Collect activations from the target model (switch to "default" adapter first)
# (2) Create an OracleInput using create_oracle_input()
# (3) Build a steering hook, switch to "oracle" adapter, generate with the hook applied
raise NotImplementedError()
# Decode response
generated_tokens = output_ids[:, input_ids.shape[1] :]
response = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
return response
# Test our implementation
target_prompt_dict = [{"role": "user", "content": "The capital of France is"}]
target_prompt = tokenizer.apply_chat_template(target_prompt_dict, tokenize=False, add_generation_prompt=True)
oracle_prompt = "What answer will the model give, as a single token?"
our_response = run_oracle(
model=model,
tokenizer=tokenizer,
target_prompt=target_prompt,
oracle_prompt=oracle_prompt,
layer_fraction=0.5,
device=device,
)
print(f"Our implementation response: {our_response!r}")
# Compare to library version
library_results = utils.run_oracle(
model=model,
tokenizer=tokenizer,
device=device,
target_prompt=target_prompt,
target_lora_path=None,
oracle_prompt=oracle_prompt,
oracle_lora_path="oracle",
oracle_input_type="full_seq",
)
library_response = library_results.full_sequence_responses[0]
print(f"Library response: {library_response!r}")
assert our_response.strip().lower() == library_response.strip().lower()
Click to see the expected output
Our implementation response: 'Paris'
Library response: 'Paris'
Solution
def run_oracle(
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
target_prompt: str,
oracle_prompt: str,
layer_fraction: float = 0.5,
device: torch.device = device,
) -> str:
"""
Run oracle query from scratch using components we built.
Args:
model: Model with oracle LoRA loaded
tokenizer: Tokenizer
target_prompt: Prompt to analyze (already formatted with chat template)
oracle_prompt: Question to ask about activations
layer_fraction: Which layer to extract from (as fraction of total, 0.0-1.0)
device: Device
Returns:
Oracle's response as string
"""
# For oracle sampling
generation_kwargs = {"do_sample": False, "temperature": 0.0, "max_new_tokens": 50}
# Tokenize target prompt and extract non-padding positions
inputs_BL = tokenizer(target_prompt, return_tensors="pt", add_special_tokens=False).to(device)
model_name = model.config._name_or_path
act_layer = layer_fraction_to_layer(model_name, layer_fraction)
submodules = {act_layer: get_hf_submodule(model, act_layer)}
seq_len = inputs_BL["input_ids"].shape[1]
attn_mask = inputs_BL["attention_mask"][0]
real_len = int(attn_mask.sum().item())
left_pad = seq_len - real_len
# Step 1: Collect activations from the target model
model.set_adapter("default")
acts_by_layer = collect_activations_multiple_layers(
model=model,
submodules=submodules,
inputs_BL=inputs_BL,
start_offset=None,
end_offset=None,
)
# Extract activations for all non-padding positions
num_positions = real_len
acts_BD = acts_by_layer[act_layer][0, left_pad:, :] # [num_positions, d_model]
# Step 2: Create oracle input
datapoint = create_oracle_input(
prompt=oracle_prompt,
layer=act_layer,
num_positions=num_positions,
tokenizer=tokenizer,
acts_BD=acts_BD,
)
# Step 3: Build steering hook, switch to oracle, generate
input_ids = torch.tensor([datapoint.input_ids], dtype=torch.long, device=device)
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
steering_vectors = datapoint.steering_vectors.to(device)
positions = datapoint.positions
injection_layer = 1 # Inject at layer 1 (gives oracle max processing depth)
injection_submodule = get_hf_submodule(model, injection_layer)
hook_fn = get_hf_activation_steering_hook(
vectors=steering_vectors,
positions=positions,
steering_coefficient=1.0,
device=device,
dtype=dtype,
)
model.set_adapter("oracle")
with add_hook(injection_submodule, hook_fn):
output_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, **generation_kwargs)
# Decode response
generated_tokens = output_ids[:, input_ids.shape[1] :]
response = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
return response
# Test our implementation
target_prompt_dict = [{"role": "user", "content": "The capital of France is"}]
target_prompt = tokenizer.apply_chat_template(target_prompt_dict, tokenize=False, add_generation_prompt=True)
oracle_prompt = "What answer will the model give, as a single token?"
our_response = run_oracle(
model=model,
tokenizer=tokenizer,
target_prompt=target_prompt,
oracle_prompt=oracle_prompt,
layer_fraction=0.5,
device=device,
)
print(f"Our implementation response: {our_response!r}")
# Compare to library version
library_results = utils.run_oracle(
model=model,
tokenizer=tokenizer,
device=device,
target_prompt=target_prompt,
target_lora_path=None,
oracle_prompt=oracle_prompt,
oracle_lora_path="oracle",
oracle_input_type="full_seq",
)
library_response = library_results.full_sequence_responses[0]
print(f"Library response: {library_response!r}")
assert our_response.strip().lower() == library_response.strip().lower()
You've now built the entire oracle pipeline from scratch, which means none of it is magic anymore. You should now have a concrete mental model for what happens when you call run_oracle(): activation extraction via hooks, the ? token placeholder mechanism, norm-matched steering, and adapter switching. In the next section, we'll put this machinery to work on alignment-relevant problems, starting with the question of whether you can extract information that a model has been deliberately trained to hide.