2️⃣ TMS: Superposition in a Privileged Basis

Learning Objectives
  • Understand the difference between neuron and bottleneck superposition (also referred to as computational vs representational superposition)
  • Learn how models can perform computation in superposition, via the example f(x) = abs(x)

Introduction

So far, we've explored superposition in a model without a privileged basis. We can rotate the hidden activations arbitrarily and, as long as we rotate all the weights, have the exact same model behavior. That is, for any ReLU output model with weights $W$, we could take an arbitrary orthogonal matrix $O$ and consider the model $W' = OW$. Since $(OW)^T(OW) = W^T W$, the result would be an identical model!
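If you want to sanity-check this claim numerically, here's a quick standalone snippet (not part of the exercises) which builds a random $W$, applies a random orthogonal matrix $O$, and confirms that $(OW)^T(OW) = W^T W$:

import torch as t

d_hidden, n_features = 2, 5
W = t.randn(d_hidden, n_features)

# Build a random orthogonal matrix O via the QR decomposition of a random square matrix
O, _ = t.linalg.qr(t.randn(d_hidden, d_hidden))

W_rot = O @ W

# The ReLU output model only depends on W through W^T W (plus the bias),
# so this rotation leaves the model's behavior unchanged.
print(t.allclose(W_rot.T @ W_rot, W.T @ W, atol=1e-5))  # True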

Models without a privileged basis are elegant, and can be an interesting analogue for certain neural network representations which don't have a privileged basis – word embeddings, or the transformer residual stream. But we'd also (and perhaps primarily) like to understand neural network representations where there are neurons which do impose a privileged basis, such as transformer MLP layers or conv net neurons.

Our goal in this section is to explore the simplest toy model which gives us a privileged basis. There are at least two ways we could do this: we could add an activation function or apply $L_1$ regularization to the hidden layer. We'll focus on adding an activation function, since the representation we are most interested in understanding is hidden layers with neurons, such as the transformer MLP layer.

This gives us the following "ReLU hidden layer" model. It's the simplest one we can use which is still likely to give us a privileged basis; we just take our previous setup and apply ReLU to the hidden layer.

$$ \begin{aligned} h & =\operatorname{ReLU}(W x) \\ x^{\prime} & =\operatorname{ReLU}\left(W^T h+b\right) \end{aligned} $$

Exercise - implement NeuronModel

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to ~10 minutes on this exercise.

In this section, you'll replicate the first set of results in the Anthropic paper on studying superposition in a privileged basis. To do this, you'll need a new NeuronModel class. It can inherit most methods from the Model class, but you'll need to redefine the forward method to include an intermediate ReLU.

class NeuronModel(ToyModel):
    def forward(self, features: Float[Tensor, "... inst feats"]) -> Float[Tensor, "... inst feats"]:
        raise NotImplementedError()


tests.test_neuron_model(NeuronModel)
Solution
class NeuronModel(ToyModel):
    def forward(self, features: Float[Tensor, "... inst feats"]) -> Float[Tensor, "... inst feats"]:
        activations = F.relu(
            einops.einsum(
                features, self.W, "... inst feats, inst d_hidden feats -> ... inst d_hidden"
            )
        )
        out = F.relu(
            einops.einsum(
                activations, self.W, "... inst d_hidden, inst d_hidden feats -> ... inst feats"
            )
            + self.b_final
        )
        return out

Once you've passed these tests, you can run the cells below to train the model in the same way as before. We use 7 instances, each with 10 features (with probability ranging between $0.75$ and $0.01$ across instances), and feature importance within each instance decaying as $0.75^{i}$.

We also visualize the matrix $W$. In these plots, we make it so the top-row visualisation is of $W$ rather than $W^T W$ - we can get away with this now because (unlike before) the individual elements of $W$ are meaningful. We're working with a privileged basis, and $W$ connects features to basis-aligned neurons.

cfg = ToyModelConfig(n_inst=7, n_features=10, d_hidden=5)

importance = 0.75 ** t.arange(1, 1 + cfg.n_features)
feature_probability = t.tensor([0.75, 0.35, 0.15, 0.1, 0.06, 0.02, 0.01])

model = NeuronModel(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability[:, None],
)
model.optimize(steps=10_000)


utils.plot_features_in_Nd(
    model.W,
    height=600,
    width=1000,
    title=f"Neuron model: {cfg.n_features=}, {cfg.d_hidden=}, I<sub>i</sub> = 0.75<sup>i</sup>",
    subplot_titles=[f"1 - S = {i:.2f}" for i in feature_probability.squeeze()],
    neuron_plot=True,
)
Click to see the expected output

Exercise - interpret these plots

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-15 minutes on this exercise.

The first row shows plots of $W$. The rows are features, the columns are hidden dimensions (neurons).

The second row shows stacked weight plots: in other words, each column is a neuron, and the values in a column are the exposures of the features to that particular neuron. In these plots, each feature is colored differently based on its interference with other features (dark blue means the feature is orthogonal to all other features, and lighter colors mean the sum of squared dot products with other features is large).
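If you want to compute a measure like this yourself, here's a sketch of the same idea (we're not claiming this is exactly what utils.plot_features_in_Nd uses internally):

import torch as t

# W: weights for a single instance, shape (d_hidden, n_features); columns are feature vectors
W = t.randn(5, 10)

gram = W.T @ W                                          # (n_features, n_features) matrix of dot products
interference = (gram**2).sum(dim=-1) - gram.diag()**2   # sum of squared dot products with *other* features
print(interference)  # near-zero entries correspond to dark blue features, large entries to lighter colors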

What is your interpretation of these plots? You should discuss things like monosemanticity / polysemanticity and how this changes with increasing sparsity.

Explanation for some of these plots

Low sparsity / high feature probability

With very low sparsity (feature prob $\approx 1$), we get no superposition: every feature is represented faithfully by a different one of the model's neurons, or not represented at all. In other words, we have pure monosemanticity.

In the heatmaps, we see a diagonal plot (up to rearrangement of neurons), i.e. each of the 5 most important features has a corresponding neuron which detects that particular feature, and no other.

In the bar charts, we see this monosemanticity represented: each neuron has just one feature exposed to it.

Medium sparsity / medium feature probability

At intermediate values, we get some monosemantic neurons, and some polysemantic ones. You should see recurring block patterns like these (up to rearrangements of rows and/or columns):

Can you see what geometric arrangements these correspond to? The answer is in the nested dropdown below.

Answer

The 3x2 block shows 3 features embedded in 2D space. Denoting the 3 features $i, j, k$ respectively, we can see that $j$ is represented along the direction $(1, 1)$ (orthogonal to the other two), and $i, k$ are represented as $(-1, 1)$ and $(1, -1)$ respectively (antipodal pairs).
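For concreteness, one weight block consistent with this description (up to scaling and permutation, with rows = features $i, j, k$ and columns = the two neurons) would be:

$$ W_{\text{block}} = \begin{pmatrix} -1 & 1 \\ 1 & 1 \\ 1 & -1 \end{pmatrix}, \qquad W_{\text{block}} W_{\text{block}}^T = \begin{pmatrix} 2 & 0 & -2 \\ 0 & 2 & 0 \\ -2 & 0 & 2 \end{pmatrix} $$

The zeros in the middle row and column show that $j$ has no interference with the other two features, while the $-2$ entries show that $i$ and $k$ form an antipodal pair.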

As for the 3x3 block, it's actually 3 of the 4 points from a regular tetrahedron! This hints at an important fact which we'll explore in the next (optional) set of exercises: superposition results in features organizing themselves into geometric structures, which often represent uniform polyhedra.

The bar chart shows some neurons are starting to become polysemantic, with exposures to more than one feature.

High sparsity / low feature probability

With high sparsity, all neurons are polysemantic, and most / all features are represented in some capacity. The neurons aren't orthogonal (since we have way more features than neurons), but they don't need to be orthogonal: we saw in earlier sections how high sparsity can allow us to represent more features than we have dimensions. The same is true in this case.

Note - Anthropic [finds](https://transformer-circuits.pub/2022/toy_model/index.html#privileged-basis:~:text=The%20solutions%20are%20visualized%20below) that with very high sparsity, each feature will correspond to a pair of neurons. However, you may not find this in your own plots (I didn't!). This is because - as Anthropic mention - they trained many separate instances and took the ones with smallest loss, since these models proved more difficult to optimize than others in their toy model setup.

Overall, it looks a great deal like there are **neuron-level phase changes from monosemantic to polysemantic** as we increase the sparsity, mirroring the feature phase changes we saw earlier.

Try playing around with different settings (sparsity, importance). What kind of results do you get?

Exercise (optional) - replicate plots more faithfully

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵⚪⚪⚪⚪
You should spend up to 10-25 minutes on this exercise, if you choose to do it.

Anthropic mention in their paper that they trained 1000 instances and chose the ones which achieved lowest loss. This is why your results might have differed from theirs, especially when the sparsity is very high / feature probability is very low.

Can you implement this "choose lowest loss" method in your own class? Some suggestions:

  • The most basic way would be to modify the optimize function to return the loss per instance, and also use a for loop to run several optimize calls & at the end give you the best instances for each different level of sparsity.
  • A much better way would be to train more instances at once (e.g. N instances per level of sparsity), then for each level of sparsity you can take the argmin of the loss over N at the end to get a single instance (see the sketch below this list). This will be much faster (although you'll have to be careful not to train 1000 instances at once; your GPU might not support it!).
  • To get very fancy, you could even add another dimension to the weight matrices, corresponding to this N dimension you minimize over. Then this "taking lowest-loss instance" behavior will be automatic.
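Here's a minimal sketch of the selection step for the second suggestion. It assumes (hypothetically) that you've modified optimize to return a per-instance loss tensor, and that you've trained N * n_sparsity_levels instances laid out so that reshaping recovers the (restart, sparsity level) structure; the details will depend on your own implementation.

import torch as t

N = 20                  # hypothetical number of random restarts per sparsity level
n_sparsity_levels = 7   # one per feature probability
d_hidden, n_features = 5, 10

# Placeholders standing in for what your modified `optimize` would give you:
# a final loss per instance, plus the trained weights.
losses = t.rand(N * n_sparsity_levels)
W = t.randn(N * n_sparsity_levels, d_hidden, n_features)

# Separate the restart dimension from the sparsity dimension...
losses = losses.reshape(N, n_sparsity_levels)
W = W.reshape(N, n_sparsity_levels, d_hidden, n_features)

# ...then keep only the lowest-loss restart for each sparsity level.
best = losses.argmin(dim=0)                    # (n_sparsity_levels,)
best_W = W[best, t.arange(n_sparsity_levels)]  # (n_sparsity_levels, d_hidden, n_features)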

Computation in superposition

The example above was interesting, but in some ways it was also limited. The key problem here is that the model doesn't benefit from the ReLU hidden layer. Adding a ReLU does encourage the model to have a privileged basis, but since the model is trying to reconstruct the input (i.e. the identity, which is a linear function) it doesn't actually need to use the ReLU, and it will try anything it can to circumvent it - including learning biases which shift all the neurons into a positive regime where they behave linearly. This is a mark against using this toy model to study superposition.

To extend this point: we don't want to study boring linear functions like the identity, we want to study how models perform (nonlinear) computation in superposition. The MLP layer in a transformer isn't just a way to represent information faithfully and recover it; it's a way to perform computation on that information. So for this next section, we'll train a model to perform some non-linear computation. Specifically, we'll train our model to compute the absolute value of inputs $x$.

Our data $x$ are now sampled from the range $[-1, 1]$ rather than $[0, 1]$ (otherwise calculating the absolute value would be equivalent to reconstructing the input). This is about as simple as a nonlinear function can get, since $\operatorname{abs}(x)$ is equivalent to $\operatorname{ReLU}(x) + \operatorname{ReLU}(-x)$. But since it's nonlinear, we can be sure that the model has to use the hidden layer ReLU.
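As a quick standalone check of this identity (separate from the model class you're about to write):

import torch as t
import torch.nn.functional as F

x = t.rand(1000) * 2 - 1  # uniform on [-1, 1], like the new generate_batch
assert t.allclose(x.abs(), F.relu(x) + F.relu(-x))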

Exercise - implement NeuronComputationModel

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 20-30 minutes on this exercise.

You should fill in the NeuronComputationModel class below. Specifically, you'll need to fill in the forward, generate_batch and calculate_loss methods. Some guidance:

  • The model's forward function is different - it has a ReLU hidden layer in its forward function (as described above & in the paper).
  • The model's data is different - see the discussion above. Your generate_batch function should be rewritten - it will be the same as the first version of this function you wrote (i.e. without correlations) except for one difference: the value is sampled uniformly from the range $[-1, 1]$ rather than $[0, 1]$.
  • The model's loss function is different. Rather than computing the importance-weighted $L_2$ error between the input $x$ and output $x'$, we're computing the importance-weighted $L_2$ error between $\operatorname{abs}(x)$ and $x'$. This should just require changing one line. The optimize function can stay the same, but it will now be optimizing this new loss function.
class NeuronComputationModel(ToyModel):
    W1: Float[Tensor, "inst d_hidden feats"]
    W2: Float[Tensor, "inst feats d_hidden"]
    b_final: Float[Tensor, "inst feats"]

    def __init__(
        self,
        cfg: ToyModelConfig,
        feature_probability: float | Tensor = 1.0,
        importance: float | Tensor = 1.0,
        device=device,
    ):
        super(ToyModel, self).__init__()  # call nn.Module.__init__ directly, skipping ToyModel.__init__ (we define our own params below)
        self.cfg = cfg

        if isinstance(feature_probability, float):
            feature_probability = t.tensor(feature_probability)
        self.feature_probability = feature_probability.to(device).broadcast_to(
            (cfg.n_inst, cfg.n_features)
        )
        if isinstance(importance, float):
            importance = t.tensor(importance)
        self.importance = importance.to(device).broadcast_to((cfg.n_inst, cfg.n_features))

        self.W1 = nn.Parameter(
            nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features)))
        )
        self.W2 = nn.Parameter(
            nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.n_features, cfg.d_hidden)))
        )
        self.b_final = nn.Parameter(t.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    def forward(self, features: Float[Tensor, "... inst feats"]) -> Float[Tensor, "... inst feats"]:
        raise NotImplementedError()

    def generate_batch(self, batch_size) -> Float[Tensor, "batch instances features"]:
        raise NotImplementedError()

    def calculate_loss(
        self,
        out: Float[Tensor, "batch instances features"],
        batch: Float[Tensor, "batch instances features"],
    ) -> Float[Tensor, ""]:
        raise NotImplementedError()


tests.test_neuron_computation_model(NeuronComputationModel)
Solution for forward
def forward(self, features: Float[Tensor, "... inst feats"]) -> Float[Tensor, "... inst feats"]:
    activations = F.relu(
        einops.einsum(features, self.W1, "... inst feats, inst d_hidden feats -> ... inst d_hidden")
    )
    out = F.relu(
        einops.einsum(activations, self.W2, "... inst d_hidden, inst feats d_hidden -> ... inst feats")
        + self.b_final
    )
    return out
Solution for generate_batch
def generate_batch(self, batch_size) -> Float[Tensor, "batch instances features"]:
    feat_mag = 2 * t.rand((batch_size, self.cfg.n_inst, self.cfg.n_features), device=self.W1.device) - 1
    feat_seed = t.rand(
        (batch_size, self.cfg.n_inst, self.cfg.n_features),
        device=self.W1.device,
    )
    batch = t.where(feat_seed < self.feature_probability, feat_mag, 0.0)
    return batch
Solution for calculate_loss
def calculate_loss(
    self,
    out: Float[Tensor, "batch instances features"],
    batch: Float[Tensor, "batch instances features"],
) -> Float[Tensor, ""]:
    error = self.importance * ((batch.abs() - out) ** 2)
    loss = einops.reduce(error, "batch inst feats -> inst", "mean").sum()
    return loss

Once you've passed these tests, you can run the code below to make the same visualisation as above.

You should see similar patterns: with very low sparsity most/all neurons are monosemantic, but more polysemantic neurons appear as sparsity increases (until all neurons are polysemantic). Another interesting observation: in the monosemantic (or mostly monosemantic) cases, for any given feature there will be some neurons which have positive exposures to that feature and others with negative exposure. This is because some neurons are representing the value $\operatorname{ReLU}(x_i)$ and others are representing the value of $\operatorname{ReLU}(-x_i)$ (as discussed above, we require both of these to compute the absolute value).

cfg = ToyModelConfig(n_inst=7, n_features=100, d_hidden=40)

importance = 0.8 ** t.arange(1, 1 + cfg.n_features)
feature_probability = t.tensor([1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001])

model = NeuronComputationModel(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability[:, None],
)
model.optimize()


utils.plot_features_in_Nd(
    model.W1,
    height=800,
    width=1600,
    title=f"Neuron computation model: n_features = {cfg.n_features}, d_hidden = {cfg.d_hidden}, I<sub>i</sub> = 0.75<sup>i</sup>",
    subplot_titles=[f"1 - S = {i:.3f}" for i in feature_probability.squeeze()],
    neuron_plot=True,
)
Click to see the expected output

To further confirm that this is happening, we can color the values in the bar chart discretely by feature, rather than continuously by the polysemanticity of that feature. We'll use a feature probability of 50% for this visualisation, which is high enough to make sure each neuron is monosemantic. You should find that the input weights $W_1$ form pairs of antipodal neurons (i.e. ones with positive / negative exposures to that feature direction), but both of these neurons have positive output weights $W_2$ for that feature.

cfg = ToyModelConfig(n_inst=6, n_features=20, d_hidden=10)

importance = 0.8 ** t.arange(1, 1 + cfg.n_features)
feature_probability = 0.5

model = NeuronComputationModel(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability,
)
model.optimize()


utils.plot_features_in_Nd_discrete(
    W1=model.W1,
    W2=model.W2,
    title="Neuron computation model (colored discretely, by feature)",
    legend_names=[
        f"I<sub>{i}</sub> = {importance.squeeze()[i]:.3f}" for i in range(cfg.n_features)
    ],
)
Click to see the expected output

Bonus - the asymmetric superposition motif

In the Asymmetric Superposition Motif section of Anthropic's paper, they discuss a particular quirk of this toy model in detail. Their write-up goes into more depth (including some visual explanations), but we'll give a relatively brief explanation here.

When we increase sparsity in our model & start to get superposed features, we don't always have monosemantic neurons which each calculate either $\operatorname{ReLU}(x_i)$ or $\operatorname{ReLU}(-x_i)$ for some feature $i$. Instead, we sometimes have asymmetric superposition, where a single neuron detects two different features $i$ and $j$, and stores these features with different magnitudes (assume the $W_1$ vector for feature $i$ is much larger). The $W_2$ vectors have flipped magnitudes (i.e. the vector for $j$ is much larger).

When $j$ is present and $i$ is not, there's no problem: the output for feature $j$ is small * large (correct size), and the interference on feature $i$ is small * small (near zero). But when $i$ is present and $j$ is not, the output for feature $i$ is large * small (correct size), while the interference on feature $j$ is large * large (much larger than it should be). This is a problem specifically when the sign of that interference is positive. The model fixes it by repurposing another neuron to correct for the case when $i$ is present and $j$ is not. We omit the exact mechanism, but it takes advantage of the fact that the model has a ReLU at the very end: a large negative pre-activation for a feature gets clipped to zero (incurring no loss), whereas a large positive one is very costly.
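To make the arithmetic concrete, here's a small standalone illustration with hypothetical hand-picked weights (not taken from a trained model): one neuron whose input weight is large for $i$ and small for $j$, with the output weights flipped.

import torch as t
import torch.nn.functional as F

# Hypothetical weights for a single neuron superposing features i and j
w1 = t.tensor([10.0, 0.1])  # W1[neuron, (i, j)]: large for i, small for j
w2 = t.tensor([0.1, 10.0])  # W2[(i, j), neuron]: flipped magnitudes

def neuron_contribution(x: t.Tensor) -> t.Tensor:
    """This neuron's pre-bias, pre-final-ReLU contribution to the outputs for features (i, j)."""
    h = F.relu(w1 @ x)
    return w2 * h

print(neuron_contribution(t.tensor([0.0, 0.5])))  # ~[0.005, 0.5]: j recovered correctly, tiny interference on i
print(neuron_contribution(t.tensor([0.5, 0.0])))  # ~[0.5, 50.0]: i recovered correctly, huge interference on j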

You should read that section of Anthropic's paper for details. We've given you code below to replicate these results. Note that some plots will display the kind of asymmetric superposition described above, whereas others will simply have a single pair of neurons for each feature - you might have to run a few random seeds to observe something closely resembling Anthropic's plots. Can you understand what the output represents? Can you play around with the hyperparameters to see how this behaviour varies (e.g. with different feature probability or importance)?

cfg = ToyModelConfig(n_inst=6, n_features=10, d_hidden=10)

importance = 0.8 ** t.arange(1, 1 + cfg.n_features)
feature_probability = (
    0.35  # slightly lower feature probability, to encourage a small degree of superposition
)

model = NeuronComputationModel(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability,
)
model.optimize()


utils.plot_features_in_Nd_discrete(
    W1=model.W1,
    W2=model.W2,
    title="Neuron computation model (colored discretely, by feature)",
    legend_names=[
        f"I<sub>{i}</sub> = {importance.squeeze()[i]:.3f}" for i in range(cfg.n_features)
    ],
)
Click to see the expected output

Summary - what have we learned?

With toy models like this, it's important to make sure we take away generalizable lessons, rather than just details of the training setup.

The core things to take away from this paper are:

  • What superposition is
  • How it varies over feature importance and sparsity
  • How it varies when we have correlated or anticorrelated features
  • The difference between neuron and bottleneck superposition (or equivalently "computational and representational superposition")