3️⃣ Analysis During Training

Learning Objectives
  • Understand the idea of tracking metrics over time, and how this can inform when certain circuits are forming.
  • Investigate and interpret the evolution over time of the singular values of the model's weight matrices.
  • Investigate the formation of other capabilities in the model, like commutativity.

Some starting notes

Note - this section has fewer exercises than previous sections, and is intended more as a showcase of some of the results from the paper.

In this section, we analyse the modular addition transformer during training. In the data we'll be using, checkpoints were taken every 100 epochs, from epoch 0 to 50K (the model we used in previous exercises was taken at 40K).

Usability note: I often use animations in this section. I recommend using the slider manually, not pressing play - Plotly smooths animations in a confusing and misleading way (and I haven't figured out how to fix it).

Usability note 2: To get plots to render properly, they can't have too many data points, so plots will often use different epoch intervals between data points, or different final epochs.

Notation: I use "trig components" to refer to the components $\cos(\omega(x+y))$, $\sin(\omega(x+y))$.
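
For concreteness (assuming the same convention as in earlier sections, where each key frequency is $\omega_k = 2\pi k / p$ for a key value of $k$), these are exactly the terms the model can build out of products of 1D Fourier components of $x$ and $y$, via the angle-addition identities:

$$\cos(\omega(x+y)) = \cos(\omega x)\cos(\omega y) - \sin(\omega x)\sin(\omega y), \qquad \sin(\omega(x+y)) = \sin(\omega x)\cos(\omega y) + \cos(\omega x)\sin(\omega y)$$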

Overview

  • The model starts to learn the generalised algorithm well before the phase change, and this develops fairly smoothly
    • We can see this with the metric of excluded loss, which seems to indicate that the model learns to "memorise more efficiently" by using the $\cos(\omega(x+y))$ directions.
    • This is a clear disproof of the 'grokking is a random walk in the loss landscape that eventually gets lucky' hypothesis.
  • We also examine more qualitatively each circuit in the model and how it develops
    • We see that all circuits somewhat develop pre-grokking, but at different rates and some have a more pronounced phase change than others
    • We examine the embedding circuit, the 'calculating trig dimensions' circuit and the development of commutativity.
    • We also explore the development of neuron activations and how it varies by cluster
  • There's a small but noticeable lag between 'the model learns the generalisable algorithm' and 'the model cleans up all memorised noise'.
  • There are indications of several smaller phase changes, beyond the main grokking one.
    • In particular, a phase change at 43K-44K, well after grokking (I have not yet interpreted what's going on here).

Setup

First, we'll define some useful functions. In particular, the get_metrics function is designed to populate a dictionary of metrics over the training period. The argument metric_fn is itself a function which takes in a model, and returns a metric (e.g. we use metric_fn=test_loss, to return the model's loss on the test set).

# Define a dictionary to store our metrics in
metric_cache = {}


def get_metrics(model: HookedTransformer, metric_cache, metric_fn, name, reset=False):
    """
    Define a metric (by metric_fn) and add it to the cache, with the name `name`.

    If `reset` is True, then the metric will be recomputed, even if it is already in the cache.
    """
    if reset or (name not in metric_cache) or (len(metric_cache[name]) == 0):
        metric_cache[name] = []
        for sd in tqdm(full_run_data["state_dicts"]):
            model = utils.load_in_state_dict(model, sd)
            out = metric_fn(model)
            if isinstance(out, Tensor):
                out = to_numpy(out)
            metric_cache[name].append(out)
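        # Restore the model to the epoch 40K checkpoint (index 400), i.e. the model used in previous sections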
        model = utils.load_in_state_dict(model, full_run_data["state_dicts"][400])
        metric_cache[name] = t.tensor(np.array(metric_cache[name]))


def test_loss(model):
    logits = model(all_data)[:, -1, :-1]
    return utils.test_logits(logits, False, mode="test")


def train_loss(model):
    logits = model(all_data)[:, -1, :-1]
    return utils.test_logits(logits, False, mode="train")


epochs = full_run_data["epochs"]
plot_metric = partial(utils.lines, x=epochs, xaxis="Epoch")

get_metrics(model, metric_cache, test_loss, "test_loss")
get_metrics(model, metric_cache, train_loss, "train_loss")

Excluded Loss

Excluded Loss for frequency $\omega$ is the loss on the training set where we delete the components of the logits corresponding to $\cos(\omega(x+y))$ and $\sin(\omega(x+y))$. We get a separate metric for each $\omega$ in the key frequencies.

Key observation: The excluded loss (especially for frequency 14) starts to go up well before the point of grokking.

(Note: this performance decrease is way more than you'd get for deleting a random direction.)
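
In symbols (writing $\mathrm{proj}_{v}$ for projection onto direction $v$, as in the earlier exercises), the quantity you'll implement below is roughly

$$\text{ExclLoss}_\omega = \mathcal{L}_{\text{train}}\Big(\text{logits} - \mathrm{proj}_{\cos(\omega(x+y))}(\text{logits}) - \mathrm{proj}_{\sin(\omega(x+y))}(\text{logits})\Big)$$

i.e. ordinary train loss, but computed on logits with both trig directions for that frequency projected out.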

Exercise - implement excluded loss

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 20-30 minutes on this exercise.

You should fill in the function below to implement excluded loss. You'll need to use the get_trig_sum_directions and project_onto_direction functions from previous exercises. We've given you the first few lines of the function as a guide.

Note - when calculating the loss, you should use the test_logits function, with arguments bias_correction=False, mode="train".

def excl_loss(model: HookedTransformer, key_freqs: list) -> list:
    """
    Returns the excluded loss (i.e. subtracting the components of logits corresponding to
    cos(w_k(x+y)) and sin(w_k(x+y)), for each frequency k in key_freqs.
    """
    excl_loss_list = []
    logits = model(all_data)[:, -1, :-1]
    raise NotImplementedError()


tests.test_excl_loss(excl_loss, model, key_freqs)
Solution
def excl_loss(model: HookedTransformer, key_freqs: list) -> list:
    """
    Returns the excluded loss (i.e. subtracting the components of logits corresponding to
    cos(w_k(x+y)) and sin(w_k(x+y)), for each frequency k in key_freqs.
    """
    excl_loss_list = []
    logits = model(all_data)[:, -1, :-1]
    for freq in key_freqs:
        cos_xplusy_direction, sin_xplusy_direction = get_trig_sum_directions(freq)
        logits_cos_xplusy = project_onto_direction(logits, cos_xplusy_direction.flatten())
        logits_sin_xplusy = project_onto_direction(logits, sin_xplusy_direction.flatten())
        logits_excl = logits - logits_cos_xplusy - logits_sin_xplusy
        loss = utils.test_logits(logits_excl, bias_correction=False, mode="train").item()
        excl_loss_list.append(loss)
    return excl_loss_list

Once you've completed this function, you can run the following code to plot the excluded loss for each of the key frequencies (as well as the training and testing loss as a baseline). This plot should match the one at the end of the Key Claims section of Neel Nanda's LessWrong post.

get_metrics(model, metric_cache, partial(excl_loss, key_freqs=key_freqs), "excl_loss")

plot_metric(
    t.concat(
        [
            metric_cache["excl_loss"].T,
            metric_cache["train_loss"][None, :],
            metric_cache["test_loss"][None, :],
        ]
    ),
    labels=[f"excl {freq}" for freq in key_freqs] + ["train", "test"],
    title="Excluded Loss for each trig component",
    log_y=True,
    yaxis="Loss",
)
Click to see the expected output

Development of the embedding

Embedding in Fourier basis

We can plot the norms of the embedding of each 1D Fourier component at each epoch. Pre-grokking, the model is learning a representation that prioritises a few components, but most components still have non-trivial norm, presumably because these directions are doing some work in memorisation. Then, during the grokking period, the other components get set to near zero - the model no longer needs those directions to memorise things, since it has learned a general algorithm.

(As a reminder, we found that the SVD of the embedding was approximately $W_E \approx F^T S V^T$, where $F$ is the Fourier basis matrix and $S$ is sparse, hence $F W_E \approx S V^T$ is also sparse with most rows equal to zero. So when we calculate the norms of the rows of $F W_E$ at intermediate points in training, we're seeing how the input space of the embedding learns to contain only these select few frequencies.)
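
To spell out the last step: since the columns of $V$ are orthonormal, right-multiplying by $V^T$ preserves row norms, so

$$\big\|(F W_E)_{i,:}\big\| \;\approx\; \big\|(S V^T)_{i,:}\big\| \;=\; \big\|S_{i,:}\big\|$$

i.e. the row norms we plot directly measure how much weight $S$ puts on each Fourier component, and for non-key frequencies this should fall to zero over training.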

Exercise - define fourier_embed

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪
This exercise shouldn't take more than ~10 minutes.

Write the function fourier_embed. This should calculate the norms of the Fourier components of the model's embedding matrix (ignoring the embedding vector corresponding to =). In other words, you should left-multiply the embedding matrix by the Fourier basis matrix, then calculate the squared norm of each row of the result (i.e. one value per Fourier component).

def fourier_embed(model: HookedTransformer):
    """
    Returns norm of Fourier transform of the model's embedding matrix.
    """
    raise NotImplementedError()


tests.test_fourier_embed(fourier_embed, model)
Solution
def fourier_embed(model: HookedTransformer):
    """
    Returns norm of Fourier transform of the model's embedding matrix.
    """
    W_E_fourier = fourier_basis.T @ model.W_E[:-1]
    return einops.reduce(W_E_fourier.pow(2), "vocab d_model -> vocab", "sum")

Next, you can plot how the norm of Fourier components of the embedding changes during training:

# Plot every 200 epochs so it's not overwhelming
get_metrics(model, metric_cache, fourier_embed, "fourier_embed")

utils.animate_lines(
    metric_cache["fourier_embed"][::2],
    snapshot_index=epochs[::2],
    snapshot="Epoch",
    hover=fourier_basis_names,
    animation_group="x",
    title="Norm of Fourier Components in the Embedding Over Training",
)
Click to see the expected output

Exercise - Examine the SVD of $W_E$

Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪
You should spend up to 5-10 minutes on this exercise.

We discussed $W_E \approx F^T S V^T$ as being a good approximation to the SVD, but what happens when we actually calculate the SVD?

You should fill in the following function, which returns the singular values from the singular value decomposition of $W_E$ at an intermediate point in training. (Remember to remove the last row of $W_E$, which corresponds to the = token.) PyTorch has an SVD function (torch.svd) which you should use for this.

def embed_SVD(model: HookedTransformer) -> Tensor:
    """
    Returns vector S, where W_E = U @ diag(S) @ V.T in singular value decomp.
    """
    raise NotImplementedError()


tests.test_embed_SVD(embed_SVD, model)
get_metrics(model, metric_cache, embed_SVD, "embed_SVD")

utils.animate_lines(
    metric_cache["embed_SVD"],
    snapshot_index=epochs,
    snapshot="Epoch",
    title="Singular Values of the Embedding During Training",
    xaxis="Singular Number",
    yaxis="Singular Value",
)
Click to see the expected output
Solution
def embed_SVD(model: HookedTransformer) -> Tensor:
    """
    Returns vector S, where W_E = U @ diag(S) @ V.T in singular value decomp.
    """
    U, S, V = t.svd(model.W_E[:-1])
    return S

Can you interpret what's going on in this plot?

Interpretation

At first, the singular values are essentially random. Over training, the smaller singular values tend to zero, while the singular values corresponding to the key frequencies increase. Eventually we're left with a sparse spectrum: all singular values are near zero except those corresponding to the key frequencies.

Note - after this point, there are no more exercises. The content just consists of plotting and interpreting/discussing results.

Development of computing trig components

In previous exercises, we've projected our logits or neuron activations onto 2D Fourier basis directions corresponding to $\cos(\omega_k(x+y))$ and $\sin(\omega_k(x+y))$ for each of the key frequencies $k$. We found that these directions explained basically all of the model's performance.

Here, we'll do the same thing, but over time, to see how the model learns to compute these trig components. The code is all provided for you below (it uses some functions you wrote in earlier sections).

The activations are first centered, then we take their total sum of squares, then we project them onto the trig directions for each key frequency and see what fraction of the variance these projections explain. This is then averaged across the "output dimension" (which is neurons in the case of the neuron activations, or output classes in the case of the logits).

def tensor_trig_ratio(model: HookedTransformer, mode: str):
    """
    Returns the fraction of variance of the (centered) activations which is explained by the Fourier
    directions corresponding to cos(ω(x+y)) and sin(ω(x+y)) for all the key frequencies.
    """
    logits, cache = model.run_with_cache(all_data)
    logits = logits[:, -1, :-1]
    if mode == "neuron_pre":
        tensor = cache["pre", 0][:, -1]
    elif mode == "neuron_post":
        tensor = cache["post", 0][:, -1]
    elif mode == "logit":
        tensor = logits
    else:
        raise ValueError(f"{mode} is not a valid mode")

    tensor_centered = tensor - einops.reduce(tensor, "xy index -> 1 index", "mean")
    tensor_var = einops.reduce(tensor_centered.pow(2), "xy index -> index", "sum")
    tensor_trig_vars = []

    for freq in key_freqs:
        cos_xplusy_direction, sin_xplusy_direction = get_trig_sum_directions(freq)
        cos_xplusy_projection_var = (
            project_onto_direction(tensor_centered, cos_xplusy_direction.flatten()).pow(2).sum(0)
        )
        sin_xplusy_projection_var = (
            project_onto_direction(tensor_centered, sin_xplusy_direction.flatten()).pow(2).sum(0)
        )
        tensor_trig_vars.extend([cos_xplusy_projection_var, sin_xplusy_projection_var])

    return to_numpy(sum(tensor_trig_vars) / tensor_var)


for mode in ["neuron_pre", "neuron_post", "logit"]:
    get_metrics(
        model,
        metric_cache,
        partial(tensor_trig_ratio, mode=mode),
        f"{mode}_trig_ratio",
        reset=True,
    )

lines_list = []
line_labels = []
for mode in ["neuron_pre", "neuron_post", "logit"]:
    tensor = metric_cache[f"{mode}_trig_ratio"]
    lines_list.append(einops.reduce(tensor, "epoch index -> epoch", "mean"))
    line_labels.append(f"{mode}_trig_frac")

plot_metric(
    lines_list,
    labels=line_labels,
    log_y=False,
    yaxis="Ratio",
    title="Fraction of logits and neurons explained by trig terms",
)
Click to see the expected output

By plotting on a log scale, we can more clearly see that all three develop a higher proportion of trig components over training, but that the logits develop smoothly while the neurons exhibit more of a phase change.
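
If you want to see the log-scale version yourself, here's a minimal sketch - the only change from the plotting cell above is log_y=True:

plot_metric(
    lines_list,
    labels=line_labels,
    log_y=True,
    yaxis="Ratio",
    title="Fraction of logits and neurons explained by trig terms (log scale)",
)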

Discussion

In the fully trained model, there are two key components to the algorithm that lets the model meaningfully use trig directions in the logits: firstly, the neuron activations have significant quadratic terms; secondly, $W_{logit}$ can cancel out all of the non-trig terms, and then map the trig terms to $(x + y) \bmod p$.
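
As a rough reminder of what "map the trig terms to $(x + y) \bmod p$" means mechanically (glossing over the per-frequency coefficients), the logit for output $c$ ends up looking something like

$$\text{logit}(c) \approx \sum_{k \in \text{key freqs}} \cos(\omega_k(x+y-c)) = \sum_{k} \Big[\cos(\omega_k(x+y))\cos(\omega_k c) + \sin(\omega_k(x+y))\sin(\omega_k c)\Big]$$

which is maximised, by constructive interference across frequencies, exactly when $c \equiv x+y \pmod p$.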

A natural question is whether one of these comes first, or if both evolve in tandem - as far as I'm aware, "how do circuits with multiple moving parts form over training" is not at all understood.

In this case, the logits develop the capability to cancel out everything but the trig directions early on, and the neurons don't develop significant quadratic or trig components until close to the grokking point.

I vaguely speculate that it makes more sense for circuits to develop in "reverse order": if we need two layers working together to produce a nice output, then a first layer feeding into a randomly initialised second layer achieves nothing. But if the first layer is randomly initialised, the second layer can learn to extract whatever components of the first layer's output correspond to the "correct" answer, and use them to crudely approximate the solution. And now the network has a training incentive to build up both parts of the circuit.

(This maybe has something to do with the Lottery Ticket hypothesis?)

Development of neuron activations

There are two notable things about the neuron activations:

  • They contain a significant component of quadratic terms in x and y of the same frequency
  • They group into clusters, each with Fourier terms of a single frequency

We can study the first by plotting, for each neuron, the fraction of its (centered) activation explained by the quadratic terms of that neuron's frequency (frequencies are taken from the epoch 40K model).

(For the always-firing cluster, we sum over all frequencies.)

def get_frac_explained(model: HookedTransformer) -> Tensor:
    _, cache = model.run_with_cache(all_data, return_type=None)

    returns = []

    for neuron_type in ["pre", "post"]:
        neuron_acts = cache[neuron_type, 0][:, -1].clone().detach()
        neuron_acts_centered = neuron_acts - neuron_acts.mean(0)
        neuron_acts_fourier = fft2d(
            einops.rearrange(neuron_acts_centered, "(x y) neuron -> x y neuron", x=p)
        )

        # Calculate the sum of squares over all inputs, for each neuron
        square_of_all_terms = einops.reduce(
            neuron_acts_fourier.pow(2), "x y neuron -> neuron", "sum"
        )

        frac_explained = t.zeros(utils.d_mlp).to(device)
        frac_explained_quadratic_terms = t.zeros(utils.d_mlp).to(device)

        for freq in key_freqs_plus:
            # Get Fourier activations for neurons in this frequency cluster
            # We arrange by frequency (i.e. each freq has a 3x3 grid with const, linear & quadratic
            # terms)
            acts_fourier = arrange_by_2d_freqs(neuron_acts_fourier[..., neuron_freqs == freq])

            # Calculate the sum of squares over all inputs, after filtering for just this frequency
            # Also calculate the sum of squares for just the quadratic terms in this frequency
            if freq == -1:
                squares_for_this_freq = squares_for_this_freq_quadratic_terms = einops.reduce(
                    acts_fourier[:, 1:, 1:].pow(2), "freq x y neuron -> neuron", "sum"
                )
            else:
                squares_for_this_freq = einops.reduce(
                    acts_fourier[freq - 1].pow(2), "x y neuron -> neuron", "sum"
                )
                squares_for_this_freq_quadratic_terms = einops.reduce(
                    acts_fourier[freq - 1, 1:, 1:].pow(2), "x y neuron -> neuron", "sum"
                )

            frac_explained[neuron_freqs == freq] = (
                squares_for_this_freq / square_of_all_terms[neuron_freqs == freq]
            )
            frac_explained_quadratic_terms[neuron_freqs == freq] = (
                squares_for_this_freq_quadratic_terms / square_of_all_terms[neuron_freqs == freq]
            )

        returns.extend([frac_explained, frac_explained_quadratic_terms])

    frac_active = (neuron_acts > 0).float().mean(0)

    return t.nan_to_num(t.stack(returns + [neuron_freqs, frac_active], axis=0))


get_metrics(model, metric_cache, get_frac_explained, "get_frac_explained")

frac_explained_pre = metric_cache["get_frac_explained"][:, 0]
frac_explained_quadratic_pre = metric_cache["get_frac_explained"][:, 1]
frac_explained_post = metric_cache["get_frac_explained"][:, 2]
frac_explained_quadratic_post = metric_cache["get_frac_explained"][:, 3]
neuron_freqs_ = metric_cache["get_frac_explained"][:, 4]
frac_active = metric_cache["get_frac_explained"][:, 5]

utils.animate_scatter(
    t.stack([frac_explained_quadratic_pre, frac_explained_quadratic_post], dim=1)[:200:5],
    color=neuron_freqs_[:200:5],
    color_name="freq",
    snapshot="epoch",
    snapshot_index=epochs[:200:5],
    xaxis="Quad ratio pre",
    yaxis="Quad ratio post",
    title="Fraction of variance explained by quadratic terms (up to epoch 20K)",
)

utils.animate_scatter(
    t.stack([neuron_freqs_, frac_explained_pre, frac_explained_post], dim=1)[:200:5],
    color=frac_active[:200:5],
    color_name="frac_active",
    snapshot="epoch",
    snapshot_index=epochs[:200:5],
    xaxis="Freq",
    yaxis="Frac explained",
    hover=list(range(utils.d_mlp)),
    title="Fraction of variance explained by this frequency (up to epoch 20K)",
)
Click to see the expected output

Development of commutativity

We can plot the average attention paid to each position, and see that the model quickly learns not to attend to the final position, but doesn't really learn commutativity (i.e. equal attention to pos 0 and pos 1) until the grokking point.

Aside: There's a weird phase change around epoch 43K, where the model starts to attend to position 2 again - I haven't investigated what's up with that yet.

(Each frame covers 500 epochs, since we only plot every fifth checkpoint)

def avg_attn_pattern(model: HookedTransformer):
    _, cache = model.run_with_cache(all_data, return_type=None)
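    # Attention paid by the final "=" token (query position 2) to each of the 3 positions,
    # averaged over all (x, y) inputs, for each head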
    return to_numpy(
        einops.reduce(cache["pattern", 0][:, :, 2], "batch head pos -> head pos", "mean")
    )


get_metrics(model, metric_cache, avg_attn_pattern, "avg_attn_pattern")

utils.imshow_div(
    metric_cache["avg_attn_pattern"][::5],
    animation_frame=0,
    animation_name="head",
    title="Avg attn by position and head, snapped every 100 epochs",
    xaxis="Pos",
    yaxis="Head",
    zmax=0.5,
    zmin=0.0,
    color_continuous_scale="Blues",
    text_auto=".3f",
)
Click to see the expected output

We can also see this by plotting the average difference in attention paid to pos 0 vs pos 1.

utils.lines(
    (metric_cache["avg_attn_pattern"][:, :, 0] - metric_cache["avg_attn_pattern"][:, :, 1]).T,
    labels=[f"Head {i}" for i in range(4)],
    x=epochs,
    xaxis="Epoch",
    yaxis="Average difference",
    title="Attention to pos 0 - pos 1 by head over training",
    width=900,
    height=450,
)
Click to see the expected output

Small lag to clean up noise

We plot test and train loss over training.

We further define trig loss as the loss where we extract out just the directions of the logits corresponding to $\cos(\omega(x+y))$, $\sin(\omega(x+y))$ in the key frequencies. We run this on all of the data, and on just the training set.

Observations:

  • Trig loss on all data and trig loss on just the training data are identical, showing that these dimensions are only used for a general algorithm that treats train and test equally, rather than for memorisation.
  • Trig loss crashes before test loss crashes, and during the grokking period trig loss is proportionately much lower (by a factor of $10^4$ to $10^5$), but after grokking they return to a low ratio. This suggests there's a small lag between the phase change where the model fully learns the general algorithm and the point where it cleans up the noise left over by the memorisation circuit.

Aside: Projecting onto the trig dimensions requires all of the data to be input into the model. To calculate the trig train loss, we first get the logits for all of the data, then project onto the trig components, then throw away the test data logits.

def trig_loss(model: HookedTransformer, mode: str = "all"):
    logits = model(all_data)[:, -1, :-1]

    trig_logits = []
    for freq in key_freqs:
        cos_xplusy_dir, sin_xplusy_dir = get_trig_sum_directions(freq)
        cos_xplusy_proj = project_onto_direction(logits, cos_xplusy_dir.flatten())
        sin_xplusy_proj = project_onto_direction(logits, sin_xplusy_dir.flatten())
        trig_logits.extend([cos_xplusy_proj, sin_xplusy_proj])
    trig_logits = sum(trig_logits)

    return utils.test_logits(trig_logits, bias_correction=True, original_logits=logits, mode=mode)


get_metrics(model, metric_cache, trig_loss, "trig_loss")
get_metrics(model, metric_cache, partial(trig_loss, mode="train"), "trig_loss_train")

line_labels = ["test_loss", "train_loss", "trig_loss", "trig_loss_train"]
plot_metric(
    [metric_cache[lab] for lab in line_labels],
    labels=line_labels,
    title="Different losses over training",
)
plot_metric(
    [metric_cache["test_loss"] / metric_cache["trig_loss"]],
    title="Ratio of trig and test loss",
)
Click to see the expected output

Development of the sum of squared weights

Another data point is looking at the sum of squared weights for each parameter. Here we see several phases:

  • (0-1K) The model first uses the neurons to memorise (which significantly increases the total squared weights of $W_{in}$ and $W_{out}$, but not of the rest)
  • (1K-8K) It then smooths out the computation across the model, so all weight matrices have roughly the same total. In parallel, every matrix's total decreases, presumably as the model learns to use the trig directions.
  • (8K-13K) It then groks the solution and the totals rapidly decrease. Presumably, it has learned to use the trig directions well enough that it can clean up all the other directions used for memorisation.
  • (13K-43K) Then all weights plateau
  • (43K-) Zooming in on the total-weight graph, we see a small but noticeable kink at this point: a final phase change where the total drops from roughly 955 to 942 (see the sketch at the end of this section)

def sum_sq_weights(model):
    return [param.pow(2).sum().item() for name, param in model.named_parameters()]


get_metrics(model, metric_cache, sum_sq_weights, "sum_sq_weights")

plot_metric(
    metric_cache["sum_sq_weights"].T,
    title="Sum of squared weights for each parameter",
    labels=[name.split(".")[-1] for name, _ in model.named_parameters()],
    log_y=False,
)
plot_metric(
    [einops.reduce(metric_cache["sum_sq_weights"], "epoch param -> epoch", "sum")],
    title="Total sum of squared weights",
    log_y=False,
)
Click to see the expected output
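
If you want to zoom in on the kink around epoch 43K mentioned above, here's a minimal sketch (relying on the fact, stated in the setup, that checkpoints are saved every 100 epochs, so index 400 corresponds to epoch 40K):

# Re-plot the total sum of squared weights, restricted to checkpoints from epoch 40K onwards
total_sq_weights = einops.reduce(metric_cache["sum_sq_weights"], "epoch param -> epoch", "sum")

utils.lines(
    [total_sq_weights[400:]],
    x=epochs[400:],
    xaxis="Epoch",
    yaxis="Total sum of squared weights",
    title="Total sum of squared weights, from epoch 40K onwards",
    log_y=False,
)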