1️⃣ Periodicity & Fourier basis Learning Objectives Understand the problem statement, the model architecture, and the corresponding functional form of any possible solutions. Learn about the Fourier basis (1D and 2D), and how it can be used to represent arbitrary functions. Understand that periodic functions are sparse in the Fourier basis, and how this relates to the model's weights. Model architecture First, let's define our model, and some useful activations and weights as shorthand. To review the information given in the previous page: The model we will be reverse-engineering today is a one-layer transformer, with no layer norm and learned positional embeddings. $d_{model} = 128$, $n_{heads} = 4$, $d_{head}=32$, $d_{mlp}=512$. The task this model was trained on is addition modulo the prime $p = 113$. The input format is a sequence of three tokens [x, y, =], with $d_{vocab}=114$ (integers from $0$ to $p - 1$ and $=$). The prediction for the next token after = should be the token corresponding to $x + y \pmod{p}$. Run the code below to define your model: p = 113 cfg = HookedTransformerConfig( n_layers=1, d_vocab=p + 1, d_model=128, d_mlp=4 * 128, n_heads=4, d_head=128 // 4, n_ctx=3, act_fn="relu", normalization_type=None, device=device, ) model = HookedTransformer(cfg) Next, run the code below to download the data from GitHub & HuggingFace. if not grokking_root.exists(): os.system( f'git clone https://github.com/neelnanda-io/Grokking.git "{grokking_root.as_posix()}"' ) assert grokking_root.exists() os.mkdir(grokking_root / "large_files") Once this has finished, you can load in your weights, using these helper functions: REPO_ID = "callummcdougall/grokking_full_run_data" FILENAME = "full_run_data.pth" local_dir = hf_hub_download(repo_id=REPO_ID, filename=FILENAME) full_run_data = t.load(local_dir, weights_only=True) state_dict = full_run_data["state_dicts"][400] model = utils.load_in_state_dict(model, state_dict) Before we start doing mech interp on our model, let's have a look at our loss curves. How quickly do the model's training and test loss curves come down? utils.lines( lines_list=[full_run_data["train_losses"][::10], full_run_data["test_losses"]], labels=["train loss", "test loss"], title="Grokking Training Curve", x=np.arange(5000) * 10, xaxis="Epoch", yaxis="Loss", log_y=True, width=900, height=450, ) This is fascinating! We can see that the model initially memorises the training data (the train loss curve falls sharply to almost zero, while the test loss curve actually goes up), but eventually "groks" the task, i.e. suddenly learns to generalise on unseen data. This section and the next will focus on doing mech interp with our model. The third section will investigate plots like this one in more detail & track other metrics over time. The last section discusses some higher-level implications of this work, and possible future directions.
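As a final check before diving into the mech interp, we can confirm that the grokked model really does compute modular addition on a concrete example. This is a minimal sanity-check sketch, not part of the original exercises; it assumes `model`, `p`, and `device` as defined above.

```python
# Minimal sanity check (a sketch): the trained model should predict (x + y) % p.
# Token index p (i.e. 113) is the "=" token.
x, y = 3, 57
logits = model(t.tensor([[x, y, p]], device=device))  # shape (1, seq=3, d_vocab)
prediction = logits[0, -1, :p].argmax().item()        # read off the final position
print(f"{x} + {y} = {prediction} (mod {p})")          # expect (3 + 57) % 113 = 60
```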
## Helper variables Let's define some useful variables, and print out their shape to verify they are what we expect: # Helper variables W_O = model.W_O[0] W_K = model.W_K[0] W_Q = model.W_Q[0] W_V = model.W_V[0] W_in = model.W_in[0] W_out = model.W_out[0] W_pos = model.W_pos W_E = model.W_E[:-1] final_pos_resid_initial = model.W_E[-1] + W_pos[2] W_U = model.W_U[:, :-1] print("W_O ", tuple(W_O.shape)) print("W_K ", tuple(W_K.shape)) print("W_Q ", tuple(W_Q.shape)) print("W_V ", tuple(W_V.shape)) print("W_in ", tuple(W_in.shape)) print("W_out", tuple(W_out.shape)) print("W_pos", tuple(W_pos.shape)) print("W_E ", tuple(W_E.shape)) print("W_U ", tuple(W_U.shape)) W_O (4, 32, 128) W_K (4, 128, 32) W_Q (4, 128, 32) W_V (4, 128, 32) W_in (128, 512) W_out (512, 128) W_pos (3, 128) W_E (113, 128) W_U (128, 113) Note here - we've taken slices of the embedding and unembedding matrices, to remove the final row/column (which corresponds to the `=` token). We've done this so that we can perform a Fourier transform on these weights later on. From now on, when we refer to $W_E$ and $W_U$, we'll usually be referring to these smaller matrices. We've explicitly defined `final_pos_resid_initial` because this will be needed later (to get the query vector for sequence position 2). Also note that we've indexed many of these matrices by `[0]`; this is because the first dimension is the layer dimension, and our model only has one layer. Next, we'll run our model on all data. It's worth being clear on what we're doing here - we're taking every single one of the $p^2 = 113^2 = 12769$ possible sequences, stacking them into a single batch, and running the model on them. This only works because, in this particular problem, our universe is pretty small. We'll use the `run_with_cache` method to store all the intermediate activations. # Get all data and labels, and cache activations all_data = t.tensor([(i, j, p) for i in range(p) for j in range(p)]).to(device) labels = t.tensor([utils.target_fn(i, j) for i, j, _ in all_data]).to(device) original_logits, cache = model.run_with_cache(all_data) # Final position only, also remove the logits for `=` original_logits = original_logits[:, -1, :-1] # Get cross entropy loss original_loss = utils.cross_entropy_high_precision(original_logits, labels) print(f"Original loss: {original_loss.item():.3e}") Original loss: 2.412e-07 ### Exercise - extract key activations Difficulty: 🔴🔴⚪⚪⚪ Importance: 🔵🔵🔵⚪⚪ You should spend up to 5-10 minutes on these exercises. These are just designed to re-familiarize you with the ActivationCache object and how to use it. Some important activations which we'll be investigating later are the attention matrices and neuron activations. In the code below, you should define the following: * `attn_mat`: the attention patterns for query token `=`, over all sequences. This should have shape `(batch, head, key_posn)`, which in this case is `(12769, 4, 3)`. * Note that we only care when `=` is the query token, because this is the position we get our classifications from. * `neuron_acts_post`: the neuron activations **for the last sequence position**, after applying our ReLU function. This should have shape `(batch, d_mlp)`, which in this case is `(12769, 512)`. * Note again that we only care about the last sequence position - can you see why? * `neuron_acts_pre`: same as above, but before applying ReLU. You can check your results by printing the tensor shapes.
# YOUR CODE HERE - get the relevant activations # Test shapes assert attn_mat.shape == (p * p, cfg.n_heads, 3) assert neuron_acts_post.shape == (p * p, cfg.d_mlp) assert neuron_acts_pre.shape == (p * p, cfg.d_mlp) # Test values tests.test_cache_activations(attn_mat, neuron_acts_post, neuron_acts_pre, cache) Solution attn_mat = cache["pattern", 0][:, :, 2] neuron_acts_post = cache["post", 0][:, -1] neuron_acts_pre = cache["pre", 0][:, -1] ## Functional form Next, let's think about the functional form of our model's solution. ### Exercise - answer some initial questions Difficulty: 🔴🔴🔴⚪⚪ Importance: 🔵🔵🔵⚪⚪ You should spend up to 20-30 minutes on these exercises. Thinking about the functional form of your model before you start analysing a toy problem is an important skill. Here are a few questions, designed to get you thinking about the problem and how it relates to the model's internals. You can find answers to all of them (as well as more thorough discussion of other points) in the dropdown below. * Of the six distinct pieces of information fed into the model (three token embeddings and three positional embeddings), which ones are relevant for solving the modular addition task? * What does this imply about the role of position embeddings? * What should the attention pattern look like? Which parts of the attention pattern will even matter? * What will the role of the direct path (i.e. embeddings -> unembeddings, without any MLP or attention) be? How about the path that goes through the MLP layer but not the attention layer? * What kinds of symmetries do you expect to see in the model? Answers The position embeddings are irrelevant, since addition is commutative. In fact, this results in the position embeddings for the first two positions being approximately equal. Only the token embeddings of the first two tokens are relevant, since the last token is always =. The attention pattern should be such that position 2 pays attention only to positions 0 and 1, since position 2 has constant embeddings and provides no relevant information. (Note that it could act as a bias term; however, this is discouraged by the use of heavy weight decay during training and does not occur empirically.) The direct path provides no relevant information and hence only acts as a bias term. Empirically, ablating the residual stream to zero before applying the unembedding matrix does not hurt performance very much. The same goes for the path through the MLP layer but not the attention layer (because information can't move from the x, y tokens to the token we'll use for prediction). As mentioned, addition is commutative, so we expect to see symmetries between how the model deals with the first two tokens. Evidence for this: we can look at the difference between the position embeddings for pos 0 and pos 1 and see that they are close together (with high cosine similarity), and we can compare the neuron activations with their transpose (i.e. compare $N(x, y)$ and $N(y, x)$) and see that they are close together.
# Get the first three positional embedding vectors W_pos_x, W_pos_y, W_pos_equals = W_pos # Look at the difference between positional embeddings; show they are symmetric def compare_tensors(v, w): return ((v-w).pow(2).sum()/v.pow(2).sum().sqrt()/w.pow(2).sum().sqrt()).item() print('Difference in position embeddings', compare_tensors(W_pos_x, W_pos_y)) print('Cosine similarity of position embeddings', t.cosine_similarity(W_pos_x, W_pos_y, dim=0).item()) # Compare N(x, y) and N(y, x) neuron_acts_square = neuron_acts_post.reshape(p, p, cfg.d_mlp) print('Difference in neuron activations for (x,y) and (y,x): {:.2f}'.format( compare_tensors( neuron_acts_square, einops.rearrange(neuron_acts_square, "x y d_mlp -> y x d_mlp") ) )) This makes sense, because addition is commutative! Positions 0 and 1 should be symmetric. Evidence that attention from position 2 to itself is negligible - I plot the average attention to each position for each head across all data points, and see that $2\to 2$ averages to near zero (and so is almost always near zero, as attention is always positive), and $2\to 0$ and $2 \to 1$ both average to about 0.5, as we'd expect from symmetry. utils.imshow(attn_mat.mean(0), xaxis='Position', yaxis='Head', title='Average Attention by source position and head', text_auto=".3f") (Note that we could use circuitsvis to plot these attention patterns, but here we don't lose anything by using Plotly, since our analysis of attention patterns isn't too complicated.) ### Exercise - derive the functional form Difficulty: 🔴🔴🔴🔴⚪ Importance: 🔵🔵🔵🔵⚪ You should spend up to 20-30 minutes on these exercises. This exercise is challenging, and involves only maths and no coding. You should look at parts of the solution if you get stuck, because there are several steps involved. Even if you can't get all the steps, any progress is good! There exists a [comprehensive mathematical framework for understanding transformer circuits](https://transformer-circuits.pub/2021/framework/index.html), as you may have encountered in previous exercises. However, we will not need the full power of that framework today to understand the circuit responsible for modular addition, both because our model only has a single layer and because our task has some special structure that we can exploit. Consider the following simplifying assumptions about our model: * The position embeddings are irrelevant and can be zero-ablated without loss of performance; * The residual stream is irrelevant and can be mean-ablated without loss of performance; * For every head, position `2` only pays attention to positions `0` and `1`. Write down the function $\ell = f(t)$ computed by the model, where $\ell \in \mathbb{R}^p$ is a vector of logits for each token and $t \in \mathbb{R}^{2 \times p}$ is a pair of one-hot vectors representing the input integers $m$ and $n$. Simplify the expression obtained as far as possible. What can be said about it? Hint - diagram Here is a diagram making the different computational stages more explicit. Can you use this to write down a closed-form expression for $f(t)$? Hint - first steps Your solution will look like: $$ f(t) = \operatorname{MLP}(\operatorname{Attn}(tW_E)_2)W_U $$ where $t \in \mathbb{R}^{3 \times p}$ are the vectors of 1-hot encoded tokens, $\operatorname{MLP}$ denotes the MLP layer (which acts identically on the residual stream vectors at each sequence position), and $\operatorname{Attn}$ denotes the attention layer (so we take the value at sequence position 2, since this is where we take our predictions from).
From here, can you write $\operatorname{MLP}$ and $\operatorname{Attn}(\cdot)_2$ in terms of the actual matrices? Can you simplify by assuming that token 2 only pays attention to tokens 0 and 1? Answer Let's work through the model step-by-step. Let $n_\mathrm{seq} = 3$ denote the sequence length, so our (one-hot encoded) input tokens are $t \in \mathbb{R}^{n_\mathrm{seq} \times p}$. This contains the one-hot encoded integers $t_0$ and $t_1$, as well as the one-hot encoded equals sign $t_2$ (which is the same for all inputs). After applying the embedding matrix $W_E \in \mathbb{R}^{p \times d_\mathrm{model}}$, we get the embeddings: $$ v = t W_E \in \mathbb{R}^{n_\mathrm{seq} \times d_\mathrm{model}}. $$ Our function will look something like: $$ f(t) = \operatorname{MLP}(\operatorname{Attn}(v)_2)W_U $$ where $\operatorname{MLP}$ denotes the MLP layer (which acts identically on the residual stream vectors at each sequence position), and $\operatorname{Attn}$ denotes the attention layer (so we take the value at sequence position 2, since this is where we take our predictions from). Note, this ignores all other residual stream terms. The only other paths which might be important are those going through the attention layer but not the MLP, but we can guess that by far the most significant ones will be those going through both. Let's first address the MLP, because it's simpler. The functional form is just: $$ \operatorname{MLP}(w) = \operatorname{ReLU}\left(w^T W_{in}\right)W_{out} $$ where $w \in \mathbb{R}^{d_\mathrm{model}}$ is a vector in the residual stream (i.e. at some sequence position) after applying $\operatorname{Attn}$. Now let's think about the attention. We have: $$ \begin{aligned} \operatorname{Attn}(v)_2&=\sum_h \operatorname{softmax}\left(\frac{v_2^{\top} W_Q^h (W_K^h)^T\left[v_0 \, v_1\right]}{\sqrt{d_{head}}}\right)\left[v_0\, v_1\right]^T W_V^h W_O^h \\ &= \sum_h (\alpha^h v_0 + (1 - \alpha^h) v_1)^T W_V^h W_O^h \\ &\in \mathbb{R}^{d_{model}} \end{aligned} $$ where $v_0, v_1, v_2 \in \mathbb{R}^{d_{model}}$ are the three embedding vectors in the residual stream, and $\alpha^h$ is the attention probability that token 2 pays to token 0 in head $h$. Note that we've ignored the attention paid by token 2 to itself (because we've seen that this is near zero). This is why we've replaced the key-side term $v = t W_E$ with just the first two vectors $\left[v_0 \, v_1\right]$, and so the softmax is just over the key positions $\{0, 1\}$. Can we simplify the formula for $\alpha^h$? As it turns out, yes. We're softmaxing over 2 elements, which is equivalent to taking the sigmoid of the difference between the logits: $$ \operatorname{softmax}\left(\begin{array}{c} \alpha \\ \beta \end{array}\right)=\left(\begin{array}{c} e^\alpha / (e^\alpha+e^\beta) \\ e^\beta / (e^\alpha+e^\beta) \end{array}\right) = \left(\begin{array}{c} \sigma(\alpha-\beta) \\ 1-\sigma(\alpha-\beta) \end{array}\right) $$ so we can write: $$ \begin{aligned} \alpha^h &= \sigma\left(\frac{v_2^{\top} W_Q^h (W_K^h)^Tv_0}{\sqrt{d_{head}}} - \frac{v_2^{\top} W_Q^h (W_K^h)^Tv_1}{\sqrt{d_{head}}}\right) \\ &= \sigma\left(\frac{v_2^{\top} W_Q^h (W_K^h)^T(v_0 - v_1)}{\sqrt{d_{head}}}\right) \\ &= \sigma\left(\frac{(t_2^T W_E) W_Q^h (W_K^h)^T W_E^T(t_0 - t_1)}{\sqrt{d_{head}}}\right) \end{aligned} $$ in terms of only the weight matrices and one-hot encoded tokens $t_i$. Now, let's put these two together.
We have the functional form as: $$ f(t)=\operatorname{ReLU}\left(\sum_h\left(\alpha^h t_0+\left(1-\alpha^h\right) t_1\right)^T W_E W_V^h W_O^h W_{in}\right) W_{out} W_U $$ --- Now that we have the functional form, we can observe that the model's behaviour is fully determined by a handful of matrices, which we call effective weight matrices. They are: * $W_{logit} := W_{out} W_U$, which has size $(d_{mlp}, d_{vocab}-1) = (512, p)$, and tells us how to get from the output of the nonlinear activation function to our final logits. * $W_{neur} := W_E W_V W_O W_{in}$, which has size $(n_{heads}, d_{vocab}-1, d_{mlp}) = (4, p, 512)$ (we're stacking the OV matrices for each head along the zeroth dimension). This tells us how to get from a weighted sum of initial embeddings, to our neuron activations. * $W_{attn} := (t_2^T W_E) W_Q^h (W_K^h)^T W_E^T / \sqrt{d_{head}}$, which has size $(n_{heads}, d_{vocab}-1) = (4, p)$. This is the set of row vectors (one per head) which we dot with $(t_0 - t_1)$, to give us our attention scores. We can see how they all act in the transformer: $$ f(t)=\operatorname{ReLU}\Bigg(\sum_h\underbrace{\bigg(\alpha^h t_0\;+\;\left(1\;-\;\alpha^h\right) t_1 \bigg)^T}_{\textstyle{\alpha^h = \sigma(W_{attn}^h(t_0 - t_1))}} \underbrace{W_E W_V^h W_O^h W_{in}}_{\textstyle{W_{neur}^h}}\Bigg) \;\underbrace{W_{out} W_U}_{\textstyle{W_{logit}}} $$ Note - the $W_E$ and $W_U$ above mostly refer to the reduced matrices (hence the sizes being $d_{vocab}-1$). This is because $t_0$ and $t_1$ can only ever be the integers $0, 1, ..., p-1$, and the only logit outputs we care about are those corresponding to integers. The only exception is when we define $W_{attn}$, because the $t_2^T W_E$ term is equal to the **last row of the full embedding matrix.** ### Exercise - define the effective weight matrices Difficulty: 🔴🔴⚪⚪⚪ Importance: 🔵🔵🔵⚪⚪ You should spend up to 10-15 minutes on these exercises. They should not be challenging, and are designed to get you more comfortable with constructing circuits from model weights in a hands-on way. In the **Answers** dropdown above, we identified three **effective weight matrices** which collectively determine the behaviour of the transformer. Below, you should calculate these three matrices directly from the model. Don't worry about using any kind of factored matrices; the model isn't large enough for this to be necessary (the same goes for all subsequent exercises). # YOUR CODE HERE - define the following matrices # Test shapes assert W_logit.shape == (cfg.d_mlp, cfg.d_vocab - 1) assert W_neur.shape == (cfg.n_heads, cfg.d_vocab - 1, cfg.d_mlp) assert W_attn.shape == (cfg.n_heads, cfg.d_vocab - 1) # Test values tests.test_effective_weights(W_logit, W_neur, W_attn, model) Explanation of solution Note, all these examples use the @ operator. You might be more comfortable with einsum since it's more explicit and harder to make errors, but there's nothing wrong with using @ if you're already comfortable with the sizes of the matrices in question, and how @ handles matrix multiplication in different cases (e.g. when >2D tensors are involved). For instance, W_OV is a 3D tensor where the first dimension is the head index, and when multiplying this with a 2D matrix using @, PyTorch helpfully interprets W_OV as a batch of matrices, which is exactly what we want. There are a few subtleties though, e.g. remember that using .T on a 3D tensor won't by default transpose the last two dimensions like you might want. This is why we need the transpose method instead.
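To make the point above concrete, here's a tiny illustrative snippet (a sketch, not part of the solution) showing why we reach for `transpose(-1, -2)` when dealing with the per-head (3D) weight tensors:

```python
# For a 3D tensor like W_K (shape (n_heads, d_model, d_head)), we want to transpose only
# the last two dimensions, keeping the head dimension as a batch dimension.
example_3d = t.randn(4, 128, 32)
print(example_3d.transpose(-1, -2).shape)  # torch.Size([4, 32, 128]) - per-head transpose
```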
Solution W_logit = W_out @ W_U W_OV = W_V @ W_O W_neur = W_E @ W_OV @ W_in W_QK = W_Q @ W_K.transpose(-1, -2) W_attn = final_pos_resid_initial @ W_QK @ W_E.T / (cfg.d_head ** 0.5) ### Everything is periodic Any initial investigation and visualisation of activations and of the above effective weight matrices shows that things in the vocab basis are obviously periodic. Run the cells below, and demonstrate this for yourself. #### Activations **Attention patterns:** The heatmap generated from the code below is a $p\times p$ image, where the cell $(x, y)$ represents some activation (i.e. a real number at a hidden layer of the network) on the input $x$ and $y$. **Note:** Animation sliders are used to represent the different heads, not as a time dimension. **Note:** $A_{2\to 2}^h\approx 0$, so $A_{2\to 1}^h = 1-A_{2\to 0}^h$. For this reason, the first thing we do below is redefine `attn_mat` to only refer to the attention paid to the first two tokens. **Note:** We start by rearranging the attention matrix, so that the first two dimensions represent the (x, y) coordinates in the modular arithmetic equation. This is the meaning of the plots' axes. attn_mat = attn_mat[:, :, :2] # we only care about the attn paid by the "=" token to the first two tokens # We rearrange attn_mat, so the first two dims represent (x, y) in modular arithmetic equation attn_mat_sq = einops.rearrange(attn_mat, "(x y) head seq -> x y head seq", x=p) utils.inputs_heatmap( attn_mat_sq[..., 0], title="Attention score for heads at position 0", animation_frame=2, animation_name="head", ) **Neuron activations:** # We rearrange activations, so the first two dims represent (x, y) in modular arithmetic equation neuron_acts_post_sq = einops.rearrange(neuron_acts_post, "(x y) d_mlp -> x y d_mlp", x=p) neuron_acts_pre_sq = einops.rearrange(neuron_acts_pre, "(x y) d_mlp -> x y d_mlp", x=p) top_k = 3 utils.inputs_heatmap( neuron_acts_post_sq[..., :top_k], title=f"Activations for first {top_k} neurons", animation_frame=2, animation_name="Neuron", ) **Effective weights:** #### **$W_{neur}$** top_k = 5 utils.animate_multi_lines( W_neur[..., :top_k], y_index=[f"head {hi}" for hi in range(4)], labels={"x": "Input token", "value": "Contribution to neuron"}, snapshot="Neuron", title=f"Contribution to first {top_k} neurons via OV-circuit of heads (not weighted by attention)", ) #### **$W_{attn}$** utils.lines( W_attn, labels=[f"head {hi}" for hi in range(4)], xaxis="Input token", yaxis="Contribution to attn score", title="Contribution to attention score (pre-softmax) for each head", ) All this periodicity might make us think that the vocabulary basis isn't the most natural one to be operating in. The question is - what is the appropriate basis? ## Fourier Transforms > TL;DR: > > * We can define a Fourier basis of cosine and sine waves with period dividing $p$ (i.e. frequency that's a multiple of $2 \pi / p$). > * We can apply a change of basis of the vocab space into the Fourier basis, and periodic functions are sparse in the Fourier basis. > * For activations that are a function of just one input we use the 1D Fourier transform; for activations that are a function of both inputs we use the 2D Fourier transform. A natural way to understand what's going on is using Fourier transforms, which represent any function as a sum of sine and cosine waves.
Everything here is discrete, which means our functions are just $p$ or $p^2$ dimensional vectors, and the Fourier transform is just a change of basis. All functions have *some* representation in the Fourier basis, but "this function looks periodic" can be operationalised as "this function is sparse in the Fourier basis". Note that we are applying a change of basis to $\mathbb{R}^p$, corresponding to the vocabulary space of one-hot encoded input vectors. (We are abusing notation by pretending that `=` is not in our vocabulary, so $d_\mathrm{vocab} = p$, allowing us to take Fourier transforms over the input space.) ### 1D Fourier Basis We define the 1D Fourier basis as a list of sine and cosine waves. **We begin with the constant wave, and then add cosine and sine waves of different frequencies.** The waves need to have period dividing $p$, so they have frequencies that are integer multiples of $\omega_1 = 2 \pi / p$. We'll use the shorthand notation that $\vec{\textbf{x}} = (0, 1, ..., (p-1))$, and so $\cos (\omega_k \vec{\textbf{x}})$ actually refers to the following vector in $\mathbb{R}^p$: $$ \cos (\omega_k \vec{\textbf{x}}) = \big(1,\; \cos (\omega_k),\; \cos (2 \omega_k),\; ...,\; \cos ((p-1) \omega_k)\big) $$ (after being scaled to unit norm), where $\omega_k = 2 \pi k / p$. We will also denote $F$ as the $p \times p$ matrix where each **row** is one such wave: $$ F = \begin{bmatrix} \leftarrow \vec{\textbf{1}} \rightarrow \\ \leftarrow \cos (\omega_1 \vec{\textbf{x}}) \rightarrow \\ \leftarrow \sin (\omega_1 \vec{\textbf{x}}) \rightarrow \\ \leftarrow \cos (\omega_2 \vec{\textbf{x}}) \rightarrow \\ \vdots \\ \leftarrow \sin (\omega_{(p-1)/2} \vec{\textbf{x}}) \rightarrow \\ \end{bmatrix} $$ Again, we've omitted the normalization constant, but you should assume each row is a basis vector with norm 1. This means the constant term $\vec{\textbf{1}}$ is scaled by $\sqrt{\frac{1}{p}}$, and the rest by $\sqrt{\frac{2}{p}}$. Note also that the waves (esp. at high frequencies) look jagged, not smooth. This is because we discretise the inputs to just be integers, rather than all reals. ### Exercise - create the 1D Fourier basis Difficulty: 🔴🔴🔴⚪⚪ Importance: 🔵🔵🔵⚪⚪ You should spend up to 15-25 minutes on this exercise. We will be working with the Fourier basis extensively, so it's important to understand what it is. Complete the function below. Don't worry about computational efficiency; using a for loop is fine. def make_fourier_basis(p: int) -> tuple[Tensor, list[str]]: """ Returns a pair `fourier_basis, fourier_basis_names`, where `fourier_basis` is a `(p, p)` tensor whose rows are Fourier components and `fourier_basis_names` is a list of length `p` containing the names of the Fourier components (e.g. `["const", "cos 1", "sin 1", ...]`). You may assume that `p` is odd. """ raise NotImplementedError() tests.test_make_fourier_basis(make_fourier_basis) Solution def make_fourier_basis(p: int) -> tuple[Tensor, list[str]]: """ Returns a pair fourier_basis, fourier_basis_names, where fourier_basis is a (p, p) tensor whose rows are Fourier components and fourier_basis_names is a list of length p containing the names of the Fourier components (e.g. ["const", "cos 1", "sin 1", ...]). You may assume that p is odd.
""" # Define a grid for the Fourier basis vecs (we'll normalize them all at the end) # Note, the first vector is just the constant wave fourier_basis = t.ones(p, p) fourier_basis_names = ["Const"] for i in range(1, p // 2 + 1): # Define each of the cos and sin terms fourier_basis[2 i - 1] = t.cos(2 t.pi t.arange(p) i / p) fourier_basis[2 i] = t.sin(2 t.pi t.arange(p) i / p) fourier_basis_names.extend([f"cos {i}", f"sin {i}"]) # Normalize vectors, and return them fourier_basis /= fourier_basis.norm(dim=1, keepdim=True) return fourier_basis.to(device), fourier_basis_names Once you've done this (and passed the tests), you can run the cell below to visualise your Fourier components. fourier_basis, fourier_basis_names = make_fourier_basis(p) utils.animate_lines( fourier_basis, snapshot_index=fourier_basis_names, snapshot="Fourier Component", title="Graphs of Fourier Components (Use Slider)", ) Click to see the expected output *Note - from this point onwards, the `fourier_basis` and `fourier_basis_names` variables are global, so you'll be using them in other functions. We won't be changing the value of `p`; this is also global.* Now, you can prove the fourier basis is orthonormal by showing that the inner product of any two vectors is one if they are the same vector, and zero otherwise. Run the following cell to see for yourself: utils.imshow(fourier_basis @ fourier_basis.T, title="Fourier Basis Cosine Similarity Matrix") Click to see the expected output Now that we've shown the Fourier transform is indeed an orthonormal basis, we can write any $p$-dimensional vector in terms of this basis. The **1D Fourier transform** is just the transformation taking the components of a vector in the standard basis to its components in the Fourier basis (in other words we project the vector along each of the Fourier basis vectors). ### Exercise - 1D Fourier transform Difficulty: 🔴🔴⚪⚪⚪ Importance: 🔵🔵🔵🔵⚪ This should be a short, one-line function. Again, this is much more important to understand conceptually as opposed to being difficult to implement. You should now write a function to compute the Fourier transform of a vector. Remember that the **rows** of `fourier_basis` are the Fourier basis vectors. def fft1d(x: Tensor) -> Tensor: """ Returns the 1D Fourier transform of `x`, which can be a vector or a batch of vectors. x.shape = (..., p) """ raise NotImplementedError() tests.test_fft1d(fft1d) Solution def fft1d(x: Tensor) -> Tensor: """ Returns the 1D Fourier transform of x, which can be a vector or a batch of vectors. x.shape = (..., p) """ return x @ fourier_basis.T Note - if x was a vector, then returning fourier_basis @ x would be perfectly fine. But if x is a batch of vectors, then we want to make sure the multiplication happens along the last dimension of x. We can demonstrate this transformation on an example function which looks periodic. The key intuition is that **'function looks periodic in the original basis'** implies **'function is sparse in the Fourier basis'**. 
Note that functions over the integers $[0, p-1]$ are equivalent to vectors in $\mathbb{R}^p$, since we can associate any such function $f$ with the vector: $$ \begin{bmatrix} f(0) \\ f(1) \\ \vdots \\ f(p-1) \end{bmatrix} $$ v = sum([fourier_basis[4], fourier_basis[15] / 5, fourier_basis[67] / 10]) utils.line(v, xaxis="Vocab basis", title="Example periodic function") utils.line( fft1d(v), xaxis="Fourier Basis", title="Fourier Transform of example function", hover=fourier_basis_names, ) Click to see the expected output You should observe a jagged but approximately periodic function in the first plot, and a very sparse function in the second plot (with only three non-zero coefficients). ### 2D Fourier Basis **All of the above ideas can be naturally extended to a 2D Fourier basis on $\mathbb{R}^{p \times p}$, i.e. $p \times p$ images. Each term in the 2D Fourier basis is the outer product $v w^T$ of two terms $v, w$ in the 1D Fourier basis.** Thus, our 2D Fourier basis contains (up to a scaling factor): * a constant term $\vec{\textbf{1}}$, * linear terms of the form $\,\cos(\omega_k \vec{\textbf{x}}),\,\sin(\omega_k \vec{\textbf{x}}),\,\cos(\omega_k \vec{\textbf{y}})$, and $\sin(\omega_k \vec{\textbf{y}})$, * and quadratic terms of the form: $$ \begin{aligned} & \cos(\omega_i\vec{\textbf{x}})\cos(\omega_j\vec{\textbf{y}}) \\ & \sin(\omega_i\vec{\textbf{x}})\cos(\omega_j\vec{\textbf{y}}) \\ & \cos(\omega_i\vec{\textbf{x}})\sin(\omega_j\vec{\textbf{y}}) \\ & \sin(\omega_i\vec{\textbf{x}})\sin(\omega_j\vec{\textbf{y}}) \end{aligned} $$ Although we can think of these as vectors of length $p^2$, it makes much more sense to think of them as matrices of size $(p, p)$. > Notation - $\cos(\omega_i \vec{\textbf{x}})\cos(\omega_j \vec{\textbf{y}})$ should be understood as the $(p, p)$-size matrix constructed from the outer product of 1D vectors $\cos (\omega_i \vec{\textbf{x}})$ and $\cos (\omega_j \vec{\textbf{y}})$. In other words, the $(x, y)$-th element of this matrix is $\cos(\omega_i x) \cos(\omega_j y)$. ### Exercise - create the 2D Fourier basis Difficulty: 🔴🔴⚪⚪⚪ Importance: 🔵🔵🔵🔵⚪ This should be a short, one-line function. Again, this is much more important to understand conceptually. Complete the following function. Note that (unlike the function we wrote for the 1D Fourier basis) this only returns a single basis term, rather than the entire basis. def fourier_2d_basis_term(i: int, j: int) -> Float[Tensor, "p p"]: """ Returns the 2D Fourier basis term corresponding to the outer product of the `i`-th component of the 1D Fourier basis in the `x` direction and the `j`-th component of the 1D Fourier basis in the `y` direction. Returns a 2D tensor of shape `(p, p)`. """ raise NotImplementedError() tests.test_fourier_2d_basis_term(fourier_2d_basis_term) Solution def fourier_2d_basis_term(i: int, j: int) -> Float[Tensor, "p p"]: """ Returns the 2D Fourier basis term corresponding to the outer product of the i-th component of the 1D Fourier basis in the x direction and the j-th component of the 1D Fourier basis in the y direction. Returns a 2D tensor of shape (p, p). """ return fourier_basis[i][:, None] * fourier_basis[j][None, :] Note, indexing with None is one of many ways to write this function. A few others are: `torch.outer`; using einsum, e.g. `torch.einsum('i,j->ij', ...)`; or using the `unsqueeze` method to add dummy dimensions to the vectors before multiplying them together. Once you've defined this function, you can visualize the 2D Fourier basis by running the following code. Verify that they do indeed look periodic.
x_term = 4 y_term = 6 utils.inputs_heatmap( fourier_2d_basis_term(x_term, y_term).T, title=f"2D Fourier Basis term {fourier_basis_names[x_term]}x {fourier_basis_names[y_term]}y", ) Click to see the expected output What benefit do we get from thinking about $(p, p)$ images? Well, the batch dimension of all our data is of size $p^2$, since we're dealing with every possible value of inputs `x` and `y`. So we might think of reshaping this batch dimension to $(p, p)$, then applying a 2D Fourier transform to it. Let's implement this transform now! ### Exercise - Implementing the 2D Fourier Transform Difficulty: 🔴🔴⚪⚪⚪ Importance: 🔵🔵🔵⚪⚪ You should spend up to ~10 minutes on this exercise. This exercise should be pretty familiar, since you've already done this in 1D. def fft2d(tensor: Tensor) -> Tensor: """ Returns the components of `tensor` in the 2D Fourier basis. Assumes that the input has shape `(p, p, ...)`, where the last dimensions (if present) are the batch dims. Output has the same shape as the input. """ raise NotImplementedError() tests.test_fft2d(fft2d) Solution def fft2d(tensor: Tensor) -> Tensor: """ Returns the components of tensor in the 2D Fourier basis. Assumes that the input has shape (p, p, ...), where the last dimensions (if present) are the batch dims. Output has the same shape as the input. """ # fourier_basis[i] is the i-th basis vector, which we want to multiply along return einops.einsum(tensor, fourier_basis, fourier_basis, "px py ..., i px, j py -> i j ...") While working with the 1D Fourier transform, we defined simple periodic functions which were linear combinations of the Fourier basis vectors, then showed that they were sparse when we expressed them in terms of the Fourier basis. That's exactly what we'll do here, but with functions of 2 inputs rather than 1. Below is some code to plot a simple 2D periodic function (which is a linear combination of 2D Fourier basis terms). Note that we call our matrix `example_fn`, because we're thinking of it as a function of its two inputs (in the x and y directions). example_fn = sum( [ fourier_2d_basis_term(4, 6), fourier_2d_basis_term(14, 46) / 3, fourier_2d_basis_term(97, 100) / 6, ] ) utils.inputs_heatmap(example_fn.T, title="Example periodic function") Click to see the expected output Code to show this function is sparse in the 2D Fourier basis (you'll have to zoom in to see the non-zero coefficients): utils.imshow_fourier(fft2d(example_fn), title="Example periodic function in 2D Fourier basis") Click to see the expected output You can run this code, and check that the non-zero components exactly match the basis terms we used to construct the function. ## Analysing our model with Fourier Transforms So far, we've made two observations: * Many of our model's activations appear periodic * Periodic functions appear sparse in the Fourier basis So let's take the obvious next step, and apply a 2D Fourier transformation to our activations! Remember that the batch dimension of our activations is $p^2$, which can be rearranged into $(p, p)$, with these two dimensions representing the `x` and `y` inputs to our modular arithmetic equation. These are the dimensions over which we'll take our Fourier transform.
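In equation form, this is just what `fft2d` computes: if $A(x, y)$ is some activation viewed as a function of the two inputs, and $F_i$ denotes the $i$-th (unit-norm) 1D Fourier basis vector, then the 2D Fourier coefficients $\hat{A}_{ij}$ satisfy

$$
A(x, y) = \sum_{i=0}^{p-1} \sum_{j=0}^{p-1} \hat{A}_{ij} \, F_i(x) \, F_j(y), \qquad \hat{A}_{ij} = \sum_{x=0}^{p-1} \sum_{y=0}^{p-1} F_i(x) \, F_j(y) \, A(x, y).
$$

Saying the activation is sparse in the Fourier basis means that only a few of the $\hat{A}_{ij}$ are meaningfully different from zero.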
### Plotting activations in the Fourier basis Recall our previous code, to plot the heatmap for the attention scores token 2 pays to token 0, for each head: utils.inputs_heatmap( attn_mat_sq[..., 0], title="Attention score for heads at position 0", animation_frame=2, animation_name="head", ) In that plot, the x and y axes represented the different values of inputs `x` and `y` in the modular arithmetic equation. The code below takes the 2D Fourier transform of the attention matrix, and plots the heatmap for the attention scores token 2 pays to token 0, for each head, in the Fourier basis: # Apply Fourier transformation attn_mat_fourier_basis = fft2d(attn_mat_sq) # Plot results utils.imshow_fourier( attn_mat_fourier_basis[..., 0], title="Attention score for heads at position 0, in Fourier basis", animation_frame=2, animation_name="head", ) Click to see the expected output You should find that the result is extremely sparse - there will only be a few cells (mostly on the zeroth rows or columns, i.e. corresponding to the constant or linear terms) which aren't zero. This suggests that we're on the right track using the 2D Fourier basis! Now, we'll do the same for the neuron activations. Recall our previous code: top_k = 3 utils.inputs_heatmap( neuron_acts_post_sq[..., :top_k], title=f"Activations for first {top_k} neurons", animation_frame=2, animation_name="Neuron", ) We'll do the exact same here, and plot the activations in the Fourier basis: neuron_acts_post_fourier_basis = fft2d(neuron_acts_post_sq) top_k = 3 utils.imshow_fourier( neuron_acts_post_fourier_basis[..., :top_k], title=f"Activations for first {top_k} neurons", animation_frame=2, animation_name="Neuron", ) Click to see the expected output ### Exercise - spot patterns in the activations Difficulty: 🔴🔴⚪⚪⚪ Importance: 🔵🔵🔵⚪⚪ You should spend up to 10-15 minutes on this exercise. Increase `top_k` from 3 to a larger number, and look at different neurons. What do you notice about the patterns in the activations? Answer (what you should see) Again, you should see sparsity, although this time some quadratic terms should be visible too. Beyond this, there are 2 distinct patterns worth commenting on: (1) each neuron has the same pattern of non-zero terms: for some value of $k$, the non-zero terms are the constant term plus all four linear and four quadratic terms involving just frequency $k$ (i.e. terms like $\cos(\omega_k \vec{\textbf{x}})$, $\sin(\omega_k \vec{\textbf{y}})$, $\cos(\omega_k \vec{\textbf{x}})\sin(\omega_k \vec{\textbf{y}})$, etc); (2) there are only a handful of different values of $k$ across all the neurons, so many of them end up having very similar-looking activation patterns. #### Aside: Change of basis on the batch dimension A change of basis on the batch dimension is a pretty weird thing to do, and it's worth thinking carefully about what happens here (note that this is a significant deviation from the prior Transformer Circuits work, and only really makes sense here because this is such a toy problem that we can enter the entire universe as one batch). There are *four* operations that are not linear with respect to the batch dimension: the attention softmax, the ReLU, and the final softmax, plus the elementwise multiplication with the attention pattern. In particular, ReLU becomes super weird - it goes from an elementwise operation to the operation 'rotate by the inverse of the Fourier basis, apply ReLU elementwise in the *new* basis, rotate back to the Fourier basis'.
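To put a rough number on this sparsity (a sketch, not one of the exercises; it assumes `neuron_acts_post_fourier_basis` from the cell above), we can ask what fraction of each neuron's squared norm is captured by its few largest 2D Fourier coefficients:

```python
# What fraction of each neuron's squared norm is captured by its 10 largest 2D Fourier
# coefficients? (A value close to 1 would indicate strong sparsity in this basis.)
coeffs_sq = neuron_acts_post_fourier_basis.pow(2).flatten(0, 1)   # shape (p*p, d_mlp)
top_frac = coeffs_sq.topk(10, dim=0).values.sum(0) / coeffs_sq.sum(0)
print(f"Mean fraction of squared norm in top 10 terms: {top_frac.mean().item():.3f}")
```

Since the Fourier basis is orthonormal, the sum of squared coefficients equals the squared norm of the activations themselves (Parseval), so this fraction is a fair measure of how concentrated each neuron is.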
### Plotting effective weights in the Fourier basis As well as plotting our activations, we can also look at the weight matrices directly. *Note - this section isn't essential to understanding the rest of the notebook, so feel free to skip it if you're short on time.* We'll now adjust our previous code which plotted the `W_neur` matrix in the standard basis, so it's visualized in the Fourier basis instead. In the plot below, each line shows the activations for some neuron as a function of the input token `x` (if the `=` token at position 2 only paid attention to a token `x`), with the neuron index determined by the slider value. If this seems a bit confusing, you can use the dropdown below to remind yourself of the functional form of this transformer, and the role of $W_{neur}$. Functional Form $$ f(t)=\operatorname{ReLU}\Bigg(\sum_h{\bigg(\alpha^h t_0\;+\;\left(1\;-\;\alpha^h\right) t_1 \bigg)^T}{W_{neur}^h}\Bigg) \; W_{logit} $$ From this, we can see clearly the role of $W_{neur}^h$. It is a matrix of shape $(d_{vocab}, d_{mlp})$, and its rows are the vectors we take a weighted average of to get our MLP activations (pre-ReLU). The code below makes the same plot, but while the previous one was in the standard basis (with the x-axis representing the input token), in this plot the x-axis is the component of the input token in each Fourier basis direction. Note that we've provided you with the helper function `fft1d_given_dim`, which performs the 1D Fourier transform over a given dimension. This is necessary for `W_neur`, since it has shape `(n_heads, d_vocab, d_mlp)`, and we want to transform over the `d_vocab` dimension. def fft1d_given_dim(tensor: Tensor, dim: int) -> Tensor: """ Performs 1D FFT along the given dimension (not necessarily the last one). """ return fft1d(tensor.transpose(dim, -1)).transpose(dim, -1) W_neur_fourier = fft1d_given_dim(W_neur, dim=1) top_k = 5 utils.animate_multi_lines( W_neur_fourier[..., :top_k], y_index=[f"head {hi}" for hi in range(4)], labels={"x": "Fourier component", "value": "Contribution to neuron"}, snapshot="Neuron", hover=fourier_basis_names, title=f"Contribution to first {top_k} neurons via OV-circuit of heads (not weighted by attn), in Fourier basis", ) Click to see the expected output Note that each line plot generally has both its $\sin k$ and $\cos k$ terms non-zero, rather than having one but not the other. Lastly, we'll do the same with `W_attn`: utils.lines( fft1d(W_attn), labels=[f"Head {hi}" for hi in range(4)], xaxis="Fourier component", yaxis="Contribution to attn score", title="Contribution to attn score (pre-softmax) for each head, in Fourier Basis", hover=fourier_basis_names, ) Click to see the expected output You may have noticed that the handful of non-zero frequencies in both these last two line charts exactly match the important frequencies we read off the attention patterns! ## Recap of section Let's review what we've learned in this section. We found that: > - The simple architecture of our 1-layer model heavily constrains the functional form of any learned solutions. > - In particular, we can define a handful of matrices which fully describe the model's behaviour (after making some simplifying assumptions). > - Many of our model's internal activations appear periodic in the inputs `x`, `y` (e.g. the attention patterns and neuron activations). > - The natural way to represent a periodic function is in the Fourier basis. Periodic functions appear sparse in this basis.
> - This suggests our model might only be using a handful of frequencies (i.e. projecting the inputs onto a few different Fourier basis vectors), and discarding the rest. > - We confirmed this hypothesis by looking at: > - The model's activations (i.e. attention patterns and neuron activations) > - The model's effective weight matrices (i.e. $W_{attn}$ and $W_{neur}$) > - Both these observations confirmed that we have sparsity in the Fourier basis. Furthermore, the same small handful of frequencies seemed to be appearing in all cases.
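As a teaser for what comes next, here's a rough sketch of how you might read off that small handful of frequencies from the neuron activations (it assumes `neuron_acts_post_fourier_basis` from the earlier cell; the later sections do this more carefully):

```python
# For each neuron, find the frequency k whose sin/cos rows and columns carry the most
# squared norm in the 2D Fourier basis, then count how often each k appears.
acts_fourier = neuron_acts_post_fourier_basis         # shape (p, p, d_mlp)
freq_norms = t.stack([
    acts_fourier[2 * k - 1 : 2 * k + 1].pow(2).sum(dim=(0, 1))       # rows for frequency k
    + acts_fourier[:, 2 * k - 1 : 2 * k + 1].pow(2).sum(dim=(0, 1))  # columns for frequency k
    for k in range(1, p // 2 + 1)
])                                                    # shape (p//2, d_mlp)
key_freq_per_neuron = freq_norms.argmax(dim=0) + 1    # dominant frequency for each neuron
print(key_freq_per_neuron.unique(return_counts=True))
```

If the picture above is right, only a handful of distinct frequencies should appear in this count, with most neurons clustered on one of them.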