Now that we've found what appear to be neuron clusters, it's time to validate our observations. We'll do this by showing that, for each neuron cluster, we can set terms for any other frequency to zero and still get good performance on the task.
Exercise - validate neuron clusters
Difficulty:
🔴🔴🔴🔴⚪
Importance:
🔵🔵🔵🔵⚪
You should spend up to 25-40 minutes on the following exercises. They are designed to get you to engage with ideas from linear algebra (specifically projections), and are very important conceptually.
We want to do the following:
Take neuron_acts_post, which is a tensor of shape (p*p, d_mlp), with the [i, j]-th element being the activation of neuron j on the i-th input sequence in our all_data batch.
Treating this tensor as d_mlp separate vectors in $\mathbb{R}^{p^2}$ (with the last dimension being the batch dimension), we will project each of these vectors onto the subspace of $\mathbb{R}^{p^2}$ spanned by the 2D Fourier basis vectors for the associated frequency of that particular neuron (i.e. the constant, linear, and quadratic terms).
Take these projected neuron_acts_post vectors, and apply W_logit to give us new logits. Compare the cross entropy loss with these logits to our original logits.
If our hypothesis is correct (i.e. that for each cluster of neurons associated with a particular frequency, that frequency is the only one that matters), then our loss shouldn't increase by much when we project out the other frequencies in this way.
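Before diving in, here's a minimal standalone sketch of the logic we're about to apply (using numpy and a made-up 1D signal, not the model's actual activations): if a vector genuinely lives in the subspace spanned by a particular frequency's basis vectors, projecting it onto that subspace barely changes it, while projecting it onto a different frequency's subspace destroys it.

```python
import numpy as np

n = 64
xs = np.arange(n)

# A "signal" that lives entirely in the frequency-3 subspace
signal = 2.0 * np.cos(2 * np.pi * 3 * xs / n) + 0.5 * np.sin(2 * np.pi * 3 * xs / n)

def project_onto_freq(x: np.ndarray, k: int) -> np.ndarray:
    """Project x onto the span of the cos & sin directions at frequency k."""
    cos_k = np.cos(2 * np.pi * k * xs / n)
    sin_k = np.sin(2 * np.pi * k * xs / n)
    cos_k /= np.linalg.norm(cos_k)
    sin_k /= np.linalg.norm(sin_k)
    return (x @ cos_k) * cos_k + (x @ sin_k) * sin_k

err_right = np.linalg.norm(signal - project_onto_freq(signal, 3))
err_wrong = np.linalg.norm(signal - project_onto_freq(signal, 5))
print(err_right, err_wrong)  # first ≈ 0, second ≈ ||signal||
```

The exercises below apply the same principle, just with 2D Fourier subspaces and cross entropy loss as the measure of "barely changes".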
First, you'll need to write the function project_onto_direction. This takes two inputs: batch_vecs (a batch of vectors, with batch dimension at the end) and v (a single vector), and returns the projection of each vector in batch_vecs onto the direction v.
def project_onto_direction(batch_vecs: Tensor, v: Tensor) -> Tensor:
    """
    Returns the component of each vector in `batch_vecs` in the direction of `v`.

    batch_vecs.shape = (n, ...)
    v.shape = (n,)
    """
    raise NotImplementedError()


tests.test_project_onto_direction(project_onto_direction)
Hint
Recall that the projection of $w$ onto $v$ (for a normalized vector $v$) is given by:
$$
w_{proj} = (w \cdot v) v
$$
You might find it easier to do this in two steps: first calculate the components in the $v$-direction by taking the inner product along the 0th dimension, then create a batch of multiples of $v$, scaled by these components.
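If it helps, here's a quick standalone check of this formula in plain torch (separate from the exercise's einops-based approach):

```python
import torch

torch.manual_seed(0)
v = torch.randn(5)
v = v / v.norm()      # the formula assumes v is normalized
w = torch.randn(5)

w_proj = (w @ v) * v  # (w · v) v

# The residual w - w_proj should be orthogonal to v
print(torch.allclose((w - w_proj) @ v, torch.tensor(0.0), atol=1e-6))  # True
```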
Solution
def project_onto_direction(batch_vecs: Tensor, v: Tensor) -> Tensor:
    """
    Returns the component of each vector in `batch_vecs` in the direction of `v`.

    batch_vecs.shape = (n, ...)
    v.shape = (n,)
    """
    # Get tensor of components of each vector in v-direction
    components_in_v_dir = einops.einsum(batch_vecs, v, "n ..., n -> ...")

    # Use these components as coefficients of v in our projections
    return einops.einsum(components_in_v_dir, v, "..., n -> n ...")
Next, you should write the function project_onto_frequency. This takes a batch of vectors with shape (p**2, batch), and a frequency freq, and returns the projection of each vector onto the subspace spanned by the nine 2D Fourier basis vectors for that frequency (i.e. one constant term, four linear terms, and four quadratic terms).
def project_onto_frequency(batch_vecs: Tensor, freq: int) -> Tensor:
    """
    Returns the projection of each vector in `batch_vecs` onto the
    2D Fourier basis directions corresponding to frequency `freq`.

    batch_vecs.shape = (p**2, ...)
    """
    assert batch_vecs.shape[0] == p**2
    raise NotImplementedError()


tests.test_project_onto_frequency(project_onto_frequency)
Hint
This will just involve summing nine calls to the project_onto_direction function you wrote above (one for each basis vector you're projecting onto), since your basis vectors are orthogonal.
You should use the function fourier_2d_basis_term to get the vectors you'll be projecting onto. Remember to flatten these vectors, because you're working with vectors of length p**2 rather than of size (p, p)!
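The reason summing works is worth verifying for yourself: for an *orthonormal* set of directions, the sum of the 1D projections equals the projection onto their span. A minimal numpy sketch, with a random orthonormal basis standing in for the Fourier vectors:

```python
import numpy as np

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(10, 3)))  # Q: (10, 3), orthonormal columns
x = rng.normal(size=10)

# Sum of 1D projections onto each basis vector
sum_of_projs = sum((x @ Q[:, i]) * Q[:, i] for i in range(3))

# Textbook projection onto the subspace spanned by the columns of Q
subspace_proj = Q @ (Q.T @ x)

print(np.allclose(sum_of_projs, subspace_proj))  # True
```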
Solution
The for loop in this code goes over the indices for the constant term (0, 0), the linear terms (2f-1, 0), (2f, 0), (0, 2f-1), (0, 2f), and the quadratic terms (2f-1, 2f-1), (2f-1, 2f), (2f, 2f-1), (2f, 2f).
def project_onto_frequency(batch_vecs: Tensor, freq: int) -> Tensor:
    """
    Returns the projection of each vector in `batch_vecs` onto the
    2D Fourier basis directions corresponding to frequency `freq`.

    batch_vecs.shape = (p**2, ...)
    """
    assert batch_vecs.shape[0] == p**2
    return sum(
        [
            project_onto_direction(
                batch_vecs,
                fourier_2d_basis_term(i, j).flatten(),
            )
            for i in [0, 2 * freq - 1, 2 * freq]
            for j in [0, 2 * freq - 1, 2 * freq]
        ]
    )
Finally, run the following code to project out the other frequencies from the neuron activations, and compare the new loss. You should make sure you understand what this code is doing.
logits_in_freqs = []

for freq in key_freqs:
    # Get all neuron activations corresponding to this frequency
    filtered_neuron_acts = neuron_acts_post[:, neuron_freqs == freq]

    # Project onto const/linear/quadratic terms in 2D Fourier basis
    filtered_neuron_acts_in_freq = project_onto_frequency(filtered_neuron_acts, freq)

    # Calculate new logits, from these filtered neuron activations
    logits_in_freq = filtered_neuron_acts_in_freq @ W_logit[neuron_freqs == freq]

    logits_in_freqs.append(logits_in_freq)

# We add on neurons in the always firing cluster, unfiltered
logits_always_firing = neuron_acts_post[:, neuron_freqs == -1] @ W_logit[neuron_freqs == -1]
logits_in_freqs.append(logits_always_firing)

# Compute new losses
key_freq_loss = utils.test_logits(sum(logits_in_freqs), bias_correction=True, original_logits=original_logits)
key_freq_loss_no_always_firing = utils.test_logits(
    sum(logits_in_freqs[:-1]), bias_correction=True, original_logits=original_logits
)

# Print new losses
print(
    f"""Loss with neuron activations ONLY in key freq (including always firing cluster): {key_freq_loss:.3e}
Loss with neuron activations ONLY in key freq (excluding always firing cluster): {key_freq_loss_no_always_firing:.3e}
Original loss: {original_loss:.3e}"""
)
Click to see the expected output
Loss with neuron activations ONLY in key freq (including always firing cluster): 2.482e-07
Loss with neuron activations ONLY in key freq (excluding always firing cluster): 1.179e-06
Original loss: 2.412e-07
You should find that the loss doesn't change much when you project out the other frequencies, and even if you remove the always firing cluster the loss is still very small.
We can also compare the importance of each cluster of neurons by ablating it (while continuing to restrict each cluster to its frequency). We see from this that freq=52 is the most important cluster (because the loss increases a lot when this is removed), although clearly all clusters are important (because the loss is still very small if any one of them is ablated). We also see that ablating the always firing cluster has a very small effect, so clearly this cluster isn't very helpful for the task. (This is something we might have guessed beforehand, since the ReLU never firing makes this essentially a linear function, and in general having non-linearities allows you to learn much more expressive functions.)
print("Loss with neuron activations excluding none: {:.9f}".format(original_loss.item()))

for c, freq in enumerate(key_freqs_plus):
    print(
        "Loss with neuron activations excluding freq={}: {:.9f}".format(
            freq,
            utils.test_logits(
                sum(logits_in_freqs) - logits_in_freqs[c],
                bias_correction=True,
                original_logits=original_logits,
            ),
        )
    )
Click to see the expected output
Loss with neuron activations excluding none: 0.000000241
Loss with neuron activations excluding freq=14: 0.000199645
Loss with neuron activations excluding freq=35: 0.000458851
Loss with neuron activations excluding freq=41: 0.001917986
Loss with neuron activations excluding freq=42: 0.005197690
Loss with neuron activations excluding freq=52: 0.024398957
Loss with neuron activations excluding freq=-1: 0.000001179
Understanding Logit Computation
TLDR: The network uses $W_{logit}=W_{out}W_U$ to cancel out all 2D Fourier components other than the directions corresponding to $\cos(\omega_k(x+y)),\sin(\omega_k(x+y))$, and then multiplies these directions by $\cos(\omega_k z),\sin(\omega_k z)$ respectively and sums to get the output logits.
Recall that (for each neuron cluster with associated frequency $k$), each neuron's activations are a linear combination of const, linear and quadratic terms:
$$
\begin{bmatrix}
1 & \cos(\omega_k x) & \sin(\omega_k x) \\
\cos(\omega_k y) & \cos(\omega_k x)\cos(\omega_k y) & \sin(\omega_k x)\cos(\omega_k y) \\
\sin(\omega_k y) & \cos(\omega_k x)\sin(\omega_k y) & \sin(\omega_k x)\sin(\omega_k y)
\end{bmatrix}
$$
for $\omega_k = 2\pi k / p$.
To calculate the logits, the network cancels out all directions apart from:
$$
\begin{aligned}
\cos(\omega_k (x+y)) &= \cos(\omega_k x)\cos(\omega_k y)-\sin(\omega_k x)\sin(\omega_k y) \\
\sin(\omega_k (x+y)) &= \sin(\omega_k x)\cos(\omega_k y)+\cos(\omega_k x)\sin(\omega_k y)
\end{aligned}
$$
The network then multiplies these by $\cos(\omega_k z),\sin(\omega_k z)$ and sums (i.e. the logit for value $z$ will be the product of these with $\cos(\omega_k z),\sin(\omega_k z)$, summed over all neurons in that cluster, summed over all clusters).
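You can spot-check the two angle-addition identities above numerically (a trivial sketch with arbitrary values):

```python
import math

omega, x, y = 0.7, 1.3, -2.1

# cos(ω(x+y)) = cos(ωx)cos(ωy) - sin(ωx)sin(ωy)
assert math.isclose(
    math.cos(omega * (x + y)),
    math.cos(omega * x) * math.cos(omega * y) - math.sin(omega * x) * math.sin(omega * y),
)

# sin(ω(x+y)) = sin(ωx)cos(ωy) + cos(ωx)sin(ωy)
assert math.isclose(
    math.sin(omega * (x + y)),
    math.sin(omega * x) * math.cos(omega * y) + math.cos(omega * x) * math.sin(omega * y),
)

print("identities hold")
```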
Question - can you explain why this algorithm works?
Hint
Think about the following expression:
$$
\cos(\omega (x+y))\cos(\omega z)+\sin(\omega (x+y))\sin(\omega z)
$$
which is (up to a scale factor) the value added to the logit score for $z$.
Answer
The reason is again thanks to our trig formulas! Each neuron will add to the final logits a vector which looks like:
$$
\cos(\omega(x+y))\cos(\omega z)+\sin(\omega(x+y))\sin(\omega z)
$$
which we know from our trig formulas equals:
$$
\cos(\omega(x+y-z))
$$
which is largest when $z=x+y$.
Another way of writing this would be that, on inputs (x, y), the model's logit output is (up to a scale factor) equal to:
$$
\cos(\omega(x+y-\vec{\textbf{z}})) = \begin{bmatrix}
\cos(\omega(x+y)) \\
\cos(\omega(x+y-1)) \\
\vdots \\
\cos(\omega(x+y-(p-1)))
\end{bmatrix}
$$
This vector is largest at the element with index $x+y \pmod p$, meaning the logit for $x+y$ will be largest (which is exactly what we want to happen, to solve our problem!).
Also, remember that we have several different frequencies $\omega_k$, and so when we sum over neurons, the vectors will combine constructively at $z = x+y$, and combine destructively everywhere else:
$$
f(t) = \sum_{k \in K} C_k \cos(\omega_k (x + y - \vec{\textbf{z}}))
$$
(where $C_k$ are large positive constants, which we'll explicitly calculate later on).
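To see this constructive interference concretely, here is a small numpy sketch. I take $p = 113$ (the modulus used in these exercises), the key frequencies found earlier, and uniform coefficients $C_k = 1$ (the real $C_k$ differ, but that doesn't change where the peak lands):

```python
import numpy as np

p = 113
x, y = 40, 90                      # target answer: (x + y) % p = 17
z = np.arange(p)
key_freqs = [14, 35, 41, 42, 52]   # key frequencies found earlier

# Sum cos(ω_k (x + y - z)) over frequencies: each term equals 1 only when
# x + y - z ≡ 0 (mod p), so the terms interfere constructively only there
logits = sum(np.cos(2 * np.pi * k * (x + y - z) / p) for k in key_freqs)

print(int(logits.argmax()))        # 17
```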
Logits in Fourier Basis
To see that the network cancels out other directions, we can transform both the neuron activations and logits to the 2D Fourier Basis, and show the norm of the vector corresponding to each Fourier component - we see that the quadratic terms have much higher norm in the logits than neuron activations, and linear terms are close to zero. Remember that, to get from activations to logits, we apply the linear map $W_{logit}=W_{out}W_U$, which involves summing the outputs of all neurons.
Below is some code to visualise this. Note the use of einops.reduce rather than the mean method in the code below. Like most other uses of einops, this is useful because it's explicit and readable.
utils.imshow_fourier(
    einops.reduce(neuron_acts_centered_fourier.pow(2), "y x neuron -> y x", "mean"),
    title="Norm of Fourier Components of Neuron Acts",
)

# Rearrange logits, so the first two dims represent (x, y) in modular arithmetic equation
original_logits_sq = einops.rearrange(original_logits, "(x y) z -> x y z", x=p)
original_logits_fourier = fft2d(original_logits_sq)

utils.imshow_fourier(
    einops.reduce(original_logits_fourier.pow(2), "y x z -> y x", "mean"),
    title="Norm of Fourier Components of Logits",
)
Click to see the expected output
You should find that the linear and constant terms have more or less vanished relative to the quadratic terms, and that the quadratic terms are much larger in the logits than in the neuron activations. This is annotated in the plots (which should match the results you get from running the code).
Exercise - validate by only taking quadratic terms
Difficulty:
🔴🔴🔴⚪⚪
Importance:
🔵🔵🔵⚪⚪
You should spend up to 10-20 minutes on this exercise. This exercise should feel similar to the previous one, since it's about vectors and projections.
Here, you will validate your results still further by just taking the components of the logits corresponding to $\cos(\omega_k(x+y))$ and $\sin(\omega_k(x+y))$ for each of our key frequencies $k$, and showing this increases performance.
First, you should write a function get_trig_sum_directions, which takes in a frequency k and returns the (p, p)-size vectors in 2D Fourier space corresponding to the directions:
$$
\begin{aligned}
\cos(\omega_k (\vec{\textbf{x}}+\vec{\textbf{y}})) &= \cos(\omega_k \vec{\textbf{x}})\cos(\omega_k \vec{\textbf{y}})-\sin(\omega_k \vec{\textbf{x}})\sin(\omega_k \vec{\textbf{y}}) \\
\sin(\omega_k (\vec{\textbf{x}}+\vec{\textbf{y}})) &= \sin(\omega_k \vec{\textbf{x}})\cos(\omega_k \vec{\textbf{y}})+\cos(\omega_k \vec{\textbf{x}})\sin(\omega_k \vec{\textbf{y}})
\end{aligned}
$$
respectively. Remember, the vectors you return should be normalized.
def get_trig_sum_directions(k: int) -> tuple[Float[Tensor, "p p"], Float[Tensor, "p p"]]:
    """
    Given frequency k, returns the normalized vectors in the 2D Fourier basis representing the
    two directions:

        cos(ω_k * (x + y))
        sin(ω_k * (x + y))

    respectively.
    """
    raise NotImplementedError()


tests.test_get_trig_sum_directions(get_trig_sum_directions)
Hint
You can get the vector $\cos(\omega_k \vec{\textbf{x}}) \cos(\omega_k \vec{\textbf{y}})$ as follows:
cosx_cosy_direction = fourier_2d_basis_term(2 * k - 1, 2 * k - 1)
Solution
def get_trig_sum_directions(k: int) -> tuple[Float[Tensor, "p p"], Float[Tensor, "p p"]]:
    """
    Given frequency k, returns the normalized vectors in the 2D Fourier basis representing the
    two directions:

        cos(ω_k * (x + y))
        sin(ω_k * (x + y))

    respectively.
    """
    cosx_cosy_direction = fourier_2d_basis_term(2 * k - 1, 2 * k - 1)
    sinx_siny_direction = fourier_2d_basis_term(2 * k, 2 * k)
    sinx_cosy_direction = fourier_2d_basis_term(2 * k, 2 * k - 1)
    cosx_siny_direction = fourier_2d_basis_term(2 * k - 1, 2 * k)

    cos_xplusy_direction = (cosx_cosy_direction - sinx_siny_direction) / np.sqrt(2)
    sin_xplusy_direction = (sinx_cosy_direction + cosx_siny_direction) / np.sqrt(2)

    return cos_xplusy_direction, sin_xplusy_direction
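As a sanity check of this construction, here's a standalone sketch (re-implementing the 2D basis terms for a small $p$, since `fourier_2d_basis_term` isn't available outside the exercise environment) showing that the combined direction really is unit-norm and proportional to $\cos(\omega_k(x+y))$:

```python
import numpy as np

p, k = 11, 3
omega = 2 * np.pi * k / p
ts = np.arange(p)

# Normalized 1D Fourier vectors for frequency k
cos_k = np.sqrt(2 / p) * np.cos(omega * ts)
sin_k = np.sqrt(2 / p) * np.sin(omega * ts)

# 2D basis terms (outer products), each of shape (p, p)
cosx_cosy = np.outer(cos_k, cos_k)
sinx_siny = np.outer(sin_k, sin_k)

# Combined direction, as in get_trig_sum_directions
cos_xplusy = (cosx_cosy - sinx_siny) / np.sqrt(2)

# Compare against cos(ω(x+y)) over the full (x, y) grid, normalized
x, y = np.meshgrid(ts, ts, indexing="ij")
expected = np.cos(omega * (x + y))
expected /= np.linalg.norm(expected)

print(np.isclose(np.linalg.norm(cos_xplusy), 1.0), np.allclose(cos_xplusy, expected))
```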
Once you've passed these tests, you can run the code to project the logits onto these directions, and see how the loss changes. Note the use of the project_onto_direction function which you wrote earlier.
trig_logits = []

for k in key_freqs:
    cos_xplusy_direction, sin_xplusy_direction = get_trig_sum_directions(k)

    cos_xplusy_projection = project_onto_direction(original_logits, cos_xplusy_direction.flatten())
    sin_xplusy_projection = project_onto_direction(original_logits, sin_xplusy_direction.flatten())

    trig_logits.extend([cos_xplusy_projection, sin_xplusy_projection])

trig_logits = sum(trig_logits)

print(f"Loss with just x+y components: {utils.test_logits(trig_logits, True, original_logits):.4e}")
print(f"Original Loss: {original_loss:.4e}")
Click to see the expected output
Loss with just x+y components: 5.5474e-08
Original Loss: 2.4122e-07
You should find that the loss with just these components is significantly lower than your original loss. This is very strong evidence that we've correctly identified the algorithm used by our model.
$W_{logit}$ in Fourier Basis
Okay, so we know that the model is mainly working with the terms $\cos(\omega_k(x+y))$ and $\sin(\omega_k(x+y))$ for each of our key frequencies $k$. Now, we want to show that the model's final output is:
$$
\cos(\omega_k (x + y - \vec{\textbf{z}})) = \cos(\omega_k (x + y))\cos(\omega_k \vec{\textbf{z}}) + \sin(\omega_k (x + y))\sin(\omega_k \vec{\textbf{z}})
$$
How do we do this?
Answer: we examine $W_{logit}$ in the Fourier basis. If we think that $W_{logit}$ is mainly projecting onto the directions $\cos(\omega_k \vec{\textbf{z}})$ and $\sin(\omega_k \vec{\textbf{z}})$, then we expect to find:
$$
W_{logit} \approx U S F = \sum_{i=1}^p \sigma_i u_i f_i^T
$$
where the singular values $\sigma_i$ are zero except those corresponding to Fourier basis vectors $f_i = \cos(\omega_k \vec{\textbf{z}}), \sin(\omega_k \vec{\textbf{z}})$ for key frequencies $k$. In other words:
$$
W_{logit} \approx \sum_{k \in K} \sigma_{2k-1} u_{2k-1} \cos(\omega_k \vec{\textbf{z}}) + \sigma_{2k} u_{2k} \sin(\omega_k \vec{\textbf{z}})
$$
Thus, if we right-multiply $W_{logit}$ by $F^T$, we should get a matrix $W_{logit} F^T \approx US$ of shape (d_mlp, p), with all columns zero except for $\sigma_{2k-1}u_{2k-1}$ and $\sigma_{2k} u_{2k}$ for key frequencies $k$. Let's verify this:
US = W_logit @ fourier_basis.T

utils.imshow_div(US, x=fourier_basis_names, yaxis="Neuron index", title="W_logit in the Fourier Basis")
Click to see the expected output
You should see that the columns of this matrix are only non-zero at positions $2k$, $2k-1$ for the key frequencies. Since our model's final output is just a linear combination of these columns (with the coefficients given by the neuron activations), this proves that $W_{logit}$ is projecting onto directions corresponding to our key frequencies.
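The linear-algebra step here (right-multiplying by $F^T$ to recover $US$) can be sanity checked in isolation, with random matrices standing in for the real $U$, $S$ and $F$:

```python
import numpy as np

rng = np.random.default_rng(0)
d_mlp, p = 8, 5

U = rng.normal(size=(d_mlp, p))
S = np.diag(rng.uniform(size=p))
F, _ = np.linalg.qr(rng.normal(size=(p, p)))  # orthonormal rows: F @ F.T = I

# If W = U S F with F orthonormal, right-multiplying by F.T recovers U S
W = U @ S @ F
print(np.allclose(W @ F.T, U @ S))  # True
```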
Note the contrast between what we just did and what we've done before. In previous sections, we've taken the 2D Fourier transform of our activations / effective weights with respect to the input space (the vectors were $\vec{\textbf{x}}$ and $\vec{\textbf{y}}$). Now, we're taking the 1D Fourier transform with respect to the output space (the vectors are $\vec{\textbf{z}}$). It's pretty cool that this works!
So we've proven that:
$$
W_{logit} \approx \sum_{k \in K} \sigma_{2k-1} u_{2k-1} \cos(\omega_k \vec{\textbf{z}}) + \sigma_{2k} u_{2k} \sin(\omega_k \vec{\textbf{z}})
$$
but we still want to show that our final output is:
$$
f(t) = n_{post}^T W_{logit} \approx \sum_{k \in K} C_k \big(\cos(\omega_k (x + y)) \cos(\omega_k \vec{\textbf{z}}) + \sin(\omega_k (x + y)) \sin(\omega_k \vec{\textbf{z}})\big)
$$
where $n_{post} \in \mathbb{R}^{d_{mlp}}$ is the vector of neuron activations, and $C_k$ are large positive constants.
Matching coefficients of the vectors $\cos(\omega_k \vec{\textbf{z}})$ and $\sin(\omega_k \vec{\textbf{z}})$, this means we want to show that:
$$
\begin{aligned}
\sigma_{2k-1} u_{2k-1} &\approx C_k \cos(\omega_k (x + y)) \\
\sigma_{2k} u_{2k} &\approx C_k \sin(\omega_k (x + y))
\end{aligned}
$$
for each key frequency $k$.
First, let's do a quick sanity check. We expect vectors $u_{2k-1}$ and $u_{2k}$ to only contain components of frequency $k$, which means we expect the only non-zero elements of these vectors to correspond to the neurons in the $k$-frequency cluster. Let's test this by rearranging the matrix $W_{logit}F^T \approx US$, so that the neurons in each cluster are grouped together:
US_sorted = t.concatenate([US[neuron_freqs == freq] for freq in key_freqs_plus])
hline_positions = np.cumsum([(neuron_freqs == freq).sum().item() for freq in key_freqs]).tolist() + [cfg.d_mlp]

utils.imshow_div(
    US_sorted,
    x=fourier_basis_names,
    yaxis="Neuron",
    title="W_logit in the Fourier Basis (rearranged by neuron cluster)",
    hline_positions=hline_positions,
    hline_labels=[f"Cluster: {freq=}" for freq in key_freqs.tolist()] + ["No freq"],
)
Click to see the expected output
You should find that, for each frequency $k$, the components of the output in directions $\cos(\omega_k \vec{\textbf{z}})$ and $\sin(\omega_k \vec{\textbf{z}})$ are determined only by the neurons in the $k$-cluster, i.e. they are determined only by the 2D Fourier components of the input $(x, y)$ with frequency $k$.
This is promising, but we still haven't shown that $\sigma_{2k-1} u_{2k-1} \propto \cos(\omega_k(x+y))$, etc. To do this, we'll calculate the vectors $\sigma_{2k-1} u_{2k-1}$ and $\sigma_{2k} u_{2k}$ over all inputs $(x, y)$, then take the 2D Fourier transform.
cos_components = []
sin_components = []

for k in key_freqs:
    sigma_u_sin = US[:, 2 * k]
    sigma_u_cos = US[:, 2 * k - 1]

    logits_in_cos_dir = neuron_acts_post_sq @ sigma_u_cos
    logits_in_sin_dir = neuron_acts_post_sq @ sigma_u_sin

    cos_components.append(fft2d(logits_in_cos_dir))
    sin_components.append(fft2d(logits_in_sin_dir))

for title, components in zip(["Cosine", "Sine"], [cos_components, sin_components]):
    utils.imshow_fourier(
        t.stack(components),
        title=f"{title} components of neuron activations in Fourier basis",
        animation_frame=0,
        animation_name="Frequency",
        animation_labels=key_freqs.tolist(),
    )
Click to see the expected output
Can you interpret this plot? Can you explain why this plot confirms our hypothesis about how logits are computed?
Output (and explanation)
Recall we are trying to show that:
$$
\begin{aligned}
\sigma_{2k-1} u_{2k-1} &\approx C_k \cos(\omega_k (x + y)) \\
\sigma_{2k} u_{2k} &\approx C_k \sin(\omega_k (x + y))
\end{aligned}
$$
Writing this in the 2D Fourier basis, we get:
$$
\begin{aligned}
\sigma_{2k-1} u_{2k-1} &\approx \frac{C_k}{\sqrt{2}} \cos(\omega_k \vec{\textbf{x}})\cos(\omega_k \vec{\textbf{y}}) - \frac{C_k}{\sqrt{2}} \sin (\omega_k \vec{\textbf{x}})\sin (\omega_k \vec{\textbf{y}}) \\
\sigma_{2k} u_{2k} &\approx \frac{C_k}{\sqrt{2}} \cos(\omega_k \vec{\textbf{x}})\sin(\omega_k \vec{\textbf{y}}) + \frac{C_k}{\sqrt{2}} \sin (\omega_k \vec{\textbf{x}})\cos (\omega_k \vec{\textbf{y}})
\end{aligned}
$$
You should find that these expected 2D Fourier coefficients match the ones you get on your plot (i.e. they are approximately equal in magnitude, and the same sign for $\sin$ / opposite sign for $\cos$). For instance, you can check this by zooming in on the $\cos$ plot for frequency $k=14$.
Recap of section
Let's review what we've learned about each part of the network. We found that:
Embedding
The embedding projects onto a select few 1D Fourier basis vectors, corresponding to a handful of key frequencies.
Neuron Activations
Each neuron's activations are a linear combination of the constant term and the linear and quadratic terms of a specific frequency. The neurons clearly cluster according to the key frequencies.
We found this by:
* Seeing how well explained the neurons' variance was by a specific frequency, and finding that every neuron was well explained by a single frequency (except for a group of neurons which always fired, in other words they weren't doing anything nonlinear).
* Projecting our neuron activations onto these corresponding frequencies, and showing that this didn't increase the overall loss by much.
These activations are calculated by some nonlinear magic involving ReLUs and attention layers (because calculating things like the product of inputs at 2 different sequence positions is not a natural operation for a transformer to learn!).
Logit Computation
The network uses $W_{logit}=W_{out}W_U$ to cancel out all 2D Fourier components other than the directions corresponding to $\cos(\omega_k(x+y)),\sin(\omega_k(x+y))$, and then multiplies these directions by $\cos(\omega_k z),\sin(\omega_k z)$ respectively and sums to get the output logits.
We found this by:
* Showing that the linear terms more or less vanished from the logits, leaving only the quadratic terms behind.
* Observing that each of our neurons was associated with a particular frequency, and that the output of a neuron in frequency cluster $k$ directly affected the logits in the $k$-th Fourier basis.
* Checking that the exact way the neurons affected the logits at this frequency matched our initial guess of $\cos(\omega_k (x + y - \vec{\textbf{z}})) = \cos(\omega_k (x + y))\cos(\omega_k \vec{\textbf{z}}) + \sin(\omega_k (x + y))\sin(\omega_k \vec{\textbf{z}})$, i.e. each frequency cluster was responsible for computing this linear combination and adding it to the logit output.