Exercise - interpret these plots
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-15 minutes on this exercise.
The first row shows plots of $W$. The rows are features and the columns are hidden dimensions (neurons).
The second row shows stacked weight plots: each column is a neuron, and the values in a column are the exposures of the features to that particular neuron. In these plots, each feature is colored based on its interference with other features (dark blue means the feature is orthogonal to all other features, and lighter colors mean the sum of squared dot products with other features is large).
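The interference score behind this coloring can be sketched in plain Python as follows. This is a minimal sketch that takes the description above literally (raw, unnormalized dot products between feature vectors); the plotting utility's exact normalization is not shown here and may differ, and `interference` is a hypothetical helper name.

```python
def interference(W):
    """Sum of squared dot products between each feature and every other feature.

    W: list of feature vectors; W[i] is feature i's embedding in neuron space.
    Returns one value per feature (0.0 means orthogonal to all other features).
    """
    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))

    return [
        sum(dot(W[i], W[j]) ** 2 for j in range(len(W)) if j != i)
        for i in range(len(W))
    ]


# Two orthogonal features, plus a third feature that overlaps with the first
W = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.0]]
print(interference(W))  # → [0.25, 0.0, 0.25]
```

Feature 1 is orthogonal to everything (it would be drawn dark blue), while features 0 and 2 interfere with each other.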
What is your interpretation of these plots? You should discuss things like monosemanticity / polysemanticity and how this changes with increasing sparsity.
Explanation for some of these plots
Low sparsity / high feature probability
With very low sparsity (feature prob $\approx 1$), we get no superposition: every feature is either represented faithfully by a different one of the model's neurons, or not represented at all. In other words, we have pure monosemanticity.
In the heatmaps, we see a diagonal plot (up to rearrangement of neurons), i.e. each of the 5 most important features has a corresponding neuron which detects that particular feature, and no other.
In the bar charts, we see this monosemanticity represented: each neuron has just one feature exposed to it.
Medium sparsity / medium feature probability
At intermediate values, we get some monosemantic neurons and some polysemantic ones. You should see recurring block patterns (up to rearrangements of rows and/or columns), in particular 3x2 and 3x3 blocks.
Can you see what geometric arrangements these correspond to? The answer is given below.
Answer
The 3x2 block shows 3 features embedded in 2D space. Denoting the 3 features $i, j, k$ respectively, we can see that $j$ is represented along the direction $(1, 1)$ (orthogonal to the other two), and $i, k$ are represented as $(-1, 1)$ and $(1, -1)$ respectively (an antipodal pair).
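A quick numerical check of this geometry, using the three directions stated above. They are unnormalized, which is fine for checking orthogonality and antipodality; a trained model's block will only match up to scaling and permutation.

```python
# Rows are features i, j, k; the two entries are the exposures to the two neurons
block = {"i": (-1, 1), "j": (1, 1), "k": (1, -1)}

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

for a in block:
    for b in block:
        if a < b:
            print(a, b, dot(block[a], block[b]))
# i j 0    (orthogonal)
# i k -2   (antipodal: k = -i)
# j k 0    (orthogonal)
```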
As for the 3x3 block, it's actually 3 of the 4 points from a regular tetrahedron! This hints at an important fact which we'll explore in the next (optional) set of exercises: **superposition results in features organizing themselves into geometric structures**, which often represent uniform polyhedra.
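The tetrahedron claim can be checked with a standard choice of vertices for a regular tetrahedron centred at the origin. The specific coordinates are an assumption for illustration: any rotation or scaling of them works equally well, and a trained block will only match up to rotation, scaling and permutation. Every pair of vertices has cosine similarity $-1/3$, so each of the three features interferes equally with the other two.

```python
import math

# Vertices of a regular tetrahedron centred at the origin; keep 3 of the 4
vertices = [(1, 1, 1), (1, -1, -1), (-1, 1, -1), (-1, -1, 1)]
block = vertices[:3]

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

for p in range(3):
    for q in range(p + 1, 3):
        print(p, q, round(cosine(block[p], block[q]), 4))
# every pair prints a cosine similarity of -0.3333 (i.e. -1/3)
```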
The bar chart shows some neurons are starting to become polysemantic, with exposures to more than one feature.
High sparsity / low feature probability
With high sparsity, all neurons are polysemantic, and most or all features are represented in some capacity. The feature directions aren't orthogonal (since we have many more features than neurons), but they don't need to be: we saw in earlier sections how high sparsity can allow us to represent more features than we have dimensions. The same is true in this case.
Note - Anthropic finds that with very high sparsity, each feature will correspond to a pair of neurons. However, you may not find this for your own plots (I didn't!). This is because - as Anthropic mention - they trained many separate instances and took the ones with smallest loss, since these models proved more difficult to optimize than others in their toy model setup.
Overall, it looks a great deal like there are neuron-level phase changes from monosemantic to polysemantic as we increase the sparsity, mirroring the feature phase changes we saw earlier.
Try playing around with different settings (sparsity, importance). What kind of results do you get?
Exercise (optional) - replicate plots more faithfully
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵⚪⚪⚪⚪
You should spend up to 10-25 minutes on this exercise, if you choose to do it.
Anthropic mention in their paper that they trained 1000 instances and chose the ones which achieved lowest loss. This is why your results might have differed from theirs, especially when the sparsity is very high / feature probability is very low.
Can you implement this "choose lowest loss" method in your own class? Some suggestions:
The most basic way would be to modify the optimize function to return the loss per instance, then use a for loop to run several optimize calls and, at the end, keep the best instance for each level of sparsity.
A much better way would be to train more instances at once (e.g. N instances per level of sparsity), then for each level of sparsity take the argmin of the loss over the N instances to get a single instance. This will be much faster (although you'll have to be careful not to train 1000 instances at once; your GPU might not support it!).
To get very fancy, you could even add another dimension to the weight matrices, corresponding to this N dimension you take the argmin over. Then this "take the lowest-loss instance" behavior will be automatic.
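Whichever option you choose, the selection step is the same: an argmin of the final loss over the N instances trained at each sparsity level. The sketch below shows only that step, in plain Python; the nested-list layout and the `select_lowest_loss` name are hypothetical. To get per-instance losses from the calculate_loss method in this section, you would keep the tensor before its final `.sum()` (which reduces over instances), and with a PyTorch tensor of shape (levels, N) the selection itself could be done with `losses.argmin(dim=-1)`.

```python
def select_lowest_loss(losses, weights):
    """Keep the lowest-loss instance for each sparsity level (hypothetical helper).

    losses:  losses[s][n] is the final loss of instance n at sparsity level s
    weights: weights[s][n] is that instance's trained weights (any object)
    Returns (best_weights, best_losses), one entry per sparsity level.
    """
    best_weights, best_losses = [], []
    for level_losses, level_weights in zip(losses, weights):
        # argmin over the N instances trained at this sparsity level (ties -> first)
        best = min(range(len(level_losses)), key=level_losses.__getitem__)
        best_losses.append(level_losses[best])
        best_weights.append(level_weights[best])
    return best_weights, best_losses


# Two sparsity levels, N = 3 instances each
losses = [[0.50, 0.20, 0.90], [0.10, 0.30, 0.05]]
weights = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
print(select_lowest_loss(losses, weights))  # → (['a1', 'b2'], [0.2, 0.05])
```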
Computation in superposition
The example above was interesting, but in some ways it was also limited. The key problem is that the model doesn't benefit from the ReLU hidden layer. Adding a ReLU does encourage the model to have a privileged basis, but since the model is trying to reconstruct the input (i.e. the identity, which is a linear function), it doesn't actually need the ReLU, and it will try anything it can to circumvent it - including learning biases which shift all the neurons into a positive regime where they behave linearly. This is a mark against using this toy model to study superposition.
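A small illustration of why the ReLU can be circumvented: if a neuron's pre-activation is shifted by a bias large enough to keep it non-negative over the whole input range, the ReLU never clips, so the neuron behaves linearly and the shift can be undone downstream. The scalar bias `b = 2.0` is an arbitrary value chosen for illustration, not something a trained model is guaranteed to learn.

```python
def relu(z):
    return max(0.0, z)

b = 2.0  # illustrative bias: x + b >= 0 for every input x in [0, 1]
for x in [0.0, 0.25, 0.5, 1.0]:
    # ReLU(x + b) - b == x whenever x + b >= 0, so the ReLU acts as the identity here
    print(x, relu(x + b) - b)
```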
To extend this point: we don't want to study boring linear functions like the identity; we want to study how models perform (nonlinear) computation in superposition. The MLP layer in a transformer isn't just a way to represent information faithfully and recover it; it's a way to perform computation on that information. So for this next section, we'll train a model to perform some nonlinear computation. Specifically, we'll train our model to compute the absolute value of inputs $x$.
Our data $x$ are now sampled from the range $[-1, 1]$ rather than $[0, 1]$ (otherwise calculating the absolute value would be equivalent to reconstructing the input). This is about as simple as a nonlinear function can get, since $\operatorname{abs}(x)$ is equivalent to $\operatorname{ReLU}(x) + \operatorname{ReLU}(-x)$. But since it's nonlinear, we can be sure that the model has to use the hidden layer ReLU.
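The identity above is easy to check numerically. The comment at the end spells out why no purely linear map can fit $\operatorname{abs}(x)$ on $[-1, 1]$, which is what forces the model to use the hidden ReLU.

```python
def relu(z):
    return max(0.0, z)

# |x| = ReLU(x) + ReLU(-x): for x != 0, exactly one of the two terms is non-zero
for x in [-1.0, -0.3, 0.0, 0.7, 1.0]:
    assert relu(x) + relu(-x) == abs(x)

# By contrast, no linear map x -> a*x + b matches |x| on [-1, 1]:
# matching x = 0 needs b = 0, then x = 1 needs a = 1 but x = -1 needs a = -1.
```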
Exercise - implement NeuronComputationModel
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 20-30 minutes on this exercise.
You should fill in the NeuronComputationModel class below. Specifically, you'll need to fill in the forward, generate_batch and calculate_loss methods. Some guidance:
The model's forward function is different: it has a ReLU hidden layer (as described above and in the paper).
The model's data is different - see the discussion above. Your generate_batch function should be rewritten - it will be the same as the first version of this function you wrote (i.e. without correlations) except for one difference: the value is sampled uniformly from the range $[-1, 1]$ rather than $[0, 1]$.
The model's loss function is different. Rather than computing the importance-weighted $L_2$ error between the input $x$ and output $x'$, we're computing the importance-weighted $L_2$ error between $\operatorname{abs}(x)$ and $x'$. This should just require changing one line. The optimize function can stay the same, but it will now be optimizing this new loss function.
```python
class NeuronComputationModel(ToyModel):
    W1: Float[Tensor, "inst d_hidden feats"]
    W2: Float[Tensor, "inst feats d_hidden"]
    b_final: Float[Tensor, "inst feats"]

    def __init__(
        self,
        cfg: ToyModelConfig,
        feature_probability: float | Tensor = 1.0,
        importance: float | Tensor = 1.0,
        device=device,
    ):
        # Deliberately bypass ToyModel.__init__ (this model defines its own weights)
        # and initialize nn.Module directly
        super(ToyModel, self).__init__()
        self.cfg = cfg

        # Broadcast feature probability and importance to shape (n_inst, n_features)
        if isinstance(feature_probability, float):
            feature_probability = t.tensor(feature_probability)
        self.feature_probability = feature_probability.to(device).broadcast_to((cfg.n_inst, cfg.n_features))
        if isinstance(importance, float):
            importance = t.tensor(importance)
        self.importance = importance.to(device).broadcast_to((cfg.n_inst, cfg.n_features))

        # Separate input (W1) and output (W2) weights, one set per instance
        self.W1 = nn.Parameter(nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features))))
        self.W2 = nn.Parameter(nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.n_features, cfg.d_hidden))))
        self.b_final = nn.Parameter(t.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    def forward(self, features: Float[Tensor, "... inst feats"]) -> Float[Tensor, "... inst feats"]:
        raise NotImplementedError()

    def generate_batch(self, batch_size) -> Float[Tensor, "batch instances features"]:
        raise NotImplementedError()

    def calculate_loss(
        self,
        out: Float[Tensor, "batch instances features"],
        batch: Float[Tensor, "batch instances features"],
    ) -> Float[Tensor, ""]:
        raise NotImplementedError()


tests.test_neuron_computation_model(NeuronComputationModel)
```
Solution for forward
```python
def forward(self, features: Float[Tensor, "... inst feats"]) -> Float[Tensor, "... inst feats"]:
    activations = F.relu(
        einops.einsum(features, self.W1, "... inst feats, inst d_hidden feats -> ... inst d_hidden")
    )
    out = F.relu(
        einops.einsum(activations, self.W2, "... inst d_hidden, inst feats d_hidden -> ... inst feats")
        + self.b_final
    )
    return out
```
Solution for generate_batch
```python
def generate_batch(self, batch_size) -> Float[Tensor, "batch instances features"]:
    feat_mag = 2 * t.rand((batch_size, self.cfg.n_inst, self.cfg.n_features), device=self.W1.device) - 1
    feat_seed = t.rand(
        (batch_size, self.cfg.n_inst, self.cfg.n_features),
        device=self.W1.device,
    )
    batch = t.where(feat_seed < self.feature_probability, feat_mag, 0.0)
    return batch
```
Solution for calculate_loss
```python
def calculate_loss(
    self,
    out: Float[Tensor, "batch instances features"],
    batch: Float[Tensor, "batch instances features"],
) -> Float[Tensor, ""]:
    error = self.importance * ((batch.abs() - out) ** 2)
    loss = einops.reduce(error, "batch inst feats -> inst", "mean").sum()
    return loss
```
Once you've passed these tests, you can run the code below to make the same visualisation as above.
You should see similar patterns: with very low sparsity most/all neurons are monosemantic, but more polysemantic neurons appear as sparsity increases (until all neurons are polysemantic). Another interesting observation: in the monosemantic (or mostly monosemantic) cases, for any given feature there will be some neurons which have positive exposures to that feature and others with negative exposure. This is because some neurons are representing the value $\operatorname{ReLU}(x_i)$ and others are representing the value of $\operatorname{ReLU}(-x_i)$ (as discussed above, we require both of these to compute the absolute value).
```python
cfg = ToyModelConfig(n_inst=7, n_features=100, d_hidden=40)

importance = 0.8 ** t.arange(1, 1 + cfg.n_features)
feature_probability = t.tensor([1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001])

model = NeuronComputationModel(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability[:, None],
)
model.optimize()

utils.plot_features_in_Nd(
    model.W1,
    height=800,
    width=1600,
    # Title matches the importance schedule defined above (0.8 ** i)
    title=f"Neuron computation model: n_features = {cfg.n_features}, d_hidden = {cfg.d_hidden}, I<sub>i</sub> = 0.8<sup>i</sup>",
    subplot_titles=[f"1 - S = {i:.3f}" for i in feature_probability.squeeze()],
    neuron_plot=True,
)
```
To further confirm that this is happening, we can color the values in the bar chart discretely by feature, rather than continuously by the polysemanticity of that feature. We'll use a feature probability of 50% for this visualisation, which is high enough to make sure each neuron is monosemantic. You should find that the input weights $W_1$ form pairs of antipodal neurons (i.e. ones with positive / negative exposures to that feature direction), but both of these neurons have positive output weights $W_2$ for that feature.
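To see why antipodal $W_1$ with positive $W_2$ is the natural solution, here is a hand-constructed (not trained) example for a single feature and two neurons, written as a plain-Python version of the forward pass from the solution above, $\operatorname{ReLU}(W_2 \operatorname{ReLU}(W_1 x) + b)$. The weights are chosen by hand for illustration; a trained model learns scaled and permuted versions of this pattern, mixed in with the other features.

```python
def relu(z):
    return max(0.0, z)

def forward(x, W1, W2, b_final):
    """Single-instance forward pass: ReLU(W2 @ ReLU(W1 @ x) + b_final)."""
    hidden = [relu(sum(w * xi for w, xi in zip(row, x))) for row in W1]
    return [relu(sum(w * h for w, h in zip(row, hidden)) + b) for row, b in zip(W2, b_final)]

# One feature, two neurons (hand-picked weights)
W1 = [[1.0], [-1.0]]   # antipodal input weights: neuron 0 computes ReLU(x), neuron 1 ReLU(-x)
W2 = [[1.0, 1.0]]      # both neurons write positively to the same output feature
b_final = [0.0]

for x in [-1.0, -0.4, 0.0, 0.6]:
    print(x, forward([x], W1, W2, b_final))  # output equals [abs(x)]
```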
```python
cfg = ToyModelConfig(n_inst=6, n_features=20, d_hidden=10)

importance = 0.8 ** t.arange(1, 1 + cfg.n_features)
feature_probability = 0.5

model = NeuronComputationModel(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability,
)
model.optimize()

utils.plot_features_in_Nd_discrete(
    W1=model.W1,
    W2=model.W2,
    title="Neuron computation model (colored discretely, by feature)",
    legend_names=[f"I<sub>{i}</sub> = {importance.squeeze()[i]:.3f}" for i in range(cfg.n_features)],
)
```
Bonus - the asymmetric superposition motif
In the Asymmetric Superposition Motif section of Anthropic's paper, they discuss a particular quirk of this toy model in detail. Their section goes into much more depth (including some visual explanations); here we'll give a relatively brief explanation.
When we increase sparsity in our model and start to get superposed features, we don't always have monosemantic neurons which each calculate either $\operatorname{ReLU}(x_i)$ or $\operatorname{ReLU}(-x_i)$ for some feature $i$. Instead, we sometimes have asymmetric superposition, where a single neuron detects two different features $i$ and $j$ and stores them with different magnitudes (assume the $W_1$ vector for feature $i$ is much larger). The $W_2$ vectors have flipped magnitudes (i.e. the vector for $j$ is much larger). When $j$ is present and $i$ is not, there's no problem: the neuron's activation is small, so the output for feature $j$ is small * large (correct size) and the output for $i$ is small * small (near zero). But when $i$ is present and $j$ is not, the activation is large, so the output for feature $i$ is large * small (correct size) while the output for $j$ is large * large (much larger than it should be). In particular, this is bad when the sign of the output for $j$ is positive. The model fixes this by repurposing another neuron to correct for the case when $i$ is present and $j$ is not. We omit the exact mechanism, but it takes advantage of the fact that the model has a ReLU at the very end: it doesn't matter if the pre-ReLU output for a feature is very large and negative (the output will be truncated at zero), but being large and positive is very bad.
You should read the Asymmetric Superposition Motif section of the Anthropic paper for details. We've given you code below to replicate Anthropic's plot. Note that some plots will display the kind of asymmetric superposition described above, whereas others will simply have a single pair of neurons for each feature; you might have to run a few random seeds to observe something exactly resembling Anthropic's plots. Can you understand what the output represents? Can you play around with the hyperparameters to see how this behaviour varies (e.g. with different feature probability or importance)?
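A toy numerical version of this motif, for a single neuron shared by features $i$ and $j$ (input order $[x_i, x_j]$). The weight magnitudes are made up for illustration, and the `correction` argument is a crude stand-in for a second neuron that fires when $i$ is present; it is not Anthropic's exact correction mechanism, which is omitted here. It shows the two cases and how the final ReLU makes an overshooting negative correction harmless.

```python
def relu(z):
    return max(0.0, z)

# Hypothetical magnitudes: feature i has the large input weight, j the large output weight
w1 = {"i": 2.0, "j": 0.5}   # W1: input weights into the shared neuron
w2 = {"i": 0.5, "j": 2.0}   # W2: output weights from the shared neuron

def outputs(x_i, x_j, correction=0.0):
    h = relu(w1["i"] * x_i + w1["j"] * x_j)  # shared neuron's activation
    # `correction` stands in for another neuron's negative contribution to output j
    return {"i": relu(w2["i"] * h), "j": relu(w2["j"] * h + correction)}

print(outputs(0.0, 1.0))  # j alone → {'i': 0.25, 'j': 1.0}: small interference on i
print(outputs(1.0, 0.0))  # i alone → {'i': 1.0, 'j': 4.0}: large positive error on j
# Here any correction <= -4.0 zeroes output j, because the final ReLU clips negatives
print(outputs(1.0, 0.0, correction=-10.0))  # → {'i': 1.0, 'j': 0.0}
```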
```python
cfg = ToyModelConfig(n_inst=6, n_features=10, d_hidden=10)

importance = 0.8 ** t.arange(1, 1 + cfg.n_features)
feature_probability = 0.35  # slightly lower feature probability, to encourage a small degree of superposition

model = NeuronComputationModel(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability,
)
model.optimize()

utils.plot_features_in_Nd_discrete(
    W1=model.W1,
    W2=model.W2,
    title="Neuron computation model (colored discretely, by feature)",
    legend_names=[f"I<sub>{i}</sub> = {importance.squeeze()[i]:.3f}" for i in range(cfg.n_features)],
)
```
Summary - what have we learned?
With toy models like this, it's important to make sure we take away generalizable lessons, rather than just details of the training setup.
The core things to take away from this paper are:
What superposition is
How it varies over feature importance and sparsity
How it varies when we have correlated or anticorrelated features
The difference between neuron and bottleneck superposition (or equivalently "computational and representational superposition")