
2️⃣ Training & comparing probes

Learning Objectives
  • Implement difference-of-means (MM) and logistic regression (LR) probes
  • Compare probe types: accuracy, direction similarity, and what each captures
  • Understand CCS (Contrast-Consistent Search) as an unsupervised alternative and its limitations

Now that we've seen truth is linearly represented, let's train proper probes and understand the differences between probe types.

From the Geometry of Truth paper:

"We find that the difference-in-means directions are more causally implicated in the model's computation of truth values than the logistic regression directions, despite the fact that they achieve similar accuracies when used as probes."

We'll implement two probe types. The MMProbe (Mass-Mean / Difference-of-Means) is the simplest: the "truth direction" is just the difference between the mean of true and false activations. The LRProbe (Logistic Regression) is a linear classifier trained via sklearn to find the best separating direction.

Both produce a single direction vector. The paper reports similar accuracies for both, but the key question is which direction is more causally meaningful. We'll answer this in Section 3.
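To build intuition for why MM and LR can find different directions even on clean data, here's a small 2D toy sketch (illustrative only, not part of the exercises): the mean-difference direction ignores within-class covariance entirely, while logistic regression downweights the high-variance dimension.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Two Gaussian classes sharing a strongly anisotropic covariance
cov = np.array([[10.0, 0.0], [0.0, 0.1]])
true_x = rng.multivariate_normal([1.0, 1.0], cov, size=500)
false_x = rng.multivariate_normal([-1.0, -1.0], cov, size=500)

# Difference-of-means direction: ignores within-class covariance
mm_dir = true_x.mean(0) - false_x.mean(0)
mm_dir /= np.linalg.norm(mm_dir)

# Logistic regression direction: downweights the noisy first dimension
X = np.concatenate([true_x, false_x])
y = np.concatenate([np.ones(500), np.zeros(500)])
lr_dir = LogisticRegression(fit_intercept=False, max_iter=1000).fit(X, y).coef_[0]
lr_dir /= np.linalg.norm(lr_dir)

print(f"cosine similarity: {mm_dir @ lr_dir:.3f}")  # noticeably below 1
```

Both are sensible linear separators, but they answer different questions: MM tracks where the class means sit, while LR approximates the covariance-aware separating direction. This is exactly why we need Section 3's causal tests to adjudicate between them.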

First, let's set up train/test splits for all our datasets. We'll use these throughout this section.

# Create train/test splits for all datasets
t.manual_seed(42)
train_acts, test_acts = {}, {}
train_labels, test_labels = {}, {}

for name in DATASET_NAMES:
    acts = activations[name]
    labs = labels_dict[name]
    n = len(acts)
    perm = t.randperm(n)
    n_train = int(0.8 * n)

    train_acts[name] = acts[perm[:n_train]]
    test_acts[name] = acts[perm[n_train:]]
    train_labels[name] = labs[perm[:n_train]]
    test_labels[name] = labs[perm[n_train:]]

    print(f"{name}: train={n_train}, test={n - n_train}")

Exercise - implement MMProbe

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 10-15 minutes on this exercise. The simplest and often most causally meaningful probe type.

Implement the Mass-Mean (difference-of-means) probe as a PyTorch nn.Module. The key components:

  • direction: the vector mean(true_acts) - mean(false_acts), stored as a non-trainable parameter
  • covariance: the pooled within-class covariance matrix (for optional IID-corrected evaluation)
  • forward(x, iid=False): returns sigmoid(x @ direction), or sigmoid(x @ inv_cov @ direction) if iid=True
  • pred(x, iid=False): returns binary predictions (round the probabilities)
  • from_data(acts, labels): class method that constructs a probe from data

class MMProbe(t.nn.Module):
    def __init__(
        self,
        direction: Float[Tensor, " d_model"],
        covariance: Float[Tensor, "d_model d_model"] | None = None,
        atol: float = 1e-3,
    ):
        super().__init__()
        # Store direction and precompute inverse covariance
        raise NotImplementedError()

    def forward(self, x: Float[Tensor, "n d_model"], iid: bool = False) -> Float[Tensor, " n"]:
        raise NotImplementedError()

    def pred(self, x: Float[Tensor, "n d_model"], iid: bool = False) -> Float[Tensor, " n"]:
        return self(x, iid=iid).round()

    @staticmethod
    def from_data(
        acts: Float[Tensor, "n d_model"],
        labels: Float[Tensor, " n"],
        device: str = "cpu",
    ) -> "MMProbe":
        raise NotImplementedError()


mm_probe = MMProbe.from_data(train_acts["cities"], train_labels["cities"])

# Train accuracy
train_preds = mm_probe.pred(train_acts["cities"])
train_acc = (train_preds == train_labels["cities"]).float().mean().item()

# Test accuracy
test_preds = mm_probe.pred(test_acts["cities"])
test_acc = (test_preds == test_labels["cities"]).float().mean().item()
assert test_acc > 0.7, "Expected at least 70% accuracy"

print("MMProbe on cities:")
print(f"  Train accuracy: {train_acc:.3f}")
print(f"  Test accuracy:  {test_acc:.3f}")
print(f"  Direction norm: {mm_probe.direction.norm().item():.3f}")
print(f"  Direction (first 5): {mm_probe.direction[:5].tolist()}")
Click to see the expected output
MMProbe on cities:
  Train accuracy: 0.751
  Test accuracy:  0.743
  Direction norm: 9.776
  Direction (first 5): [0.04566267877817154, 0.2775959372520447, -0.09119868278503418, 0.09136010706424713, -0.14157512784004211]
Solution
class MMProbe(t.nn.Module):
    def __init__(
        self,
        direction: Float[Tensor, " d_model"],
        covariance: Float[Tensor, "d_model d_model"] | None = None,
        atol: float = 1e-3,
    ):
        super().__init__()
        self.direction = t.nn.Parameter(direction, requires_grad=False)
        if covariance is not None:
            self.inv = t.nn.Parameter(t.linalg.pinv(covariance, hermitian=True, atol=atol), requires_grad=False)
        else:
            self.inv = None

    def forward(self, x: Float[Tensor, "n d_model"], iid: bool = False) -> Float[Tensor, " n"]:
        if iid and self.inv is not None:
            return t.sigmoid(x @ self.inv @ self.direction)
        else:
            return t.sigmoid(x @ self.direction)

    def pred(self, x: Float[Tensor, "n d_model"], iid: bool = False) -> Float[Tensor, " n"]:
        return self(x, iid=iid).round()

    @staticmethod
    def from_data(
        acts: Float[Tensor, "n d_model"],
        labels: Float[Tensor, " n"],
        device: str = "cpu",
    ) -> "MMProbe":
        acts, labels = acts.to(device), labels.to(device)
        pos_acts = acts[labels == 1]
        neg_acts = acts[labels == 0]
        pos_mean = pos_acts.mean(0)
        neg_mean = neg_acts.mean(0)
        direction = pos_mean - neg_mean

        centered = t.cat([pos_acts - pos_mean, neg_acts - neg_mean], dim=0)
        covariance = centered.t() @ centered / acts.shape[0]

        return MMProbe(direction, covariance=covariance).to(device)

Exercise - implement LRProbe

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 15-20 minutes on this exercise. Logistic regression is the most common probe type in the literature.

Implement a logistic regression probe using sklearn's LogisticRegression with StandardScaler normalization, matching the methodology of the deception-detection repo.

You need to fill in three methods:

  1. __init__: Create a Sequential(Linear(d_in, 1, bias=False), Sigmoid()) network, and register scaler_mean and scaler_scale as buffers (these store the scaler parameters for use at inference time).
  2. forward: Apply _normalize to the input, then pass through self.net, squeezing the last dimension.
  3. from_data: Train a LogisticRegression(C=C, fit_intercept=False) on StandardScaler-normalized activations, then construct an LRProbe with the fitted scaler parameters and copy the learned coefficients into self.net[0].weight.data[0].

Key details:

  • L2 regularization: C=0.1 (matching the paper's lambda=10), fit_intercept=False
  • The _normalize method applies the stored scaler parameters during forward passes and is already provided for you

class LRProbe(t.nn.Module):
    def __init__(self, d_in: int, scaler_mean: Tensor | None = None, scaler_scale: Tensor | None = None):
        super().__init__()
        raise NotImplementedError()

    def _normalize(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, "n d_model"]:
        """Apply StandardScaler normalization if scaler parameters are available."""
        if self.scaler_mean is not None and self.scaler_scale is not None:
            return (x - self.scaler_mean) / self.scaler_scale
        return x

    def forward(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, " n"]:
        raise NotImplementedError()

    def pred(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, " n"]:
        return self(x).round()

    @property
    def direction(self) -> Float[Tensor, " d_model"]:
        return self.net[0].weight.data[0]

    @staticmethod
    def from_data(
        acts: Float[Tensor, "n d_model"],
        labels: Float[Tensor, " n"],
        C: float = 0.1,
        device: str = "cpu",
    ) -> "LRProbe":
        """
        Train an LR probe using sklearn's LogisticRegression with StandardScaler normalization.

        Args:
            acts: Activation matrix [n_samples, d_model].
            labels: Binary labels (1=true, 0=false).
            C: Inverse regularization strength (lower = stronger regularization).
                Default 0.1 (reg_coeff=10) matches the deception-detection paper's cfg.yaml.
                The repo class default is reg_coeff=1000 (C=0.001), which is stronger.
            device: Device to place the resulting probe on.
        """
        raise NotImplementedError()


lr_probe = LRProbe.from_data(train_acts["cities"], train_labels["cities"], device="cpu")

# Train accuracy
train_preds = lr_probe.pred(train_acts["cities"])
train_acc = (train_preds == train_labels["cities"]).float().mean().item()

# Test accuracy
test_preds = lr_probe.pred(test_acts["cities"])
test_acc = (test_preds == test_labels["cities"]).float().mean().item()

print("LRProbe on cities:")
print(f"  Train accuracy: {train_acc:.3f}")
print(f"  Test accuracy:  {test_acc:.3f}")
print(f"  Direction norm: {lr_probe.direction.norm().item():.3f}")
assert test_acc >= 0.90, f"Test accuracy too low: {test_acc:.3f} (expected >= 0.90)"

# Compare directions
mm_dir = mm_probe.direction / mm_probe.direction.norm()
lr_dir = lr_probe.direction / lr_probe.direction.norm()
cos_sim = (mm_dir @ lr_dir).item()
print(f"\nCosine similarity between MM and LR directions: {cos_sim:.4f}")

# Compare both probes across all 3 datasets
results_rows = []
for name in DATASET_NAMES:
    mm_p = MMProbe.from_data(train_acts[name], train_labels[name])
    lr_p = LRProbe.from_data(train_acts[name], train_labels[name])

    mm_test_acc = (mm_p.pred(test_acts[name]) == test_labels[name]).float().mean().item()
    lr_test_acc = (lr_p.pred(test_acts[name]) == test_labels[name]).float().mean().item()
    results_rows.append({"Dataset": name, "MM Test Acc": f"{mm_test_acc:.3f}", "LR Test Acc": f"{lr_test_acc:.3f}"})

results_df = pd.DataFrame(results_rows)
print("\nProbe accuracy comparison across datasets:")
display(results_df)

# Bar chart
fig = go.Figure()
fig.add_trace(go.Bar(name="MMProbe", x=DATASET_NAMES, y=[float(r["MM Test Acc"]) for r in results_rows]))
fig.add_trace(go.Bar(name="LRProbe", x=DATASET_NAMES, y=[float(r["LR Test Acc"]) for r in results_rows]))
fig.update_layout(
    title="Probe Test Accuracy by Dataset",
    yaxis_title="Test Accuracy",
    yaxis_range=[0.5, 1.05],
    barmode="group",
    height=400,
    width=600,
)
fig.show()
Click to see the expected output
LRProbe on cities:
  Train accuracy: 1.000
  Test accuracy:  0.993
  Direction norm: 0.540

Cosine similarity between MM and LR directions: 0.6393
Solution
class LRProbe(t.nn.Module):
    def __init__(self, d_in: int, scaler_mean: Tensor | None = None, scaler_scale: Tensor | None = None):
        super().__init__()
        self.net = t.nn.Sequential(t.nn.Linear(d_in, 1, bias=False), t.nn.Sigmoid())
        self.register_buffer("scaler_mean", scaler_mean)
        self.register_buffer("scaler_scale", scaler_scale)

    def _normalize(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, "n d_model"]:
        """Apply StandardScaler normalization if scaler parameters are available."""
        if self.scaler_mean is not None and self.scaler_scale is not None:
            return (x - self.scaler_mean) / self.scaler_scale
        return x

    def forward(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, " n"]:
        return self.net(self._normalize(x)).squeeze(-1)

    def pred(self, x: Float[Tensor, "n d_model"]) -> Float[Tensor, " n"]:
        return self(x).round()

    @property
    def direction(self) -> Float[Tensor, " d_model"]:
        return self.net[0].weight.data[0]

    @staticmethod
    def from_data(
        acts: Float[Tensor, "n d_model"],
        labels: Float[Tensor, " n"],
        C: float = 0.1,
        device: str = "cpu",
    ) -> "LRProbe":
        """
        Train an LR probe using sklearn's LogisticRegression with StandardScaler normalization.

        Args:
            acts: Activation matrix [n_samples, d_model].
            labels: Binary labels (1=true, 0=false).
            C: Inverse regularization strength (lower = stronger regularization).
                Default 0.1 (reg_coeff=10) matches the deception-detection paper's cfg.yaml.
                The repo class default is reg_coeff=1000 (C=0.001), which is stronger.
            device: Device to place the resulting probe on.
        """
        X = acts.cpu().float().numpy()
        y = labels.cpu().float().numpy()

        # Standardize features (zero mean, unit variance) before fitting, as in the paper
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        # fit_intercept=False: the paper fits on normalized data so the intercept is redundant
        lr_model = LogisticRegression(C=C, random_state=42, fit_intercept=False, max_iter=1000)
        lr_model.fit(X_scaled, y)

        # Build probe with scaler parameters baked in
        scaler_mean = t.tensor(scaler.mean_, dtype=t.float32)
        scaler_scale = t.tensor(scaler.scale_, dtype=t.float32)
        probe = LRProbe(acts.shape[-1], scaler_mean=scaler_mean, scaler_scale=scaler_scale).to(device)
        probe.net[0].weight.data[0] = t.tensor(lr_model.coef_[0], dtype=t.float32).to(device)

        return probe

Exercise - cross-dataset generalization matrix

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵
You should spend up to 15-20 minutes on this exercise. Cross-dataset generalization is the real test of whether your probe has found a universal truth direction.

Train MM and LR probes on each of the 3 curated datasets and evaluate each probe on all 3 datasets, producing 3×3 accuracy matrices. If the model has a unified truth direction, off-diagonal entries should be high. The Geometry of Truth paper reports strong cross-generalization for LR probes at 13B scale, and even stronger at 70B:

"The LR probes transfer almost perfectly: probes trained on any one of our three datasets achieve near-ceiling accuracy on the other two."

Do your results agree? You should also compute pairwise cosine similarities between probe directions trained on different datasets to check whether they point in a similar direction.

def compute_generalization_matrix(
    train_acts: dict[str, Float[Tensor, "n d"]],
    train_labels: dict[str, Float[Tensor, " n"]],
    test_acts: dict[str, Float[Tensor, "n d"]],
    test_labels: dict[str, Float[Tensor, " n"]],
    dataset_names: list[str],
    probe_cls: type,
) -> Float[Tensor, "n_datasets n_datasets"]:
    """
    Compute a generalization matrix: entry (i, j) is the test accuracy of a probe trained on dataset i
    and evaluated on dataset j.

    Args:
        train_acts, train_labels: Training data per dataset.
        test_acts, test_labels: Test data per dataset.
        dataset_names: Names of datasets (determines matrix ordering).
        probe_cls: Probe class to use (MMProbe or LRProbe), must have from_data and pred methods.

    Returns:
        Tensor of shape [n_datasets, n_datasets] with accuracy values.
    """
    raise NotImplementedError()


mm_matrix = compute_generalization_matrix(train_acts, train_labels, test_acts, test_labels, DATASET_NAMES, MMProbe)
lr_matrix = compute_generalization_matrix(train_acts, train_labels, test_acts, test_labels, DATASET_NAMES, LRProbe)

assert mm_matrix.shape == (3, 3), f"Wrong shape: {mm_matrix.shape}"
assert (mm_matrix.diag() > 0.6).all(), "In-distribution accuracy should be at least 60%"

# Heatmap visualization
fig = make_subplots(rows=1, cols=2, subplot_titles=["MMProbe", "LRProbe"], horizontal_spacing=0.15)

for idx, (matrix, name) in enumerate([(mm_matrix, "MM"), (lr_matrix, "LR")]):
    text_vals = [[f"{matrix[i, j]:.3f}" for j in range(len(DATASET_NAMES))] for i in range(len(DATASET_NAMES))]
    fig.add_trace(
        go.Heatmap(
            z=matrix.numpy(),
            x=DATASET_NAMES,
            y=DATASET_NAMES,
            text=text_vals,
            texttemplate="%{text}",
            colorscale="RdYlGn",
            zmin=0.5,
            zmax=1.0,
            showscale=(idx == 1),
        ),
        row=1,
        col=idx + 1,
    )
    fig.update_yaxes(title_text="Train dataset" if idx == 0 else "", row=1, col=idx + 1)
    fig.update_xaxes(title_text="Test dataset", row=1, col=idx + 1)

fig.update_layout(title="Cross-dataset Generalization (Test Accuracy)", height=400, width=800)
fig.show()

# Cosine similarity between probe directions
mm_directions = {name: MMProbe.from_data(train_acts[name], train_labels[name]).direction for name in DATASET_NAMES}
lr_directions = {name: LRProbe.from_data(train_acts[name], train_labels[name]).direction for name in DATASET_NAMES}

print("\nPairwise cosine similarity between probe directions:")
for probe_name, directions in [("MM", mm_directions), ("LR", lr_directions)]:
    print(f"\n  {probe_name}Probe:")
    for i, n1 in enumerate(DATASET_NAMES):
        for j, n2 in enumerate(DATASET_NAMES):
            if j > i:
                d1 = directions[n1] / directions[n1].norm()
                d2 = directions[n2] / directions[n2].norm()
                print(f"    {n1} vs {n2}: {(d1 @ d2).item():.4f}")
Discussion - interpreting the generalization matrix

If the model has a single "truth direction" that is shared across domains, then probes trained on any dataset should achieve high accuracy on all datasets (high off-diagonal entries). If instead the model represents truth in domain-specific ways, you'll see high diagonal (in-distribution) but low off-diagonal (out-of-distribution) accuracy.

The Geometry of Truth paper reports that LR probes generalize almost perfectly at 13B+ scale, which suggests a largely unified truth representation. However, at smaller model scales or for very different domains, you may see more variation.

As a further test, try evaluating your cities probe on the likely dataset from the geometry-of-truth repo, which contains nonfactual text with likely or unlikely continuations (code snippets, random text). If the truth probe also separates this dataset, it may be detecting text probability rather than factual truth. The paper predicts (and confirms) that a well-trained truth probe should not separate the likely dataset, because truth and likelihood are distinct features.

Solution
def compute_generalization_matrix(
    train_acts: dict[str, Float[Tensor, "n d"]],
    train_labels: dict[str, Float[Tensor, " n"]],
    test_acts: dict[str, Float[Tensor, "n d"]],
    test_labels: dict[str, Float[Tensor, " n"]],
    dataset_names: list[str],
    probe_cls: type,
) -> Float[Tensor, "n_datasets n_datasets"]:
    """
    Compute a generalization matrix: entry (i, j) is the test accuracy of a probe trained on dataset i
    and evaluated on dataset j.

    Args:
        train_acts, train_labels: Training data per dataset.
        test_acts, test_labels: Test data per dataset.
        dataset_names: Names of datasets (determines matrix ordering).
        probe_cls: Probe class to use (MMProbe or LRProbe), must have from_data and pred methods.

    Returns:
        Tensor of shape [n_datasets, n_datasets] with accuracy values.
    """
    n = len(dataset_names)
    matrix = t.zeros(n, n)
    for i, train_name in enumerate(dataset_names):
        probe = probe_cls.from_data(train_acts[train_name], train_labels[train_name])
        for j, test_name in enumerate(dataset_names):
            preds = probe.pred(test_acts[test_name])
            acc = (preds == test_labels[test_name]).float().mean().item()
            matrix[i, j] = acc
    return matrix
A note about CCS (optional)

Contrast-Consistent Search (CCS) and the unsupervised discovery debate

Before moving to causal interventions, it's worth discussing CCS (Contrast-Consistent Search), which represents one of the most exciting and debated ideas in probing research.

The CCS method

CCS is unsupervised: it requires paired positive/negative statements but no labels. The loss enforces consistency (p(x) ≈ 1 - p(neg_x)) and confidence (probabilities pushed away from the degenerate p = 0.5 solution). This suggests truth might be discoverable without any supervision.
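The objective is short enough to sketch directly. Below is a minimal, hedged implementation of the loss and a bare-bones training loop (function names are ours; the original method also uses multiple random restarts and keeps the lowest-loss probe):

```python
import torch as t

def ccs_loss(p_pos: t.Tensor, p_neg: t.Tensor) -> t.Tensor:
    # Consistency: a statement and its negation should get probabilities summing to ~1
    consistency = ((p_pos - (1 - p_neg)) ** 2).mean()
    # Confidence: penalize the degenerate solution p = 0.5 everywhere
    confidence = (t.minimum(p_pos, p_neg) ** 2).mean()
    return consistency + confidence

def train_ccs_probe(
    pos_acts: t.Tensor, neg_acts: t.Tensor, n_steps: int = 1000, lr: float = 1e-2
) -> t.nn.Module:
    # Center each set separately so the probe can't just read off
    # which set an activation came from
    pos = pos_acts - pos_acts.mean(0)
    neg = neg_acts - neg_acts.mean(0)
    probe = t.nn.Sequential(t.nn.Linear(pos.shape[1], 1), t.nn.Sigmoid())
    opt = t.optim.Adam(probe.parameters(), lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        loss = ccs_loss(probe(pos).squeeze(-1), probe(neg).squeeze(-1))
        loss.backward()
        opt.step()
    return probe
```

Note that labels never appear, only the pairing. The learned direction also has an arbitrary sign, so in practice you resolve the sign with a handful of labeled examples.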

Key critiques

Contrast pairs do all the work. Emmons (2023) shows that PCA on contrast pair differences achieves 97% of CCS's accuracy. The CCS loss function is largely redundant; the contrast pair construction is the real innovation.

CCS may not find knowledge. Farquhar et al. (Google DeepMind, 2023) prove that for every possible binary classification, there exists a zero-loss CCS probe, so the loss has no structural preference for knowledge over arbitrary features. When they append random distractor words to contrast pairs, CCS learns to classify the distractor instead of truth.

The XOR problem. Sam Marks (2024) shows that LLMs linearly represent XORs of arbitrary features, creating a new failure mode for probes: a probe could learn truth XOR geography, which works on geographic data but fails under distribution shift. Empirically, probes often do generalize, likely because basic features have higher variance than their XORs, but this is a matter of degree, not a guarantee. (Note: most of this paper treats the XOR problem more broadly than just its relevance to CCS; it's still strongly recommended as further reading!)
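This failure mode is easy to reproduce synthetically. The sketch below is entirely illustrative (the planted orthogonal directions and the sklearn probe are our assumptions, not the paper's setup): we give activations "truth", "geography", and "truth XOR geography" components, train a probe on data where geography is constant, then flip geography.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
d = 12
# Three orthogonal planted directions: truth, geography, and their XOR
u, v, w = np.eye(d)[:3]

def make_acts(truth: np.ndarray, geo: np.ndarray) -> np.ndarray:
    return (
        truth[:, None] * u
        + geo[:, None] * v
        + (truth ^ geo)[:, None] * w       # the XOR feature the model "also" represents
        + 0.05 * rng.normal(size=(len(truth), d))
    )

# Training data: geography fixed at 0, so truth == truth XOR geography
truth_train = rng.integers(0, 2, 400)
X_train = make_acts(truth_train, np.zeros(400, dtype=int))

# Test data: geography flipped to 1, decorrelating truth from the XOR feature
truth_test = rng.integers(0, 2, 400)
X_test = make_acts(truth_test, np.ones(400, dtype=int))

probe = LogisticRegression().fit(X_train, truth_train)
print(f"train acc: {probe.score(X_train, truth_train):.3f}")  # near-perfect
print(f"test acc under geography shift: {probe.score(X_test, truth_test):.3f}")  # near chance
```

On the training distribution both the truth and XOR directions predict the label perfectly, so the probe splits its weight between them; once geography flips, the two components cancel and accuracy collapses.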

Takeaway

The CCS literature tells a cautionary tale: probing accuracy on a single dataset tells you very little. This is why Section 3 demands causal evidence via interventions, and why cross-dataset generalization is the real test of a probe.

Bonus exercise: implement CCS

If you want to dig deeper, try implementing CCS (or Emmons' PCA shortcut) yourself using the neg_cities dataset (which provides natural contrast pairs). The shortcut's core idea: take the difference vector for each contrast pair (true statement activation minus false statement activation), then run PCA on those difference vectors. As Emmons shows, the first principal component of the difference vectors alone recovers ~97% of CCS's accuracy, so you don't even need the CCS loss. You can then compare this unsupervised direction to your supervised MM and LR probes on the generalization datasets from this section, and see how the causal effects compare in Section 3.
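A minimal sketch of the PCA shortcut (the function name is ours, and it assumes the two activation tensors are row-aligned contrast pairs):

```python
import torch as t

def contrast_pca_direction(pos_acts: t.Tensor, neg_acts: t.Tensor) -> t.Tensor:
    """Top principal component of the contrast-pair difference vectors."""
    diffs = pos_acts - neg_acts      # one difference vector per contrast pair
    diffs = diffs - diffs.mean(0)    # center before PCA
    # First right singular vector = first principal component
    _, _, Vh = t.linalg.svd(diffs, full_matrices=False)
    return Vh[0]                     # unit norm; sign is arbitrary (resolve with a few labels)

# Hypothetical usage, assuming cities / neg_cities activations are row-aligned pairs:
# direction = contrast_pca_direction(activations["cities"], activations["neg_cities"])
```

Like CCS itself, this uses no labels, so the returned direction may point from true to false or vice versa; flip it with a couple of labeled statements before measuring accuracy.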

For further reading: Levinstein & Herrmann (2024), "Still No Lie Detector for Language Models"; the geometry-of-truth repo includes a full CCSProbe implementation in probes.py.