3️⃣ Training a Transformer

Learning Objectives
  • Understand how to train a transformer from scratch
  • Write a basic transformer training loop
  • Interpret the transformer's falling cross entropy loss with reference to features of the training data (e.g. bigram frequencies)

Now that we've built our transformer, and verified that it performs as expected when we load in weights, let's try training it from scratch!

This is a lightweight demonstration of how you can actually train your own GPT-2 with this code! Here we train a tiny model on a tiny dataset, but it's fundamentally the same code for training a larger/more real model (though you'll need beefier GPUs and data parallelism to do it remotely efficiently, and fancier parallelism for much bigger ones).

For our purposes, we'll train a 4-layer model with 16 heads per layer and a context length of 128, for 10 epochs of (at most) 500 steps each with batch size 32, just to show what it looks like (and so the notebook doesn't melt your Colab / machine!).
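(If every epoch ran its full 500 steps, that would be roughly 10 × 500 × 32 × 128 ≈ 20 million tokens processed over the whole run, which gives a sense of how tiny this demo is compared to a real pretraining run.)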

Create Model

model_cfg = Config(
    debug=False,
    d_model=32,
    n_heads=16,
    d_head=2,
    d_mlp=32 * 4,
    n_layers=4,
    n_ctx=128,
    d_vocab=reference_gpt2.cfg.d_vocab,
)
model = DemoTransformer(model_cfg)

Training Args

Note, for this optimization we'll be using weight decay.

@dataclass
class TransformerTrainingArgs:
    batch_size: int = 32
    epochs: int = 10
    max_steps_per_epoch: int = 500
    lr: float = 1e-3
    weight_decay: float = 1e-2
    wandb_project: str | None = "day1-demotransformer"
    wandb_name: str | None = None


args = TransformerTrainingArgs()

Create Data

We load in the TinyStories dataset: a dataset of simple, synthetically generated stories which only use a small vocabulary of words that a typical 3 to 4-year-old would understand. This dataset was designed to explore how small an LLM can be while still generating coherent text.

dataset = datasets.load_dataset("roneneldan/TinyStories", split="train")
print(dataset)
print(dataset[0]["text"])

tokenize_and_concatenate is a useful function which takes our dataset of strings and returns a dataset of token IDs, ready to feed into the model. We then create a dataloader from this tokenized dataset. The useful method train_test_split gives us a training and a testing set.

tokenized_dataset = tokenize_and_concatenate(
    dataset,
    reference_gpt2.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name="text",
    add_bos_token=True,
    num_proc=4,
)

dataset_dict = tokenized_dataset.train_test_split(test_size=1000)
train_loader = DataLoader(
    dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
)
test_loader = DataLoader(
    dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
)

When we iterate through these dataloaders, we will find dictionaries with the single key 'tokens', which maps to a tensor of token IDs with shape (batch, seq_len).

first_batch = train_loader.dataset[: args.batch_size]

print(first_batch.keys())
print(first_batch["tokens"].shape)

Training Loop

If you did the material on training loops during the first week, this should all be familiar to you. If not, you can skim that section for an overview of the key concepts. The start of the Training loop section is most important, and the subsections on Modularisation and dataclasses are also very useful. Lastly, we'll also be using Weights and Biases to track our training runs - you can read about how to use it here. Here are (roughly) all the things you should know for the following exercises:

  • The key parts of a gradient update step are:
    • Calculating the (cross-entropy) loss between a model's output and the true labels,
    • loss.backward() - calculate gradients of the loss with respect to the model parameters,
    • optimizer.step() - update the model parameters using the gradients,
    • optimizer.zero_grad() - zero the gradients so they don't accumulate.
  • We can nicely package up training loops into a class, which includes methods for training and validation steps among other things. This helps with writing code that can be reused in different contexts.
  • We can use dataclasses to store all the arguments relevant to training in one place, and then pass them to our trainer class. Autocompletion is one nice bonus of this!
    • Be careful of scope here, you want to make sure you're referring to self.args within the trainer class, rather than the global args.
  • You can use Weights and Biases to track experiments and log relevant variables (there's a minimal sketch of this workflow just after this list). The three essential functions are:
    • wandb.init() - initialize a new run, takes arguments project, name and config (among others).
    • wandb.log() - log a dictionary of variables, e.g. {"loss": loss}. Also takes a step argument.
    • wandb.finish() - called at the end of training (no arguments).
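As a reminder of what that wandb workflow looks like in practice, here's a minimal sketch (the project name, config and logged values below are placeholders, not part of the exercise):

import wandb

wandb.init(project="example-project", name="example-run", config={"lr": 1e-3})

for step in range(5):
    placeholder_loss = 1.0 / (step + 1)  # stand-in for a real training loss
    wandb.log({"loss": placeholder_loss}, step=step)

wandb.finish()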

Exercise - write training loop

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪
You should spend up to 10-20 minutes on this exercise.

You should fill in the methods below. Some guidance:

  • Remember we were able to calculate cross entropy loss using the get_log_probs function in the previous section.
  • You should use the optimizer t.optim.AdamW (Adam with weight decay), and with hyperparameters lr and weight_decay taken from your TransformerTrainingArgs dataclass instance.
  • We've given you the argument max_steps_per_epoch, a hacky way of making sure the training phase in each epoch doesn't go on for too long. You can terminate each training phase after this many steps. It's set to a default value that should lead to a very short run demonstrating nontrivial model performance.
  • Remember to move tokens to your device, via tokens.to(device) (device should be a global variable, defined at the top of your notebook).
  • You can refer back to the training loops from the previous chapter of the course if you'd like.
  • We've also provided an instance of the TransformerSampler class so you can generate text from your model during training to see how it's doing. We will cover how sampling works in the next section.
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args
        self.sampler = solutions.TransformerSampler(self.model, reference_gpt2.tokenizer)
        self.optimizer = t.optim.AdamW(
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
        self.step = 0

        self.train_loader = DataLoader(
            dataset_dict["train"],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )
        self.test_loader = DataLoader(
            dataset_dict["test"],
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )

    def training_step(self, batch: dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """
        raise NotImplementedError()
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the accuracy.
        """
        self.model.eval()
        #
        # YOUR CODE HERE - fill in the `evaluate` method
        #
        self.model.train()
        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader):
                loss = self.training_step(batch)
                progress_bar.update()
                progress_bar.set_description(
                    f"Epoch {epoch + 1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}"
                )
                if i >= self.args.max_steps_per_epoch:
                    break

            accuracy = self.evaluate()
            sample_text = self.sampler.sample("Once upon a time", max_tokens_generated=50)
            print(sample_text)

        wandb.finish()


# See the full run here: https://api.wandb.ai/links/dquarel/nrxuwnv7
model = DemoTransformer(model_cfg).to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model)
trainer.train()
Solution
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args
        self.sampler = solutions.TransformerSampler(self.model, reference_gpt2.tokenizer)
        self.optimizer = t.optim.AdamW(
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
        self.step = 0
        self.train_loader = DataLoader(
            dataset_dict["train"],
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )
        self.test_loader = DataLoader(
            dataset_dict["test"],
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )

    def training_step(self, batch: dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """
        tokens = batch["tokens"].to(device)
        logits = self.model(tokens)
        loss = -get_log_probs(logits, tokens).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        wandb.log({"train_loss": loss}, step=self.step)
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the accuracy.
        """
        self.model.eval()
        total_correct, total_samples = 0, 0

        for batch in tqdm(self.test_loader, desc="Evaluating"):
            tokens = batch["tokens"].to(device)
            logits: Tensor = self.model(tokens)[:, :-1]
            predicted_tokens = logits.argmax(dim=-1)
            total_correct += (predicted_tokens == tokens[:, 1:]).sum().item()
            total_samples += tokens.size(0) * (tokens.size(1) - 1)

        accuracy = total_correct / total_samples
        wandb.log({"accuracy": accuracy}, step=self.step)
        self.model.train()
        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader):
                loss = self.training_step(batch)
                progress_bar.update()
                progress_bar.set_description(
                    f"Epoch {epoch + 1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}"
                )
                if i >= self.args.max_steps_per_epoch:
                    break

            accuracy = self.evaluate()
            sample_text = self.sampler.sample("Once upon a time", max_tokens_generated=50)
            print(sample_text)

        wandb.finish()

When you run the code for the first time, you'll have to log in to Weights and Biases and paste an API key into VSCode. After this is done, your Weights and Biases training run will start. It'll print a lot of output text, one line of which will look like:

View run at https://wandb.ai/<USERNAME>/<PROJECT-NAME>/runs/<RUN-NAME>

which you can click on to visit the run page.

Note - to see the plots more clearly in Weights and Biases, you can click on the edit panel of your plot (the small pencil symbol at the top-right), then move the smoothing slider to the right.

A note on this loss curve (optional)

What's up with the shape of our loss curve? It starts at around 10-11, drops very fast, then levels out. It turns out this is all to do with the kinds of algorithms the model learns during training.

When it starts out, your model will be outputting random noise, which might look a lot like "predict each token with approximately uniform probability", i.e. $Q(x) = 1/d_\text{vocab}$ for all $x$. This gives us a cross entropy loss of $\log (d_\text{vocab})$.

d_vocab = model.cfg.d_vocab

print(f"d_vocab = {d_vocab}")
print(f"Cross entropy loss on uniform distribution = {math.log(d_vocab):.3f}")
d_vocab = 50257
Cross entropy loss on uniform distribution = 10.825

The next thing we might expect the model to learn is the frequencies of tokens in the English language. After all, common tokens like " and" or " the" appear much more frequently than others. Predicting each token according to its overall frequency would give us an average cross entropy loss of:

$$ - \sum_x p_x \log p_x $$

where $p_x$ is the empirical frequency of token $x$ in our training data.

We can evaluate this quantity as follows:

toks = tokenized_dataset[:]["tokens"].flatten()

d_vocab = model.cfg.d_vocab
freqs = t.bincount(toks, minlength=d_vocab)
probs = freqs.float() / freqs.sum()

distn = t.distributions.categorical.Categorical(probs=probs)
entropy = distn.entropy()

print(f"Entropy of training data = {entropy:.3f}")
Entropy of training data = 7.349

After unigram frequencies, the next thing our model usually learns is bigram frequencies (i.e. the frequency of pairs of adjacent tokens in the training data). For instance, "I" and " am" are common tokens, but their bigram frequency is much higher than it would be if they occurred independently. Bigram frequencies actually take you pretty far, since they also help with:

  • Some simple grammatical rules (e.g. a full stop being followed by a capitalized word)
  • Weird quirks of tokenization (e.g. " manip" being followed by "ulative")
  • Common names (e.g. "Barack" being followed by " Obama")

After approximating bigram frequencies, the model needs to start using smarter techniques, like trigrams (which can only be implemented using attention heads), induction heads (which we'll learn a lot more about in the next set of exercises!), and memorization of facts plus more complex grammar and syntax rules. Marginal improvements get harder to come by around this point, which is why the loss curve flattens out.
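If you want to see this for yourself, here's a rough sketch of how you could inspect bigram statistics in the training data directly. It assumes toks, d_vocab and reference_gpt2.tokenizer from the cells above are still in scope; the choice of " once" as the example token and the 2-million-token slice are arbitrary:

# Find the tokens that most often follow a given token in the training data
example_tok = reference_gpt2.tokenizer.encode(" once")[0]

# Encode each adjacent pair of tokens as a single integer, then count the unique pairs
# (we only use a slice of the data, to keep this quick)
sample = toks[:2_000_000].long()
pairs = sample[:-1] * d_vocab + sample[1:]
unique_pairs, counts = t.unique(pairs, return_counts=True)

# Restrict to bigrams starting with our example token, and print the top 5 continuations
mask = (unique_pairs // d_vocab) == example_tok
followers, follower_counts = unique_pairs[mask] % d_vocab, counts[mask]
for idx in follower_counts.argsort(descending=True)[:5]:
    print(f"{reference_gpt2.tokenizer.decode([followers[idx].item()])!r}: {follower_counts[idx].item()}")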

Exercise (optional) - log completions

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵⚪⚪⚪⚪
You should spend up to 20-40 minutes on this exercise, if you choose to attempt it. Note, you might want to come back to this exercise *after* you learn how sampling works.

Choose a handful of prompts, and log the model's completions on those sentences. We recommend you do this with a lower frequency than loss is logged (e.g. once every 10-100 batches).

The wandb syntax for logging text is pretty simple. Firstly, you can just print output to stdout, and this is also logged to Weights & Biases (you can find it under the "Logs" section of your run). Alternatively, you can log data in the form of a table, and have it appear next to your other charts:

wandb.log({"completions_table": wandb.Table(
    data = data,
    columns = ["epoch", "step", "text"]
)})

where data is a list of length-3 lists, with each list containing (epoch, step, text). If you choose this option, we recommend logging the table less frequently than you're sampling from the model, to make sure you're not sending too much data (because unfortunately wandb doesn't have methods to incrementally update the table during logging).
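For example, following this "accumulate rows, then re-log the whole table" pattern might look roughly like the sketch below (the frequencies and placeholder text are illustrative, not prescribed, and it assumes wandb.init has already been called):

completion_rows = []  # grows over the course of training

for step in range(1000):
    # ... your training step goes here ...
    if step % 100 == 0:
        completion_rows.append([0, step, f"model completion at step {step}"])  # placeholder text
    if step % 500 == 0 and completion_rows:
        # Re-log the full table each time, since wandb tables can't be updated incrementally
        wandb.log({"completions_table": wandb.Table(data=completion_rows, columns=["epoch", "step", "text"])})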

If you want to try this before going through the sampling exercises (which are quite long!), you can use the code below to sample output from the model. Note that the TransformerSampler object is already in inference mode, so you don't need to worry about this.

def sampling_fn(model: DemoTransformer, prompt: str) -> str:
    sampler = solutions.TransformerSampler(model, reference_gpt2.tokenizer)
    output = sampler.sample(prompt, temperature=0.7, top_p=0.95, max_tokens_generated=16)
    return output


model = DemoTransformer(model_cfg).to(device)

# Should be entirely random, because it uses a newly initialized model
print(sampling_fn(model, prompt="John and Mary went to the"))
John and Mary went to theLittlealmernaut estranged broadcaster Workers reapp skull consecutivepexuaniaarrow drilling Burnett ASDMusic
# YOUR CODE HERE - rewrite the TransformerTrainer.train method, so that it logs completions


prompt_list = [
    "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for",
    "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.",
    "John and Mary went to the",
]

model = DemoTransformer(model_cfg).to(device)
args = TransformerTrainingArgsLogText()
trainer = TransformerTrainer(args, model)
trainer.train(sampling_fn, prompt_list)
# Read full report here - https://api.wandb.ai/links/callum-mcdougall/5ex16e5w
Solution
@dataclass
class TransformerTrainingArgsLogText(TransformerTrainingArgs):
    text_sample_freq: int = 20
    table_log_freq: int = 200
    def __post_init__(self):
        assert self.table_log_freq >= self.text_sample_freq, (
            "You should log the table less frequently than you add text to it."
        )


def train_log_text(self: TransformerTrainer, sampling_fn: Callable, prompt_list: list[str]):
    """
    Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
    for each epoch at `self.args.max_steps_per_epoch` steps.

    This also takes 2 extra arguments:
        sampling_fn: function which takes model & a single prompt (i.e. text string) and returns text string output
        prompt_list: list of prompts we'll log output on
    """
    wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
    accuracy = np.nan
    progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

    # Create a list for storing data
    completions_list = []

    for epoch in range(self.args.epochs):
        for i, batch in enumerate(self.train_loader):
            loss = self.training_step(batch)
            progress_bar.update()
            progress_bar.set_description(
                f"Epoch {epoch + 1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}"
            )

            # Control the adding of text to the table, and the logging of text
            if self.step % self.args.text_sample_freq == 0:
                text_completions = [sampling_fn(self.model, prompt) for prompt in prompt_list]
                completions_list.append([epoch, self.step, *text_completions])
            if self.step % self.args.table_log_freq == 0:
                wandb.log(
                    {
                        "completions_table": wandb.Table(
                            data=completions_list,
                            columns=[
                                "epoch",
                                "step",
                                *[f"prompt_{i}" for i in range(len(prompt_list))],
                            ],
                        )
                    }
                )

            if i >= self.args.max_steps_per_epoch:
                break

        accuracy = self.evaluate()

    wandb.finish()


TransformerTrainer.train = train_log_text
prompt_list = [
    "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for",
    "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.",
    "John and Mary went to the",
]
model = DemoTransformer(model_cfg).to(device)
args = TransformerTrainingArgsLogText()
trainer = TransformerTrainer(args, model)
trainer.train(sampling_fn, prompt_list)
# Read full report here - https://api.wandb.ai/links/callum-mcdougall/5ex16e5w

You shouldn't expect to see perfect logical coherence from your model, but you should at least see that it respects basic word frequencies, and follows basic rules of grammar some of the time. Hopefully this gives some perspective on how difficult training a transformer can be!