[1.5.1] Balanced Bracket Classifier

Colab: exercises | solutions

Please send any problems / bugs on the #errata channel in the Slack group, and ask any questions on the dedicated channels for this chapter of material.



Introduction

When models are trained on synthetic, algorithmic tasks, they often learn to do some clean, interpretable computation inside. Choosing a suitable task and reverse-engineering the resulting model can yield a rich collection of interesting circuits to interpret! In some sense, this is interpretability on easy mode - the model is normally trained on a single task (unlike language models, which need to learn everything about language!), we know the exact ground truth about the data and the optimal solution, and the models are tiny. So why care?

Working on algorithmic problems gives us the opportunity to:

  • Practice interpretability, and build intuitions and learn techniques.
  • Refine our understanding of the right tools and techniques, by trying them out on problems with well-understood ground truth.
  • Isolate a particularly interesting kind of behaviour, in order to study it in detail and understand it better (e.g. Anthropic's Toy Models of Superposition paper).
  • Take the insights you've learned from reverse-engineering small models, and investigate which results will generalise, or whether any of the techniques you used to identify circuits can be automated and used at scale.

The algorithmic problem we'll work on in these exercises is bracket classification, i.e. taking a string of parentheses like "(())()" and outputting a prediction of "balanced" or "unbalanced". We will write an algorithmic solution to this problem, and reverse-engineer one of the circuits in our model that is responsible for implementing one part of this algorithm.
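To make the task concrete, here is a simple reference implementation of the balance check, phrased in terms of the "elevation" framing the exercises use (a sketch of the idea, not necessarily the exact handwritten solution you'll write later): a string is balanced iff the running count of unmatched left brackets never goes negative and ends at exactly zero.

def is_balanced(s: str) -> bool:
    elevation = 0  # number of currently-unmatched left brackets
    for char in s:
        elevation += 1 if char == "(" else -1
        if elevation < 0:  # a ")" appeared with no matching "("
            return False
    return elevation == 0  # every "(" must have been matched by the end

assert is_balanced("(())()")
assert not is_balanced("())(")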

This page contains a large number of exercises. Each exercise has a difficulty and importance rating out of 5, as well as an estimated maximum time you should spend on it, and sometimes a short annotation. You should interpret the ratings & time estimates relatively (e.g. if you find yourself spending about 50% longer on the exercises than the time estimates, adjust accordingly). Please do skip exercises / look at solutions if you don't feel like they're important enough to be worth doing, and you'd rather get to the good stuff!

Motivation

In A Mathematical Framework for Transformer Circuits, we got a lot of traction interpreting toy language models - that is, transformers trained in exactly the same way as larger models, but with only 1 or 2 layers. It seems likely that there’s a lot of low-hanging fruit left to pluck when studying toy language models!

So, why care about studying toy language models? The obvious reason is that it’s way easier to get traction. In particular, the inputs and outputs of a model are intrinsically interpretable, and in a toy model there’s just not as much space between the inputs and outputs for weird complexity to build up. But the obvious objection to the above is that, ultimately, we care about understanding real models (and ideally extremely large ones like GPT-3), and learning to interpret toy models is not the actual goal. This is a pretty valid objection, but there are two natural ways that studying toy models can be valuable:

The first is by finding fundamental circuits that recur in larger models, and motifs that allow us to easily identify these circuits in larger models. A key underlying question here is that of universality: does each model learn its own weird way of completing its task, or are there some fundamental principles and algorithms that all models converge on?

The second is by forming a better understanding of how to reverse engineer models - what the right intuitions and conceptual frameworks are, what tooling and techniques do and do not work, and what weird limitations we might face. For instance, the work in A Mathematical Framework presents ideas like the residual stream as the central object, and the significance of the QK and OV circuits, which seem to generalise to many different models. We'll also see an example later in these exercises which illustrates how MLPs can be thought of as a collection of neurons which activate on different features, just like many seem to in language models. But there are also ways toy models can be misleading, and some techniques that work well in toy models seem to generalise less well.
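As a small numerical illustration of the "MLP as a collection of neurons" view (a sketch, not code from the exercises - the weight shapes follow TransformerLens conventions): the MLP's output is the sum over neurons of each neuron's scalar post-activation times its output weight row, plus the output bias, so each neuron's contribution can be examined separately.

import torch as t

d_model, d_mlp = 8, 32
x = t.randn(d_model)
W_in, b_in = t.randn(d_model, d_mlp), t.randn(d_mlp)      # [d_model, d_mlp], [d_mlp]
W_out, b_out = t.randn(d_mlp, d_model), t.randn(d_model)  # [d_mlp, d_model], [d_model]

post = t.relu(x @ W_in + b_in)      # [d_mlp] - one scalar activation per neuron
mlp_out = post @ W_out + b_out      # standard MLP forward pass
per_neuron = post[:, None] * W_out  # [d_mlp, d_model] - each neuron's contribution to the output
assert t.allclose(mlp_out, per_neuron.sum(dim=0) + b_out, atol=1e-5)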

The purpose / structure of these exercises

At a surface level, these exercises are designed to guide you through a partial interpretation of the bidirectional model trained on bracket classification. But they're also designed to make you a better interpretability researcher! As a result, most exercises will be doing a combination of:

  1. Showing you some new feature/component of the circuit, and
  2. Teaching you how to use tools and interpret results in a broader mech interp context.

As you're going through these exercises, it's easy to get lost in the fiddly details of the techniques you're implementing or the things you're computing. Make sure you keep taking a high-level view, asking yourself what questions you're currently trying to answer and how you'll interpret the output you're getting, as well as how the tools you're currently using are helping guide you towards a better understanding of the model.

Content & Learning Objectives

1️⃣ Bracket classifier

This section describes how transformers can be used for classification, and the details of how this works in TransformerLens (using permanent hooks). It also takes you through the exercise of hand-writing a solution to the balanced brackets problem.

This section mainly just lays the groundwork; it is very light on content.
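As a preview of the permanent-hook mechanic, here is a minimal sketch (not the exercises' exact implementation - the toy config, pad positions, and choice of hooking the attention scores are illustrative assumptions): masking padding tokens can be done by permanently registering a hook that sets attention scores to padded key positions to a large negative value, so the softmax assigns them ~zero attention in every future forward pass.

import torch as t
from functools import partial
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils
from transformer_lens.hook_points import HookPoint

def mask_padding_scores(attn_scores, hook: HookPoint, padding_mask):
    # attn_scores: [batch, head, query_pos, key_pos]; padding_mask: [batch, key_pos], True where padded
    return attn_scores.masked_fill(padding_mask[:, None, None, :], -1e9)

# Toy model purely for illustration (the bracket classifier's real config is loaded later in the exercises)
cfg = HookedTransformerConfig(n_layers=1, d_model=16, d_head=8, n_heads=2, n_ctx=8, d_vocab=10, act_fn="relu")
model = HookedTransformer(cfg)

padding_mask = t.zeros(1, cfg.n_ctx, dtype=t.bool)
padding_mask[0, -2:] = True  # pretend the last two positions are padding

for layer in range(cfg.n_layers):
    model.add_hook(
        utils.get_act_name("attn_scores", layer),
        partial(mask_padding_scores, padding_mask=padding_mask),
        is_permanent=True,  # survives reset_hooks, so it applies to all subsequent forward passes
    )

logits = model(t.randint(0, cfg.d_vocab, (1, cfg.n_ctx)))  # the mask is applied automatically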

Learning Objectives
  • Understand how transformers can be used for classification.
  • Understand how to implement specific kinds of transformer behaviour (e.g. masking of padding tokens) via permanent hooks in TransformerLens.
  • Start thinking about the kinds of algorithmic solutions a transformer is likely to find for problems such as these, given its inductive biases.

2️⃣ Moving backwards

Here, you'll perform logit attribution, and learn how to work backwards through particular paths of a model to figure out which components matter most for the final classification probabilities.

This is the first time you'll have to deal with LayerNorm in your models.

This section should be familiar if you've done logit attribution for induction heads (although these exercises are slightly more challenging from a coding perspective). The LayerNorm-based exercises are a bit fiddly!
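One trick worth previewing (a sketch, assuming you've already cached the relevant activations; the function and variable names here are illustrative): although LayerNorm is nonlinear, over a fixed data distribution it can often be well-approximated by a fitted linear map, which lets you treat the path from residual stream to logits as approximately linear for attribution purposes.

import numpy as np
from sklearn.linear_model import LinearRegression

def fit_layernorm_as_linear(ln_inputs: np.ndarray, ln_outputs: np.ndarray) -> tuple[LinearRegression, float]:
    """Fit a linear map from pre-LayerNorm to post-LayerNorm activations.

    ln_inputs, ln_outputs: arrays of shape [n_samples, d_model], e.g. the residual stream
    just before the final LayerNorm and its output, cached at the sequence position used
    for classification. Returns the fitted regression and its r^2 score."""
    reg = LinearRegression().fit(ln_inputs, ln_outputs)
    return reg, reg.score(ln_inputs, ln_outputs)

# A high r^2 justifies substituting the fitted linear map (reg.coef_) for the LayerNorm
# when attributing the final logits back to earlier components.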

Learning Objectives
  • Understand how to perform logit attribution.
  • Understand how to work backwards through a model to identify which components matter most for the final classification probabilities.
  • Understand how LayerNorm works, and look at some ways to deal with it in your models.

3️⃣ Total elevation circuit

This section is quite challenging both from a coding and conceptual perspective, because you need to link the results of your observations and interventions to concrete hypotheses about how the model works.

In the largest section of the exercises, you'll examine the attention patterns in different heads, and interpret them as performing some human-understandable algorithm (e.g. copying, or aggregation). You'll use your observations to make deductions about how a particular type of balanced brackets failure mode (mismatched number of left and right brackets) is detected by your model.

This is the first time you'll have to deal with MLPs in your models.
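If you want a head start on inspecting attention patterns, the circuitsvis library imported in the setup code can render them inline. A minimal sketch (assuming `model`, a tokenized batch `tokens`, and the corresponding string tokens are already defined; the layer chosen is arbitrary):

import circuitsvis as cv
from IPython.display import display

def show_attention_patterns(model, tokens, str_tokens: list[str], layer: int = 0) -> None:
    """Run the model, cache activations, and display every head's attention pattern in `layer`."""
    _, cache = model.run_with_cache(tokens)
    pattern = cache["pattern", layer][0]  # [n_heads, seq_len, seq_len], first sequence in the batch
    display(cv.attention.attention_patterns(tokens=str_tokens, attention=pattern))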

Learning Objectives
  • Practice connecting distinctive attention patterns to human-understandable algorithms, and making deductions about model behaviour.
  • Understand how MLPs can be viewed as a collection of neurons.
  • Build up to a full picture of the total elevation circuit and how it works.

☆ Bonus exercises

Lastly, there are a few optional bonus exercises which build on the previous content (e.g. having you examine different parts of the model, or use your understanding of how the model works to generate adversarial examples).

This final section is less guided, although the suggested exercises are similar in flavour to the previous section.

Learning Objectives
  • Use your understanding of how the model works to generate adversarial examples.
  • Take deeper dives into specific anomalous features of the model.

Setup code

import json
import sys
from functools import partial
from pathlib import Path

import circuitsvis as cv
import einops
import torch as t
from IPython.display import display
from jaxtyping import Bool, Float, Int
from sklearn.linear_model import LinearRegression
from torch import Tensor, nn
from tqdm import tqdm
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig, utils
from transformer_lens.hook_points import HookPoint

device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)
t.set_grad_enabled(False)

# Make sure exercises are in the path
chapter = "chapter1_transformer_interp"
section = "part51_balanced_bracket_classifier"
exercises_dir = next(p for p in Path.cwd().parents if p.name == chapter) / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

import part51_balanced_bracket_classifier.tests as tests
import plotly_utils
from part51_balanced_bracket_classifier.brackets_datasets import BracketsDataset, SimpleTokenizer
from plotly_utils import bar, hist

MAIN = __name__ == "__main__"