3️⃣ Analysis During Training
Learning Objectives
- Understand the idea of tracking metrics over time, and how this can inform when certain circuits are forming.
- Investigate and interpret the evolution over time of the singular values of the model's weight matrices.
- Investigate the formation of other capabilities in the model, like commutativity.
Some starting notes
Note - this section has fewer exercises than previous sections, and is intended more as a showcase of some of the results from the paper.
In this section, we analyse the modular addition transformer during training. In the data we'll be using, checkpoints were taken every 100 epochs, from epoch 0 to 50K (the model we used in previous exercises was taken at 40K).
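If you want to sanity-check this checkpoint data yourself, here's a minimal sketch (this assumes `full_run_data` is the dictionary of saved training data loaded earlier, with "epochs" and "state_dicts" keys, as used in the Setup code below):

# Sketch: inspect the saved checkpoints (assumes full_run_data has "epochs" and "state_dicts" keys)
print(len(full_run_data["state_dicts"]))  # number of checkpoints, one per 100 epochs
print(full_run_data["epochs"][:5])        # should look like [0, 100, 200, 300, 400]
print(full_run_data["epochs"][-1])        # final checkpointed epoch (around 50K)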
Usability note: I often use animations in this section. I recommend using the slider manually, not pressing play - Plotly smooths animations in a confusing and misleading way (and I haven't figured out how to fix it).
Usability note 2: Plots won't display properly if they have too many data points, so different plots will often use different epoch intervals between data points, or different final epochs.
Notation: I use "trig components" to refer to the components $\cos(\omega(x+y))$, $\sin(\omega(x+y))$.
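To make this notation concrete, here's a minimal sketch of what the trig components look like as functions of the input pair $(x, y)$ (assuming the modulus $p=113$ used in these exercises, and taking $k=14$ as an example key frequency, with $\omega = 2\pi k / p$):

import torch as t

p = 113                      # modulus used in these exercises (assumed from earlier sections)
k = 14                       # example key frequency (frequency 14 is referenced below)
omega = 2 * t.pi * k / p

# Evaluate the trig components over every input pair (x, y)
x, y = t.meshgrid(t.arange(p), t.arange(p), indexing="ij")
cos_component = t.cos(omega * (x + y))   # shape (p, p): cos(w(x+y)) for each pair
sin_component = t.sin(omega * (x + y))   # shape (p, p): sin(w(x+y)) for each pair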
Overview
- The model starts to learn the generalised algorithm well before the phase change, and this develops fairly smoothly
- We can see this with the metric of excluded loss, which seems to indicate that the model learns to "memorise more efficiently" by using the $\cos(\omega(x+y))$ directions.
- This is a clear disproof of the 'grokking is a random walk in the loss landscape that eventually gets lucky' hypothesis.
- We also examine more qualitatively each circuit in the model and how it develops
- We see that all circuits develop to some extent before grokking, though at different rates, and some show a more pronounced phase change than others.
- We examine the embedding circuit, the 'calculating trig dimensions' circuit and the development of commutativity.
- We also explore the development of neuron activations and how it varies by cluster
- There's a small but noticeable lag between 'the model learns the generalisable algorithm' and 'the model cleans up all memorised noise'.
- There are indications of several smaller phase changes, beyond the main grokking one.
- In particular, a phase change at 43K-44K, well after grokking (I have not yet interpreted what's going on here).
Setup
First, we'll define some useful functions. In particular, the `get_metrics` function is designed to populate a dictionary of metrics over the training period. The argument `metric_fn` is itself a function which takes in a model and returns a metric (e.g. we use `metric_fn=test_loss` to return the model's loss on the test set).
# Define a dictionary to store our metrics in
metric_cache = {}


def get_metrics(model: HookedTransformer, metric_cache, metric_fn, name, reset=False):
    """
    Define a metric (by metric_fn) and add it to the cache, with the name `name`.
    If `reset` is True, then the metric will be recomputed, even if it is already in the cache.
    """
    if reset or (name not in metric_cache) or (len(metric_cache[name]) == 0):
        metric_cache[name] = []
        for sd in tqdm(full_run_data["state_dicts"]):
            model = utils.load_in_state_dict(model, sd)
            out = metric_fn(model)
            if isinstance(out, Tensor):
                out = to_numpy(out)
            metric_cache[name].append(out)
        # Restore the model to the epoch-40K checkpoint used in previous exercises
        model = utils.load_in_state_dict(model, full_run_data["state_dicts"][400])
        metric_cache[name] = t.tensor(np.array(metric_cache[name]))
def test_loss(model):
    # Logits at the final sequence position, excluding the logit for the trailing "=" token
    logits = model(all_data)[:, -1, :-1]
    return utils.test_logits(logits, False, mode="test")


def train_loss(model):
    logits = model(all_data)[:, -1, :-1]
    return utils.test_logits(logits, False, mode="train")


epochs = full_run_data["epochs"]
plot_metric = partial(utils.lines, x=epochs, xaxis="Epoch")

get_metrics(model, metric_cache, test_loss, "test_loss")
get_metrics(model, metric_cache, train_loss, "train_loss")
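As a taste of the singular-value analysis mentioned in the learning objectives, here's a hedged sketch of how you might track another metric over training with `get_metrics`: the top singular values of the embedding matrix at each checkpoint. (The use of `model.W_E[:-1]`, dropping the row for the trailing `=` token, is an assumption carried over from how the embedding was treated in earlier sections.)

def embedding_singular_values(model):
    # Top 10 singular values of the embedding matrix (dropping the final "=" token row)
    return t.linalg.svdvals(model.W_E[:-1])[:10]


get_metrics(model, metric_cache, embedding_singular_values, "W_E_singular_values")

plot_metric(
    metric_cache["W_E_singular_values"].T,
    labels=[f"singular value {i}" for i in range(10)],
    title="Top 10 singular values of W_E over training",
    yaxis="Singular value",
)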
Excluded Loss
Excluded Loss for frequency $\omega$ is the loss on the training set where we delete the components of the logits corresponding to $\cos(\omega(x+y))$ and $\sin(\omega(x+y))$. We get a separate metric for each $\omega$ in the key frequencies.
Key observation: The excluded loss (especially for frequency 14) starts to go up well before the point of grokking.
(Note: this performance decrease is way more than you'd get for deleting a random direction.)
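To state the definition a bit more precisely: writing $\text{proj}_{v}$ for the projection of each logit vector onto the direction $v$ (as implemented by `project_onto_direction` in the previous section), the excluded loss for frequency $\omega$ is

$$
L_{\text{excl}}^{(\omega)} = \mathcal{L}_{\text{train}}\Big(\text{logits} - \text{proj}_{\cos(\omega(x+y))}(\text{logits}) - \text{proj}_{\sin(\omega(x+y))}(\text{logits})\Big)
$$

where $\mathcal{L}_{\text{train}}$ is the cross-entropy loss evaluated on the training set (i.e. `utils.test_logits` with `mode="train"`).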
Exercise - implement excluded loss
Difficulty:
🔴🔴🔴⚪⚪
Importance:
🔵🔵🔵🔵⚪
You should spend up to 20-30 minutes on this exercise.
You should fill in the function below to implement excluded loss. You'll need to use the `get_trig_sum_directions` and `project_onto_direction` functions from previous exercises. We've given you the first few lines of the function as a guide.
Note - when calculating the loss, you should use the `utils.test_logits` function, with arguments `bias_correction=False, mode="train"`.
def excl_loss(model: HookedTransformer, key_freqs: list) -> list:
    """
    Returns the excluded loss (i.e. the loss after subtracting the components of logits
    corresponding to cos(w_k(x+y)) and sin(w_k(x+y))), for each frequency k in key_freqs.
    """
    excl_loss_list = []
    logits = model(all_data)[:, -1, :-1]
    raise NotImplementedError()


tests.test_excl_loss(excl_loss, model, key_freqs)
Solution
def excl_loss(model: HookedTransformer, key_freqs: list) -> list:
    """
    Returns the excluded loss (i.e. the loss after subtracting the components of logits
    corresponding to cos(w_k(x+y)) and sin(w_k(x+y))), for each frequency k in key_freqs.
    """
    excl_loss_list = []
    logits = model(all_data)[:, -1, :-1]
    for freq in key_freqs:
        cos_xplusy_direction, sin_xplusy_direction = get_trig_sum_directions(freq)
        logits_cos_xplusy = project_onto_direction(logits, cos_xplusy_direction.flatten())
        logits_sin_xplusy = project_onto_direction(logits, sin_xplusy_direction.flatten())
        # Remove the cos(w(x+y)) and sin(w(x+y)) components, then evaluate the training loss
        logits_excl = logits - logits_cos_xplusy - logits_sin_xplusy
        loss = utils.test_logits(logits_excl, bias_correction=False, mode="train").item()
        excl_loss_list.append(loss)
    return excl_loss_list
Once you've completed this function, you can run the following code to plot the excluded loss for each of the key frequencies (as well as the training and testing loss as a baseline). This plot should match the one at the end of the Key Claims section of Neel Nanda's LessWrong post.
get_metrics(model, metric_cache, partial(excl_loss, key_freqs=key_freqs), "excl_loss")

plot_metric(
    t.concat(
        [
            metric_cache["excl_loss"].T,
            metric_cache["train_loss"][None, :],
            metric_cache["test_loss"][None, :],
        ]
    ),
    labels=[f"excl {freq}" for freq in key_freqs] + ["train", "test"],
    title="Excluded Loss for each trig component",
    log_y=True,
    yaxis="Loss",
)