2️⃣ Training & comparing probes
Learning Objectives
- Implement difference-of-means (MM) and logistic regression (LR) probes
- Compare probe types: accuracy, direction similarity, and what each captures
- Understand CCS (Contrast-Consistent Search) as an unsupervised alternative and its limitations
Now that we've seen truth is linearly represented, let's train proper probes and understand the differences between probe types.
From the Geometry of Truth paper:
"We find that the difference-in-means directions are more causally implicated in the model's computation of truth values than the logistic regression directions, despite the fact that they achieve similar accuracies when used as probes."
We'll implement two probe types. The MMProbe (Mass-Mean / Difference-of-Means) is the simplest: the "truth direction" is just the difference between the mean of true and false activations. The LRProbe (Logistic Regression) is a linear classifier trained via sklearn to find the best separating direction.
Both produce a single direction vector. The paper reports similar accuracies for both, but the key question is which direction is more causally meaningful. We'll answer this in Section 3.
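To build intuition for how these two directions can differ, here is a small synthetic sketch (the toy data and names like `cov_chol` are ours, not from the paper): when the within-class noise is anisotropic, logistic regression tilts its direction toward the whitened (Fisher) direction, while difference-of-means uses only the class means and ignores the covariance entirely.

```python
import torch as t

t.manual_seed(0)
n, d = 500, 2
# Two Gaussian clusters with correlated, anisotropic noise
cov_chol = t.tensor([[1.0, 0.0], [0.8, 0.3]])
true_acts = t.randn(n, d) @ cov_chol.T + t.tensor([1.0, 0.0])
false_acts = t.randn(n, d) @ cov_chol.T + t.tensor([-1.0, 0.0])

# Difference-of-means direction: subtract class means, nothing else
mm_dir = true_acts.mean(0) - false_acts.mean(0)
mm_dir = mm_dir / mm_dir.norm()

# Logistic regression direction: fit a bias-free linear classifier by gradient descent
X = t.cat([true_acts, false_acts])
y = t.cat([t.ones(n), t.zeros(n)])
w = t.zeros(d, requires_grad=True)
opt = t.optim.Adam([w], lr=0.1)
for _ in range(500):
    opt.zero_grad()
    loss = t.nn.functional.binary_cross_entropy_with_logits(X @ w, y)
    loss.backward()
    opt.step()
lr_dir = (w / w.norm()).detach()

cos = (mm_dir @ lr_dir).item()
print(f"Cosine similarity between MM and LR directions: {cos:.3f}")
```

The two probes agree on which half-space is "true" but disagree on the exact direction, which is exactly the situation the causal experiments in Section 3 are designed to adjudicate.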
First, let's set up train/test splits for all our datasets. We'll use these throughout this section.
```python
# Create train/test splits for all datasets
t.manual_seed(42)
train_acts, test_acts = {}, {}
train_labels, test_labels = {}, {}

for name in DATASET_NAMES:
    acts = activations[name]
    labs = labels_dict[name]
    n = len(acts)
    perm = t.randperm(n)
    n_train = int(0.8 * n)
    train_acts[name] = acts[perm[:n_train]]
    test_acts[name] = acts[perm[n_train:]]
    train_labels[name] = labs[perm[:n_train]]
    test_labels[name] = labs[perm[n_train:]]
    print(f"{name}: train={n_train}, test={n - n_train}")
```
Exercise - implement MMProbe
Implement the Mass-Mean (difference-of-means) probe as a PyTorch nn.Module. The key components:
* direction: the vector mean(true_acts) - mean(false_acts), stored as a non-trainable parameter
* covariance: the pooled within-class covariance matrix (for optional IID-corrected evaluation)
* forward(x, iid=False): returns sigmoid(x @ direction), or sigmoid(x @ inv_cov @ direction) if iid=True
* pred(x, iid=False): returns binary predictions (round the probabilities)
* from_data(acts, labels): class method that constructs a probe from data
```python
class MMProbe(t.nn.Module):
    def __init__(
        self,
        direction: Float[Tensor, " d_model"],
        covariance: Float[Tensor, "d_model d_model"] | None = None,
        atol: float = 1e-3,
    ):
        super().__init__()
        # Store direction and precompute inverse covariance
        raise NotImplementedError()

    def forward(self, x: Float[Tensor, "n d_model"], iid: bool = False) -> Float[Tensor, " n"]:
        raise NotImplementedError()

    def pred(self, x: Float[Tensor, "n d_model"], iid: bool = False) -> Float[Tensor, " n"]:
        return self(x, iid=iid).round()

    @staticmethod
    def from_data(
        acts: Float[Tensor, "n d_model"],
        labels: Float[Tensor, " n"],
        device: str = "cpu",
    ) -> "MMProbe":
        raise NotImplementedError()
```
```python
mm_probe = MMProbe.from_data(train_acts["cities"], train_labels["cities"])

# Train accuracy
train_preds = mm_probe.pred(train_acts["cities"])
train_acc = (train_preds == train_labels["cities"]).float().mean().item()

# Test accuracy
test_preds = mm_probe.pred(test_acts["cities"])
test_acc = (test_preds == test_labels["cities"]).float().mean().item()
assert test_acc > 0.7, "Expected at least 70% accuracy"

print("MMProbe on cities:")
print(f"  Train accuracy: {train_acc:.3f}")
print(f"  Test accuracy: {test_acc:.3f}")
print(f"  Direction norm: {mm_probe.direction.norm().item():.3f}")
print(f"  Direction (first 5): {mm_probe.direction[:5].tolist()}")
```
Click to see the expected output
```
MMProbe on cities:
  Train accuracy: 0.751
  Test accuracy: 0.743
  Direction norm: 9.776
  Direction (first 5): [0.04566267877817154, 0.2775959372520447, -0.09119868278503418, 0.09136010706424713, -0.14157512784004211]
```
Solution
```python
class MMProbe(t.nn.Module):
    def __init__(
        self,
        direction: Float[Tensor, " d_model"],
        covariance: Float[Tensor, "d_model d_model"] | None = None,
        atol: float = 1e-3,
    ):
        super().__init__()
        self.direction = t.nn.Parameter(direction, requires_grad=False)
        if covariance is not None:
            self.inv = t.nn.Parameter(
                t.linalg.pinv(covariance, hermitian=True, atol=atol), requires_grad=False
            )
        else:
            self.inv = None

    def forward(self, x: Float[Tensor, "n d_model"], iid: bool = False) -> Float[Tensor, " n"]:
        if iid and self.inv is not None:
            return t.sigmoid(x @ self.inv @ self.direction)
        else:
            return t.sigmoid(x @ self.direction)

    def pred(self, x: Float[Tensor, "n d_model"], iid: bool = False) -> Float[Tensor, " n"]:
        return self(x, iid=iid).round()

    @staticmethod
    def from_data(
        acts: Float[Tensor, "n d_model"],
        labels: Float[Tensor, " n"],
        device: str = "cpu",
    ) -> "MMProbe":
        acts, labels = acts.to(device), labels.to(device)
        pos_acts = acts[labels == 1]
        neg_acts = acts[labels == 0]
        pos_mean = pos_acts.mean(0)
        neg_mean = neg_acts.mean(0)
        direction = pos_mean - neg_mean
        # Pooled within-class covariance: center each class on its own mean
        centered = t.cat([pos_acts - pos_mean, neg_acts - neg_mean], dim=0)
        covariance = centered.t() @ centered / acts.shape[0]
        return MMProbe(direction, covariance=covariance).to(device)
```
Exercise - implement LRProbe
Implement a logistic regression probe using sklearn's LogisticRegression with StandardScaler normalization, matching the methodology of the deception-detection repo.
You need to fill in three methods:
* `__init__`: Create a `Sequential(Linear(d_in, 1, bias=False), Sigmoid())` network, and register `scaler_mean` and `scaler_scale` as buffers (these store the scaler parameters for use at inference time).
* `forward`: Apply `_normalize` to the input, then pass through `self.net`, squeezing the last dimension.
* `from_data`: Train a `LogisticRegression(C=C, fit_intercept=False)` on `StandardScaler`-normalized activations, then construct an `LRProbe` with the fitted scaler parameters and copy the learned coefficients into `self.net[0].weight.data[0]`.
Key details:
* L2 regularization: C=0.1 (matching the paper's lambda=10), fit_intercept=False
* The _normalize method applies the stored scaler parameters during forward passes and is already provided for you
```python
class LRProbe(t.nn.Module):
    def __init__(self, d_in: int, scaler_mean: Tensor | None = None, scaler_scale: Tensor | None = None):
        super().__init__()
        raise NotImplementedError()

    def _normalize(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, "n d_model"]:
        """Apply StandardScaler normalization if scaler parameters are available."""
        if self.scaler_mean is not None and self.scaler_scale is not None:
            return (x - self.scaler_mean) / self.scaler_scale
        return x

    def forward(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, " n"]:
        raise NotImplementedError()

    def pred(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, " n"]:
        return self(x).round()

    @property
    def direction(self) -> Float[Tensor, " d_model"]:
        return self.net[0].weight.data[0]

    @staticmethod
    def from_data(
        acts: Float[Tensor, "n d_model"],
        labels: Float[Tensor, " n"],
        C: float = 0.1,
        device: str = "cpu",
    ) -> "LRProbe":
        """
        Train an LR probe using sklearn's LogisticRegression with StandardScaler normalization.

        Args:
            acts: Activation matrix [n_samples, d_model].
            labels: Binary labels (1=true, 0=false).
            C: Inverse regularization strength (lower = stronger regularization).
                Default 0.1 (reg_coeff=10) matches the deception-detection paper's cfg.yaml.
                The repo class default is reg_coeff=1000 (C=0.001), which is stronger.
            device: Device to place the resulting probe on.
        """
        raise NotImplementedError()
```
```python
lr_probe = LRProbe.from_data(train_acts["cities"], train_labels["cities"], device="cpu")

# Train accuracy
train_preds = lr_probe.pred(train_acts["cities"])
train_acc = (train_preds == train_labels["cities"]).float().mean().item()

# Test accuracy
test_preds = lr_probe.pred(test_acts["cities"])
test_acc = (test_preds == test_labels["cities"]).float().mean().item()

print("LRProbe on cities:")
print(f"  Train accuracy: {train_acc:.3f}")
print(f"  Test accuracy: {test_acc:.3f}")
print(f"  Direction norm: {lr_probe.direction.norm().item():.3f}")
assert test_acc >= 0.90, f"Test accuracy too low: {test_acc:.3f} (expected >= 0.90)"

# Compare directions
mm_dir = mm_probe.direction / mm_probe.direction.norm()
lr_dir = lr_probe.direction / lr_probe.direction.norm()
cos_sim = (mm_dir @ lr_dir).item()
print(f"\nCosine similarity between MM and LR directions: {cos_sim:.4f}")
```
```python
# Compare both probes across all 3 datasets
results_rows = []
for name in DATASET_NAMES:
    mm_p = MMProbe.from_data(train_acts[name], train_labels[name])
    lr_p = LRProbe.from_data(train_acts[name], train_labels[name])
    mm_test_acc = (mm_p.pred(test_acts[name]) == test_labels[name]).float().mean().item()
    lr_test_acc = (lr_p.pred(test_acts[name]) == test_labels[name]).float().mean().item()
    results_rows.append({"Dataset": name, "MM Test Acc": f"{mm_test_acc:.3f}", "LR Test Acc": f"{lr_test_acc:.3f}"})

results_df = pd.DataFrame(results_rows)
print("\nProbe accuracy comparison across datasets:")
display(results_df)

# Bar chart
fig = go.Figure()
fig.add_trace(go.Bar(name="MMProbe", x=DATASET_NAMES, y=[float(r["MM Test Acc"]) for r in results_rows]))
fig.add_trace(go.Bar(name="LRProbe", x=DATASET_NAMES, y=[float(r["LR Test Acc"]) for r in results_rows]))
fig.update_layout(
    title="Probe Test Accuracy by Dataset",
    yaxis_title="Test Accuracy",
    yaxis_range=[0.5, 1.05],
    barmode="group",
    height=400,
    width=600,
)
fig.show()
```
Click to see the expected output
```
LRProbe on cities:
  Train accuracy: 1.000
  Test accuracy: 0.993
  Direction norm: 0.540

Cosine similarity between MM and LR directions: 0.6393
```
Solution
```python
class LRProbe(t.nn.Module):
    def __init__(self, d_in: int, scaler_mean: Tensor | None = None, scaler_scale: Tensor | None = None):
        super().__init__()
        self.net = t.nn.Sequential(t.nn.Linear(d_in, 1, bias=False), t.nn.Sigmoid())
        self.register_buffer("scaler_mean", scaler_mean)
        self.register_buffer("scaler_scale", scaler_scale)

    def _normalize(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, "n d_model"]:
        """Apply StandardScaler normalization if scaler parameters are available."""
        if self.scaler_mean is not None and self.scaler_scale is not None:
            return (x - self.scaler_mean) / self.scaler_scale
        return x

    def forward(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, " n"]:
        return self.net(self._normalize(x)).squeeze(-1)

    def pred(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, " n"]:
        return self(x).round()

    @property
    def direction(self) -> Float[Tensor, " d_model"]:
        return self.net[0].weight.data[0]

    @staticmethod
    def from_data(
        acts: Float[Tensor, "n d_model"],
        labels: Float[Tensor, " n"],
        C: float = 0.1,
        device: str = "cpu",
    ) -> "LRProbe":
        """
        Train an LR probe using sklearn's LogisticRegression with StandardScaler normalization.

        Args:
            acts: Activation matrix [n_samples, d_model].
            labels: Binary labels (1=true, 0=false).
            C: Inverse regularization strength (lower = stronger regularization).
                Default 0.1 (reg_coeff=10) matches the deception-detection paper's cfg.yaml.
                The repo class default is reg_coeff=1000 (C=0.001), which is stronger.
            device: Device to place the resulting probe on.
        """
        X = acts.cpu().float().numpy()
        y = labels.cpu().float().numpy()

        # Standardize features (zero mean, unit variance) before fitting, as in the paper
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        # fit_intercept=False: the paper fits on normalized data so the intercept is redundant
        lr_model = LogisticRegression(C=C, random_state=42, fit_intercept=False, max_iter=1000)
        lr_model.fit(X_scaled, y)

        # Build probe with scaler parameters baked in
        scaler_mean = t.tensor(scaler.mean_, dtype=t.float32)
        scaler_scale = t.tensor(scaler.scale_, dtype=t.float32)
        probe = LRProbe(acts.shape[-1], scaler_mean=scaler_mean, scaler_scale=scaler_scale).to(device)
        probe.net[0].weight.data[0] = t.tensor(lr_model.coef_[0], dtype=t.float32).to(device)
        return probe
```
Exercise - cross-dataset generalization matrix
Train MM and LR probes on each of the 3 curated datasets and evaluate each probe on all 3 datasets, producing 3×3 accuracy matrices. If the model has a unified truth direction, off-diagonal entries should be high. The Geometry of Truth paper reports strong cross-generalization for LR probes at 13B scale, and even stronger at 70B:
"The LR probes transfer almost perfectly: probes trained on any one of our three datasets achieve near-ceiling accuracy on the other two."
Do your results agree? You should also compute pairwise cosine similarities between probe directions trained on different datasets to check whether they point in a similar direction.
```python
def compute_generalization_matrix(
    train_acts: dict[str, Float[Tensor, "n d"]],
    train_labels: dict[str, Float[Tensor, " n"]],
    test_acts: dict[str, Float[Tensor, "n d"]],
    test_labels: dict[str, Float[Tensor, " n"]],
    dataset_names: list[str],
    probe_cls: type,
) -> Float[Tensor, "n_datasets n_datasets"]:
    """
    Compute a generalization matrix: entry (i, j) is the test accuracy of a probe trained on
    dataset i and evaluated on dataset j.

    Args:
        train_acts, train_labels: Training data per dataset.
        test_acts, test_labels: Test data per dataset.
        dataset_names: Names of datasets (determines matrix ordering).
        probe_cls: Probe class to use (MMProbe or LRProbe), must have from_data and pred methods.

    Returns:
        Tensor of shape [n_datasets, n_datasets] with accuracy values.
    """
    raise NotImplementedError()
```
```python
mm_matrix = compute_generalization_matrix(train_acts, train_labels, test_acts, test_labels, DATASET_NAMES, MMProbe)
lr_matrix = compute_generalization_matrix(train_acts, train_labels, test_acts, test_labels, DATASET_NAMES, LRProbe)

assert mm_matrix.shape == (3, 3), f"Wrong shape: {mm_matrix.shape}"
assert (mm_matrix.diag() > 0.6).all(), "In-distribution accuracy should be at least 60%"

# Heatmap visualization
fig = make_subplots(rows=1, cols=2, subplot_titles=["MMProbe", "LRProbe"], horizontal_spacing=0.15)
for idx, (matrix, name) in enumerate([(mm_matrix, "MM"), (lr_matrix, "LR")]):
    text_vals = [[f"{matrix[i, j]:.3f}" for j in range(len(DATASET_NAMES))] for i in range(len(DATASET_NAMES))]
    fig.add_trace(
        go.Heatmap(
            z=matrix.numpy(),
            x=DATASET_NAMES,
            y=DATASET_NAMES,
            text=text_vals,
            texttemplate="%{text}",
            colorscale="RdYlGn",
            zmin=0.5,
            zmax=1.0,
            showscale=(idx == 1),
        ),
        row=1,
        col=idx + 1,
    )
    fig.update_yaxes(title_text="Train dataset" if idx == 0 else "", row=1, col=idx + 1)
    fig.update_xaxes(title_text="Test dataset", row=1, col=idx + 1)
fig.update_layout(title="Cross-dataset Generalization (Test Accuracy)", height=400, width=800)
fig.show()

# Cosine similarity between probe directions
mm_directions = {name: MMProbe.from_data(train_acts[name], train_labels[name]).direction for name in DATASET_NAMES}
lr_directions = {name: LRProbe.from_data(train_acts[name], train_labels[name]).direction for name in DATASET_NAMES}

print("\nPairwise cosine similarity between probe directions:")
for probe_name, directions in [("MM", mm_directions), ("LR", lr_directions)]:
    print(f"\n  {probe_name}Probe:")
    for i, n1 in enumerate(DATASET_NAMES):
        for j, n2 in enumerate(DATASET_NAMES):
            if j > i:
                d1 = directions[n1] / directions[n1].norm()
                d2 = directions[n2] / directions[n2].norm()
                print(f"    {n1} vs {n2}: {(d1 @ d2).item():.4f}")
```
Click to see the expected output
Discussion - interpreting the generalization matrix
If the model has a single "truth direction" that is shared across domains, then probes trained on any dataset should achieve high accuracy on all datasets (high off-diagonal entries). If instead the model represents truth in domain-specific ways, you'll see high diagonal (in-distribution) but low off-diagonal (out-of-distribution) accuracy.
The Geometry of Truth paper reports that LR probes generalize almost perfectly at 13B+ scale, which suggests a largely unified truth representation. However, at smaller model scales or for very different domains, you may see more variation.
As a further test, try evaluating your cities probe on the likely dataset from the geometry-of-truth repo, which contains nonfactual text with likely or unlikely continuations (code snippets, random text). If the truth probe also separates this dataset, the probe may detect text probability rather than factual truth. The paper predicts (and confirms) that a well-trained truth probe should not separate likely, because truth and likelihood are distinct features.
Solution
```python
def compute_generalization_matrix(
    train_acts: dict[str, Float[Tensor, "n d"]],
    train_labels: dict[str, Float[Tensor, " n"]],
    test_acts: dict[str, Float[Tensor, "n d"]],
    test_labels: dict[str, Float[Tensor, " n"]],
    dataset_names: list[str],
    probe_cls: type,
) -> Float[Tensor, "n_datasets n_datasets"]:
    """
    Compute a generalization matrix: entry (i, j) is the test accuracy of a probe trained on
    dataset i and evaluated on dataset j.

    Args:
        train_acts, train_labels: Training data per dataset.
        test_acts, test_labels: Test data per dataset.
        dataset_names: Names of datasets (determines matrix ordering).
        probe_cls: Probe class to use (MMProbe or LRProbe), must have from_data and pred methods.

    Returns:
        Tensor of shape [n_datasets, n_datasets] with accuracy values.
    """
    n = len(dataset_names)
    matrix = t.zeros(n, n)
    for i, train_name in enumerate(dataset_names):
        probe = probe_cls.from_data(train_acts[train_name], train_labels[train_name])
        for j, test_name in enumerate(dataset_names):
            preds = probe.pred(test_acts[test_name])
            acc = (preds == test_labels[test_name]).float().mean().item()
            matrix[i, j] = acc
    return matrix
```
A note about CCS (optional)
Contrast-Consistent Search (CCS) and the unsupervised discovery debate
Before moving to causal interventions, it's worth discussing CCS (Contrast-Consistent Search), which represents one of the most exciting and debated ideas in probing research.
The CCS method
CCS is unsupervised: it requires paired positive/negative statements but no labels. The loss enforces that p(x) ≈ 1 - p(neg x) (consistency) and that the probe makes confident predictions away from 0.5 (confidence). This suggests truth might be discoverable without any supervision.
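To make the objective concrete, here's a minimal sketch of the CCS loss trained on synthetic paired activations (the toy data and variable names are ours for illustration; real CCS also mean-normalizes activations per prompt template and takes the best of several random restarts):

```python
import torch as t

def ccs_loss(p_pos: t.Tensor, p_neg: t.Tensor) -> t.Tensor:
    """CCS objective: consistency + confidence, with no labels."""
    consistency = (p_pos - (1 - p_neg)) ** 2   # p(x) should equal 1 - p(neg x)
    confidence = t.minimum(p_pos, p_neg) ** 2  # discourage the degenerate p = 0.5 solution
    return (consistency + confidence).mean()

# Tiny demo on synthetic paired activations with a planted truth direction
t.manual_seed(0)
n, d = 128, 16
truth_dir = t.randn(d)
truth = t.randint(0, 2, (n,)).float()
signs = (2 * truth - 1)[:, None]               # +1 for true statements, -1 for false
pos_acts = signs * truth_dir + 0.1 * t.randn(n, d)
neg_acts = -signs * truth_dir + 0.1 * t.randn(n, d)

w = t.nn.Parameter(0.01 * t.randn(d))
opt = t.optim.Adam([w], lr=0.05)
for _ in range(300):
    opt.zero_grad()
    loss = ccs_loss(t.sigmoid(pos_acts @ w), t.sigmoid(neg_acts @ w))
    loss.backward()
    opt.step()

# The loss is symmetric under w -> -w, so CCS only finds truth up to sign
preds = (t.sigmoid(pos_acts @ w) > 0.5).float()
acc = (preds == truth).float().mean().item()
acc = max(acc, 1 - acc)
print(f"final loss: {loss.item():.4f}, accuracy (up to sign): {acc:.3f}")
```

Note the sign ambiguity in the last step: nothing in the loss says which direction is "true", which is one small instance of the identifiability problems discussed below.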
Key critiques
Contrast pairs do all the work. Emmons (2023) shows that PCA on contrast pair differences achieves 97% of CCS's accuracy. The CCS loss function is largely redundant; the contrast pair construction is the real innovation.
CCS may not find knowledge. Farquhar et al. (Google DeepMind, 2023) prove that for every possible binary classification, there exists a zero-loss CCS probe, so the loss has no structural preference for knowledge over arbitrary features. When they append random distractor words to contrast pairs, CCS learns to classify the distractor instead of truth.
The XOR problem. Sam Marks (2024) shows that LLMs linearly represent XORs of arbitrary features, creating a new failure mode for probes: a probe could learn truth XOR geography, which works on geographic data but fails under distribution shift. Empirically probes often do generalize, likely because basic features have higher variance than their XORs, but this is a matter of degree, not a guarantee. (Note: most of this paper discusses the XOR problem more broadly than just its relevance to CCS - still strongly recommended as further reading!)
Takeaway
The CCS literature tells a cautionary tale: probing accuracy on a single dataset tells you very little. This is why Section 3 demands causal evidence via interventions, and why cross-dataset generalization is the real test of a probe.
Bonus exercise: implement CCS
If you want to dig deeper, try implementing CCS yourself using the neg_cities dataset (which provides natural contrast pairs). The core idea: take the difference vector for each contrast pair (true statement activation minus false statement activation), then run PCA on those difference vectors. As Emmons shows, just the first principal component of the difference vectors recovers ~97% of CCS's accuracy, so you don't even need the CCS loss. You can then compare this unsupervised direction to your supervised MM and LR probes on the generalization datasets from Section 2, and see how the causal effects compare in Section 3.
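As a starting point, here's a minimal sketch of the PCA-on-differences shortcut on synthetic data (the helper name `crc_direction` is ours; swap in real paired activations such as cities/neg_cities):

```python
import torch as t

def crc_direction(pos_acts: t.Tensor, neg_acts: t.Tensor) -> t.Tensor:
    """Top principal component of the contrast-pair difference vectors."""
    diffs = pos_acts - neg_acts      # [n_pairs, d_model]
    diffs = diffs - diffs.mean(0)    # center before PCA
    _, _, Vh = t.linalg.svd(diffs, full_matrices=False)
    return Vh[0]                     # first right singular vector = first PC

# Demo on synthetic paired activations with a planted truth direction
t.manual_seed(0)
n, d = 200, 32
truth_dir = t.randn(d)
truth_dir = truth_dir / truth_dir.norm()
signs = t.randint(0, 2, (n, 1)).float() * 2 - 1   # +1 if the statement is true
pos_acts = signs * truth_dir + 0.1 * t.randn(n, d)
neg_acts = -signs * truth_dir + 0.1 * t.randn(n, d)

direction = crc_direction(pos_acts, neg_acts)
cos = (direction @ truth_dir).abs().item()        # PC sign is arbitrary
print(f"|cos(CRC direction, planted truth direction)| = {cos:.3f}")
```

Like CCS, this recovers the truth direction only up to sign, so you'd need a handful of labeled examples (or a convention) to orient it before computing accuracies.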
For further reading: Levinstein & Herrmann (2024), "Still No Lie Detector for Language Models"; the geometry-of-truth repo includes a full CCSProbe implementation in probes.py.