[1.1] - Transformers from scratch

Colab: exercises | solutions

Please send any problems / bugs to the #errata channel in the Slack group, and ask any questions on the dedicated channels for this chapter of material.


Links to all other chapters: (0) Fundamentals, (1) Transformer Interpretability, (2) RL.

Introduction

This is a clean, first-principles implementation of GPT-2 in PyTorch. The architectural choices closely follow those used by the TransformerLens library (which you'll be using a lot more in later exercises).

The exercises are written to accompany Neel Nanda's TransformerLens library for doing mechanistic interpretability research on GPT-2 style language models. We'll be working with this library extensively in this chapter of the course.

Each exercise will have a difficulty and importance rating out of 5, as well as an estimated maximum time you should spend on these exercises and sometimes a short annotation. You should interpret the ratings & time estimates relatively (e.g. if you find yourself spending about 50% longer on the exercises than the time estimates, adjust accordingly). Please do skip exercises / look at solutions if you don't feel like they're important enough to be worth doing, and you'd rather get to the good stuff!

For a lecture on today's material, which gives some high-level understanding before you dive into the exercises, watch the video below:

Content & Learning Objectives

1️⃣ Understanding Inputs & Outputs of a Transformer

In this section, we'll take a first look at transformers - what their function is, how information moves inside a transformer, and what inputs & outputs they take.

Learning Objectives
  • Understand what a transformer is used for
  • Understand causal attention, and what a transformer's output represents
  • Learn what tokenization is, and how models do it
  • Understand what logits are, and how to use them to derive a probability distribution over the vocabulary
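To make causal attention and logits concrete, here is a toy sketch that uses random tensors as stand-ins for a real model's logits and attention scores (all sizes are arbitrary assumptions). Softmax over the vocabulary dimension turns each position's logits into a probability distribution, and setting the scores above the diagonal to -inf before the softmax is one common way to enforce causal attention.

```python
import torch as t

# Toy setup (arbitrary sizes): batch=1, seq_len=4, vocabulary size d_vocab=10
t.manual_seed(0)
batch, seq_len, d_vocab = 1, 4, 10
logits = t.randn(batch, seq_len, d_vocab)  # stand-in for a model's output

# Softmax over the vocabulary dimension gives, at each position,
# a probability distribution over the *next* token.
probs = logits.softmax(dim=-1)
assert t.allclose(probs.sum(dim=-1), t.ones(batch, seq_len))

# The distribution at the final position is the guess for the token after the prompt.
next_token = probs[0, -1].argmax().item()

# Causal attention: position i may only attend to positions j <= i.
# Scores above the diagonal become -inf, so the softmax gives them weight 0.
scores = t.randn(seq_len, seq_len)  # stand-in for query-key dot products
mask = t.triu(t.ones(seq_len, seq_len, dtype=t.bool), diagonal=1)
pattern = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)
print(pattern)  # lower-triangular; each row sums to 1
```

Note that the logits at position i are the model's prediction for the token at position i + 1, which is why the final position's distribution is the one used when generating text.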

2️⃣ Clean Transformer Implementation

Here, we'll implement a transformer from scratch, using only PyTorch's tensor operations. This will give us a good understanding of how transformers work, and how to use them. We do this by going module-by-module, in an experience which should feel somewhat similar to last week's ResNet exercises. Much like with ResNets, you'll conclude by loading in pretrained weights and verifying that your model works as expected.

Learning Objectives
  • Understand that a transformer is composed of attention heads and MLPs, with each one performing operations on the residual stream
  • Understand that the attention heads in a single layer operate independently, and that they have the role of calculating attention patterns (which determine where information is moved to & from in the residual stream)
  • Learn about & implement the following transformer modules:
    • LayerNorm (transforming the input to have zero mean and unit variance)
    • Positional embedding (a lookup table from position indices to residual stream vectors)
    • Attention (the method of computing attention patterns for residual stream vectors)
    • MLP (the collection of linear and nonlinear transformations which operate on each residual stream vector in the same way)
    • Embedding (a lookup table from tokens to residual stream vectors)
    • Unembedding (a matrix for converting residual stream vectors into logits, which a softmax turns into a distribution over tokens)
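The lookup-table modules and LayerNorm can be sketched in a few lines. This is a minimal illustration with random weights and toy sizes, not the exercise solution; the weight names W_E, W_pos and W_U are chosen to echo TransformerLens's naming style, and the attention and MLP layers are left out entirely.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
# Toy sizes (GPT-2 small uses d_vocab=50257, n_ctx=1024, d_model=768)
d_vocab, n_ctx, d_model = 100, 16, 32


class LayerNorm(nn.Module):
    """Normalize each residual vector to zero mean / unit variance, then scale and shift."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.w = nn.Parameter(t.ones(d_model))   # learned scale
        self.b = nn.Parameter(t.zeros(d_model))  # learned shift

    def forward(self, x: t.Tensor) -> t.Tensor:
        x = x - x.mean(dim=-1, keepdim=True)
        scale = (x.pow(2).mean(dim=-1, keepdim=True) + self.eps).sqrt()
        return x / scale * self.w + self.b


W_E = t.randn(d_vocab, d_model)   # embedding: one row per token id
W_pos = t.randn(n_ctx, d_model)   # positional embedding: one row per position
W_U = t.randn(d_model, d_vocab)   # unembedding: residual vector -> logits

tokens = t.tensor([[5, 17, 42, 8]])          # (batch=1, seq_len=4)
seq_len = tokens.shape[1]
residual = W_E[tokens] + W_pos[:seq_len]     # two lookups, broadcast over batch
# (attention and MLP blocks would read from and add to `residual` here)
normed = LayerNorm(d_model)(residual)
logits = normed @ W_U                        # (batch, seq_len, d_vocab)
print(logits.shape)  # torch.Size([1, 4, 100])
```

Each residual vector is normalized independently along d_model; the learned w and b (initialized here to ones and zeros) let the model rescale and shift the result.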

3️⃣ Training a Transformer

Next, you'll learn how to train your transformer from scratch. This will be quite similar to the training loops you wrote for ResNet in your first week.

Learning Objectives
  • Understand how to train a transformer from scratch
  • Write a basic transformer training loop
  • Interpret the transformer's falling cross entropy loss with reference to features of the training data (e.g. bigram frequencies)
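The core of a transformer training loop is the next-token cross-entropy loss. Below is a hedged sketch, assuming logits of shape (batch, seq_len, d_vocab) and integer tokens of shape (batch, seq_len); the helper name next_token_loss is hypothetical. The logits at position i are scored against the token at position i + 1, so the last position's logits and the first token are dropped.

```python
import math

import torch as t
import torch.nn.functional as F


def next_token_loss(logits: t.Tensor, tokens: t.Tensor) -> t.Tensor:
    """Mean cross-entropy of predicting tokens[:, 1:] from logits[:, :-1]."""
    log_probs = logits[:, :-1].log_softmax(dim=-1)  # predictions for positions 1..seq_len-1
    targets = tokens[:, 1:]                          # the tokens that actually came next
    correct = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
    return -correct.mean()


t.manual_seed(0)
batch, seq_len, d_vocab = 2, 8, 50
logits = t.randn(batch, seq_len, d_vocab)
tokens = t.randint(0, d_vocab, (batch, seq_len))

loss = next_token_loss(logits, tokens)
# Sanity check against PyTorch's built-in cross entropy on flattened inputs
ref = F.cross_entropy(logits[:, :-1].reshape(-1, d_vocab), tokens[:, 1:].reshape(-1))
print(loss.item(), ref.item())

# Uniform logits give a loss of ln(d_vocab): the loss of a model that knows nothing.
uniform = next_token_loss(t.zeros(batch, seq_len, d_vocab), tokens)
print(uniform.item(), math.log(d_vocab))
```

The uniform value ln(d_vocab) is a useful reference point when interpreting a falling loss curve: a model that has only picked up token frequency statistics (such as bigram frequencies) should already sit below it.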

4️⃣ Sampling from a Transformer

Lastly, you'll learn how to sample from a transformer. This will involve implementing a few different sampling methods, and writing a caching system which can reuse computations from previous forward passes to improve your model's text generation speed.

The second half of this section is less important, and you can skip it if you want.

Learning Objectives
  • Learn how to sample from a transformer
    • This includes basic methods like greedy search or top-k, and more advanced methods like beam search
  • Learn how to cache the output of a transformer, so that it can be used to generate text more efficiently
    • Optionally, rewrite your sampling functions to make use of your caching methods
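Greedy search and top-k sampling can both be written as a single function acting on the logits for the final position. This is a minimal sketch; the function name sample_next_token and its signature are hypothetical, and beam search and caching are not covered.

```python
import torch as t


def sample_next_token(logits, temperature=1.0, top_k=None, generator=None):
    """Pick one token id from a 1D logits vector over the vocabulary.

    temperature == 0 means greedy search (argmax); otherwise sample from
    softmax(logits / temperature), optionally restricted to the top_k tokens.
    """
    if temperature == 0.0:
        return int(logits.argmax())
    logits = logits / temperature
    if top_k is not None:
        top_values, top_indices = logits.topk(top_k)
        choice = t.multinomial(top_values.softmax(dim=-1), num_samples=1, generator=generator)
        return int(top_indices[choice])
    probs = logits.softmax(dim=-1)
    return int(t.multinomial(probs, num_samples=1, generator=generator))


logits = t.tensor([2.0, 1.0, 0.5, -1.0, 3.0])
print(sample_next_token(logits, temperature=0.0))  # 4 (greedy picks the largest logit)

g = t.Generator().manual_seed(0)
samples = [sample_next_token(logits, top_k=2, generator=g) for _ in range(20)]
print(set(samples) <= {0, 4})  # True: top-k with k=2 keeps only tokens 4 and 0
```

Treating temperature=0 as greedy search is a convention chosen here for convenience; dividing the logits by a temperature below 1 sharpens the distribution, and a temperature above 1 flattens it.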

Setup code

```python
import math
import os
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

import datasets
import einops
import numpy as np
import torch as t
import torch.nn as nn
import wandb
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast

# Prefer Apple-silicon GPU (MPS), then CUDA, then fall back to CPU
device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)

# Make sure exercises are in the path
chapter = "chapter1_transformer_interp"
section = "part1_transformer_from_scratch"
# Repo root: the first parent of the working directory that contains the chapter folder
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

import part1_transformer_from_scratch.solutions as solutions
import part1_transformer_from_scratch.tests as tests
from plotly_utils import imshow

MAIN = __name__ == "__main__"
```