☆ Bonus
Discussion & Future Directions
The key claim I want to make is that grokking happens when the process of learning a general algorithm exhibits a phase change, and the model is given the minimal amount of data such that the phase change still happens.
Why Phase Changes?
I further speculate that phase changes happen when a model learns any single generalising circuit that requires different parts of the model to work together in a non-linear way. Learning a sophisticated generalising circuit is hard and requires different parts of the model to line up correctly (ie coordinate on a shared representation). It doesn't seem super surprising if the circuit is in some sense either working or not, rather than smoothly increasing in performance.
The natural next question is why the model learns the general algorithm at all - if performance is constant pre-phase change, there should be no gradient. Here I think the key is that it is hard to generalise to unseen data. The circuits pre-phase change are still fairly helpful for predicting seen data, as excluded loss shows. (That said, test loss does tend to decrease even pre-phase change, just more slowly - the question is how large the gradient needs to be, not whether it is zero.)
My current model is that every gradient update is partly an incentive to memorise the data in the batch and partly an incentive to pick up patterns within the data. The memorising gradient updates point in arbitrary directions and tend to cancel out (both within a batch and between batches) if the model doesn't have enough capacity to memorise fully, while the pattern-finding gradient updates reinforce each other. Regularisation dampens the memorising gradients relative to the pattern-recognition gradients. We don't observe this dynamic in the infinite data setting because we never re-run the model on past data points, but I expect that even there the model has at least slightly memorised previous training data. (See, eg, Does GPT-2 Know Your Phone Number?)
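A minimal sketch of the cancellation argument, under simplifying assumptions that are mine rather than properties of the trained model: each example's gradient is a shared unit "pattern" direction plus an independent random unit "memorisation" direction, and capacity and regularisation are not modelled at all. When n such gradients are summed, the pattern component grows linearly in n, while the memorisation component only grows roughly like sqrt(n), because independent random directions mostly cancel.

```python
import math
import random


def summed_gradient_components(n_examples, dim, pattern_strength=1.0,
                               memorise_strength=1.0, seed=0):
    """Toy model of per-example gradients (an illustrative assumption).

    Each example's gradient is pattern_strength * u + memorise_strength * r_i,
    where u is a fixed unit "pattern" direction shared by every example and
    r_i is an independent random unit "memorisation" direction.
    Returns the norms of the two components of the summed gradient.
    """
    rng = random.Random(seed)

    def random_unit_vector():
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        length = math.sqrt(sum(x * x for x in v))
        return [x / length for x in v]

    def norm(v):
        return math.sqrt(sum(x * x for x in v))

    u = random_unit_vector()
    pattern_sum = [0.0] * dim
    memorise_sum = [0.0] * dim
    for _ in range(n_examples):
        r = random_unit_vector()  # a fresh, unrelated direction per example
        for j in range(dim):
            pattern_sum[j] += pattern_strength * u[j]
            memorise_sum[j] += memorise_strength * r[j]

    return norm(pattern_sum), norm(memorise_sum)


for n in (10, 100, 1000):
    pattern, memorise = summed_gradient_components(n, dim=256)
    print(f"n={n:5d}  pattern={pattern:8.1f}  memorise={memorise:6.1f}")
```

The sketch deliberately ignores repeated examples across epochs (whose memorisation directions would recur rather than cancel) and any limit on model capacity.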
I speculate that this is because the model is incentivised to memorise the data as efficiently as possible, ie is biased towards simplicity. This comes from both explicit regularisation, such as weight decay, and implicit regularisation, like inherent model capacity and SGD. The model learns to pick up patterns between data points because these regularities allow it to memorise more efficiently. But this needs enough training data to make the memorisation circuit more complex than the generalising circuit(s).
This further requires there to be a few crisp circuits for the model to learn - if it fuzzily learns many different circuits, then lowering the amount of data will fairly smoothly decrease test performance, as it learns fewer and fewer circuits.
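A cartoon of the crisp vs fuzzy distinction, assuming (hypothetically) that each circuit is either fully learned or absent depending only on the amount of training data, and that each learned circuit adds a fixed amount of test accuracy. The thresholds and weights below are made up purely for illustration.

```python
def accuracy_from_circuits(n_data, thresholds, weights):
    """Cartoon model (a hypothetical assumption): circuit i is learned iff
    n_data >= thresholds[i], and then contributes weights[i] to test accuracy."""
    return sum(w for t, w in zip(thresholds, weights) if n_data >= t)


# One crisp circuit that solves the whole task once there are >= 3000 examples.
crisp = ([3000], [1.0])
# Ten fuzzy circuits, each worth 10% accuracy, with thresholds spread out.
fuzzy = ([1000 + 400 * i for i in range(10)], [0.1] * 10)

for n in (500, 1500, 2500, 3000, 4000, 5000):
    print(n, accuracy_from_circuits(n, *crisp),
          round(accuracy_from_circuits(n, *fuzzy), 2))
```

With the crisp circuit, removing a little data below the threshold drops test accuracy all the way to zero; with the fuzzy circuits, accuracy falls off in small steps.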
Limitations
There are several limitations to this work and to how confidently I can generalise from it to a general understanding of deep learning/grokking. In particular:
- The modular addition transformer is a toy model, which only needs a few circuits to solve the problem fully (unlike eg an LLM or an image model, which needs many circuits)
- The model is over-parametrised, as shown by, eg, it learning many redundant neurons doing roughly the same task. It has many more parameters than it needs for near-perfect performance on its task, unlike real models.
- I only study weight decay as the form of regularisation
- My explanations of what's going on rely heavily on fairly fuzzy and intuitive notions that I have not explicitly defined: crisp circuits, simplicity, phase changes, what it means to memorise vs generalise, model capacity, etc.
Relevance to Alignment
Model Training Dynamics
The main thing which seems relevant from an alignment point of view is a better understanding of model training dynamics.
In particular, phase changes seem pretty bad from an alignment perspective, because they suggest that rapid capability gain as models train or are scaled up is likely. This makes it less likely that we get warning shots (ie systems exhibiting non-sophisticated unaligned behaviour) before we get dangerous unaligned systems, and makes it less likely that systems near to AGI are good empirical test beds for alignment work. This work exhibits a bunch more examples of phase changes, and marginally updates me towards alignment being hard.
However, this work also makes me slightly more hopeful, since it shows that phase changes are clearly foreshadowed by the model learning the circuits beforehand, and that we can predict capability gain with interpretability-inspired metrics (though obviously this is only exhibited in a toy case, and only after I understood the circuit well - I'd rather not have to train and interpret an unaligned AGI before we can identify misalignment!).
More speculatively, it suggests that we may be able to use interpretability tools to shape training dynamics. A natural concern here is that a model may learn to evade our tools, performing the same algorithms but without them showing up in our tools. This could happen either because gradient descent directly learns to obfuscate things (because the unaligned solution performs much better than an aligned solution), or because the system itself has learned to be deceptive and alters itself to avoid our tools. The fact that we observe the circuits developing well before they are able to generalise suggests that we might be able to disincentivise deception before the model gets good enough at deception to generalise and evade deception detectors (though this does little to address obfuscation learned directly by gradient descent).
The natural future direction here is to explore training on interpretability-inspired metrics, and to see how much gradient descent learns to Goodhart them vs shifting its inductive bias to learn a different algorithm. Eg, can we incentivise generalisation and get grokking with less data? Can we incentivise memorisation and change the algorithm it learns? What happens if we disincentivise learning to add with certain frequencies?
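As a hypothetical example of an interpretability-inspired metric for the last question, one could measure how much of the embedding matrix lies in particular Fourier frequencies over the input tokens, and add a multiple of it to the loss as a penalty. This is not a method from this work: the sketch is pure Python for readability, while a usable regulariser would be written with differentiable tensor operations. It assumes W_E is a p x d nested list with one row per input token, and that p is odd.

```python
import math


def fourier_fraction(W_E, p, freqs):
    """Fraction of the squared norm of a (p x d) embedding matrix W_E that lies
    in the cos/sin Fourier directions (over the p input tokens) for freqs.

    Assumes p is odd and every k in freqs satisfies 1 <= k <= (p - 1) // 2, so
    the vectors cos(2*pi*k*x/p) and sin(2*pi*k*x/p) for x = 0..p-1, scaled by
    sqrt(2/p), are orthonormal.
    """
    d = len(W_E[0])
    total = sum(W_E[x][j] ** 2 for x in range(p) for j in range(d))
    scale = math.sqrt(2 / p)
    in_freqs = 0.0
    for k in freqs:
        cos_k = [scale * math.cos(2 * math.pi * k * x / p) for x in range(p)]
        sin_k = [scale * math.sin(2 * math.pi * k * x / p) for x in range(p)]
        for j in range(d):
            column = [W_E[x][j] for x in range(p)]
            # Squared projections of this embedding column onto the two basis vectors.
            in_freqs += sum(c * e for c, e in zip(cos_k, column)) ** 2
            in_freqs += sum(s * e for s, e in zip(sin_k, column)) ** 2
    return in_freqs / total


# A toy embedding that uses only frequency 3.
p = 13
W_E = [[math.cos(2 * math.pi * 3 * x / p), math.sin(2 * math.pi * 3 * x / p)]
       for x in range(p)]
print(fourier_fraction(W_E, p, [3]))  # close to 1.0
print(fourier_fraction(W_E, p, [5]))  # close to 0.0
```

A hypothetical penalty would then be something like a coefficient times fourier_fraction(W_E, p, banned_freqs), added to the training loss.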
Other Relevance
This also seems relevant as more evidence for the circuits hypothesis, that networks can be mechanistically understood and are learning interpretable algorithms, and as one of the first examples of interpreting a transformer's neurons (though on a wildly different task to language).
It also seems relevant to furthering the science of deep learning, by better understanding how models generalise and the weird phenomenon of grokking. Understanding whether a model will be aligned or not requires us to understand how it will generalise, to predict its future behaviour, and to know which training setups will and will not generalise. This motivation is somewhat more diffuse and less directed, but seems analogous to how statistical learning theory allows us to predict things about models in classical statistics (though statistical learning theory is clearly insufficient for deep learning).
(Though honestly, I mostly worked on this because it was fun, I was on holiday, and I got pretty nerd-sniped by the problem. So all of this is somewhat ad hoc and backwards-looking, rather than this being purely alignment-motivated work.)
Training Dynamics
One interesting thing this research gives us insight into is the training dynamics of circuits - which parts of the circuit develop first, at what rates, and why?
My speculative guess is that, when a circuit has several components interacting between layers, the component closest to the logits is easier to form and will tend to form first.
Intuition: Imagine a circuit involving two components interacting with a non-linearity in the middle, which are both randomly initialised. They want to converge on a shared representation, involving a few directions in the hidden space. Initially, the random first component will likely have some output corresponding to the correct features, but with small coefficients and among a lot of noise. The second component can learn to focus on these correct features and cancel out the noise, reinforcing the incentive for the first component to focus on these features. On the other hand, if the second component is random, it's difficult to see how the first component can produce reasonable features that are productively used by the second component.
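A toy illustration of the first half of this intuition, under assumptions introduced here rather than taken from the trained model: the "second component" is a zero-initialised linear readout trained with squared error, and the random "first component" outputs a vector h containing a weak copy of the target along a hidden direction u, buried in unit Gaussian noise. Even with a weak signal, the readout's very first gradient already points mostly along u. The sketch does not model the non-linearity or the reverse case.

```python
import math
import random


def readout_gradient_alignment(n, dim, signal=0.3, seed=0):
    """Cosine similarity between the negative gradient of mean squared error
    w.r.t. a zero-initialised readout w and the hidden feature direction u.

    Assumed toy setup: target y is +/-1, and the upstream (random) component
    outputs h = signal * y * u + noise, with unit Gaussian noise per dimension.
    """
    rng = random.Random(seed)
    u = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    u_norm = math.sqrt(sum(x * x for x in u))
    u = [x / u_norm for x in u]

    # At w = 0, loss = mean((y - w . h)^2) has gradient -2 * mean(y * h).
    neg_grad = [0.0] * dim
    for _ in range(n):
        y = rng.choice([-1.0, 1.0])
        h = [signal * y * u[j] + rng.gauss(0.0, 1.0) for j in range(dim)]
        for j in range(dim):
            neg_grad[j] += 2 * y * h[j] / n

    dot = sum(g * v for g, v in zip(neg_grad, u))
    g_norm = math.sqrt(sum(g * g for g in neg_grad))
    return dot / g_norm


print(readout_gradient_alignment(2000, 20, signal=0.3))
print(readout_gradient_alignment(2000, 20, signal=0.0))
```

With signal=0.0 there is nothing for the readout to latch onto, so its gradient direction is essentially noise.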
Qualitatively, we observe that all circuits form in parallel pre-grokking, but roughly the order seems to be logit circuit > embedding circuit > neuron circuit > attention circuit (ie logit is fastest, attention is slowest).
This seems like an interesting direction for future research, since this work gives a concrete example of crisp circuits that are easy to study during training. Possible initial experiments would be fixing parts of the network to be random, initialising parts of the network to be fully trained and frozen, or giving different parts of the network different learning rates.
Suggested capstone projects
Investigate phase changes
You could look for other examples of phase changes, for example:
- Toy problems
  - Something incentivising skip trigrams
  - Something incentivising virtual attention heads
  - e.g. one of the models below (or pick an easier task)
- Looking for curve detectors in a ConvNet
  - A dumb way to try this would be to train a model to imitate the actual curve detectors in Inception (eg minimising OLS loss between the model's output and curve detector activations)
- Looking at the formation of interpretable neurons in a SoLU transformer
- Looking inside an LLM with many checkpoints
  - Eleuther have many checkpoints of GPT-J and GPT-Neo, and will share if you ask
  - Mistral have public versions of GPT-2 small and medium, with 5 runs and many checkpoints
- Possible capabilities to look for
  - Performance on benchmarks, or specific questions from benchmarks
  - Simple algorithmic tasks like addition, or sorting words into alphabetical order, or matching open and close brackets
  - Soft induction heads, eg translation
- Look at attention heads on various text and see if any have recognisable attention patterns (eg start of word, adjective describing current word, syntactic features of code like indents or variable definitions, most recent open bracket, etc).
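For the "most recent open bracket" pattern (and the bracket-matching capability), one simple approach is to compute, for each query position, the key position a head would attend to if it implemented that pattern, and then score how much attention a real head puts there. A minimal sketch, assuming the brackets are well nested, that each bracket is its own token, and that the attention pattern is a query x key nested list (for example, one head's pattern from a TransformerLens activation cache converted with .tolist()). Both function names are hypothetical.

```python
def most_recent_open_bracket(tokens, pairs=(("(", ")"), ("[", "]"), ("{", "}"))):
    """For each position i, return the index of the most recent unmatched open
    bracket in tokens[:i + 1], or None if every bracket so far is matched.
    Assumes the brackets in tokens are well nested."""
    openers = {o for o, _ in pairs}
    closers = {c for _, c in pairs}
    stack, targets = [], []
    for i, tok in enumerate(tokens):
        if tok in openers:
            stack.append(i)
        elif tok in closers and stack:
            stack.pop()  # this closer matches the innermost open bracket
        targets.append(stack[-1] if stack else None)
    return targets


def pattern_score(attention, targets):
    """Mean attention mass (rows = query positions, columns = key positions)
    placed on the hypothesised target key, over positions that have a target."""
    scores = [attention[q][k] for q, k in enumerate(targets) if k is not None]
    return sum(scores) / len(scores) if scores else 0.0


targets = most_recent_open_bracket(list("f(a[b]c)"))
print(targets)  # [None, 1, 1, 3, 3, 1, 1, None]
```

A head that really tracks the most recent open bracket should score close to 1 under pattern_score, while a head with uniform attention over n positions scores around 1/n.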
Try more algorithmic problems
Interpreting toy models is a good way to increase your confidence working with TransformerLens and basic interpretability methods. It's maybe not the most exciting category of open problems in mechanistic interpretability, but can still be a useful exercise - and sometimes it can lead to interesting new insights about how interpretability tools can be used.
If you're feeling like it, you can hop onto LeetCode and pick a suitable problem (we recommend the "Easy" section), then train a transformer on it and interpret what it learns. Here are a few suggestions to get you started (some of these were taken from LeetCode, others from Neel Nanda's open problems post). They're listed approximately from easier to harder, although this is just a guess since I haven't personally interpreted these. Note, there are ways you could make any of these problems easier or harder with modifications - I've included some ideas inline.
- Calculating sequences with a Fibonacci-style recurrence relation (i.e. predicting the next element from the previous two)
- Search Insert Position - an easier version would be when the target is always guaranteed to be in the list (you also wouldn't need to worry about sorting in this case). The version without this guarantee is a very different problem, and would be much harder
- Is Subsequence - you should start with subsequences of length 1 (in which case this problem is pretty similar to the easier version of the previous problem), and work up from there
- Majority Element - you can try playing around with the data generation process to change the difficulty, e.g. sequences where there is no guarantee on the frequency of the majority element (i.e. you're just looking for the token which appears more than any other token) would be much harder
- Number of Equivalent Domino Pairs - you could restrict this problem to very short lists of dominoes to make it easier (e.g. start with just 2 dominoes!)
- Longest Substring Without Repeating Characters
- Isomorphic Strings - you could make it simpler by only allowing the first string to have duplicate characters, or by restricting the string length / vocabulary size
- Plus One - you might want to look at the "sum of numbers" algorithmic problem before trying this, and/or the grokking exercises in this chapter. Understanding this problem well might actually help you build up to interpreting the "sum of numbers" problem (I haven't done this, so it's very possible you could come up with a better interpretation of that monthly problem than mine, since I didn't go super deep into the carrying mechanism)
- Predicting permutations, i.e. predicting the last 3 tokens of the 12-token sequence (17 3 11) (17 11 3) (11 2 4) (11 4 2) (i.e. the model has to learn what permutation function is being applied to the first group to get the second group, and then apply that permutation to the third group to correctly predict the fourth group). Note, this problem might require 3 layers to solve - can you see why?
- Train models for automata tasks and interpret them - do your results match the theory?
- Predicting the output of simple code functions, e.g. predicting the "1 2 4" text in the following sequence (which could be made harder with some obvious modifications, e.g. adding more variable definitions so the model has to attend back to the right one):
a = 1 2 3
a[2] = 4
a -> 1 2 4
- Graph theory problems like this. You might have to get creative with the input format when training transformers on tasks like this!
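For the permutation task above, a data generator might look like the sketch below. The choices here are assumptions rather than a spec: the vocabulary is the integers 0 to n_tokens - 1, and each group has distinct tokens, so the permutation is uniquely identifiable from the first two groups. The function name and defaults are hypothetical.

```python
import random


def permutation_example(n_tokens=20, group_size=3, rng=random):
    """Build one training sequence x1, perm(x1), x2, perm(x2), flattened into
    4 * group_size tokens. The model sees everything except the last
    group_size tokens and must predict them."""
    perm = list(range(group_size))
    rng.shuffle(perm)
    # Distinct tokens within a group keep the permutation identifiable.
    x1 = rng.sample(range(n_tokens), group_size)
    x2 = rng.sample(range(n_tokens), group_size)
    y1 = [x1[i] for i in perm]
    y2 = [x2[i] for i in perm]
    return x1 + y1 + x2 + y2


rng = random.Random(0)
seq = permutation_example(rng=rng)
inputs, targets = seq[:-3], seq[-3:]  # the model sees inputs and predicts targets
print(inputs, targets)
```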
Note, ARENA runs a monthly algorithmic problems sequence, and you can get ideas from looking at past problems from this sequence. You can also use these repos to get some sample code for building & training a transformer on a toy task, and for constructing a dataset for your particular problem.
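As an example of constructing a dataset for one of the problems above, here is a hypothetical generator for Majority Element with a switch for the frequency guarantee, which is one way to play with the difficulty as suggested in that item. The names and defaults are illustrative, and the tie-breaking rule for the unguaranteed case is an arbitrary assumption.

```python
import random
from collections import Counter


def majority_example(seq_len=10, vocab_size=5, guaranteed=True, rng=random):
    """Return (tokens, label) for a Majority Element-style task (vocab_size >= 2).

    guaranteed=True: the label appears more than seq_len // 2 times, matching
    the LeetCode guarantee (the easier version).
    guaranteed=False: tokens are sampled uniformly and the label is simply the
    most frequent token, with ties broken by smallest token id (arbitrary).
    """
    if guaranteed:
        label = rng.randrange(vocab_size)
        n_major = rng.randint(seq_len // 2 + 1, seq_len)
        others = [t for t in range(vocab_size) if t != label]
        tokens = [label] * n_major + [rng.choice(others) for _ in range(seq_len - n_major)]
        rng.shuffle(tokens)
    else:
        tokens = [rng.randrange(vocab_size) for _ in range(seq_len)]
        counts = Counter(tokens)
        best = max(counts.values())
        label = min(t for t, c in counts.items() if c == best)
    return tokens, label


rng = random.Random(0)
print(majority_example(rng=rng))
print(majority_example(guaranteed=False, rng=rng))
```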
Suggested paper replications
A Toy Model of Universality: Reverse Engineering How Networks Learn Group Operations
This paper extends the analysis of this particular task and model to general group operations, of which modular addition is a special case.
This might be a good replication for you if:
- You enjoyed all subsections in this exercise set, and would like to perform similar analysis on more complex algorithms
- You're interested in studying grokking and training dynamics
- You have a background in mathematics, and in particular have some familiarity with group theory (and ideally some representation theory)