# 2️⃣ Circuit and Feature Analysis

**Learning Objectives**

* Apply your understanding of the 1D and 2D Fourier bases to show that the activations / effective weights of your model are highly sparse in the Fourier basis.
* Turn these observations into concrete hypotheses about the model's algorithm.
* Verify these hypotheses using statistical methods, and interventions like ablation.
* Fully understand the model's algorithm, and how it solves the task.

Understanding a transformer breaks down into two high-level parts: interpreting the features represented by non-linear activations (output probabilities, attention patterns, neuron activations), and interpreting the circuits that calculate each feature (the weights representing the computation done to convert earlier features into later features). In this section we interpret the embedding, the neuron activations (a feature) and the logit computation (a circuit). These are the most important parts to interpret.

Let's start with the embedding.

## Understanding the embedding

Below is some code to plot the embedding in the Fourier basis. You should run this code, and interpret the output.

```python
utils.line(
    (fourier_basis @ W_E).pow(2).sum(1),
    hover=fourier_basis_names,
    title="Norm of embedding of each Fourier Component",
    xaxis="Fourier Component",
    yaxis="Norm",
)
```

**Interpretation**

You should find that the embedding is sparse in the Fourier basis, and throws away all Fourier components apart from a handful of frequencies (the number of frequencies and their values are arbitrary, and vary between training runs).

The Fourier basis vector for component $2k-1$ is:

$$
\cos\left(\frac{2 \pi k}{p}x\right)_{x = 0, ..., p-1}
$$

and the vector for component $2k$ is the same, but with $\cos$ replaced by $\sin$ (component $0$ is the constant vector). So this result tells us that, for the input $x$, we're keeping the information $\cos\left(\frac{2 \pi k}{p}x\right)$ and $\sin\left(\frac{2 \pi k}{p}x\right)$ for each of the key frequencies $k$, and throwing away this information for all other frequencies.
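To make the index convention and the sparsity claim concrete, here is a minimal NumPy sketch. `make_fourier_basis` is a hypothetical helper (not the notebook's own code) that builds a normalized basis with the layout described above: row $0$ is the constant, row $2k-1$ is the cosine and row $2k$ is the sine for frequency $k$. It assumes $p$ is odd. The toy `W_E` is built by construction from only two "key" frequencies, so its analogue of the plot above is non-zero only on those four rows.

```python
import numpy as np


def make_fourier_basis(p: int) -> np.ndarray:
    """Hypothetical helper: rows 0 = constant, 2k-1 = cos(w_k x), 2k = sin(w_k x), all unit norm.

    Assumes p is odd, so the p rows form a complete orthonormal basis of R^p.
    """
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, p // 2 + 1):
        w_k = 2 * np.pi * k / p
        rows.append(np.cos(w_k * x))
        rows.append(np.sin(w_k * x))
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


p, d_model = 13, 8
F = make_fourier_basis(p)

# Toy embedding that (by construction) only uses the cos/sin directions of two "key" frequencies
key_freqs = [2, 5]
key_rows = [i for k in key_freqs for i in (2 * k - 1, 2 * k)]
rng = np.random.default_rng(0)
W_E = F[key_rows].T @ rng.normal(size=(len(key_rows), d_model))  # shape (p, d_model)

# Analogue of (fourier_basis @ W_E).pow(2).sum(1): squared norm of each Fourier component
component_norms = np.square(F @ W_E).sum(axis=1)
print(np.flatnonzero(component_norms > 1e-10).tolist())  # [3, 4, 9, 10]
print(np.linalg.matrix_rank(W_E))  # 4: two directions (cos, sin) per key frequency
```

Because the rows of `F` are orthonormal, `F @ W_E` simply reads off the coefficients on each basis vector. This is also why such an embedding has rank equal to twice the number of key frequencies.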
Let us term the frequencies with non-trivial norm the **key frequencies** (here, 14, 31, 35, 41, 42, 52).

### Another perspective (from singular value decomposition)

Recall that we can write any matrix as the product of an orthogonal matrix, a diagonal matrix, and another orthogonal matrix:

$$
\begin{aligned}
A &= U S V^T \\
&= \sum_{i=1}^k \sigma_i u_i v_i^T
\end{aligned}
$$

where $u_i$, $v_i$ are the column vectors of the orthogonal matrices $U$, $V$, and $\sigma_1, ..., \sigma_k$ are the non-zero singular values of $A$. If this isn't familiar, you might want to go through the induction heads exercises, which discuss SVD in more detail. This is often a natural way to represent low-rank matrices (because most singular values $\sigma_i$ will be zero).

Denote the matrix `fourier_basis` as $F$ (remember that the rows are orthonormal basis vectors, so $F F^T = I$). Here, we've found that the matrix $F W_E$ is very sparse (i.e. most of its row vectors are zero). From this, we can deduce that $W_E$ is well-approximated by the following low-rank SVD:

$$
W_E \approx F^T S V^T
$$

because then left-multiplying by $F$ gives us a matrix with zeros everywhere except for the rows corresponding to non-zero singular values:

$$
F W_E \approx F F^T S V^T = S V^T
$$

In other words, the input directions we're projecting onto when we take the embedding of our tokens $t_0$, $t_1$ are approximately the Fourier basis directions.

To directly visualise the matrix product $S V^T$ (which will be sparse in rows because most diagonal values of $S$ are zero), run the following:

```python
imshow_div(fourier_basis @ W_E)
```

## Understanding neuron activations

Now that we've established that $W_E$ only preserves the information corresponding to a few key frequencies, let's look at what those frequencies are actually being used for.

TL;DR:

* Each neuron's activations are a linear combination of the constant term and the linear and quadratic terms of a specific frequency.
* The neurons clearly cluster according to the key frequencies.
### Neurons produce quadratic terms

First, recall the diagrams of the neuron activations in the input basis and the 2D Fourier basis, which we made in the previous section:

```python
top_k = 5
utils.inputs_heatmap(neuron_acts_post_sq[..., :top_k], ...)
utils.imshow_fourier(neuron_acts_post_fourier_basis[..., :top_k], ...)
```

We found that the first plot looked periodic (recall, periodic in standard basis = sparse in Fourier basis), and the second gave us more detail about the Fourier basis representation by showing that each neuron was associated with one of the key frequencies we observed earlier.

For instance, look at the first neuron in these plots. We can see that the only terms which matter in the 2D Fourier basis are the constant term and the terms for frequency $k = 42$ (we get both $\sin$ and $\cos$ terms). In total, this gives us nine terms, which are (up to scale factors):

$$
\begin{bmatrix}
1 & \cos(\omega_k x) & \sin(\omega_k x) \\
\cos(\omega_k y) & \cos(\omega_k x)\cos(\omega_k y) & \sin(\omega_k x)\cos(\omega_k y) \\
\sin(\omega_k y) & \cos(\omega_k x)\sin(\omega_k y) & \sin(\omega_k x)\sin(\omega_k y)
\end{bmatrix}
$$

where $\omega_k = 2 \pi k / p$, and in this case $k = 42$. These include the constant term, four linear terms, and four quadratic terms.

What is the significance of this? Importantly, we have the following trig formulas:

$$
\begin{aligned}
\cos(\omega_k(x + y)) &= \cos(\omega_k x) \cos(\omega_k y) - \sin(\omega_k x) \sin(\omega_k y) \\
\sin(\omega_k(x + y)) &= \sin(\omega_k x) \cos(\omega_k y) + \cos(\omega_k x) \sin(\omega_k y)
\end{aligned}
$$

The plots tell us that some of the terms on the right of these equations (i.e. the quadratic terms) are the ones being detected by our neurons. Since we know that the model will eventually have to internally represent the quantity $x+y$ in some way in order to perform modular addition, we might guess that this is how it does it (i.e.
it calculates the quantities on the left hand side of the equations, by first computing the things on the right). In other words, our neurons are in some sense storing the information $\cos(\omega_k(x + y))$ and $\sin(\omega_k(x + y))$, for different frequencies $k$.

Let's have a closer look at some of the coefficients for these 2D Fourier basis terms.

### Exercise - calculate the mean squared coefficient

Difficulty: 🔴🔴⚪⚪⚪

Importance: 🔵🔵🔵⚪⚪

You should spend up to ~10 minutes on this exercise. This exercise asks you to perform some basic operations and make some simple plots with your neuron activations.

We might speculate that the neuron activation (which is a function of the inputs $x$, $y$) is in some sense trying to approximate the function $\sum_i \sum_j B_{i, j} F_{i, x} F_{j, y}$, where:

* $F$ is the Fourier change-of-basis matrix. Note that, for example, $F_{1,17} \propto \cos(\frac{2\pi}{p} 17)$ and $F_{2, 17} \propto \sin(\frac{2\pi}{p} 17)$ (the rows of $F$ are normalized).
* $B$ is a matrix of coefficients, which we suspect is highly sparse.

Specifically, we suspect that the "true function" our model is trying to approximate looks something like a linear combination of the nine terms in the matrix above:

$$
\sum_{i\in \{0, 2k-1, 2k\}} \sum_{j\in \{0, 2k-1, 2k\}} B_{i,j} F_{i,x} F_{j,y}
$$

(recall that the $(2k-1)$-th and $2k$-th basis vectors are the cosine and sine vectors for frequency $k$).

Create a heatmap of the mean squared coefficient for each 2D Fourier component (in other words, the $(i, j)$-th value in your heatmap is the mean of $B_{i, j}^2$ across all neurons). Your code should involve two steps:

1. Centering the neuron activations, by subtracting the mean over all batches (this essentially removes the bias term). ReLU neurons always have non-negative activations, so the constant term would otherwise dominate, and it should usually be considered separately.
2. Taking the 2D Fourier transform of the centered neuron activations.

We've given you the code to plot your results.
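As a minimal sketch of these two steps on toy data, assuming the notebook's `fft2d` is the change of basis $B_{i,j} = \sum_{x,y} F_{i,x} F_{j,y} A_{x,y}$ applied with `fourier_basis` on both grid axes: `fft2d_sketch` and `make_fourier_basis` below are hypothetical NumPy stand-ins, and the "neurons" are synthetic, with coefficients only on the nine frequency-$k$ terms. The sketch shows that centering only zeroes the $(0, 0)$ coefficient, and that the mean squared coefficient is then supported on the remaining eight terms.

```python
import numpy as np


def make_fourier_basis(p: int) -> np.ndarray:
    """Hypothetical helper: rows 0 = constant, 2k-1 = cos(w_k x), 2k = sin(w_k x), unit norm (p odd)."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, p // 2 + 1):
        rows += [np.cos(2 * np.pi * k * x / p), np.sin(2 * np.pi * k * x / p)]
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


def fft2d_sketch(acts: np.ndarray, F: np.ndarray) -> np.ndarray:
    """acts[x, y, n] -> B[i, j, n], where acts[x, y, n] = sum_{i,j} B[i, j, n] F[i, x] F[j, y]."""
    return np.einsum("ix,jy,xyn->ijn", F, F, acts)


p, n_neurons, k = 11, 4, 3
F = make_fourier_basis(p)
idx = [0, 2 * k - 1, 2 * k]  # const, cos and sin rows for frequency k

# Synthetic "neurons" whose coefficients are non-zero only on the nine frequency-k terms
rng = np.random.default_rng(0)
B_true = np.zeros((p, p, n_neurons))
B_true[np.ix_(idx, idx)] = rng.normal(size=(3, 3, n_neurons))
acts = np.einsum("ix,jy,ijn->xyn", F, F, B_true)  # activations on the (x, y) grid

# Step 1: center over the (x, y) grid.  Step 2: change to the 2D Fourier basis.
acts_centered = acts - acts.mean(axis=(0, 1), keepdims=True)
B = fft2d_sketch(acts_centered, F)
mean_sq = np.square(B).mean(axis=-1)  # analogue of .pow(2).mean(-1) in the plotting code

print(np.allclose(B[0, 0], 0))  # True: centering only removes the constant term
```

Subtracting a per-neuron constant is the same as subtracting a multiple of the constant 2D basis vector, which is why every other coefficient is left unchanged.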
Remember to work with the `neuron_acts_post_sq` object, which already has its batch dimensions shaped into a grid.

```python
# YOUR CODE HERE - compute neuron activations (centered)
neuron_acts_centered = neuron_acts_post_sq - neuron_acts_post_sq.mean((0, 1), keepdim=True)

# Take 2D Fourier transform
neuron_acts_centered_fourier = fft2d(neuron_acts_centered)

utils.imshow_fourier(
    neuron_acts_centered_fourier.pow(2).mean(-1),
    title="Norms of 2D Fourier components of centered neuron activations",
)
```

Your plot of the average $B_{i, j}^2$ values should show the kind of sparsity you saw earlier: the only non-zero terms will be the constant, linear and quadratic terms corresponding to a few select frequencies. You can compare these frequencies to the key frequencies we defined earlier, and validate that they're the same.

### How do we get the quadratic terms?

Exactly how these quadratic terms are calculated is a bit convoluted. A good mental model for neural networks is that they are really good at matrix multiplication and addition, and anything else takes a lot of effort. So we should expect taking the product of two different parts of the input to be pretty hard!

You can see the original notebook for more details on this calculation. The short version: the model uses both ReLU activations and element-wise products with attention to multiply terms, in hacky ways.

The attention pattern products (which we won't discuss much here) work because attention involves taking the product of value vectors and attention probabilities, which are each functions of the input. This seems a bit weird because attention probabilities and value vectors are usually thought of as playing different roles, but here they're working together to allow the model to multiply two different parts of the input.

The ReLU activations are also pretty surprising.
It turns out that linear combinations of 1D and 2D Fourier components are well approximated by the ReLU of a linear combination of the 1D Fourier components. Specifically, if we approximate the expression:

$$
\operatorname{ReLU}(A + B \cos(\omega x) + B \cos(\omega y))
$$

(for $A$, $B > 0$) as a linear combination of the following 4 terms in the 2D Fourier basis:

$$
\alpha + \beta \cos(\omega x) + \beta \cos(\omega y) + \gamma \cos(\omega x) \cos(\omega y)
$$

we find that it includes a significant component in the $\cos(\omega x) \cos(\omega y)$ direction.

The key intuition for why this happens is that the quadratic term captures interaction between $x$ and $y$. $\operatorname{ReLU}$ is a convex function of $\cos(\omega x) + \cos(\omega y)$, so (if we pretend $x$ and $y$ are random variables) it is larger in expectation when these two inputs are correlated, hence $\gamma > 0$. (Note: this is quite a handwavey argument, so don't worry too much if it doesn't seem intuitive!)

This is important because our model can calculate the first of these two expressions (it can take linear combinations of $\cos(\omega x)$ and $\cos(\omega y)$ in the attention layer, then apply $\operatorname{ReLU}$ during the MLP), but it can't directly calculate the second expression. So it can essentially use the ReLU to approximate a sum of linear and quadratic terms.

### Exercise - verify the quadratic term matters, and that $\gamma > 0$

Difficulty: 🔴🔴🔴🔴⚪

Importance: 🔵⚪⚪⚪⚪

You should spend up to 10-15 minutes on this exercise. Doing this exercise isn't super valuable to the overall experience of this section. You can skip it if you want.

Take $A = \frac{1}{2\sqrt{p}}, \; B = 1, \;$ and $\;\omega = \omega_{42} = (2 \pi \times 42) / p \;$ (one of the key frequencies we observed earlier). Find the coefficients $\alpha$, $\beta$ and $\gamma$ that minimize the mean squared error between this linear combination of Fourier terms and the ReLU function. Find the $r^2$ score of this fit.
Verify that $\gamma > 0$, and that the score is close to 1. Also, verify that the $r^2$ score decreases by quite a lot when you omit the quadratic term (showing that this quadratic term is important).

You can use the `LinearRegression` class from `sklearn`. Remember to use the normalized 1D and 2D Fourier basis vectors in your regression.

```python
from sklearn.linear_model import LinearRegression

# YOUR CODE HERE - compute quadratic term, and r^2 of regression with/without it
```

**Discussion of results**

You should get the following:

```
ReLU(0.5 + cos(wx) + cos(wy)) ≈ 9.190const + 6.807cos(wx) + 6.807cos(wy) + 3.566cos(wx)cos(wy)
r2: 0.966
r2 (no quadratic term): 0.849
```

This confirms that the quadratic term does indeed have a positive coefficient, and that it explains a lot of the variance in the ReLU function (specifically, it explains over 2/3 of the variance which is left after we've accounted for the linear terms).

(Note that we didn't print out the coefficients for the regression without the quadratic term: our 2D Fourier basis vectors are orthogonal, so we can guarantee that these coefficients would be the same as the coefficients in the regression with the quadratic term.)
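If `sklearn` isn't available, the same regression can be sketched with plain least squares. This version assumes that `fourier_2d_basis_term(i, j)` is the outer product of two normalized 1D basis vectors, and it re-creates both that and `fourier_basis` with the hypothetical helper `make_fourier_basis`. Under those assumptions we'd expect its output to be close to the numbers quoted above. `fit_r2` computes the same $r^2$ that `LinearRegression.score` reports.

```python
import numpy as np


def make_fourier_basis(p: int) -> np.ndarray:
    """Hypothetical helper: rows 0 = constant, 2k-1 = cos(w_k x), 2k = sin(w_k x), unit norm (p odd)."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, p // 2 + 1):
        rows += [np.cos(2 * np.pi * k * x / p), np.sin(2 * np.pi * k * x / p)]
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


def fit_r2(X: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, float]:
    """Least-squares fit (the constant column plays the role of the intercept); returns (coefs, r^2)."""
    coefs, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ coefs
    return coefs, 1.0 - (resid @ resid) / np.sum((y - y.mean()) ** 2)


p, k = 113, 42
F = make_fourier_basis(p)
idx = 2 * k - 1  # row of the normalized cos(w_k x) vector
vec = F[idx]

# ReLU(A + cos(w x) + cos(w y)) with A = 1 / (2 sqrt(p)), using the normalized cosine vector
relu_vals = np.maximum(0.5 * p**-0.5 + vec[:, None] + vec[None, :], 0.0)

# Flattened 2D basis terms: const, cos(wx), cos(wy), cos(wx)cos(wy)
pairs = [(0, 0), (idx, 0), (0, idx), (idx, idx)]
X = np.stack([np.outer(F[i], F[j]).ravel() for i, j in pairs], axis=-1)  # (p*p, 4)
y = relu_vals.ravel()

coefs, r2 = fit_r2(X, y)
_, r2_no_quad = fit_r2(X[:, :3], y)
print("const, cos(wx), cos(wy), cos(wx)cos(wy):", np.round(coefs, 3))
print(f"r2: {r2:.3f}, r2 (no quadratic term): {r2_no_quad:.3f}")
```

Because the regressors are orthonormal, each coefficient is just the inner product of the ReLU values with that basis vector. This is why dropping the quadratic column leaves the other three coefficients unchanged.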
**Solution**

```python
import torch as t
import torch.nn.functional as F
from sklearn.linear_model import LinearRegression

# Choose a particular frequency, and get the corresponding cosine basis vector
k = 42
idx = 2 * k - 1
vec = fourier_basis[idx]

# Get ReLU function values
relu_func_values = F.relu(0.5 * (p**-0.5) + vec[None, :] + vec[:, None])

# Get terms we'll be using to approximate it
# Note we're including the constant term here
data = t.stack(
    [fourier_2d_basis_term(i, j) for (i, j) in [(0, 0), (idx, 0), (0, idx), (idx, idx)]], dim=-1
)

# Reshape, and convert to numpy
data = to_numpy(data.reshape(p * p, 4))
relu_func_values = to_numpy(relu_func_values.flatten())

# Fit a linear model (we don't need intercept because we have const Fourier basis term)
reg = LinearRegression(fit_intercept=False).fit(data, relu_func_values)
coefs = reg.coef_
r2 = reg.score(data, relu_func_values)
print(
    "ReLU(0.5 + cos(wx) + cos(wy)) ≈ {:.3f}const + {:.3f}cos(wx) + {:.3f}cos(wy) + {:.3f}cos(wx)cos(wy)".format(
        *coefs
    )
)
print(f"r2: {r2:.3f}")

# Run the regression again, but without the quadratic term
data = data[:, :3]
reg = LinearRegression(fit_intercept=False).fit(data, relu_func_values)
r2 = reg.score(data, relu_func_values)
print(f"r2 (no quadratic term): {r2:.3f}")
```

### Neurons cluster by frequency

Now that we've established that the neurons each seem to have some single frequency that they're most sensitive to (and ignore all others), let's try to sort the neurons by this frequency, and see how effective each of these frequencies is at explaining the neurons' behaviour.

### Exercise - find neuron clusters

Difficulty: 🔴🔴🔴🔴⚪

Importance: 🔵🔵🔵⚪⚪

You should spend up to 15-30 minutes on this exercise. This exercise is conceptually important, but quite challenging.
For each neuron, you should find the frequency such that the Fourier components containing that frequency explain the largest amount of the variance of that neuron's activation. In other words, find the frequency for which the sum of squares of the constant, linear and quadratic terms of that frequency (for this particular neuron) is largest, as a fraction of the sum of squares of all the Fourier coefficients for this neuron.

We've provided you with the helper function `arrange_by_2d_freqs`. This takes in a tensor of coefficients in the 2D Fourier basis (with shape `(p, p, ...)`), and returns a tensor of shape `(p//2 - 1, 3, 3, ...)`, representing the Fourier coefficients sorted into each frequency. In other words, the `[k-1, ...]`-th slice of this tensor will be the `(3, 3, ...)`-shape tensor containing the Fourier coefficients for the following (normalized) 2D Fourier basis vectors:

$$
\begin{bmatrix}
1 & \cos(\omega_k x) & \sin(\omega_k x) \\
\cos(\omega_k y) & \cos(\omega_k x)\cos(\omega_k y) & \sin(\omega_k x)\cos(\omega_k y) \\
\sin(\omega_k y) & \cos(\omega_k x)\sin(\omega_k y) & \sin(\omega_k x)\sin(\omega_k y)
\end{bmatrix}
$$

Think of this as just a fancy rearranging of the tensor, so that you can more easily access all the (constant, linear, quadratic) terms for any particular frequency.

```python
def arrange_by_2d_freqs(tensor: Tensor) -> Tensor:
    """
    Takes a tensor of shape (p, p, *batch_dims), returns tensor of shape (p//2 - 1, 3, 3, *batch_dims)
    representing the Fourier coefficients sorted by frequency (each slice contains const, linear and
    quadratic terms).

    In other words, if `tensor` had shape (p, p) and looked like this:

        1           cos(w_1*x)            sin(w_1*x)            ...
        cos(w_1*y)  cos(w_1*x)cos(w_1*y)  sin(w_1*x)cos(w_1*y)  ...
        sin(w_1*y)  cos(w_1*x)sin(w_1*y)  sin(w_1*x)sin(w_1*y)  ...
        cos(w_2*y)  cos(w_1*x)cos(w_2*y)  sin(w_1*x)cos(w_2*y)  ...
        ...         ...                   ...

    Then arrange_by_2d_freqs(tensor)[k-1] should be the following (3, 3) tensor of kth-mode Fourier
    frequencies:

        1           cos(w_k*x)            sin(w_k*x)
        cos(w_k*y)  cos(w_k*x)cos(w_k*y)  sin(w_k*x)cos(w_k*y)
        sin(w_k*y)  cos(w_k*x)sin(w_k*y)  sin(w_k*x)sin(w_k*y)

    for k = 1, 2, ..., p//2 - 1. Note we omit the constant term, i.e. the 0th slice has frequency k=1.

    Any dimensions beyond the first 2 are treated as batch dimensions, i.e. we only rearrange the
    first 2.
    """
    idx_2d_y_all = []
    idx_2d_x_all = []
    for freq in range(1, p // 2):
        idx_1d = [0, 2 * freq - 1, 2 * freq]
        idx_2d_x_all.append([idx_1d for _ in range(3)])
        idx_2d_y_all.append([[i] * 3 for i in idx_1d])
    return tensor[idx_2d_y_all, idx_2d_x_all]


def find_neuron_freqs(
    fourier_neuron_acts: Float[Tensor, "p p d_mlp"],
) -> tuple[Float[Tensor, "d_mlp"], Float[Tensor, "d_mlp"]]:
    """
    Returns the tensors `neuron_freqs` and `neuron_frac_explained`, containing the frequencies that
    explain the most variance of each neuron and the fraction of variance explained, respectively.
    """
    fourier_neuron_acts_by_freq = arrange_by_2d_freqs(fourier_neuron_acts)
    assert fourier_neuron_acts_by_freq.shape == (p // 2 - 1, 3, 3, utils.d_mlp)

    raise NotImplementedError()

    return neuron_freqs, neuron_frac_explained


neuron_freqs, neuron_frac_explained = find_neuron_freqs(neuron_acts_centered_fourier)
key_freqs, neuron_freq_counts = t.unique(neuron_freqs, return_counts=True)

assert key_freqs.tolist() == [14, 35, 41, 42, 52]

print("All tests for `find_neuron_freqs` passed!")
```

**Help - all my key frequencies are off by one from the true answer.**

Remember that the 0th slice of the `arrange_by_2d_freqs` output contains the coefficients for $k=1$, not $k=0$ (and so on). If you get the key frequencies by argmaxing, then make sure you add one to these indices!
**Solution**

```python
def find_neuron_freqs(
    fourier_neuron_acts: Float[Tensor, "p p d_mlp"],
) -> tuple[Float[Tensor, "d_mlp"], Float[Tensor, "d_mlp"]]:
    """
    Returns the tensors `neuron_freqs` and `neuron_frac_explained`, containing the frequencies that
    explain the most variance of each neuron and the fraction of variance explained, respectively.
    """
    fourier_neuron_acts_by_freq = arrange_by_2d_freqs(fourier_neuron_acts)
    assert fourier_neuron_acts_by_freq.shape == (p // 2 - 1, 3, 3, utils.d_mlp)

    # Sum squares of all frequency coeffs, for each neuron
    square_of_all_terms = einops.reduce(
        fourier_neuron_acts.pow(2), "x_coeff y_coeff neuron -> neuron", "sum"
    )

    # Sum squares just corresponding to const+linear+quadratic terms,
    # for each frequency, for each neuron
    square_of_each_freq = einops.reduce(
        fourier_neuron_acts_by_freq.pow(2), "freq x_coeff y_coeff neuron -> freq neuron", "sum"
    )

    # Find the freq explaining most variance for each neuron
    # (and the fraction of variance explained)
    neuron_variance_explained, neuron_freqs = square_of_each_freq.max(0)
    neuron_frac_explained = neuron_variance_explained / square_of_all_terms

    # The actual frequencies count up from k=1, not 0!
    neuron_freqs += 1

    return neuron_freqs, neuron_frac_explained
```

Note the use of `einops.reduce` in the solution, rather than just using e.g. `fourier_neuron_acts.pow(2).sum((0, 1))`. As in most situations where `einops` is helpful, this makes your code more explicit and readable, and reduces the chance of mistakes.

Once you've written this function and passed the tests, you can plot the fraction of variance explained.
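For intuition about what `find_neuron_freqs` computes, here is a NumPy sketch on hand-built coefficients. `arrange_by_2d_freqs_np` and `find_neuron_freqs_np` are hypothetical ports of the functions above, written for illustration. The three synthetic neurons represent three cases: one is perfectly explained by one frequency, one is mostly explained by one frequency, and one is split evenly between two frequencies and so only gets a 50% fraction.

```python
import numpy as np


def arrange_by_2d_freqs_np(coeffs: np.ndarray, p: int) -> np.ndarray:
    """NumPy port of arrange_by_2d_freqs: (p, p, *batch) -> (p//2 - 1, 3, 3, *batch)."""
    idx_y, idx_x = [], []
    for freq in range(1, p // 2):
        idx_1d = [0, 2 * freq - 1, 2 * freq]
        idx_x.append([idx_1d for _ in range(3)])
        idx_y.append([[i] * 3 for i in idx_1d])
    return coeffs[np.array(idx_y), np.array(idx_x)]


def find_neuron_freqs_np(coeffs: np.ndarray, p: int) -> tuple[np.ndarray, np.ndarray]:
    """Best frequency per neuron, and the fraction of the total sum of squares it accounts for."""
    by_freq = arrange_by_2d_freqs_np(coeffs, p)      # (p//2 - 1, 3, 3, d_mlp)
    total = np.square(coeffs).sum(axis=(0, 1))        # (d_mlp,)
    per_freq = np.square(by_freq).sum(axis=(1, 2))    # (p//2 - 1, d_mlp)
    best = per_freq.argmax(axis=0)
    frac = per_freq[best, np.arange(coeffs.shape[-1])] / total
    return best + 1, frac  # slice 0 corresponds to k = 1


# Synthetic, already-centered 2D Fourier coefficients for 3 hypothetical neurons, p = 11
p = 11
coeffs = np.zeros((p, p, 3))
for i in [0, 5, 6]:
    for j in [0, 5, 6]:
        coeffs[i, j, 0] = 1.0    # neuron 0: all mass on frequency k=3 terms
coeffs[0, 0, 0] = 0.0            # centering removes the constant term
for i, j in [(0, 3), (3, 0), (3, 3), (4, 4)]:
    coeffs[i, j, 1] = 1.0        # neuron 1: mostly frequency k=2 ...
coeffs[7, 7, 1] = 0.1            # ... plus a little frequency k=4
coeffs[1, 1, 2] = 1.0            # neuron 2: split evenly between k=1 ...
coeffs[7, 7, 2] = 1.0            # ... and k=4 (argmax picks the first on a tie)

freqs, frac = find_neuron_freqs_np(coeffs, p)
print(freqs.tolist())  # [3, 2, 1]
print(np.round(frac, 3).tolist())
```

Every frequency's slice includes the $(0, 0)$ constant term. This is one reason the activations are centered first: otherwise the constant would be counted toward whichever frequency wins.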
```python
fraction_of_activations_positive_at_posn2 = (cache["pre", 0][:, -1] > 0).float().mean(0)

utils.scatter(
    x=neuron_freqs,
    y=neuron_frac_explained,
    xaxis="Neuron frequency",
    yaxis="Frac explained",
    colorbar_title="Frac positive",
    title="Fraction of neuron activations explained by key freq",
    color=to_numpy(fraction_of_activations_positive_at_posn2),
)
```

We color the neurons according to the fraction of data points for which they are active. We see that there are 5 distinct clusters of neurons that are well explained (frac > 0.85) by one frequency.

There is a sixth, diffuse cluster of neurons that always fire. They are not well explained by any particular frequency. This makes sense: since ReLU acts as the identity on this cluster, there's no reason to privilege the neuron basis. In other words, there's no reason to expect the specific value of one of these neurons' activations to have any particular meaning in relation to the Fourier components of the input, since we could just as easily apply rotations to the always-firing neurons.

```python
# To represent that they are in a special sixth cluster, we set the frequency of these neurons to -1
neuron_freqs[neuron_frac_explained < 0.85] = -1.0
key_freqs_plus = t.concatenate([key_freqs, -key_freqs.new_ones((1,))])

for i, k in enumerate(key_freqs_plus):
    print(f"Cluster {i}: freq k={k}, {(neuron_freqs == k).sum()} neurons")
```

Expected output:

```
Cluster 0: freq k=14, 44 neurons
Cluster 1: freq k=35, 93 neurons
Cluster 2: freq k=41, 145 neurons
Cluster 3: freq k=42, 87 neurons
Cluster 4: freq k=52, 64 neurons
Cluster 5: freq k=-1, 79 neurons
```

### Further investigation of neuron clusters

We can separately view the norms of the Fourier components of the neuron activations for each cluster. The following code should do the same thing as your plot of the average $B_{i, j}^2$ values earlier, except it sorts the neurons into clusters by their frequency before taking the mean.
(Note, we're using the argument `facet_col` rather than `animation_frame`, so we can see all the plots at once.)

```python
fourier_norms_in_each_cluster = []
for freq in key_freqs:
    fourier_norms_in_each_cluster.append(
        einops.reduce(
            neuron_acts_centered_fourier.pow(2)[..., neuron_freqs == freq],
            "batch_y batch_x neuron -> batch_y batch_x",
            "mean",
        )
    )

utils.imshow_fourier(
    t.stack(fourier_norms_in_each_cluster),
    title="Norm of 2D Fourier components of neuron activations in each cluster",
    facet_col=0,
    facet_labels=[f"Freq={freq}" for freq in key_freqs],
)
```

Now that we've found what appear to be neuron clusters, it's time to validate our observations. We'll do this by showing that, for each neuron cluster, we can set terms for any other frequency to zero and still get good performance on the task.

### Exercise - validate neuron clusters

Difficulty: 🔴🔴🔴🔴⚪

Importance: 🔵🔵🔵🔵⚪

You should spend up to 25-40 minutes on the following exercises. They are designed to get you to engage with ideas from linear algebra (specifically projections), and are very important conceptually.

We want to do the following:

* Take `neuron_acts_post`, which is a tensor of shape `(p*p, d_mlp)`, with the `[i, j]`-th element being the activation of neuron `j` on the `i`-th input sequence in our `all_data` batch.
* Treating this tensor as `d_mlp` separate vectors in $\mathbb{R}^{p^2}$ (with the last dimension acting as the batch dimension), we will project each of these vectors onto the subspace of $\mathbb{R}^{p^2}$ spanned by the 2D Fourier basis vectors for the associated frequency of that particular neuron (i.e. the constant, linear, and quadratic terms).
* Take these projected `neuron_acts_post` vectors, and apply `W_logit` to give us new logits. Compare the cross entropy loss with these logits to our original loss.

If our hypothesis is correct (i.e.
that for each cluster of neurons associated with a particular frequency, that frequency is the only one that matters), then our loss shouldn't increase by much when we project out the other frequencies in this way.

First, you'll need to write the function `project_onto_direction`. This takes two inputs: `batch_vecs` (a batch of vectors, with batch dimension at the end) and `v` (a single vector), and returns the projection of each vector in `batch_vecs` onto the direction `v`.

```python
def project_onto_direction(batch_vecs: Tensor, v: Tensor) -> Tensor:
    """
    Returns the component of each vector in `batch_vecs` in the direction of `v`.

    batch_vecs.shape = (n, ...)
    v.shape = (n,)
    """
    raise NotImplementedError()


tests.test_project_onto_direction(project_onto_direction)
```

**Hint**

Recall that the projection of $w$ onto $v$ (for a normalized vector $v$) is given by:

$$
w_{proj} = (w \cdot v) v
$$

You might find it easier to do this in two steps: first calculate the components in the $v$-direction by taking the inner product along the 0th dimension, then create a batch of multiples of $v$, scaled by these components.

**Solution**

```python
def project_onto_direction(batch_vecs: Tensor, v: Tensor) -> Tensor:
    """
    Returns the component of each vector in `batch_vecs` in the direction of `v`
    (assumes `v` is normalized).

    batch_vecs.shape = (n, ...)
    v.shape = (n,)
    """
    # Get tensor of components of each vector in v-direction
    components_in_v_dir = einops.einsum(batch_vecs, v, "n ..., n -> ...")

    # Use these components as coefficients of v in our projections
    return einops.einsum(components_in_v_dir, v, "..., n -> n ...")
```

Next, you should write the function `project_onto_frequency`. This takes a batch of vectors with shape `(p**2, batch)`, and a frequency `freq`, and returns the projection of each vector onto the subspace spanned by the nine 2D Fourier basis vectors for that frequency (i.e. one constant term, four linear terms, and four quadratic terms).
```python
def project_onto_frequency(batch_vecs: Tensor, freq: int) -> Tensor:
    """
    Returns the projection of each vector in `batch_vecs` onto the
    2D Fourier basis directions corresponding to frequency `freq`.

    batch_vecs.shape = (p**2, ...)
    """
    assert batch_vecs.shape[0] == p**2

    raise NotImplementedError()


tests.test_project_onto_frequency(project_onto_frequency)
```

**Hint**

This will just involve summing nine calls to the `project_onto_direction` function you wrote above (one for each basis vector you're projecting onto), since your basis vectors are orthogonal.

You should use the function `fourier_2d_basis_term` to get the vectors you'll be projecting onto. Remember to flatten these vectors, because you're working with vectors of length `p**2` rather than of shape `(p, p)`!

**Solution**

The `for` loop in this code goes over the indices for the constant term `(0, 0)`, the linear terms `(2f-1, 0)`, `(2f, 0)`, `(0, 2f-1)`, `(0, 2f)`, and the quadratic terms `(2f-1, 2f-1)`, `(2f-1, 2f)`, `(2f, 2f-1)`, `(2f, 2f)`.

```python
def project_onto_frequency(batch_vecs: Tensor, freq: int) -> Tensor:
    """
    Returns the projection of each vector in `batch_vecs` onto the
    2D Fourier basis directions corresponding to frequency `freq`.

    batch_vecs.shape = (p**2, ...)
    """
    assert batch_vecs.shape[0] == p**2

    return sum(
        [
            project_onto_direction(
                batch_vecs,
                fourier_2d_basis_term(i, j).flatten(),
            )
            for i in [0, 2 * freq - 1, 2 * freq]
            for j in [0, 2 * freq - 1, 2 * freq]
        ]
    )
```

Finally, run the following code to project out the other frequencies from the neuron activations, and compare the new loss. You should make sure you understand what this code is doing.
```python
logits_in_freqs = []

for freq in key_freqs:
    # Get all neuron activations corresponding to this frequency
    filtered_neuron_acts = neuron_acts_post[:, neuron_freqs == freq]

    # Project onto const/linear/quadratic terms in 2D Fourier basis
    filtered_neuron_acts_in_freq = project_onto_frequency(filtered_neuron_acts, freq)

    # Calculate new logits, from these filtered neuron activations
    logits_in_freq = filtered_neuron_acts_in_freq @ W_logit[neuron_freqs == freq]

    logits_in_freqs.append(logits_in_freq)

# We add on neurons in the always firing cluster, unfiltered
logits_always_firing = neuron_acts_post[:, neuron_freqs == -1] @ W_logit[neuron_freqs == -1]
logits_in_freqs.append(logits_always_firing)

# Compute new losses
key_freq_loss = utils.test_logits(
    sum(logits_in_freqs), bias_correction=True, original_logits=original_logits
)
key_freq_loss_no_always_firing = utils.test_logits(
    sum(logits_in_freqs[:-1]), bias_correction=True, original_logits=original_logits
)

# Print new losses
print(f"""Loss with neuron activations ONLY in key freq (including always firing cluster): {key_freq_loss:.3e}
Loss with neuron activations ONLY in key freq (excluding always firing cluster): {key_freq_loss_no_always_firing:.3e}
Original loss: {original_loss:.3e}
""")
```

Expected output:

```
Loss with neuron activations ONLY in key freq (including always firing cluster): 2.482e-07
Loss with neuron activations ONLY in key freq (excluding always firing cluster): 1.179e-06
Original loss: 2.412e-07
```

You should find that the loss doesn't change much when you project out the other frequencies, and even if you remove the always firing cluster the loss is still very small.

We can also compare the importance of each cluster of neurons by ablating it (while continuing to restrict each cluster to its frequency).
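Two linear-algebra facts make this procedure valid: projecting onto a frequency's nine orthonormal 2D basis vectors is an idempotent projection, and the logits are linear in the neuron activations. Linearity is what allows one cluster to be ablated by subtracting its entry from `sum(logits_in_freqs)`. The sketch below checks both facts on random stand-in data. `project_onto_frequency_np` is a hypothetical NumPy re-implementation that computes the same thing as summing nine `project_onto_direction` calls, and the cluster labels are made up for illustration.

```python
import numpy as np


def make_fourier_basis(p: int) -> np.ndarray:
    """Hypothetical helper: rows 0 = constant, 2k-1 = cos(w_k x), 2k = sin(w_k x), unit norm (p odd)."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, p // 2 + 1):
        rows += [np.cos(2 * np.pi * k * x / p), np.sin(2 * np.pi * k * x / p)]
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


def project_onto_frequency_np(batch_vecs: np.ndarray, F: np.ndarray, freq: int) -> np.ndarray:
    """Project each column of batch_vecs (shape (p*p, n)) onto the nine 2D basis terms for freq."""
    idx = [0, 2 * freq - 1, 2 * freq]
    # Columns are flattened outer products F[i] (x) F[j]; orthonormal because F's rows are
    Q = np.stack([np.outer(F[i], F[j]).ravel() for i in idx for j in idx], axis=1)  # (p*p, 9)
    # Summing nine rank-1 projections (w . q) q is the same as Q @ (Q^T w)
    return Q @ (Q.T @ batch_vecs)


p, d_mlp, freq = 7, 6, 2
F = make_fourier_basis(p)
rng = np.random.default_rng(0)
acts = rng.normal(size=(p * p, d_mlp))   # random stand-in for neuron_acts_post
W_logit = rng.normal(size=(d_mlp, p))    # random stand-in for W_logit

# 1) Projection is idempotent: projecting twice changes nothing
proj = project_onto_frequency_np(acts, F, freq)
print(np.allclose(project_onto_frequency_np(proj, F, freq), proj))  # True

# 2) Logits are linear in the neurons, so per-cluster logits add up to the full logits
labels = np.array([1, 1, 2, 2, 3, -1])   # hypothetical cluster label per neuron
clusters = [1, 2, 3, -1]
per_cluster = [acts[:, labels == c] @ W_logit[labels == c] for c in clusters]
print(np.allclose(sum(per_cluster), acts @ W_logit))  # True

# 3) Subtracting one cluster's logits == zeroing that cluster's activations
ablated = acts.copy()
ablated[:, labels == 2] = 0.0
print(np.allclose(sum(per_cluster) - per_cluster[1], ablated @ W_logit))  # True
```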
```python
print("Loss with neuron activations excluding none: {:.9f}".format(original_loss.item()))
for c, freq in enumerate(key_freqs_plus):
    print(
        "Loss with neuron activations excluding freq={}: {:.9f}".format(
            freq,
            utils.test_logits(
                sum(logits_in_freqs) - logits_in_freqs[c],
                bias_correction=True,
                original_logits=original_logits,
            ),
        )
    )
```

Expected output:

```
Loss with neuron activations excluding none: 0.000000241
Loss with neuron activations excluding freq=14: 0.000199645
Loss with neuron activations excluding freq=35: 0.000458851
Loss with neuron activations excluding freq=41: 0.001917986
Loss with neuron activations excluding freq=42: 0.005197690
Loss with neuron activations excluding freq=52: 0.024398957
Loss with neuron activations excluding freq=-1: 0.000001179
```

We see from this that freq=52 is the most important cluster (because the loss increases the most when it is removed). However, all five frequency clusters clearly matter: ablating any one of them increases the loss by at least two orders of magnitude relative to the original loss, even though the loss stays small in absolute terms. We also see that ablating the always firing cluster has a very small effect, so this cluster isn't very helpful for the task. We might have guessed this beforehand: since the ReLU always fires for these neurons, they essentially compute a linear function, and non-linearities are what allow a network to learn much more expressive functions.

## Understanding Logit Computation

TL;DR: The network uses $W_{logit}=W_{out}W_U$ to cancel out all 2D Fourier components other than the directions corresponding to $\cos(\omega(x+y))$ and $\sin(\omega(x+y))$. It then multiplies these directions by $\cos(\omega z)$ and $\sin(\omega z)$ respectively, and sums to get the output logits.
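A quick numerical check of why only the quadratic directions need to survive: for a fixed output $z$, the idealised contribution $\cos(\omega_k(x+y-z))$, viewed as a function on the $(x, y)$ grid, has 2D Fourier coefficients only on the four (cos, sin) × (cos, sin) terms of frequency $k$. It has no constant or linear part. This sketch uses toy values ($p=13$, $k=5$, $z=4$) and the hypothetical helper `make_fourier_basis`, and it assumes the same index convention as above (cosine at $2k-1$, sine at $2k$).

```python
import numpy as np


def make_fourier_basis(p: int) -> np.ndarray:
    """Hypothetical helper: rows 0 = constant, 2k-1 = cos(w_k x), 2k = sin(w_k x), unit norm (p odd)."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, p // 2 + 1):
        rows += [np.cos(2 * np.pi * k * x / p), np.sin(2 * np.pi * k * x / p)]
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


p, k, z = 13, 5, 4
F = make_fourier_basis(p)
w = 2 * np.pi * k / p
x = np.arange(p)[:, None]
y = np.arange(p)[None, :]

# Product-to-sum identity on the full (x, y) grid
identity_holds = np.allclose(
    np.cos(w * (x + y)), np.cos(w * x) * np.cos(w * y) - np.sin(w * x) * np.sin(w * y)
)
print(identity_holds)  # True

# One cluster's idealised contribution to the logit for a fixed z, as a function of (x, y)
g = np.cos(w * (x + y - z))
B = F @ g @ F.T  # B[i, j] = sum_{x,y} F[i, x] g[x, y] F[j, y]

support = np.argwhere(np.abs(B) > 1e-9).tolist()
print(support)  # [[9, 9], [9, 10], [10, 9], [10, 10]]: only the quadratic terms of frequency k
```

Expanding $\cos(\omega(x+y-z))$ with the trig formulas gives coefficients proportional to $\cos(\omega z)$ on the $\cos\cos$ and (negated) $\sin\sin$ terms, and proportional to $\sin(\omega z)$ on the two mixed terms. This matches the TL;DR.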
Recall that (for each neuron cluster with associated frequency $k$), each neuron's activations are a linear combination of constant, linear and quadratic terms:

$$
\begin{bmatrix}
1 & \cos(\omega_k x) & \sin(\omega_k x) \\
\cos(\omega_k y) & \cos(\omega_k x)\cos(\omega_k y) & \sin(\omega_k x)\cos(\omega_k y) \\
\sin(\omega_k y) & \cos(\omega_k x)\sin(\omega_k y) & \sin(\omega_k x)\sin(\omega_k y)
\end{bmatrix}
$$

for $\omega_k = 2\pi k / p$. To calculate the logits, the network cancels out all directions apart from:

$$
\begin{aligned}
\cos(\omega_k (x+y)) &= \cos(\omega_k x)\cos(\omega_k y)-\sin(\omega_k x)\sin(\omega_k y) \\
\sin(\omega_k (x+y)) &= \sin(\omega_k x)\cos(\omega_k y)+\cos(\omega_k x)\sin(\omega_k y)
\end{aligned}
$$

The network then multiplies these by $\cos(\omega_k z)$ and $\sin(\omega_k z)$ and sums. In other words, the logit for value $z$ will be the product of these with $\cos(\omega_k z)$ and $\sin(\omega_k z)$, summed over all neurons in that cluster, summed over all clusters.

**Question - can you explain why this algorithm works?**

**Hint**

Think about the following expression:

$$
\cos(\omega (x+y))\cos(\omega z)+\sin(\omega (x+y))\sin(\omega z)
$$

which is (up to a scale factor) the value added to the logit score for $z$.

**Answer**

The reason is again thanks to our trig formulas! Each neuron will add to the final logits a vector which looks like:

$$
\cos(\omega(x+y))\cos(\omega z)+\sin(\omega(x+y))\sin(\omega z)
$$

which we know from our trig formulas equals:

$$
\cos(\omega(x+y-z))
$$

which is largest when $z = x+y$ (modulo $p$).

---

Another way of writing this would be that, on inputs $(x, y)$, the model's logit output is (up to a scale factor) equal to:

$$
\cos(\omega(x+y-\vec{\textbf{z}})) = \begin{bmatrix}
\cos(\omega(x+y)) \\
\cos(\omega(x+y-1)) \\
\vdots \\
\cos(\omega(x+y-(p-1)))
\end{bmatrix}
$$

This vector is largest at the element with index $x+y$ (modulo $p$), meaning the logit for $x+y$ will be largest (which is exactly what we want to happen, to solve our problem!).
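The argmax claim can be checked directly in an idealised model that sets every scale factor $C_k$ to 1. That is an assumption: the real constants differ and are only computed later. The sketch builds the idealised logits for all $p^2$ inputs, first for a single frequency and then summed over the key frequencies reported in this section, and confirms that the largest logit is always at $z = (x+y) \bmod p$. `logits_for` is a hypothetical helper. The printed margin (top logit minus runner-up, minimised over inputs) shows how much headroom each version has.

```python
import numpy as np

p = 113
key_freqs = [14, 35, 41, 42, 52]  # key frequencies found in this run (they vary between runs)
xs, ys = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
z = np.arange(p)
diff = xs[..., None] + ys[..., None] - z  # diff[x, y, z] = x + y - z


def logits_for(freqs):
    """Idealised logits sum_k cos(w_k (x + y - z)), with every C_k assumed to be 1."""
    return sum(np.cos(2 * np.pi * k / p * diff) for k in freqs)


answer = (xs + ys) % p
for freqs in ([42], key_freqs):
    logits = logits_for(freqs)
    top_two = np.sort(logits, axis=-1)[..., -2:]
    margin = (top_two[..., 1] - top_two[..., 0]).min()
    print(freqs, (logits.argmax(-1) == answer).all(), f"min margin {margin:.4f}")
```

Because $p$ is prime, $\cos(\omega_k d) = 1$ only when $d \equiv 0 \pmod p$, so even one frequency gives the right argmax. Summing several frequencies makes the peak at $z = x+y$ stand out more clearly from the other values.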
Also, remember that we have several different frequencies $\omega_k$. When we sum over neurons, the vectors combine constructively at $z \equiv x+y \pmod p$ and destructively everywhere else:

$$ f(t) = \sum_{k \in K} C_k \cos(\omega_k (x + y - \vec{\textbf{z}})) $$

where $C_k$ are large positive constants, which we'll calculate explicitly later on.

Logits in Fourier Basis

To see that the network cancels out the other directions, we can transform both the neuron activations and the logits to the 2D Fourier basis, and show the norm of the vector corresponding to each Fourier component. The quadratic terms have much higher norm in the logits than in the neuron activations, and the linear terms are close to zero. Remember that to get from activations to logits, we apply the linear map $W_{logit}=W_{out}W_U$, which involves summing the outputs of all neurons.

Below is some code to visualise this. Note the use of einops.reduce rather than the mean method. As with most other uses of einops, this is useful because it's explicit and readable.

```python
utils.imshow_fourier(
    einops.reduce(neuron_acts_centered_fourier.pow(2), "y x neuron -> y x", "mean"),
    title="Norm of Fourier Components of Neuron Acts",
)

# Rearrange logits, so the first two dims represent (x, y) in modular arithmetic equation
original_logits_sq = einops.rearrange(original_logits, "(x y) z -> x y z", x=p)
original_logits_fourier = fft2d(original_logits_sq)

utils.imshow_fourier(
    einops.reduce(original_logits_fourier.pow(2), "y x z -> y x", "mean"),
    title="Norm of Fourier Components of Logits",
)
```

You should find two things. The linear and constant terms have more or less vanished relative to the quadratic terms. The quadratic terms are also much larger in the logits than in the neuron activations.
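The claim that the logits should contain only quadratic 2D Fourier terms can be checked on a synthetic version of $f(t)$ with every $C_k = 1$. The helper make_fourier_basis below is a hypothetical re-implementation of the convention used in this section: row 0 constant, row $2k-1$ cosine, row $2k$ sine, all rows of unit norm. It is not the tutorial's own fourier_basis, and p = 113 is also an assumption. The sketch then measures how much power lands on the constant/linear terms versus the key-frequency quadratic terms.

```python
import numpy as np

p = 113                            # assumption: the model's modulus
key_freqs = [14, 35, 41, 42, 52]   # from the ablation output above


def make_fourier_basis(p: int) -> np.ndarray:
    """Hypothetical helper: row 0 is constant, row 2k-1 is cos(w_k x),
    row 2k is sin(w_k x); every row is normalised to unit length."""
    xs = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, p // 2 + 1):
        rows.append(np.cos(2 * np.pi * k * xs / p))
        rows.append(np.sin(2 * np.pi * k * xs / p))
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


F = make_fourier_basis(p)  # shape (p, p) for odd p

# Idealised logits L[x, y, z] = sum_k cos(w_k (x + y - z)), with every C_k = 1
x = np.arange(p)[:, None, None]
y = np.arange(p)[None, :, None]
z = np.arange(p)[None, None, :]
L = sum(np.cos(2 * np.pi * k / p * (x + y - z)) for k in key_freqs)

# 2D Fourier transform over (x, y) for each z, then mean squared coefficient
L_fourier = np.einsum("ax,by,xyz->abz", F, F, L, optimize=True)
power = (L_fourier**2).mean(axis=-1)  # shape (p, p)

# Row 0 / column 0 hold the constant and linear terms
const_and_linear = power[0, :].sum() + power[:, 0].sum()
idx = [i for k in key_freqs for i in (2 * k - 1, 2 * k)]
quadratic = power[np.ix_(idx, idx)].sum()
print(f"constant + linear power: {const_and_linear:.2e}")
print(f"fraction in key-frequency quadratic terms: {quadratic / power.sum():.6f}")  # prints 1.000000
```

The trained model's logits only approximately have this structure, which is why the plots above show the linear terms as close to zero rather than exactly zero.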
This is annotated in the plots below (which should match the results you get from running the code):

Exercise - validate by only taking quadratic terms

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-20 minutes on this exercise. This exercise should feel similar to the previous one, since it's about vectors and projections.

Here, you will validate your results further. You will keep only the components of the logits corresponding to $\cos(\omega_k(x+y))$ and $\sin(\omega_k(x+y))$ for each of our key frequencies $k$, and show that this improves performance (i.e. lowers the loss).

First, write a function get_trig_sum_directions. It should take a frequency k and return the (p, p)-size vectors in 2D Fourier space corresponding to the following two directions, respectively:

$$ \begin{aligned} \cos(\omega_k (\vec{\textbf{x}}+\vec{\textbf{y}})) &= \cos(\omega_k \vec{\textbf{x}})\cos(\omega_k \vec{\textbf{y}})-\sin(\omega_k \vec{\textbf{x}})\sin(\omega_k \vec{\textbf{y}}) \\ \sin(\omega_k (\vec{\textbf{x}}+\vec{\textbf{y}})) &= \sin(\omega_k \vec{\textbf{x}})\cos(\omega_k \vec{\textbf{y}})+\cos(\omega_k \vec{\textbf{x}})\sin(\omega_k \vec{\textbf{y}}) \end{aligned} $$

Remember, the vectors you return should be normalized.

```python
def get_trig_sum_directions(k: int) -> tuple[Float[Tensor, "p p"], Float[Tensor, "p p"]]:
    """
    Given frequency k, returns the normalized vectors in the 2D Fourier basis representing the two directions:

        cos(ω_k * (x + y))
        sin(ω_k * (x + y))

    respectively.
    """
    raise NotImplementedError()


tests.test_get_trig_sum_directions(get_trig_sum_directions)
```

Hint

You can get the vector $\cos(\omega_k \vec{\textbf{x}}) \cos(\omega_k \vec{\textbf{y}})$ as follows:

```python
cosx_cosy_direction = fourier_2d_basis_term(2 * k - 1, 2 * k - 1)
```

Solution

```python
def get_trig_sum_directions(k: int) -> tuple[Float[Tensor, "p p"], Float[Tensor, "p p"]]:
    """
    Given frequency k, returns the normalized vectors in the 2D Fourier basis representing the two directions:

        cos(ω_k (x + y))
        sin(ω_k (x + y))

    respectively.
    """
    # Products of 1D basis vectors: index 2k-1 is cos(ω_k ·), index 2k is sin(ω_k ·)
    cosx_cosy_direction = fourier_2d_basis_term(2 * k - 1, 2 * k - 1)
    sinx_siny_direction = fourier_2d_basis_term(2 * k, 2 * k)
    sinx_cosy_direction = fourier_2d_basis_term(2 * k, 2 * k - 1)
    cosx_siny_direction = fourier_2d_basis_term(2 * k - 1, 2 * k)

    # Each term is a unit vector and the terms are mutually orthogonal,
    # so dividing the sum of two of them by sqrt(2) gives a unit vector
    cos_xplusy_direction = (cosx_cosy_direction - sinx_siny_direction) / np.sqrt(2)
    sin_xplusy_direction = (sinx_cosy_direction + cosx_siny_direction) / np.sqrt(2)

    return cos_xplusy_direction, sin_xplusy_direction
```

Once you've passed these tests, you can run the code below to project the logits onto these directions and see how the loss changes. Note the use of the project_onto_direction function which you wrote earlier.

```python
trig_logits = []

for k in key_freqs:
    cos_xplusy_direction, sin_xplusy_direction = get_trig_sum_directions(k)

    cos_xplusy_projection = project_onto_direction(original_logits, cos_xplusy_direction.flatten())
    sin_xplusy_projection = project_onto_direction(original_logits, sin_xplusy_direction.flatten())

    trig_logits.extend([cos_xplusy_projection, sin_xplusy_projection])

trig_logits = sum(trig_logits)

print(f"Loss with just x+y components: {utils.test_logits(trig_logits, True, original_logits):.4e}")
print(f"Original Loss: {original_loss:.4e}")
```

```
Loss with just x+y components: 5.5474e-08
Original Loss: 2.4122e-07
```

You should find that the loss with just these components is significantly lower than your original loss. This is very strong evidence that we've correctly identified the algorithm used by our model.

$W_{logit}$ in Fourier Basis

So we know that the model is mainly working with the terms $\cos(\omega_k(x+y))$ and $\sin(\omega_k(x+y))$ for each of our key frequencies $k$. Now, we want to show that the model's final output is:

$$ \cos(\omega_k (x + y - \vec{\textbf{z}})) = \cos(\omega_k (x + y))\cos(\omega_k \vec{\textbf{z}}) + \sin(\omega_k (x + y))\sin(\omega_k \vec{\textbf{z}}) $$

How do we do this? Answer: we examine $W_{logit}$ in the Fourier basis.
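To see why the solution divides by $\sqrt{2}$, the sketch below builds the two directions from outer products of normalised 1D basis vectors. It then compares them with $\cos(\omega_k(x+y))$ and $\sin(\omega_k(x+y))$ evaluated directly on the (x, y) grid and normalised. Both make_fourier_basis and basis_term_2d are hypothetical stand-ins for the tutorial's fourier_basis and fourier_2d_basis_term (with x along rows and y along columns). The values p = 113 and k = 14 are assumptions chosen for illustration.

```python
import numpy as np

p, k = 113, 14  # assumptions for illustration


def make_fourier_basis(p: int) -> np.ndarray:
    """Hypothetical helper: row 0 constant, row 2k-1 cos, row 2k sin, unit-norm rows."""
    xs = np.arange(p)
    rows = [np.ones(p)]
    for j in range(1, p // 2 + 1):
        rows.append(np.cos(2 * np.pi * j * xs / p))
        rows.append(np.sin(2 * np.pi * j * xs / p))
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


F = make_fourier_basis(p)


def basis_term_2d(i: int, j: int) -> np.ndarray:
    """Hypothetical stand-in for fourier_2d_basis_term: x along rows, y along columns."""
    return np.outer(F[i], F[j])


# Same construction as the solution above
cos_dir = (basis_term_2d(2 * k - 1, 2 * k - 1) - basis_term_2d(2 * k, 2 * k)) / np.sqrt(2)
sin_dir = (basis_term_2d(2 * k, 2 * k - 1) + basis_term_2d(2 * k - 1, 2 * k)) / np.sqrt(2)

# Direct construction on the (x, y) grid, then normalised
xs = np.arange(p)
w = 2 * np.pi * k / p
direct_cos = np.cos(w * (xs[:, None] + xs[None, :]))
direct_sin = np.sin(w * (xs[:, None] + xs[None, :]))
direct_cos /= np.linalg.norm(direct_cos)
direct_sin /= np.linalg.norm(direct_sin)

print(np.allclose(cos_dir, direct_cos), np.allclose(sin_dir, direct_sin))  # prints: True True
```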
If $W_{logit}$ is mainly projecting onto the directions $\cos(\omega_k \vec{\textbf{z}})$ and $\sin(\omega_k \vec{\textbf{z}})$, then we expect to find:

$$ W_{logit} \approx U S F = \sum_{i=1}^p \sigma_i u_i f_i^T $$

where the singular values $\sigma_i$ are zero except for those corresponding to the Fourier basis vectors $f_i = \cos(\omega_k \vec{\textbf{z}}), \sin(\omega_k \vec{\textbf{z}})$ for key frequencies $k$. In other words:

$$ W_{logit} \approx \sum_{k \in K} \sigma_{2k-1} u_{2k-1} \cos(\omega_k \vec{\textbf{z}}) + \sigma_{2k} u_{2k} \sin(\omega_k \vec{\textbf{z}}) $$

The rows of $F$ are orthonormal, so $F F^T = I$. Thus, if we right-multiply $W_{logit}$ by $F^T$, we should get a matrix $W_{logit} F^T \approx US$ of shape (d_mlp, p). All columns of this matrix should be zero except for $\sigma_{2k-1}u_{2k-1}$ and $\sigma_{2k} u_{2k}$ for key frequencies $k$. Let's verify this:

```python
US = W_logit @ fourier_basis.T

utils.imshow_div(US, x=fourier_basis_names, yaxis="Neuron index", title="W_logit in the Fourier Basis")
```

You should see that this matrix is only non-zero in the columns at positions $2k-1$ and $2k$ for the key frequencies. In the Fourier basis, our model's final output is a linear combination of the rows of this matrix, with the coefficients given by the neuron activations. So the output can only have non-zero components in those same columns. This shows that $W_{logit}$ is projecting onto the directions corresponding to our key frequencies.

Note the contrast between what we just did and what we've done before. In previous sections, we took the 2D Fourier transform of our activations / effective weights with respect to the input space (the vectors were $\vec{\textbf{x}}$ and $\vec{\textbf{y}}$). Now, we're taking the 1D Fourier transform with respect to the output space (the vector is $\vec{\textbf{z}}$). It's pretty cool that this works!
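The step $W_{logit} F^T \approx US$ relies only on the rows of $F$ being orthonormal. The sketch below plants a synthetic $W_{logit}$ that uses only the key-frequency rows of a hypothetical Fourier basis (same row convention as in this section). It then checks that right-multiplying by $F^T$ leaves non-zero columns only at indices $2k-1$ and $2k$. The values d_mlp = 512 and p = 113 and the random weights are illustrative assumptions, not the trained model.

```python
import numpy as np

p, d_mlp = 113, 512                # assumptions for illustration
key_freqs = [14, 35, 41, 42, 52]   # from the ablation output above
rng = np.random.default_rng(0)


def make_fourier_basis(p: int) -> np.ndarray:
    """Hypothetical helper: row 0 constant, row 2k-1 cos, row 2k sin, unit-norm rows."""
    xs = np.arange(p)
    rows = [np.ones(p)]
    for j in range(1, p // 2 + 1):
        rows.append(np.cos(2 * np.pi * j * xs / p))
        rows.append(np.sin(2 * np.pi * j * xs / p))
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


F = make_fourier_basis(p)

# Synthetic W_logit = sum_i (sigma_i u_i) f_i^T over key-frequency rows only
W_logit = np.zeros((d_mlp, p))
planted = {}
for k in key_freqs:
    for i in (2 * k - 1, 2 * k):
        planted[i] = rng.normal(size=d_mlp)  # stands in for sigma_i * u_i
        W_logit += np.outer(planted[i], F[i])

US = W_logit @ F.T  # should recover each planted sigma_i * u_i as column i
nonzero_cols = np.flatnonzero(np.linalg.norm(US, axis=0) > 1e-8)
print(nonzero_cols.tolist())  # prints [27, 28, 69, 70, 81, 82, 83, 84, 103, 104]
```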
So we've shown that:

$$ W_{logit} \approx \sum_{k \in K} \sigma_{2k-1} u_{2k-1} \cos(\omega_k \vec{\textbf{z}}) + \sigma_{2k} u_{2k} \sin(\omega_k \vec{\textbf{z}}) $$

but we still want to show that our final output is:

$$ f(t) = n_{post}^T W_{logit} \approx \sum_{k \in K} C_k \big(\cos(\omega_k (x + y)) \cos(\omega_k \vec{\textbf{z}}) + \sin(\omega_k (x + y)) \sin(\omega_k \vec{\textbf{z}})\big) $$

where $n_{post} \in \mathbb{R}^{d_{mlp}}$ is the vector of neuron activations and $C_k$ are large positive constants. Matching the coefficients of the vectors $\cos(\omega_k \vec{\textbf{z}})$ and $\sin(\omega_k \vec{\textbf{z}})$, this means we want to show that:

$$ \begin{aligned} n_{post}^T \sigma_{2k-1} u_{2k-1} &\approx C_k \cos(\omega_k (x + y)) \\ n_{post}^T \sigma_{2k} u_{2k} &\approx C_k \sin(\omega_k (x + y)) \end{aligned} $$

for each key frequency $k$.

First, let's do a quick sanity check. We expect the vectors $u_{2k-1}$ and $u_{2k}$ to pick up only frequency-$k$ information. That means the only non-zero elements of these vectors should correspond to the neurons in the $k$-frequency cluster. Let's test this by rearranging the matrix $W_{logit}F^T \approx US$ so that the neurons in each cluster are grouped together:

```python
US_sorted = t.concatenate([US[neuron_freqs == freq] for freq in key_freqs_plus])
hline_positions = np.cumsum(
    [(neuron_freqs == freq).sum().item() for freq in key_freqs]
).tolist() + [cfg.d_mlp]

utils.imshow_div(
    US_sorted,
    x=fourier_basis_names,
    yaxis="Neuron",
    title="W_logit in the Fourier Basis (rearranged by neuron cluster)",
    hline_positions=hline_positions,
    hline_labels=[f"Cluster: {freq=}" for freq in key_freqs.tolist()] + ["No freq"],
)
```

You should find that, for each frequency $k$, the components of the output in directions $\cos(\omega_k \vec{\textbf{z}})$ and $\sin(\omega_k \vec{\textbf{z}})$ are determined only by the neurons in the $k$-cluster, i.e.
they are determined only by the 2D Fourier components of the input $(x, y)$ with frequency $k$.

This is promising, but we still haven't shown that $n_{post}^T \sigma_{2k-1} u_{2k-1} \propto \cos(\omega_k(x+y))$, etc. To do this, we'll calculate $n_{post}^T \sigma_{2k-1} u_{2k-1}$ and $n_{post}^T \sigma_{2k} u_{2k}$ over all inputs $(x, y)$, then take the 2D Fourier transform.

```python
cos_components = []
sin_components = []

for k in key_freqs:
    sigma_u_sin = US[:, 2 * k]
    sigma_u_cos = US[:, 2 * k - 1]

    logits_in_cos_dir = neuron_acts_post_sq @ sigma_u_cos
    logits_in_sin_dir = neuron_acts_post_sq @ sigma_u_sin

    cos_components.append(fft2d(logits_in_cos_dir))
    sin_components.append(fft2d(logits_in_sin_dir))

for title, components in zip(["Cosine", "Sine"], [cos_components, sin_components]):
    utils.imshow_fourier(
        t.stack(components),
        title=f"{title} components of neuron activations in Fourier basis",
        animation_frame=0,
        animation_name="Frequency",
        animation_labels=key_freqs.tolist(),
    )
```

Can you interpret this plot? Can you explain why this plot confirms our hypothesis about how logits are computed?

Output (and explanation)

Recall we are trying to show that:

$$ \begin{aligned} n_{post}^T \sigma_{2k-1} u_{2k-1} &\approx C_k \cos(\omega_k (x + y)) \\ n_{post}^T \sigma_{2k} u_{2k} &\approx C_k \sin(\omega_k (x + y)) \end{aligned} $$

Writing this in the 2D Fourier basis, we get:

$$ \begin{aligned} n_{post}^T \sigma_{2k-1} u_{2k-1} &\approx \frac{C_k}{\sqrt{2}} \cos(\omega_k \vec{\textbf{x}})\cos(\omega_k \vec{\textbf{y}}) - \frac{C_k}{\sqrt{2}} \sin (\omega_k \vec{\textbf{x}})\sin (\omega_k \vec{\textbf{y}}) \\ n_{post}^T \sigma_{2k} u_{2k} &\approx \frac{C_k}{\sqrt{2}} \cos(\omega_k \vec{\textbf{x}})\sin(\omega_k \vec{\textbf{y}}) + \frac{C_k}{\sqrt{2}} \sin (\omega_k \vec{\textbf{x}})\cos (\omega_k \vec{\textbf{y}}) \end{aligned} $$

You should find that these expected 2D Fourier coefficients match the ones you get on your plot (i.e.
they are approximately equal in magnitude, with the same sign for $\sin$ and opposite signs for $\cos$). For instance, zooming in on the $\cos$ plot for frequency $k=14$, we get:

Recap of section

Let's review what we've learned about each part of the network. We found that:

Embedding

The embedding projects onto a select few 1D Fourier basis vectors, corresponding to a handful of key frequencies.

Neuron Activations

Each neuron's activations are a linear combination of the constant term and the linear and quadratic terms of a specific frequency. The neurons clearly cluster according to the key frequencies.

We found this by:

* Seeing how well each neuron's variance was explained by a specific frequency, and finding that every neuron was well explained by a single frequency. The exception was a group of neurons which always fired, which means they weren't doing anything nonlinear.
* Projecting our neuron activations onto these corresponding frequencies, and showing that this didn't increase the overall loss by much.

These activations are calculated by some nonlinear magic involving ReLUs and attention layers. Calculating things like the product of inputs at 2 different sequence positions is not a natural operation for a transformer to learn!

Logit Computation

The network uses $W_{logit}=W_{out}W_U$ to cancel out all 2D Fourier components except the directions corresponding to $\cos(\omega(x+y)),\sin(\omega(x+y))$. It then multiplies these directions by $\cos(\omega z),\sin(\omega z)$ respectively and sums the results to get the output logits.

We found this by:

* Showing that the linear terms more or less vanished from the logits, leaving only the quadratic terms behind.
* Showing that each of our neurons was associated with a particular frequency, and that the output of a neuron in frequency cluster $k$ affected only the frequency-$k$ Fourier components of the logits.
* Showing that the exact way the neurons affected the logits at this frequency matched our initial guess of $\cos(\omega_k (x + y - \vec{\textbf{z}})) = \cos(\omega_k (x + y))\cos(\omega_k \vec{\textbf{z}}) + \sin(\omega_k (x + y))\sin(\omega_k \vec{\textbf{z}})$. In other words, each frequency cluster was responsible for computing this linear combination and adding it to the logit output.