# 1️⃣ TMS: Superposition in a Nonprivileged Basis

> ### Learning Objectives
>
> * Understand the concept of superposition, and how it helps models represent a larger set of features
> * Understand the difference between superposition and polysemanticity
> * Learn how sparsity contributes to superposition
> * Understand the idea of the feature importance curve
> * Learn how feature correlation changes the nature and degree of superposition

## Toy Model setup

In this section, we'll be examining & running experiments on the toy model studied in Anthropic's paper. You can follow along with the paper from the Demonstrating Superposition section onwards; it will approximately follow the order of the sections in this notebook.

This paper presented a very rudimentary model for **bottleneck superposition**: when you try to represent more than $n$ features in a vector space of dimension $n$. The model is as follows:

* We take a 5-dimensional input $x$
* We map it down into 2D space
* We map it back up into 5D space (using the transpose of the first matrix)
* We add a bias and ReLU

$$
\begin{aligned}
h &= W x \\
x' &= \operatorname{ReLU}(W^T h + b)
\end{aligned}
$$

What's the motivation for this setup? The input $x$ represents our five features (when a feature is present, its value is sampled uniformly between 0 and 1). Each feature can have **importance** and **sparsity**. Recall our earlier definitions:

* **Importance** = how useful is this feature for achieving lower loss?
* **Sparsity** = how frequently is it in the input data?

This is realised in our toy model as follows:

* **Importance** = the coefficient on the weighted mean squared error between the input and output, which we use for training the model
    * In other words, our loss function is $L = \sum_x \sum_i I_i (x_i - x_i^\prime)^2$, where $I_i$ is the importance of feature $i$.
* **Sparsity** = the probability of the corresponding element in $x$ being zero
    * In other words, this affects the way our training data is generated (see the method `generate_batch` in the `ToyModel` class below)
    * We often refer to **feature probability** (1 minus sparsity) rather than sparsity

The justification for using $W^T W$ is as follows. We can think of $W$ (which is a matrix of shape `(2, 5)`) as a grid of "overlap values" between the features and bottleneck dimensions. The values of the 5x5 matrix $W^T W$ are the dot products between the 2D representations of each pair of features.

To make this intuition clearer, imagine each of the columns of $W$ were unit vectors. Then $W^T W$ would be a matrix of cosine similarities between the features, with diagonal elements equal to 1 because the similarity of a feature with itself is 1. To see this for yourself:

```python
t.manual_seed(2)

W = t.randn(2, 5)
W_normed = W / W.norm(dim=0, keepdim=True)  # normalize each column (feature) to unit length

imshow(
    W_normed.T @ W_normed,
    title="Cosine similarities of each pair of 2D feature embeddings",
    width=600,
)
```

To put it another way: if the columns of $W$ were orthogonal, then $W^T W$ would be the identity. This can't actually be the case because $W$ is a 2x5 matrix. Its columns can still be "nearly orthogonal", in the sense of having pairwise cosine similarities close to 0.

Question - can you prove that `W.T @ W` can't be the identity when `W` has more columns than rows (or alternatively, when the hidden dimension is strictly smaller than the input dimension)?

Proof #1: the rank of a matrix product $AB$ is upper-bounded by the *minimum* of the ranks of the two factors $A$ and $B$. In the case of $W^T W$, both matrices have rank at most 2, so the product has rank at most 2. The 5x5 identity has rank 5.

Proof #2: for any vector $x$, $W^T W x = W^T (Wx)$ is in the span of the columns of $W^T$, which is a vector space of dimension at most 2. The identity maps $\mathbb{R}^5$ onto all of $\mathbb{R}^5$, so $W^T W$ can't be the identity.

Another nice thing about using two bottleneck dimensions is that we get to visualise our output! We've got a few helper functions for this purpose.
```python
utils.plot_features_in_2d(
    W_normed.unsqueeze(0),  # shape [instances=1 d_hidden=2 features=5]
)
```

Compare this plot to the `imshow` plot above, and make sure you understand what's going on here (and how the two plots relate to each other). A lot of the subsequent exercises run with this idea of a geometric interpretation of the model's features and bottleneck dimensions.

Help - I'm confused about how these plots work.

As mentioned, you can view $W$ as being a set of five 2D vectors, one for each of our five features. The heatmap shows us the cosine similarities between each pair of these vectors, and the second plot shows us these five vectors in 2D space.

In the example above, two pairs of vectors (the 1st & 2nd, and the 0th & 4th) have very high cosine similarity. This is reflected in the 2D plot, where these features are very close to each other (the 0th feature is the darkest color, the 4th feature is the lightest).

### Defining our model

Below is some code for your model (with most methods not filled out yet). It should be familiar to you if you've already built simple neural networks earlier in this course.

Some notes on the initialization method, which is filled out for you:

#### Weights & instances

The `ToyModelConfig` class has an `n_inst` field. This lets us optimize multiple models at once in a single training loop (this'll be useful later on).

You should treat `n_inst` as basically like a batch dimension for your weights. Each of your weights/biases will actually be `n_inst` separate weights/biases stacked along the zeroth dimension. Each of these is trained independently, on different data, in parallel (using the same optimizer).

We initialize weights `W` and `b_final`, which correspond to $W$ and $b$ in the Anthropic paper.

#### Sparsity & Importance

The `feature_probability` argument tells us the probability that any given feature will be active. We have the relation `feature_probability = 1 - sparsity`.
We'll often be dealing with very small feature probabilities $p = 1 - S \approx 0$, i.e. sparsities close to 1. The feature probability is used to generate our training data; the importance is used in our loss function (see later for both of these). The default is `feature_probability = 0.01`, i.e. each feature is present with probability 1%.

The `importance` argument is used when calculating loss (see later exercise). The default is `importance = 1.0`, which results in uniform importance.

In the `__init__` method, we have code to broadcast `feature_probability` and `importance`, so that by the end they both always have shape `(n_inst, n_features)`.

### Exercise - implement `forward`

> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 10-20 minutes on this exercise.

For now, you just need to fill in the `forward` method. As the exercises go on, you'll fill in some more of these functions, but for now you can ignore the others.

```python
def linear_lr(step, steps):
    return 1 - (step / steps)


def constant_lr(*_):
    return 1.0


def cosine_decay_lr(step, steps):
    return np.cos(0.5 * np.pi * step / (steps - 1))


@dataclass
class ToyModelConfig:
    # We optimize n_inst models in a single training loop to let us sweep over sparsity or importance
    # curves efficiently. You should treat the number of instances `n_inst` like a batch dimension,
    # but one which is built into our training setup. Ignore the latter 3 arguments for now, they'll
    # return in later exercises.
    n_inst: int
    n_features: int = 5
    d_hidden: int = 2
    n_correlated_pairs: int = 0
    n_anticorrelated_pairs: int = 0
    feat_mag_distn: Literal["unif", "normal"] = "unif"


class ToyModel(nn.Module):
    W: Float[Tensor, "inst d_hidden feats"]
    b_final: Float[Tensor, "inst feats"]

    # Our linear map (for a single instance) is x -> ReLU(W.T @ W @ x + b_final)

    def __init__(
        self,
        cfg: ToyModelConfig,
        feature_probability: float | Tensor = 0.01,
        importance: float | Tensor = 1.0,
        device=device,
    ):
        super(ToyModel, self).__init__()
        self.cfg = cfg

        # Broadcast feature probability & importance to shape (n_inst, n_features)
        if isinstance(feature_probability, float):
            feature_probability = t.tensor(feature_probability)
        self.feature_probability = feature_probability.to(device).broadcast_to(
            (cfg.n_inst, cfg.n_features)
        )
        if isinstance(importance, float):
            importance = t.tensor(importance)
        self.importance = importance.to(device).broadcast_to((cfg.n_inst, cfg.n_features))

        self.W = nn.Parameter(
            nn.init.xavier_normal_(t.empty((cfg.n_inst, cfg.d_hidden, cfg.n_features)))
        )
        self.b_final = nn.Parameter(t.zeros((cfg.n_inst, cfg.n_features)))
        self.to(device)

    def forward(
        self,
        features: Float[Tensor, "... inst feats"],
    ) -> Float[Tensor, "... inst feats"]:
        """
        Performs a single forward pass. For a single instance, this is given by:
            x -> ReLU(W.T @ W @ x + b_final)
        """
        raise NotImplementedError()

    def generate_batch(self, batch_size: int) -> Float[Tensor, "batch inst feats"]:
        """
        Generates a batch of data of shape (batch_size, n_instances, n_features).
        """
        # You'll fill this in later
        raise NotImplementedError()

    def calculate_loss(
        self,
        out: Float[Tensor, "batch inst feats"],
        batch: Float[Tensor, "batch inst feats"],
    ) -> Float[Tensor, ""]:
        """
        Calculates the loss for a given batch (as a scalar tensor), using this loss described in the
        Toy Models of Superposition paper:

            https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

        Note, `self.importance` is guaranteed to broadcast with the shape of `out` and `batch`.
        """
        # You'll fill this in later
        raise NotImplementedError()

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 5_000,
        log_freq: int = 50,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
    ):
        """
        Optimizes the model using the given hyperparameters.
        """
        optimizer = t.optim.Adam(self.parameters(), lr=lr)

        progress_bar = tqdm(range(steps))

        for step in progress_bar:
            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group["lr"] = step_lr

            # Optimize
            optimizer.zero_grad()
            batch = self.generate_batch(batch_size)
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # Display progress bar
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item() / self.cfg.n_inst, lr=step_lr)


tests.test_model(ToyModel)
```

Solution

```python
def forward(
    self,
    features: Float[Tensor, "... inst feats"],
) -> Float[Tensor, "... inst feats"]:
    """
    Performs a single forward pass. For a single instance, this is given by:
        x -> ReLU(W.T @ W @ x + b_final)
    """
    # Project down into the hidden space: h = W x
    h = einops.einsum(
        features, self.W, "... inst feats, inst hidden feats -> ... inst hidden"
    )
    # Project back up with the transpose: W^T h
    out = einops.einsum(
        h, self.W, "... inst hidden, inst hidden feats -> ... inst feats"
    )
    return F.relu(out + self.b_final)
```

### Exercise - implement `generate_batch`

> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 10-15 minutes on this exercise.

Next, you should implement the method `generate_batch` above. This should return a tensor of shape `(n_batch, instances, features)`, where:

* The `instances` and `features` values are taken from the model config,
* Each feature is present with probability `self.feature_probability`,
* For each present feature, its **magnitude** is sampled from a uniform distribution between 0 and 1.
Make sure you understand this function well (we recommend looking at the solutions even after you pass the tests), because we'll be making more complicated versions of this function in the section on correlations.

Remember, you can assume `model.feature_probability` has shape `(n_inst, n_features)`.

When you've implemented this function, run the code below to test it.

```python
# Go back up and edit your `ToyModel.generate_batch` method, then run the test below
tests.test_generate_batch(ToyModel)
```

Solution

```python
def generate_batch(self: ToyModel, batch_size: int) -> Float[Tensor, "batch inst feats"]:
    """
    Generates a batch of data of shape (batch_size, n_instances, n_features).
    """
    batch_shape = (batch_size, self.cfg.n_inst, self.cfg.n_features)
    feat_mag = t.rand(batch_shape, device=self.W.device)  # magnitudes ~ U(0, 1)
    feat_seeds = t.rand(batch_shape, device=self.W.device)  # decides which features are present
    return t.where(feat_seeds <= self.feature_probability, feat_mag, 0.0)


ToyModel.generate_batch = generate_batch
```

## Training our model

The details of training aren't very conceptually important, so we've given you most of the code for this (in the `optimize` method). We use **learning rate schedulers** to control the learning rate as the model trains. This is the `lr_scale` argument, which takes functions like `constant_lr`, `linear_lr` or `cosine_decay_lr` defined above. You'll use this idea again later on during the RL chapter.

### Exercise - implement `calculate_loss`

> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 5-10 minutes on this exercise.

You should fill in the `calculate_loss` method above. The loss function **for a single instance** is given by:

$$
L=\frac{1}{BF}\sum_x \sum_i I_i\left(x_i-x_i^{\prime}\right)^2
$$

where:

* $B$ is the batch size,
* $F$ is the number of features,
* $x_i$ are the inputs and $x_i'$ are the model's outputs,
* $I_i$ is the importance of feature $i$,
* $\sum_i$ is a sum over features,
* $\sum_x$ is a sum over the elements in the batch.

For the general case, we sum this formula over all instances.
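As a quick worked example of the single-instance formula, the plain-Python sketch below evaluates it for a hypothetical batch with $B = 2$ and $F = 2$. The numbers are made up purely for illustration. It deliberately avoids the tensor reduction you'll write in the exercise and just follows the formula term by term.

```python
# Hypothetical toy values for ONE instance: B = 2 batch elements, F = 2 features
importance = [1.0, 0.5]  # I_i for each feature i
x = [[1.0, 0.0], [0.5, 0.2]]  # inputs x (each row is one batch element)
x_prime = [[0.8, 0.1], [0.5, 0.0]]  # model outputs x'


def weighted_mse(importance, x, x_prime):
    """L = 1/(B*F) * sum over batch and features of I_i * (x_i - x'_i)^2."""
    B, F = len(x), len(importance)
    total = 0.0
    for row, row_prime in zip(x, x_prime):  # sum over the batch (sum_x)
        for I_i, x_i, xp_i in zip(importance, row, row_prime):  # sum over features (sum_i)
            total += I_i * (x_i - xp_i) ** 2
    return total / (B * F)


loss = weighted_mse(importance, x, x_prime)
print(round(loss, 5))  # 0.01625
```

Here the weighted squared errors are $0.04 + 0.005 + 0 + 0.02 = 0.065$, and dividing by $BF = 4$ gives $0.01625$. Note how the less important second feature contributes only half its squared error.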
Question - why do you think we take the mean over the feature and batch dimensions, but we sum over the instances dimension?

We take the mean over batch size because this is standard for loss functions (and means we don't have to use a different learning rate for different batch sizes). We take the mean over the feature dimension because that's [normal for MSE loss](https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html).

We sum over the instances dimension because we want to train each instance independently, and at the same rate as we would train a single instance. Since instances share no parameters, each instance's gradient comes only from its own term in the sum. That gradient is exactly what it would be if we trained that instance alone, whereas taking the mean would scale every instance's gradient down by a factor of `n_inst`.

```python
# Go back up and edit your `ToyModel.calculate_loss` method, then run the test below
tests.test_calculate_loss(ToyModel)
```

Solution

```python
def calculate_loss(
    self: ToyModel,
    out: Float[Tensor, "batch inst feats"],
    batch: Float[Tensor, "batch inst feats"],
) -> Float[Tensor, ""]:
    """
    Calculates the loss for a given batch, using this loss described in the Toy Models paper:

        https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

    Remember, self.importance will always have shape (n_inst, n_features).
    """
    error = self.importance * ((batch - out) ** 2)
    # Mean over batch & features (per instance), then sum over instances
    loss = einops.reduce(error, "batch inst feats -> inst", "mean").sum()
    return loss


ToyModel.calculate_loss = calculate_loss
```

Now, we'll reproduce a version of the figure from the introduction. A few notes:

* The `importance` argument is the same for all instances. It takes values between 1 and ~0.66 for each feature, so for every instance, there will be some features which are more important than others.
* The `feature_probability` is the same for all features, but it varies across instances. In other words, we're running several different experiments at once, and we can compare the effect of having larger feature sparsity in these experiments.
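To make these two curves concrete, here is a plain-Python sketch (no torch) of the same two expressions used in the next cell, `0.9 ** t.arange(5)` and `50 ** -t.linspace(0, 1, 8)`, written as lists. Here `linspace(0, 1, 8)` is expanded by hand as `k / 7` for `k = 0, ..., 7`.

```python
n_features, n_inst = 5, 8

# Importance: 0.9 ** i for each feature i (same for every instance)
importance = [0.9**i for i in range(n_features)]

# Feature probability: 50 ** -s for s evenly spaced over [0, 1] (one value per instance)
feature_probability = [50 ** -(k / (n_inst - 1)) for k in range(n_inst)]

print([round(v, 3) for v in importance])  # [1.0, 0.9, 0.81, 0.729, 0.656]
print(feature_probability[0], round(feature_probability[-1], 3))  # 1.0 0.02
```

So the least important feature has importance of about 0.66. The right-most instance has feature probability 2%, i.e. sparsity 98%.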
```python
cfg = ToyModelConfig(n_inst=8, n_features=5, d_hidden=2)

# importance varies within features for each instance
importance = 0.9 ** t.arange(cfg.n_features)

# sparsity is the same for all features in a given instance, but varies over instances
feature_probability = 50 ** -t.linspace(0, 1, cfg.n_inst)

line(
    importance,
    width=600,
    height=400,
    title="Importance of each feature (same over all instances)",
    labels={"y": "Feature importance", "x": "Feature"},
)
line(
    feature_probability,
    width=600,
    height=400,
    title="Feature probability (varied over instances)",
    labels={"y": "Probability", "x": "Instance"},
)

model = ToyModel(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability[:, None],
)
model.optimize()

utils.plot_features_in_2d(
    model.W,
    colors=model.importance,
    title=f"Superposition: {cfg.n_features} features represented in 2D space",
    subplot_titles=[f"1 - S = {i:.3f}" for i in feature_probability.squeeze()],
)
```

### Exercise - interpret these diagrams

> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵🔵⚪
>
> You should spend up to 10-20 minutes on this exercise.

Remember that for all these diagrams, the darker colors have lower importance and the lighter colors have higher importance. Also, the sparsity of all features increases as we move from left to right. At the far left there is no sparsity. At the far right, feature probability is 2% for all features, i.e. sparsity of 98%.

Hint

For low sparsity, think about what the model would learn to do if all 5 features were present all the time. What's the best our model could do in this case, and how does that relate to the importance values?

For high sparsity, think about what the model would learn to do if there was always exactly one feature present. Does this make interference between features less of a problem?
Answer (intuitive)

When there is no sparsity, the model can never represent more than 2 features faithfully, so it makes sense for it to only represent the two most important features. It stores them orthogonally in 2D space, and sets the other 3 features to zero. This way, it can reconstruct these two features perfectly, and ignores all the rest.

When there is high sparsity, we get a pentagon structure. Most of the time at most one of these five features will be active, which helps avoid interference between features. We recover our initial features by projecting our point in 2D space onto these five directions. Most of the time when feature $i$ is present, we can be confident that our projection onto the $i$-th feature direction only captures this feature, rather than being affected by the presence of other features. We omit the mathematical details here.

The key idea here is that two forces are competing in our model: feature benefit (representing more things is good!), and interference (representing things non-orthogonally is bad). The higher the sparsity, the more we can reduce the negative impact of interference, and so the trade-off skews towards "represent more features, non-orthogonally".

We can also generate a batch and visualise its embedding. Most interestingly, you should see that in the plots with high sparsity (to the right), we very rarely have interference between the five features. Most often $\leq 1$ of those features is present, and the model can recover it by projecting along the corresponding feature dimension without losing any information.
```python
with t.inference_mode():
    batch = model.generate_batch(200)
    # Project each batch element into the 2D hidden space: h = W x
    hidden = einops.einsum(
        batch,
        model.W,
        "batch instances features, instances hidden features -> instances hidden batch",
    )

utils.plot_features_in_2d(hidden, title="Hidden state representation of a random batch of data")
```

### Visualizing features across varying sparsity

Now that we've got our pentagon plots and started to get geometric intuition for what's going on, let's scale things up! We're now operating in dimensions too large to visualise, but hopefully our intuitions will carry over.

```python
cfg = ToyModelConfig(n_inst=10, n_features=100, d_hidden=20)

importance = 100 ** -t.linspace(0, 1, cfg.n_features)
feature_probability = 20 ** -t.linspace(0, 1, cfg.n_inst)

line(
    importance,
    width=600,
    height=400,
    title="Feature importance (same over all instances)",
    labels={"y": "Importance", "x": "Feature"},
)
line(
    feature_probability,
    width=600,
    height=400,
    title="Feature probability (varied over instances)",
    labels={"y": "Probability", "x": "Instance"},
)

model = ToyModel(
    cfg=cfg,
    device=device,
    importance=importance[None, :],
    feature_probability=feature_probability[:, None],
)
model.optimize(steps=10_000)
```

Because we can't plot features in 2D anymore, we're going to use a different kind of visualisation:

* The **bottom row plots** show a bar graph of all the features and their corresponding embedding norms $||W_i||$.
    * As we increase sparsity, the model is able to represent more features (i.e. we have more features with embedding norms close to 1).
    * We also color the bars according to whether they're orthogonal to other features (purple) or not (yellow). For low sparsity, most features are represented orthogonally (like our left-most plots above). As we increase sparsity, we transition to all features being represented non-orthogonally (like our right-most pentagon plots above).
* The **top row plots** show us the dot products between all pairs of feature vectors (kinda like the heatmaps we plotted at the start of this section).
    * This is another way of visualising the increasing interference between features as we increase sparsity.
    * Note that all these top row plots represent **matrices with rank at most `d_hidden=20`**. The first few are approximately submatrices of the identity, because we perfectly reconstruct 20 features and delete the rest. The later plots start to display interference as we plot more than 20 values (the diagonals of these matrices have more than 20 non-zero elements).

See the section [Basic Results](https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-basic-results) for more of an explanation of this graph and what you should interpret from it.

```python
utils.plot_features_in_Nd(
    model.W,
    height=800,
    width=1600,
    title="ReLU output model: n_features = 100, d_hidden = 20, I<sub>i</sub> = 100<sup>-i/99</sup>",
    subplot_titles=[f"Feature prob = {i:.3f}" for i in feature_probability],
)
```

## Superposition with correlation

> Note, if you're aiming to progress quickly through these exercises in order to just cover the key ideas behind superposition, this is probably the point at which you can jump to the next section! The key idea here is essentially that negative correlation between features leads to more superposition, because the model suffers less from interference (the cases when both features are active at once). If you're interested in the details and actually performing the replication, read on!

One major thing we haven't considered in our experiments is **correlation**. We could guess that superposition is even more common when features are **anticorrelated** (for a similar reason as why it's more common when features are sparse). Most real-world features are anticorrelated (e.g.
the feature "this is a sorted Python list" and "this is some text in an edgy teen vampire romance novel" are probably anticorrelated - that is, unless you've been reading some pretty weird fanfics).

In this section, you'll define a new data-generating function for correlated features, and run the same experiments as in the first section.

### Exercise - implement `generate_correlated_batch`

> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵🔵⚪⚪⚪
>
> You should spend up to 20-40 minutes on this exercise.
>
> The exercise itself is a bit fiddly / delicate, so you should definitely look at the solutions if you get stuck.

You should now fill in the three functions `generate_correlated_features`, `generate_anticorrelated_features` and `generate_uncorrelated_features`. Each takes a `ToyModel` as its first argument, and they generate correlated / anticorrelated / uncorrelated data respectively. We've given you a new `generate_batch` function which concatenates the outputs of these functions along the feature dimension.

Note, in the correlated & anticorrelated cases you can assume that the feature probability is the same for all features in each instance. We start these functions by asserting this for you, and creating a vector `p` which contains this feature probability for each instance (which is what you should use instead of `model.feature_probability`). The same is also true for the uncorrelated case, when the number of uncorrelated features we're generating is less than `cfg.n_features`. If they're equal, it's fine to use the full `self.feature_probability` tensor.

You'll also need to be careful with your probabilities in the anticorrelated case. For example, if you do the following for your pair of features 1 & 2:

```python
# `shape` stands for the (batch, inst, pair) shape of the mask
feat1_is_present = t.rand(shape) < p
feat2_is_present = (t.rand(shape) < p) & ~feat1_is_present
```

then your `feat2` probability will actually be `p * (1 - p)` rather than the intended `p`. You want to try and make both features have probability `p`, while _also_ ensuring that they are never both active at the same time!
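To see this pitfall numerically, the Monte Carlo sketch below samples the naive scheme above many times. It uses Python's standard `random` module instead of torch, with an arbitrarily chosen `p = 0.1`. The second feature's empirical frequency comes out close to $p(1-p) = 0.09$ rather than $p = 0.1$, even though the two features indeed never co-occur.

```python
import random

random.seed(0)
p = 0.1  # arbitrary feature probability, chosen for illustration
n_samples = 200_000

count1 = count2 = count_both = 0
for _ in range(n_samples):
    # Naive scheme: feature 2 is only allowed when feature 1 is absent
    feat1_is_present = random.random() < p
    feat2_is_present = (random.random() < p) and not feat1_is_present
    count1 += feat1_is_present
    count2 += feat2_is_present
    count_both += feat1_is_present and feat2_is_present

freq1, freq2 = count1 / n_samples, count2 / n_samples
print(f"feature 1: {freq1:.4f} (target p = {p})")
print(f"feature 2: {freq2:.4f} (p * (1 - p) = {p * (1 - p):.4f})")
print(f"both present: {count_both}")  # always 0 under this scheme
```

The anticorrelation constraint is satisfied, but the marginal probability of the second feature is biased downwards. Any correct scheme has to fix the marginals as well as the joint constraint.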
The hints provide some guidance on how you can implement this (it's a bit fiddly and not very conceptually important!). For more details, you can read the [experimental details in Anthropic's paper](https://transformer-circuits.pub/2022/toy_model/index.html#geometry-correlated-setup), where they describe how they set up correlated and anticorrelated sets.

Help - I'm confused about how to implement the correlated features function.

Try first creating a boolean mask of shape `(batch_size, n_inst, n_correlated_pairs)` representing whether the pair is present, then repeating that mask across feature pairs with `einops.repeat`.

Help - I'm confused about how to implement the anticorrelated features function.

Here are 2 suggested methods:

1. Create a boolean mask of shape `(batch_size, n_inst, n_anticorrelated_pairs)` with probability $2p$, which represents whether either feature is present. Where it is true, choose the present feature uniformly at random from the pair. This works because both features will have probability $2p \times 0.5 = p$.
2. Create 2 boolean masks `M1`, `M2`, both of shape `(batch_size, n_inst, n_anticorrelated_pairs)`, with probability $p$ and $p / (1 - p)$ respectively. Set the first feature to be present where `M1` is true, and the second feature to be present where `~M1 & M2` is true. This works because the first feature will have probability $p$, and the second will have probability $(1 - p) \cdot \frac{p}{1 - p} = p$.

The solution below uses a single uniform seed per pair. The first feature is present when the seed is at most $p$, and the second when the seed is at least $1 - p$. These events are disjoint because $p \le 0.5$, and each has probability $p$, so this is closest to method (1). Either method is valid.

```python
def generate_correlated_features(
    self: ToyModel, batch_size: int, n_correlated_pairs: int
) -> Float[Tensor, "batch inst 2*n_correlated_pairs"]:
    """
    Generates a batch of correlated features. For each pair `batch[i, j, [2k, 2k+1]]`, one of
    them is non-zero if and only if the other is non-zero.
    """
    assert t.all((self.feature_probability == self.feature_probability[:, [0]]))
    p = self.feature_probability[:, [0]]  # shape (n_inst, 1)

    # YOUR CODE HERE!
    raise NotImplementedError()


def generate_anticorrelated_features(
    self: ToyModel, batch_size: int, n_anticorrelated_pairs: int
) -> Float[Tensor, "batch inst 2*n_anticorrelated_pairs"]:
    """
    Generates a batch of anti-correlated features. For each pair `batch[i, j, [2k, 2k+1]]`, each
    of them can only be non-zero if the other one is zero.
    """
    assert t.all((self.feature_probability == self.feature_probability[:, [0]]))
    p = self.feature_probability[:, [0]]  # shape (n_inst, 1)

    assert p.max().item() <= 0.5, "For anticorrelated features, must have 2p < 1"

    # YOUR CODE HERE!
    raise NotImplementedError()


def generate_uncorrelated_features(self: ToyModel, batch_size: int, n_uncorrelated: int) -> Tensor:
    """
    Generates a batch of uncorrelated features.
    """
    if n_uncorrelated == self.cfg.n_features:
        p = self.feature_probability
    else:
        assert t.all((self.feature_probability == self.feature_probability[:, [0]]))
        p = self.feature_probability[:, [0]]  # shape (n_inst, 1)

    # YOUR CODE HERE!
    raise NotImplementedError()


def generate_batch(self: ToyModel, batch_size) -> Float[Tensor, "batch inst feats"]:
    """
    Generates a batch of data, with optional correlated & anticorrelated features.
    """
    n_corr_pairs = self.cfg.n_correlated_pairs
    n_anti_pairs = self.cfg.n_anticorrelated_pairs
    n_uncorr = self.cfg.n_features - 2 * n_corr_pairs - 2 * n_anti_pairs

    data = []
    if n_corr_pairs > 0:
        data.append(generate_correlated_features(self, batch_size, n_corr_pairs))
    if n_anti_pairs > 0:
        data.append(generate_anticorrelated_features(self, batch_size, n_anti_pairs))
    if n_uncorr > 0:
        data.append(generate_uncorrelated_features(self, batch_size, n_uncorr))
    batch = t.cat(data, dim=-1)
    return batch


ToyModel.generate_batch = generate_batch
```

Solution

```python
def generate_correlated_features(
    self: ToyModel, batch_size: int, n_correlated_pairs: int
) -> Float[Tensor, "batch inst 2*n_correlated_pairs"]:
    """
    Generates a batch of correlated features. For each pair batch[i, j, [2k, 2k+1]], one of
    them is non-zero if and only if the other is non-zero.
    """
    assert t.all((self.feature_probability == self.feature_probability[:, [0]]))
    p = self.feature_probability[:, [0]]  # shape (n_inst, 1)

    feat_mag = t.rand((batch_size, self.cfg.n_inst, 2 * n_correlated_pairs), device=self.W.device)
    # One seed per pair: the whole pair is either present or absent
    feat_set_seeds = t.rand((batch_size, self.cfg.n_inst, n_correlated_pairs), device=self.W.device)
    feat_set_is_present = feat_set_seeds <= p
    feat_is_present = einops.repeat(
        feat_set_is_present,
        "batch instances features -> batch instances (features pair)",
        pair=2,
    )
    return t.where(feat_is_present, feat_mag, 0.0)


def generate_anticorrelated_features(
    self: ToyModel, batch_size: int, n_anticorrelated_pairs: int
) -> Float[Tensor, "batch inst 2*n_anticorrelated_pairs"]:
    """
    Generates a batch of anti-correlated features. For each pair batch[i, j, [2k, 2k+1]], each
    of them can only be non-zero if the other one is zero.
    """
    assert t.all((self.feature_probability == self.feature_probability[:, [0]]))
    p = self.feature_probability[:, [0]]  # shape (n_inst, 1)

    assert p.max().item() <= 0.5, "For anticorrelated features, must have 2p < 1"

    feat_mag = t.rand(
        (batch_size, self.cfg.n_inst, 2 * n_anticorrelated_pairs), device=self.W.device
    )
    # First feature present iff seed <= p, second iff (1 - seed) <= p: disjoint since p <= 0.5
    seed = t.rand((batch_size, self.cfg.n_inst, n_anticorrelated_pairs), device=self.W.device)
    mask = (
        einops.rearrange(t.stack([seed, 1 - seed], dim=-1), "... feat pair -> ... (feat pair)")
        <= p
    )
    return feat_mag * mask


def generate_uncorrelated_features(self: ToyModel, batch_size: int, n_uncorrelated: int) -> Tensor:
    """
    Generates a batch of uncorrelated features.
    """
    if n_uncorrelated == self.cfg.n_features:
        p = self.feature_probability
    else:
        assert t.all((self.feature_probability == self.feature_probability[:, [0]]))
        p = self.feature_probability[:, [0]]  # shape (n_inst, 1)

    feat_mag = t.rand((batch_size, self.cfg.n_inst, n_uncorrelated), device=self.W.device)
    feat_seeds = t.rand((batch_size, self.cfg.n_inst, n_uncorrelated), device=self.W.device)
    return t.where(feat_seeds <= p, feat_mag, 0.0)
```

The code below tests your functions by generating a large batch of data and checking its statistics.

```python
cfg = ToyModelConfig(
    n_inst=30, n_features=4, d_hidden=2, n_correlated_pairs=1, n_anticorrelated_pairs=1
)

feature_probability = 10 ** -t.linspace(0.5, 1, cfg.n_inst).to(device)

model = ToyModel(cfg=cfg, device=device, feature_probability=feature_probability[:, None])

# Generate a batch of 4 features: first 2 are correlated, second 2 are anticorrelated
batch = model.generate_batch(batch_size=100_000)
corr0, corr1, anticorr0, anticorr1 = batch.unbind(dim=-1)

assert ((corr0 != 0) == (corr1 != 0)).all(), "Correlated features should be active together"
assert ((corr0 != 0).float().mean(0) - feature_probability).abs().mean() < 0.002, (
    "Each correlated feature should be active with probability `feature_probability`"
)

assert not ((anticorr0 != 0) & (anticorr1 != 0)).any(), (
    "Anticorrelated features should never be active together"
)
assert ((anticorr0 != 0).float().mean(0) - feature_probability).abs().mean() < 0.002, (
    "Each anticorrelated feature should be active with probability `feature_probability`"
)
```

We can also visualise these features in the form of a bar chart. You should see the correlated features always co-occurring, and the anticorrelated features never co-occurring.
```python
# Generate a batch of 4 features: first 2 are correlated, second 2 are anticorrelated
batch = model.generate_batch(batch_size=1)
correlated_feature_batch, anticorrelated_feature_batch = batch.split(2, dim=-1)

# Plot correlated features
utils.plot_correlated_features(
    correlated_feature_batch,
    title="Correlated feature pairs: should always co-occur",
)
# Plot anticorrelated features
utils.plot_correlated_features(
    anticorrelated_feature_batch,
    title="Anti-correlated feature pairs: should never co-occur",
)
```

Now, let's try training our model & visualising features in 2D, when we have 2 pairs of correlated features. This matches the [first row of the correlation figure](https://transformer-circuits.pub/2022/toy_model/index.html#geometry-organization) in the Anthropic paper.

```python
cfg = ToyModelConfig(n_inst=5, n_features=4, d_hidden=2, n_correlated_pairs=2)

# All same importance, very low feature probabilities (ranging from 5% down to 0.25%)
feature_probability = 400 ** -t.linspace(0.5, 1, cfg.n_inst)

model = ToyModel(
    cfg=cfg,
    device=device,
    feature_probability=feature_probability[:, None],
)
model.optimize(steps=10_000)

utils.plot_features_in_2d(
    model.W,
    colors=["blue"] * 2 + ["limegreen"] * 2,
    title="Correlated feature sets are represented in local orthogonal bases",
    subplot_titles=[f"1 - S = {i:.3f}" for i in feature_probability],
)
```

### Exercise - generate more correlated feature plots

> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to ~10 minutes on this exercise.
>
> It should just involve changing the parameters in your code above.

You should now reproduce the second and third rows from the paper's [correlation figure](https://transformer-circuits.pub/2022/toy_model/index.html#geometry-organization). You may not get exactly the same results as the paper, but they should still roughly match (e.g.
you should see no antipodal pairs in the code above, but you should see at least some when you test the anticorrelated sets, even if not every pair ends up antipodal). You can look at the solutions colab to see some examples.

Question - for the anticorrelated feature plots, you'll have to increase the feature probability to something like ~10%, or else you won't always form antipodal pairs. Why do you think this is?

Antipodal pairs are better for handling anticorrelated features because the model can be sure that at most one feature from each antipodal pair is active at a time, so the two features in a pair never interfere with each other. Effectively, at most 2 directions are non-zero at a time (one from each pair), and if they co-occur they are guaranteed to be orthogonal, because in the antipodal solution the two pairs lie along orthogonal axes. So the model can reach zero loss.

If we don't have antipodal pairs, then we'll sometimes get interference between features from different pairs: these can co-occur, and their directions might be antipodal (or otherwise non-orthogonal).

The key point here - antipodal pairs are only better because they handle interference better, i.e. the cases where both feature pairs are active at once. This happens with probability $O(p^2)$, where $p$ is the feature probability. So for very small values of $p$, the advantage the antipodal solution has over the non-antipodal one is much smaller, and the model may just settle on whichever solution it finds first.
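The $O(p^2)$ argument can be checked numerically. Below is a minimal pure-Python sketch with hypothetical function names. It assumes the two anticorrelated pairs are sampled independently of each other. Each pair then fires with probability $2p$, which follows from its two mutually exclusive features each having probability $p$, so both pairs fire together with probability $(2p)^2$.

```python
import random


def both_pairs_active_exact(p: float) -> float:
    # Each anticorrelated pair fires with prob 2p (its features are mutually exclusive,
    # each with prob p); assuming the pairs are independent, both fire with prob (2p)^2
    return (2 * p) ** 2


def both_pairs_active_mc(p: float, n: int = 200_000, seed: int = 0) -> float:
    # Monte Carlo estimate of the same probability
    rng = random.Random(seed)
    hits = 0
    for _ in range(n):
        pair_a = rng.random() < 2 * p
        pair_b = rng.random() < 2 * p
        hits += pair_a and pair_b
    return hits / n


for p in (0.1, 0.05, 0.01):
    exact = both_pairs_active_exact(p)
    # Share of "pair A active" inputs where pair B is also active (equals 2p)
    share = exact / (2 * p)
    print(f"p={p}: P(both active)={exact:.4f} (MC {both_pairs_active_mc(p):.4f}), "
          f"interference share of pair-A inputs={share:.2f}")
```

The last column shows why small $p$ weakens the pressure towards antipodal pairs. Only a fraction $2p$ of the inputs that activate a given pair involve any cross-pair interference, so the loss difference between the two solutions shrinks as $p$ falls.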
# YOUR CODE HERE - generate more correlated feature plots

Solution (example code, and what you should find)

# Anticorrelated feature pairs
cfg = ToyModelConfig(n_inst=5, n_features=4, d_hidden=2, n_anticorrelated_pairs=2)

# All same importance, not-super-low feature probabilities (ranging from ~32% down to 10%)
feature_probability = 10 ** -t.linspace(0.5, 1, cfg.n_inst)

model = ToyModel(cfg=cfg, device=device, feature_probability=feature_probability[:, None])
model.optimize(steps=10_000)

utils.plot_features_in_2d(
    model.W,
    colors=["red"] * 2 + ["orange"] * 2,
    title="Anticorrelated feature sets are frequently represented as antipodal pairs",
    subplot_titles=[f"1 - S = {i:.3f}" for i in feature_probability],
)

# 3 correlated feature pairs
cfg = ToyModelConfig(n_inst=5, n_features=6, d_hidden=2, n_correlated_pairs=3)

# All same importance, low feature probabilities (ranging from 10% down to 1%)
feature_probability = 100 ** -t.linspace(0.5, 1, cfg.n_inst)

model = ToyModel(cfg=cfg, device=device, feature_probability=feature_probability[:, None])
model.optimize(steps=10_000)

utils.plot_features_in_2d(
    model.W,
    colors=["blue"] * 2 + ["limegreen"] * 2 + ["purple"] * 2,
    title="Correlated feature sets are side by side if they can't be orthogonal (and sometimes we get collapse)",
    subplot_titles=[f"1 - S = {i:.3f}" for i in feature_probability],
)
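The comments in these cells describe the probability range produced by each `base ** -t.linspace(0.5, 1, n_inst)` expression. The short pure-Python sketch below makes those endpoints easy to check. It uses a hypothetical helper `feature_probabilities` with plain Python floats rather than torch float32, so the last digits may differ slightly. With this helper, base 400 runs from 5% down to 0.25%, base 100 from 10% down to 1%, and base 10 from about 31.6% down to 10%.

```python
def feature_probabilities(base: float, n_inst: int = 5) -> list[float]:
    # Mirrors `base ** -t.linspace(0.5, 1, n_inst)` in plain Python floats:
    # exponents are evenly spaced from 0.5 to 1 inclusive (needs n_inst >= 2)
    assert n_inst >= 2
    exponents = [0.5 + 0.5 * i / (n_inst - 1) for i in range(n_inst)]
    return [base ** -e for e in exponents]


for base in (400, 100, 10):
    probs = feature_probabilities(base)
    print(f"base {base}: " + ", ".join(f"{q:.2%}" for q in probs))
```

The exponent always spans $[0.5, 1]$, so the probabilities always run from $1/\sqrt{\text{base}}$ down to $1/\text{base}$. That is a quick way to pick a base for a target range.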