5️⃣ Attention Probes for High-Stakes Detection
Learning Objectives
- Understand what "high-stakes interactions" means as a probe target, and why it differs from probing model intent
- Extract full-sequence activations (shape (n, seq, d_model)) rather than last-token only
- Implement an attention probe as an nn.Module: a single learned query that computes a weighted sum over token positions before classification
- Compare attention pooling against last-token and mean-pool baselines using AUROC
- Inspect learned attention weights to understand which parts of a prompt are most diagnostic
The probes in sections 1-4 target properties of the model's own state: Is this statement true? Is the model being deceptive? This section shifts the target. The question is now whether a user's request is high-stakes, a property of the input itself that is readable from activations before the model generates anything.
The paper Detecting High-Stakes Interactions with Activation Probes (NeurIPS 2025) trains probes on LLM activations to flag requests where a poor or harmful response could have serious real-world consequences. The probe observes the model processing a user's message and assigns a high-stakes score from the activation pattern at a particular layer. Probes trained this way are competitive with a prompted LLM classifier, at around six orders of magnitude lower compute cost. The key efficiency gain is that the probe reuses activations already computed by the monitored model during its normal forward pass.
The labelling rubric
A request is classified as high-stakes if it meets any of the following criteria:
- Facilitates or encourages physical harm or violence
- Facilitates other crimes or serious wrongdoing
- Encourages extreme harassment, doxxing, or personal threats
- Explicit incitement of harm or abuse
- Serious violations of privacy or personal safety
- Severe medical or self-harm advice that could cause fatal harm
- Sexual content involving minors
- Major financial or legal fraud
Everything else is low-stakes. The paper also labels some requests as ambiguous, which we drop here to keep training binary.
The training dataset
The training data is entirely synthetic. Each pair of prompts shares a topic and a user role, with one high-stakes and one low-stakes version generated side-by-side. The dataset covers 24 topics (medical, cybersecurity, finance, military, law, etc.), 10 user roles (executives, healthcare professionals, security experts, students, etc.), and 20 impact factors that can make a situation high-stakes: health and safety outcomes, financial consequences, permanent harm, mass impact, emotional vulnerability, and so on. The full dataset is hosted on HuggingFace at Arrrlex/models-under-pressure.
Why attention probes?
The paper benchmarks several probe architectures. Simpler approaches (classifying from the last token, or from the mean of all token activations) discard information about which parts of the prompt are diagnostically relevant. An attention probe learns to weight token positions selectively: after training, the attention weights tend to concentrate on the clause or phrase that makes a request dangerous.
The specific architecture used is called AttnLite. It has a single global context query vector $W_Q \in \mathbb{R}^{d}$ that scores each token position, applies softmax over the sequence to get attention weights, computes a weighted sum of token activations, and feeds the result through a linear classifier. The query is position-independent: the same learned vector scores every token, so the attention weights can be read directly as a relevance score for each position. In the exercise below we implement this, with an optional generalization to multiple independent heads.
Compare this to transformer self-attention: in self-attention, each position generates a query from its own hidden state and attends over all other positions, producing a per-position output. AttnLite differs in two ways: the query is a learned constant rather than derived from the input, and the output is a single classification logit per example rather than updated token representations. It is closest in spirit to cross-attention with a fixed query side.
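The description above can be written compactly for a single head, with $X \in \mathbb{R}^{s \times d}$ the token activations of one example and $x_i$ its rows:

$$\alpha = \mathrm{softmax}\!\left(\frac{X W_Q}{\sqrt{d}}\right) \in \mathbb{R}^{s}, \qquad c = \sum_{i=1}^{s} \alpha_i x_i \in \mathbb{R}^{d}, \qquad z = c^\top W_\mathrm{out} + b.$$

With multiple heads, each head produces its own $\alpha$ and $c$, and the concatenated context vectors feed the final linear layer. Padded positions are filled with $-\infty$ before the softmax so they receive zero weight.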
# The models-under-pressure training dataset on HuggingFace.
# Download locally with: uv run mup datasets download (from the models-under-pressure repo root)
hs_raw = load_dataset("Arrrlex/models-under-pressure", split="train")
print(f"Total prompts: {len(hs_raw)}")
print(f"Columns: {hs_raw.column_names}")
# Discover the label and text column names (schema varies across dataset versions)
label_key = next(k for k in ("high_stakes", "label") if k in hs_raw.column_names)
text_key = next(k for k in ("inputs", "prompt", "text") if k in hs_raw.column_names)
def to_int_label(lbl) -> int:
if isinstance(lbl, bool):
return int(lbl)
return 1 if lbl in ("high-stakes", "high_stakes", True) else 0
def is_binary_label(x) -> bool:
lbl = x[label_key]
if isinstance(lbl, bool):
return True
return lbl in ("high-stakes", "low-stakes", "high_stakes", "low_stakes")
hs_binary = hs_raw.filter(is_binary_label)
hs_texts = hs_binary[text_key]
hs_int_labels = [to_int_label(x) for x in hs_binary[label_key]]
n_high = sum(hs_int_labels)
n_low = len(hs_int_labels) - n_high
print(f"\nAfter filtering: {len(hs_texts)} prompts ({n_high} high-stakes, {n_low} low-stakes)")
# Print one example from each class to build intuition for the rubric
hi_idx = next(i for i, l in enumerate(hs_int_labels) if l == 1)
lo_idx = next(i for i, l in enumerate(hs_int_labels) if l == 0)
print("\n=== High-stakes example ===")
print(hs_texts[hi_idx][:500])
print("\n=== Low-stakes example ===")
print(hs_texts[lo_idx][:500])
HS_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
# The paper finds mid-network layers work best. Layer 16 is the midpoint for
# Llama-3.1-8B (32 layers) and a reasonable starting point; try a sweep if you want.
HS_LAYER = 16
try:
# Reuse the instruct model already loaded for section 4
hs_model = instruct_model
hs_tokenizer = instruct_tokenizer
print(f"Reusing section 4 model ({HS_MODEL_NAME})")
except NameError:
hs_tokenizer = AutoTokenizer.from_pretrained(HS_MODEL_NAME)
hs_model = AutoModelForCausalLM.from_pretrained(HS_MODEL_NAME, dtype=dtype, device_map="auto")
hs_tokenizer.pad_token = hs_tokenizer.eos_token
hs_tokenizer.padding_side = "right"
HS_N_LAYERS = hs_model.config.num_hidden_layers
HS_D_MODEL = hs_model.config.hidden_size
print(f"Layers: {HS_N_LAYERS}, d_model: {HS_D_MODEL}, probe layer: {HS_LAYER}")
# Build a balanced train/test split from the raw dataset.
# The full dataset has ~3800 high-stakes and ~4100 low-stakes examples.
HS_MAX_LEN = 256
HS_N_TRAIN = 1500 # per class
HS_N_TEST = 500 # per class
hi_indices = [i for i, l in enumerate(hs_int_labels) if l == 1]
lo_indices = [i for i, l in enumerate(hs_int_labels) if l == 0]
np.random.seed(42)
np.random.shuffle(hi_indices)
np.random.shuffle(lo_indices)
train_hi = hi_indices[:HS_N_TRAIN]
train_lo = lo_indices[:HS_N_TRAIN]
test_hi = hi_indices[HS_N_TRAIN : HS_N_TRAIN + HS_N_TEST]
test_lo = lo_indices[HS_N_TRAIN : HS_N_TRAIN + HS_N_TEST]
train_indices = train_hi + train_lo
test_indices = test_hi + test_lo
def format_as_chat(text: str) -> str:
"""Format a raw prompt as a chat-template user turn."""
return hs_tokenizer.apply_chat_template(
[{"role": "user", "content": text}],
tokenize=False,
add_generation_prompt=False,
)
hs_train_texts = [format_as_chat(hs_texts[i]) for i in train_indices]
hs_test_texts = [format_as_chat(hs_texts[i]) for i in test_indices]
hs_train_labels = t.tensor([hs_int_labels[i] for i in train_indices], dtype=t.float32)
hs_test_labels = t.tensor([hs_int_labels[i] for i in test_indices], dtype=t.float32)
print(
f"Train: {len(hs_train_texts)} prompts ({hs_train_labels.sum().int():.0f} high, "
f"{(1 - hs_train_labels).sum().int():.0f} low)"
)
print(
f"Test: {len(hs_test_texts)} prompts ({hs_test_labels.sum().int():.0f} high, "
f"{(1 - hs_test_labels).sum().int():.0f} low)"
)
Exercise - implement extract_full_sequence_activations
Implement extract_full_sequence_activations. This extends the extract_activations function from Section 1: instead of returning only the hidden state at the last real token, it returns the full sequence of hidden states at a given layer.
Return type: a tuple (activations, mask) where:
- activations has shape (n_texts, max_length, d_model), dtype float32, on CPU
- mask has shape (n_texts, max_length), dtype bool, True for real tokens and False for padding
Use padding="max_length" with a fixed max_length so all batches produce tensors of the same shape, making concatenation across batches straightforward. The boolean mask is what AttnLite (and our probe below) expects: the probe uses it to fill padding positions with -inf before the softmax.
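To see why the mask matters, here is a standalone toy sketch (values invented for illustration) of the fill-with-$-\infty$-then-softmax step:

```python
import torch

# Toy attention logits for one example of length 4; the last position is padding.
scores = torch.tensor([[1.0, 2.0, 3.0, 0.5]])
mask = torch.tensor([[True, True, True, False]])

# Fill padding positions with -inf, then softmax over the sequence dimension.
masked = scores.masked_fill(~mask, float("-inf"))
weights = masked.softmax(dim=-1)

print(weights)  # padding gets weight exactly 0; the three real positions sum to 1
```

Since exp(-inf) = 0, padding positions contribute nothing to the weighted sum, no matter how large their (meaningless) activations are.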
Hint - hidden_states indexing
outputs.hidden_states[0] is the embedding layer output. outputs.hidden_states[layer + 1] is the output of transformer layer layer. This is the same offset as in extract_activations from Section 1.
def extract_full_sequence_activations(
texts: list[str],
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
layer: int,
batch_size: int = 8,
max_length: int = 256,
) -> tuple[Float[Tensor, "n seq d_model"], Bool[Tensor, "n seq"]]:
"""
Extract full-sequence hidden states from a given layer for a list of texts.
Args:
texts: List of formatted text strings to process.
model: A HuggingFace causal language model.
tokenizer: The corresponding tokenizer.
layer: Layer index (0-indexed) to extract activations from.
batch_size: Number of texts per forward pass.
max_length: Fixed sequence length to pad/truncate all inputs to.
Returns:
Tuple of (activations, mask):
activations: shape (n, max_length, d_model), float32 on CPU.
mask: shape (n, max_length), bool, True = real token.
"""
raise NotImplementedError()
test_acts, test_masks = extract_full_sequence_activations(
hs_train_texts[:4], hs_model, hs_tokenizer, HS_LAYER, batch_size=4, max_length=HS_MAX_LEN
)
assert test_acts.shape == (4, HS_MAX_LEN, HS_D_MODEL), (
f"Expected (4, {HS_MAX_LEN}, {HS_D_MODEL}), got {test_acts.shape}"
)
assert test_masks.shape == (4, HS_MAX_LEN), f"Expected (4, {HS_MAX_LEN}), got {test_masks.shape}"
assert test_masks.dtype == t.bool, f"Mask should be bool, got {test_masks.dtype}"
assert test_masks[:, 0].all(), "First token should always be a real token"
assert t.isfinite(test_acts[test_masks]).all(), "Real-token activations should be finite"
print("extract_full_sequence_activations tests passed!")
print("Extracting activations for train and test sets...")
hs_acts_train, hs_masks_train = extract_full_sequence_activations(
hs_train_texts, hs_model, hs_tokenizer, HS_LAYER, batch_size=8, max_length=HS_MAX_LEN
)
hs_acts_test, hs_masks_test = extract_full_sequence_activations(
hs_test_texts, hs_model, hs_tokenizer, HS_LAYER, batch_size=8, max_length=HS_MAX_LEN
)
print(f"Train: {hs_acts_train.shape}, Test: {hs_acts_test.shape}")
# Derive last-token and mean-pooled activations for baseline probes
def last_token_acts(acts: Float[Tensor, "n s d"], masks: Bool[Tensor, "n s"]) -> Float[Tensor, "n d"]:
last_idx = masks.long().sum(dim=1) - 1 # index of final real token
return acts[t.arange(acts.shape[0]), last_idx]
def mean_pool_acts(acts: Float[Tensor, "n s d"], masks: Bool[Tensor, "n s"]) -> Float[Tensor, "n d"]:
mask_f = masks.float().unsqueeze(-1)
return (acts * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1)
hs_last_train = last_token_acts(hs_acts_train, hs_masks_train)
hs_last_test = last_token_acts(hs_acts_test, hs_masks_test)
hs_mean_train = mean_pool_acts(hs_acts_train, hs_masks_train)
hs_mean_test = mean_pool_acts(hs_acts_test, hs_masks_test)
# Train baseline probes (same classes as sections 1-4, new data)
mm_probe_hs = MMProbe.from_data(hs_last_train, hs_train_labels)
lr_probe_hs_last = LRProbe.from_data(hs_last_train, hs_train_labels)
lr_probe_hs_mean = LRProbe.from_data(hs_mean_train, hs_train_labels)
def compute_auroc(score_fn, acts_test, labels_test) -> float:
with t.no_grad():
scores = score_fn(acts_test).cpu().numpy()
return roc_auc_score(labels_test.numpy(), scores)
auroc_mm = compute_auroc(mm_probe_hs, hs_last_test, hs_test_labels)
auroc_lr_l = compute_auroc(lr_probe_hs_last, hs_last_test, hs_test_labels)
auroc_lr_m = compute_auroc(lr_probe_hs_mean, hs_mean_test, hs_test_labels)
print(f"\nBaseline AUROCs (layer {HS_LAYER}, test set):")
print(f" MMProbe (last token): {auroc_mm:.3f}")
print(f" LRProbe (last token): {auroc_lr_l:.3f}")
print(f" LRProbe (mean pool): {auroc_lr_m:.3f}")
Click to see the expected output
Extracting activations for train and test sets...
Train: torch.Size([3000, 256, 4096]), Test: torch.Size([1000, 256, 4096])

Baseline AUROCs (layer 16, test set):
  MMProbe (last token): 0.873
  LRProbe (last token): 0.962
  LRProbe (mean pool): 0.991
Solution
def extract_full_sequence_activations(
texts: list[str],
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
layer: int,
batch_size: int = 8,
max_length: int = 256,
) -> tuple[Float[Tensor, "n seq d_model"], Bool[Tensor, "n seq"]]:
"""
Extract full-sequence hidden states from a given layer for a list of texts.
Args:
texts: List of formatted text strings to process.
model: A HuggingFace causal language model.
tokenizer: The corresponding tokenizer.
layer: Layer index (0-indexed) to extract activations from.
batch_size: Number of texts per forward pass.
max_length: Fixed sequence length to pad/truncate all inputs to.
Returns:
Tuple of (activations, mask):
activations: shape (n, max_length, d_model), float32 on CPU.
mask: shape (n, max_length), bool, True = real token.
"""
all_acts: list[Float[Tensor, "batch seq d_model"]] = []
all_masks: list[Bool[Tensor, "batch seq"]] = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
inputs = tokenizer(
batch,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_length,
).to(model.device)
with t.no_grad():
outputs = model(**inputs, output_hidden_states=True)
# hidden_states[0] is the embedding; hidden_states[layer+1] is transformer layer output
hidden = outputs.hidden_states[layer + 1].cpu().float() # (batch, seq, d_model)
mask = inputs["attention_mask"].bool().cpu() # (batch, seq)
all_acts.append(hidden)
all_masks.append(mask)
return t.cat(all_acts, dim=0), t.cat(all_masks, dim=0)
Exercise - implement attention_probe_forward
Implement the forward pass of an attention probe. The function takes weight matrices directly as arguments so that it can be tested and reused independently of the surrounding class. The AttentionProbe class below will call this function, passing its nn.Parameter tensors through.
The computation has four steps:
1. Attention logits: project each token's activation to $n_\text{heads}$ scalars via $W_Q \in \mathbb{R}^{d \times n_\text{heads}}$, scaled by $1/\sqrt{d}$. Result shape (batch, seq, n_heads).
2. Masked softmax: fill positions where mask is False with $-\infty$, then softmax over the sequence dimension (dim 1). Each head now has a probability distribution over real token positions. Result shape (batch, seq, n_heads).
3. Context vector: attention-weighted sum of token activations per head, then concatenate heads. Result shape (batch, n_heads * d_model).
4. Classify: apply $W_\text{out} \in \mathbb{R}^{n_\text{heads} \cdot d \times 1}$ plus a scalar bias to get one logit per example.
The function returns both the logits and the attention weights so that after training we can inspect what the probe attends to.
Hint - einops einsum for step 3
context = einops.einsum(attn_weights, x, "b s n, b s d -> b n d") # (batch, n_heads, d_model)
context = context.flatten(start_dim=1) # (batch, n_heads * d_model)
def attention_probe_forward(
x: Float[Tensor, "batch seq d_model"],
mask: Bool[Tensor, "batch seq"],
W_q: Float[Tensor, "d_model n_heads"],
W_out: Float[Tensor, "n_heads_times_d_model 1"],
b_out: Float[Tensor, "1"],
scale: float,
) -> tuple[Float[Tensor, " batch"], Float[Tensor, "batch seq n_heads"]]:
"""
Forward pass of an attention probe.
Args:
x: Token activations, shape (batch, seq, d_model).
mask: Boolean mask; True = real token, shape (batch, seq).
W_q: Query weight matrix, shape (d_model, n_heads).
W_out: Output classifier weights, shape (n_heads * d_model, 1).
b_out: Output bias, shape (1,).
scale: Attention scale factor, typically sqrt(d_model).
Returns:
logits: Classification logit per example, shape (batch,).
attn_weights: Attention weights over tokens per head, shape (batch, seq, n_heads).
"""
raise NotImplementedError()
class AttentionProbe(t.nn.Module):
"""Attention-based probe that learns to weight token positions for binary classification."""
def __init__(self, d_model: int, n_heads: int = 1):
super().__init__()
self.n_heads = n_heads
self.scale = d_model**0.5
self.W_q = t.nn.Parameter(t.empty(d_model, n_heads))
self.W_out = t.nn.Parameter(t.empty(n_heads * d_model, 1))
self.b_out = t.nn.Parameter(t.zeros(1))
t.nn.init.normal_(self.W_q, std=d_model**-0.5)
t.nn.init.normal_(self.W_out, std=(n_heads * d_model) ** -0.5)
def forward(
self,
x: Float[Tensor, "batch seq d_model"],
mask: Bool[Tensor, "batch seq"],
) -> Float[Tensor, " batch"]:
logits, _ = attention_probe_forward(x, mask, self.W_q, self.W_out, self.b_out, self.scale)
return logits
@t.no_grad()
def get_attention_weights(
self,
x: Float[Tensor, "batch seq d_model"],
mask: Bool[Tensor, "batch seq"],
) -> Float[Tensor, "batch seq n_heads"]:
_, attn_weights = attention_probe_forward(x, mask, self.W_q, self.W_out, self.b_out, self.scale)
return attn_weights
t.manual_seed(0)
batch, seq_len, d, n_heads = 3, 12, 32, 2
x = t.randn(batch, seq_len, d)
mask = t.ones(batch, seq_len, dtype=t.bool)
mask[0, 8:] = False # 4 padding tokens for first example
mask[1, 11:] = False # 1 padding token for second example
W_q = t.randn(d, n_heads) * 0.1
W_out = t.randn(n_heads * d, 1) * 0.1
b_out = t.zeros(1)
scale = d**0.5
logits, attn = attention_probe_forward(x, mask, W_q, W_out, b_out, scale)
assert logits.shape == (batch,), f"Logits shape: {logits.shape}, expected ({batch},)"
assert attn.shape == (batch, seq_len, n_heads), (
f"Attn shape: {attn.shape}, expected ({batch}, {seq_len}, {n_heads})"
)
# Weights over valid positions must sum to 1 per head
assert t.allclose(attn[0, :8, :].sum(0), t.ones(n_heads), atol=1e-5), (
"Attention weights over valid tokens should sum to 1 per head"
)
# Padding positions should receive negligible weight
assert (attn[0, 8:, :].abs() < 1e-5).all(), "Padding positions should have ~0 attention weight"
assert t.isfinite(logits).all(), "Logits should be finite"
# Verify the class wraps it correctly
probe = AttentionProbe(d_model=d, n_heads=n_heads)
probe_logits = probe(x, mask)
assert probe_logits.shape == (batch,)
print("attention_probe_forward tests passed!")
Solution
def attention_probe_forward(
x: Float[Tensor, "batch seq d_model"],
mask: Bool[Tensor, "batch seq"],
W_q: Float[Tensor, "d_model n_heads"],
W_out: Float[Tensor, "n_heads_times_d_model 1"],
b_out: Float[Tensor, "1"],
scale: float,
) -> tuple[Float[Tensor, " batch"], Float[Tensor, "batch seq n_heads"]]:
"""
Forward pass of an attention probe.
Args:
x: Token activations, shape (batch, seq, d_model).
mask: Boolean mask; True = real token, shape (batch, seq).
W_q: Query weight matrix, shape (d_model, n_heads).
W_out: Output classifier weights, shape (n_heads * d_model, 1).
b_out: Output bias, shape (1,).
scale: Attention scale factor, typically sqrt(d_model).
Returns:
logits: Classification logit per example, shape (batch,).
attn_weights: Attention weights over tokens per head, shape (batch, seq, n_heads).
"""
# Step 1: attention logit per token per head
attn_logits = einops.einsum(x, W_q, "b s d, d n -> b s n") / scale # (batch, seq, n_heads)
# Step 2: mask padding positions, then softmax over the sequence dimension
attn_logits = attn_logits.masked_fill(~mask.unsqueeze(-1), float("-inf"))
attn_weights = attn_logits.softmax(dim=1) # (batch, seq, n_heads)
# Step 3: attention-weighted sum of token activations, concatenated across heads
context = einops.einsum(attn_weights, x, "b s n, b s d -> b n d") # (batch, n_heads, d_model)
context = context.flatten(start_dim=1) # (batch, n_heads * d_model)
# Step 4: linear classification
logits = einops.einsum(context, W_out, "b h, h one -> b one").squeeze(-1) + b_out.squeeze()
return logits, attn_weights
class AttentionProbe(t.nn.Module):
"""Attention-based probe that learns to weight token positions for binary classification."""
def __init__(self, d_model: int, n_heads: int = 1):
super().__init__()
self.n_heads = n_heads
self.scale = d_model**0.5
self.W_q = t.nn.Parameter(t.empty(d_model, n_heads))
self.W_out = t.nn.Parameter(t.empty(n_heads * d_model, 1))
self.b_out = t.nn.Parameter(t.zeros(1))
t.nn.init.normal_(self.W_q, std=d_model**-0.5)
t.nn.init.normal_(self.W_out, std=(n_heads * d_model) ** -0.5)
def forward(
self,
x: Float[Tensor, "batch seq d_model"],
mask: Bool[Tensor, "batch seq"],
) -> Float[Tensor, " batch"]:
logits, _ = attention_probe_forward(x, mask, self.W_q, self.W_out, self.b_out, self.scale)
return logits
@t.no_grad()
def get_attention_weights(
self,
x: Float[Tensor, "batch seq d_model"],
mask: Bool[Tensor, "batch seq"],
) -> Float[Tensor, "batch seq n_heads"]:
_, attn_weights = attention_probe_forward(x, mask, self.W_q, self.W_out, self.b_out, self.scale)
return attn_weights
Training and comparing probes
With the probe architecture defined and activations extracted, the code below trains an AttentionProbe using AdamW with binary cross-entropy loss directly on the pre-extracted tensors (no additional model forward passes needed). It then compares all four approaches by AUROC on the held-out test set.
The paper reports that AttnLite consistently outperforms simpler aggregation strategies on the high-stakes detection task. Whether that holds here, with a smaller training set and no hyperparameter search, is worth checking.
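As a reminder of the metric: AUROC is the probability that a randomly chosen positive example scores above a randomly chosen negative one, which is what sklearn's roc_auc_score computes efficiently. A pure-Python sketch (O(n²), illustration only, not the implementation used below):

```python
import itertools

def auroc(scores: list[float], labels: list[int]) -> float:
    """AUROC as P(random positive outscores random negative), ties counting half."""
    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p, n in itertools.product(pos, neg))
    return wins / (len(pos) * len(neg))

print(auroc([0.9, 0.2, 0.8, 0.4], [1, 0, 0, 1]))  # 0.75: 3 of 4 pos/neg pairs ranked correctly
```

Note that AUROC depends only on the ranking of scores, so applying sigmoid (a monotonic function) to logits before scoring, as the code below does, leaves it unchanged.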
def train_attention_probe(
acts: Float[Tensor, "n seq d_model"],
masks: Bool[Tensor, "n seq"],
labels: Float[Tensor, "n"],
n_heads: int = 1,
n_epochs: int = 200,
lr: float = 5e-3,
weight_decay: float = 1e-3,
) -> "AttentionProbe":
probe = AttentionProbe(d_model=acts.shape[-1], n_heads=n_heads)
optimizer = t.optim.AdamW(probe.parameters(), lr=lr, weight_decay=weight_decay)
criterion = t.nn.BCEWithLogitsLoss()
probe.train()
for _ in range(n_epochs):
logits = probe(acts, masks)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return probe.eval()
attn_probe = train_attention_probe(hs_acts_train, hs_masks_train, hs_train_labels, n_heads=1)
with t.no_grad():
attn_scores_test = attn_probe(hs_acts_test, hs_masks_test).sigmoid().numpy()
auroc_attn = roc_auc_score(hs_test_labels.numpy(), attn_scores_test)
# Print comparison table
methods = [
"MMProbe (last token)",
"LRProbe (last token)",
"LRProbe (mean pool) ",
"AttentionProbe (full seq)",
]
aurocs = [auroc_mm, auroc_lr_l, auroc_lr_m, auroc_attn]
best = max(aurocs)
print(f"\n{'Method':<35} {'AUROC':>7}")
print("-" * 44)
for name, auc in zip(methods, aurocs):
marker = " <-- best" if auc == best else ""
print(f"{name:<35} {auc:>7.3f}{marker}")
# ROC curves for all four methods
fig = go.Figure()
curve_data = [
("MMProbe (last)", mm_probe_hs(hs_last_test).detach().numpy()),
("LRProbe (last)", lr_probe_hs_last(hs_last_test).detach().numpy()),
("LRProbe (mean)", lr_probe_hs_mean(hs_mean_test).detach().numpy()),
("AttentionProbe", attn_scores_test),
]
for curve_name, scores in curve_data:
fpr, tpr, _ = roc_curve(hs_test_labels.numpy(), scores)
fig.add_trace(go.Scatter(x=fpr, y=tpr, mode="lines", name=curve_name))
fig.add_trace(
go.Scatter(
x=[0, 1],
y=[0, 1],
mode="lines",
line=dict(dash="dash", color="gray"),
name="Chance",
showlegend=True,
)
)
fig.update_layout(
title=f"ROC curves - high-stakes detection (layer {HS_LAYER}, {len(hs_test_texts)} test examples)",
xaxis_title="False positive rate",
yaxis_title="True positive rate",
height=450,
width=650,
)
fig.show()
Click to see the expected output
Method                                AUROC
--------------------------------------------
MMProbe (last token)                  0.873
LRProbe (last token)                  0.962
LRProbe (mean pool)                   0.991
AttentionProbe (full seq)             0.994 <-- best
Out-of-distribution evaluation
The synthetic training data uses a fixed set of topics and user roles. Does the probe generalize to real-world text? The Arrrlex/models-under-pressure HuggingFace dataset includes several real-world evaluation sets: conversations from Anthropic HH-RLHF, medical transcriptions, and ToolACE (a tool-using assistant dataset). These differ from the synthetic training data in both style and domain.
Below we load two of these evaluation datasets, extract activations with the same model and layer, and score with all four probes trained above. No retraining is performed.
eval_configs = ["anthropic_hh_balanced", "toolace_balanced"]
ood_results: dict[str, dict[str, float]] = {}
for cfg in eval_configs:
try:
eval_ds = load_dataset("Arrrlex/models-under-pressure", cfg, split="test")
except Exception as e:
print(f" Skipping {cfg}: {e}")
continue
eval_label_key = next(k for k in ("high_stakes", "label", "labels") if k in eval_ds.column_names)
eval_text_key = next(k for k in ("inputs", "prompt", "text") if k in eval_ds.column_names)
raw_labels = eval_ds[eval_label_key]
raw_texts = eval_ds[eval_text_key]
# Keep only binary labels (drop any "ambiguous" examples)
pairs = [
(format_as_chat(txt), to_int_label(lbl))
for txt, lbl in zip(raw_texts, raw_labels)
if isinstance(lbl, bool) or lbl in ("high-stakes", "low-stakes", "high_stakes", "low_stakes")
]
eval_texts = [p[0] for p in pairs]
eval_labels = t.tensor([p[1] for p in pairs], dtype=t.float32)
n_hi = int(eval_labels.sum())
print(f"{cfg}: {len(eval_texts)} examples ({n_hi} high, {len(eval_texts) - n_hi} low)")
eval_acts, eval_masks = extract_full_sequence_activations(
eval_texts, hs_model, hs_tokenizer, HS_LAYER, batch_size=8, max_length=HS_MAX_LEN
)
eval_last = last_token_acts(eval_acts, eval_masks)
eval_mean = mean_pool_acts(eval_acts, eval_masks)
ood_results[cfg] = {
"MMProbe (last)": compute_auroc(mm_probe_hs, eval_last, eval_labels),
"LRProbe (last)": compute_auroc(lr_probe_hs_last, eval_last, eval_labels),
"LRProbe (mean)": compute_auroc(lr_probe_hs_mean, eval_mean, eval_labels),
}
with t.no_grad():
attn_scores_ood = attn_probe(eval_acts, eval_masks).sigmoid().numpy()
ood_results[cfg]["AttnProbe"] = roc_auc_score(eval_labels.numpy(), attn_scores_ood)
# Combined comparison table: synthetic test set + OOD eval datasets
all_cols = ["Synthetic"] + list(ood_results.keys())
col_w = max(len(c) for c in all_cols) + 2
method_names = ["MMProbe (last)", "LRProbe (last)", "LRProbe (mean)", "AttnProbe"]
synth_aurocs = {
"MMProbe (last)": auroc_mm,
"LRProbe (last)": auroc_lr_l,
"LRProbe (mean)": auroc_lr_m,
"AttnProbe": auroc_attn,
}
header = f"{'Method':<25}" + "".join(f"{c:>{col_w}}" for c in all_cols)
print(f"\n{header}")
print("-" * len(header))
for m in method_names:
row = f"{m:<25}{synth_aurocs[m]:>{col_w}.3f}"
for cfg in ood_results:
row += f"{ood_results[cfg][m]:>{col_w}.3f}"
print(row)
Click to see the expected output
anthropic_hh_balanced: 2984 examples (1492 high, 1492 low)
toolace_balanced: 734 examples (367 high, 367 low)

Method                                 Synthetic  anthropic_hh_balanced       toolace_balanced
----------------------------------------------------------------------------------------------
MMProbe (last)                             0.873                  0.767                  0.573
LRProbe (last)                             0.962                  0.732                  0.533
LRProbe (mean)                             0.991                  0.876                  0.741
AttnProbe                                  0.994                  0.849                  0.786
What does the probe attend to?
The attention weights from a trained probe tell us directly which token positions were most useful for the classification. For correctly identified high-stakes prompts, we would expect the weights to concentrate on the phrase or clause that makes the request dangerous, rather than being spread uniformly across the whole input.
AttnLite uses a position-independent global query: the same learned weight vector scores every token, regardless of where it appears in the sequence. This means the attention distribution can be read as a per-token relevance score with no query-position dependence.
# Pick a few high-stakes examples to visualize attention patterns.
n_vis = 2
vis_indices = [i for i, l in enumerate(hs_test_labels.tolist()) if l == 1][:n_vis]
vis_inputs = hs_tokenizer(
[hs_test_texts[i] for i in vis_indices],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=HS_MAX_LEN,
)
vis_acts = hs_acts_test[vis_indices] # (n_vis, seq, d_model)
vis_masks = hs_masks_test[vis_indices] # (n_vis, seq)
with t.no_grad():
vis_attn = attn_probe.get_attention_weights(vis_acts, vis_masks) # (n_vis, seq, n_heads)
for ex, idx in enumerate(vis_indices):
n_valid = int(vis_masks[ex].sum().item())
raw_tokens = hs_tokenizer.convert_ids_to_tokens(vis_inputs["input_ids"][ex, :n_valid].tolist())
tokens = [utils.clean_bpe_token(tok) for tok in raw_tokens]
# Attention weights for valid positions: (valid_len, n_heads)
weights = vis_attn[ex, :n_valid, :]
    # Bar chart of the attention distribution over token positions (head 0).
    # Since the probe query is position-independent, this 1-D vector is the
    # full attention pattern for this example.
w = weights[:, 0].float().numpy()
top_thresh = float(np.percentile(w, 95))
bar_fig = go.Figure(
go.Bar(
x=list(range(n_valid)),
y=w,
text=tokens,
hovertemplate="Token: %{text}<br>Pos: %{x}<br>Weight: %{y:.4f}<extra></extra>",
marker_color=["#d62728" if wi >= top_thresh else "#1f77b4" for wi in w],
)
)
bar_fig.update_layout(
title="Attention weight per token (head 0)",
xaxis_title="Token position",
yaxis_title="Attention weight",
height=300,
width=max(700, n_valid * 5),
showlegend=False,
)
bar_fig.show()
Click to see the expected output
Here are some things you should look for in the attention patterns, as sanity checks:
- Mostly uniform-ish attention, rather than extreme concentration on any single token. This is consistent with the earlier result that mean-pool probes outperform last-token probes: near-uniform attention is essentially mean pooling.
- Less attention to very early tokens, which haven't yet built up representations of what the prompt is about.
- Attention to "signalling words" like "damage" or "emergency", since these drive classification. This is more common when probing early layers, where the model has formed fewer intermediate representations and the probe must rely on surface-level tokens.
- Conversely, when probing later layers, attention sometimes lands on punctuation like "." - these tokens can aggregate higher-level representations of a sentence's semantic meaning.
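To make the first of these checks quantitative, one option (our own helper, not from the paper) is normalized entropy: the entropy of the attention distribution divided by its maximum $\log n$, so 1.0 means perfectly uniform:

```python
import torch

def uniformity(weights: torch.Tensor) -> float:
    """Entropy of a 1-D attention distribution divided by log(n); 1.0 = uniform."""
    p = weights / weights.sum()
    entropy = -(p * p.clamp_min(1e-12).log()).sum()
    return (entropy / torch.log(torch.tensor(float(len(p))))).item()

# Toy distributions: near-uniform vs sharply peaked on one token.
near_uniform = torch.tensor([0.26, 0.24, 0.25, 0.25])
peaked = torch.tensor([0.94, 0.02, 0.02, 0.02])
print(uniformity(near_uniform))  # close to 1.0
print(uniformity(peaked))        # much lower
```

On the trained probe, something like uniformity(vis_attn[ex, :n_valid, 0]) gives this score for one example's valid positions.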
More discussion: interpreting the attention patterns
For correctly classified high-stakes prompts, attention weight tends to concentrate on the specific phrase or clause that triggers the high-stakes label: words like "without anyone noticing", "forge", "without consent", or a direct reference to the harmful act. For low-stakes prompts the distribution is typically more diffuse.
A couple of things to watch out for:
- Watch out for system-prompt tokens. The chat template adds boilerplate before the user's message. If the probe heavily weights these identical tokens it may have found a shortcut that won't generalise to differently-formatted inputs. Compare patterns across examples to check whether the concentrated weight is in the user turn, not the template.
- The paper notes that mean-pooling can fail on long prompts where only one clause is dangerous. The attention probe can in principle focus on that clause, but only if training exposed it to this pattern.
- Attention operates on BPE tokens, not words. Multi-token phrases (e.g. "over▁dose") may have their attention weight split across constituent tokens.
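A simple mitigation for the last point is to sum attention weight over the BPE pieces of each word. The sketch below uses invented tokens, weights, and word boundaries; in practice the boundaries would come from the tokenizer (e.g. a fast tokenizer's word_ids() or offset mapping):

```python
# Hypothetical BPE pieces and attention weights for a three-piece span
# ("overdose" split into two pieces, as in the example above).
tokens = ["over", "dose", "risk"]
weights = [0.30, 0.25, 0.45]
word_starts = [0, 2]  # piece index where each word begins

words, word_weights = [], []
for i, start in enumerate(word_starts):
    end = word_starts[i + 1] if i + 1 < len(word_starts) else len(tokens)
    words.append("".join(tokens[start:end]))
    word_weights.append(sum(weights[start:end]))

print(words)         # ['overdose', 'risk']
print(word_weights)  # roughly [0.55, 0.45]
```

Word-level scores like these are easier to eyeball than raw subword weights when checking whether the probe attends to the dangerous phrase.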
Bonus: out-of-distribution generalisation
The training data is LLM-generated synthetic pairs. The paper evaluates whether probes trained on this data generalise to real-world datasets: Anthropic HH-RLHF conversations, medical transcriptions, mental health dialogues, and Aya red-teaming prompts. These are substantially different from the training distribution in style and domain.
To test this yourself, download the evaluation datasets:
cd /path/to/models-under-pressure
uv run mup datasets download
Evaluation datasets land in data/evals/dev/ as JSONL files. Load them with datasets.load_dataset("json", data_files=...), extract activations using extract_full_sequence_activations with the same layer and model, then score with the trained attn_probe directly. No retraining required.
The paper finds that the attention probe generalises better than last-token baselines across all evaluation datasets. Cross-dataset generalisation is also studied via a generalisation heatmap (train on dataset A, evaluate on B), which the repo reproduces with uv run mup exp +experiment=generalisation_heatmap.