2️⃣ Training Neural Networks
Learning Objectives
- Understand how to work with transforms, datasets and dataloaders
- Understand the basic structure of a training loop
- Learn how to write your own validation loop
Transforms, Datasets & DataLoaders
Before we use this model to make any predictions, we first need to think about our input data. Below is a block of code to fetch and process MNIST data. We will go through it line by line.
```python
import torch as t
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# Note: `exercises_dir` (a pathlib.Path) is assumed to be defined earlier in the notebook.

MNIST_TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.1307, 0.3081),
    ]
)


def get_mnist(trainset_size: int = 10_000, testset_size: int = 1_000) -> tuple[Subset, Subset]:
    """Returns subsets of the MNIST training and test data."""
    # Get original datasets, which are downloaded to `exercises_dir / "data"` for future use
    mnist_trainset = datasets.MNIST(
        exercises_dir / "data", train=True, download=True, transform=MNIST_TRANSFORM
    )
    mnist_testset = datasets.MNIST(
        exercises_dir / "data", train=False, download=True, transform=MNIST_TRANSFORM
    )

    # Return a subset of the original datasets
    mnist_trainset = Subset(mnist_trainset, indices=range(trainset_size))
    mnist_testset = Subset(mnist_testset, indices=range(testset_size))

    return mnist_trainset, mnist_testset


mnist_trainset, mnist_testset = get_mnist()
mnist_trainloader = DataLoader(mnist_trainset, batch_size=64, shuffle=True)
mnist_testloader = DataLoader(mnist_testset, batch_size=64, shuffle=False)

# Get the first batch of test data, by starting to iterate over `mnist_testloader`
for img_batch, label_batch in mnist_testloader:
    print(f"{img_batch.shape=}\n{label_batch.shape=}\n")
    break

# Get the first datapoint in the test set, by starting to iterate over `mnist_testset`
for img, label in mnist_testset:
    print(f"{img.shape=}\n{label=}\n")
    break

t.testing.assert_close(img, img_batch[0])
assert label == label_batch[0].item()
```
The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision, and torchvision.transforms provides access to a suite of functions for preprocessing data. We define a transform for the MNIST data (which is applied to each image in the dataset) by composing ToTensor (which converts a PIL.Image object into a PyTorch tensor, rescaling its 8-bit pixel values from 0-255 to the range [0, 1]) and Normalize (which takes arguments for the mean and standard deviation, and performs the linear transformation x -> (x - mean) / std). For the latter, we use 0.1307 and 0.3081, which are the empirical mean & std of the MNIST training images after this rescaling (so after this transformation, the data will have approximately mean 0 and variance 1).
Next, we define our datasets using torchvision.datasets. The first argument tells torchvision where to save the data (so that when we run this in the future we won't have to download it again), train selects the training or test split, and transform=MNIST_TRANSFORM tells it to apply our previously defined transform to each element in the dataset. We also use Subset, which lets us take only part of the dataset (here, the first trainset_size or testset_size datapoints) rather than the whole thing, because our model won't need much data to train!
Finally, since our dataset only allows for iteration over individual datapoints, we wrap it in DataLoader, which enables iteration over batches. It also provides useful arguments like shuffle, which determines whether the order of the data is randomized (reshuffled at the start of every epoch). The code above demonstrates iteration over the dataset & dataloader respectively, showing that the first element in the dataloader's first batch equals the first element in the dataset (note that this wouldn't be true for the training set, because we've shuffled it).
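A minimal sketch of the dataset → Subset → DataLoader pipeline, using a made-up TensorDataset of ten tiny fake images in place of MNIST so that it runs without downloading anything. It shows how Subset restricts the indices and how DataLoader groups datapoints into batches (the last batch is smaller when the dataset size isn't a multiple of the batch size).

```python
import torch as t
from torch.utils.data import DataLoader, Subset, TensorDataset

# Made-up dataset: 10 fake single-channel 2x2 "images", with labels 0..9
imgs = t.arange(40, dtype=t.float32).reshape(10, 1, 2, 2)
labels = t.arange(10)
dataset = TensorDataset(imgs, labels)

# Keep only the first 6 datapoints, similar to what get_mnist does with trainset_size
subset = Subset(dataset, indices=range(6))

# Iterating over the dataset yields single datapoints; the dataloader yields batches
loader = DataLoader(subset, batch_size=4, shuffle=False)
batches = list(loader)

print(len(batches))         # 2 (batch sizes 4 and 2)
print(batches[0][0].shape)  # torch.Size([4, 1, 2, 2])
print(batches[1][1])        # tensor([4, 5])

# With shuffle=False, the first element of the first batch is the first datapoint
img0, label0 = subset[0]
t.testing.assert_close(batches[0][0][0], img0)
```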
Aside - why batch sizes are often powers of 2
It's common to see batch sizes which are powers of two. The motivation is efficient GPU utilisation: processor architectures are normally organised around powers of 2, so computation may be more efficient when the items in each batch can be split evenly across processors. Or at least, that's the idea. The truth is a bit more complicated, and some studies dispute whether it actually saves time, so at this point it's more of a standard convention than a hard rule which will always lead to more efficient training.
Before proceeding, try and answer the following questions:
Question - can you explain why we include a data normalization function in torchvision.transforms?
One consequence of unnormalized data is that you might find yourself stuck in a very flat region of the domain, and gradient descent may take much longer to converge.
Normalization isn't strictly necessary for this reason, because any rescaling of an input vector can be effectively undone by the network learning different weights and biases. But in practice, it does usually help speed up convergence.
Normalization also helps avoid numerical issues.
Question - what is the benefit of using shuffle=True when defining our dataloaders? What might the problem be if we didn't do this?
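Shuffling is done during training to make sure we aren't exposing our model to the same cycle (order) of data in every epoch, and so that each batch is a random sample of the data (for example, if a dataset happened to be sorted by label, unshuffled batches would each contain only one or two classes). This helps ensure the model isn't adapting its learning to spurious patterns in the ordering.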
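The sketch below (again with a made-up toy dataset, deliberately sorted by label) shows what shuffle=True does: each pass over the dataloader visits every datapoint exactly once, but in a freshly randomized order. The seeded generator is only there to make the run reproducible.

```python
import torch as t
from torch.utils.data import DataLoader, TensorDataset

# Made-up dataset of 10 labelled points, sorted by label
dataset = TensorDataset(t.arange(10, dtype=t.float32), t.arange(10))

generator = t.Generator().manual_seed(0)  # fixed seed so the example is reproducible
loader = DataLoader(dataset, batch_size=5, shuffle=True, generator=generator)

epoch_orders = []
for epoch in range(2):
    # Concatenate the labels from every batch to see the order used in this epoch
    order = t.cat([labels for _, labels in loader]).tolist()
    epoch_orders.append(order)
    print(f"epoch {epoch}: {order}")

# Every epoch still covers each datapoint exactly once; only the order changes
assert all(sorted(order) == list(range(10)) for order in epoch_orders)
```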
Aside - tqdm
You might have seen some blue progress bars running when you first downloaded your MNIST data. These were generated using a library called tqdm, which is also a really useful tool when training models or running any process that takes a long period of time.
The tqdm function wraps around an iterable, and displays a progress bar as you iterate through it. The code below shows a minimal example:
```python
import time

from tqdm.notebook import tqdm

for i in tqdm(range(100)):
    time.sleep(0.1)
```
There are some more advanced features of tqdm too, for example:
- If you define the progress bar `pbar = tqdm(...)` before your iteration, then you have the option of adding extra information to it using `pbar.set_description` or `pbar.set_postfix`.
- You can specify the total number of iterations with `tqdm(iterable, total=...)`; this is actually very important when the iterable is something like `enumerate(...)`, which doesn't have a length attribute, since tqdm will usually try to infer the total by calling `len` on the iterable you pass it.
Here's some code that demonstrates these extra features:
```python
import time

from tqdm.notebook import tqdm

word = "hello!"
pbar = tqdm(enumerate(word), total=len(word))
t0 = time.time()

for i, letter in pbar:
    time.sleep(1.0)
    pbar.set_postfix(i=i, letter=letter, time=f"{time.time()-t0:.3f}")
```
Aside - device
One last thing to discuss before we move on to training our model: GPUs. We'll discuss this in more detail in later exercises. For now, this page should provide a basic overview of how to use your GPU. A few things to be aware of here:
- The `to` method is really useful here: it can move objects between different devices (i.e. CPU and GPU) as well as change a tensor's datatype.
  - Note that `to` is never inplace for tensors (i.e. you have to call `x = x.to(device)`), but when working with models, calling `model = model.to(device)` or `model.to(device)` are both perfectly valid.
- Errors from having one tensor on cpu and another on cuda are very common. Some useful practices to avoid this:
  - Throw in assert statements, to make sure tensors are on the same device.
  - Remember that when you initialise an array (e.g. with `t.zeros` or `t.arange`), it will be on CPU by default.
  - Tensor methods like `new_zeros` or `new_full` are useful, because they'll create tensors which match the device and dtype of the base tensor.
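A small CPU-only sketch of the points above. It uses dtype conversions in place of device moves so it can run on any machine; `to` follows the same rules for both. Tensor `to` never modifies the tensor in place, `new_zeros` copies the dtype and device of the base tensor, and `nn.Module.to` converts the module itself and returns it.

```python
import torch as t

x = t.arange(3)          # created on CPU by default, with dtype int64
y = x.to(t.float32)      # returns a converted tensor; x itself is unchanged
print(x.dtype, y.dtype)  # torch.int64 torch.float32

z = y.new_zeros((2, 2))  # same dtype and device as y
assert z.dtype == y.dtype and z.device == y.device

model = t.nn.Linear(3, 1)
returned = model.to(t.float64)  # modules are converted in place, and `to` returns the same object
assert returned is model and model.weight.dtype == t.float64

# A cheap guard against device-mismatch bugs: check before combining tensors
assert y.device == z.device, "tensors are on different devices"
```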
It's common practice to put a line like this at the top of your file, defining a global variable which you can use in subsequent modules and functions (excluding the print statement):
```python
import torch as t

device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)

# If this is CPU, we recommend figuring out how to get cuda access (or MPS if you're on a Mac).
print(device)
```
Training loop
Below is a very simple training loop, which you can run to train your model.
In later exercises, we'll try to modularize our training loops. This will involve things like creating a Trainer class which wraps around our model, and giving it methods like training_step and validation_step which correspond to different parts of the training loop. This will make it easier to add features like logging and validation, and will also make our code more readable and easier to refactor. However, for now we've kept things simple.
```python
import torch as t
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# `SimpleMLP`, `get_mnist`, `device` and the plotting helper `line` are assumed to be defined earlier.

model = SimpleMLP().to(device)

batch_size = 128
epochs = 3

mnist_trainset, _ = get_mnist()
mnist_trainloader = DataLoader(mnist_trainset, batch_size=batch_size, shuffle=True)

optimizer = t.optim.Adam(model.parameters(), lr=1e-3)
loss_list = []

for epoch in range(epochs):
    pbar = tqdm(mnist_trainloader)

    for imgs, labels in pbar:
        # Move data to device, perform forward pass
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)

        # Calculate loss, perform backward pass, update parameters, then reset gradients
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Update logs & progress bar
        loss_list.append(loss.item())
        pbar.set_postfix(epoch=f"{epoch + 1}/{epochs}", loss=f"{loss:.3f}")

line(
    loss_list,
    x_max=epochs * len(mnist_trainset),
    labels={"x": "Examples seen", "y": "Cross entropy loss"},
    title="SimpleMLP training on MNIST",
    width=700,
)
```
Let's break down the important parts of this code.
The batch size is the number of samples in each batch (i.e. the number of samples we feed into the model at once). While training our model, we compute gradients of the average loss over all samples in the batch, so a smaller batch usually gives a noisier estimate of the gradient. On the other hand, if you're working with large models, a batch size that is too large will often result in an out-of-memory error. This will be relevant for models later on in the course, but for now we're working with very small models so this isn't an issue.
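To make "gradients of the average loss over the batch" concrete, here is a minimal sketch using a made-up linear classifier (a random weight matrix `w`, not the course's SimpleMLP). The gradient of the batch-mean cross entropy equals the average of the per-sample gradients, which is why each batch gives an estimate of the gradient that becomes less noisy as more samples are averaged.

```python
import torch as t
import torch.nn.functional as F

t.manual_seed(0)
batch_size, n_features, n_classes = 8, 5, 3

# Made-up linear "model": logits = x @ w
w = t.randn(n_features, n_classes, requires_grad=True)
x = t.randn(batch_size, n_features)
y = t.randint(0, n_classes, (batch_size,))

# Gradient of the mean loss over the batch (F.cross_entropy defaults to reduction="mean")
F.cross_entropy(x @ w, y).backward()
batch_grad = w.grad.clone()

# Gradient of each sample's loss on its own
per_sample_grads = []
for i in range(batch_size):
    w.grad = None  # reset so gradients don't accumulate across samples
    F.cross_entropy(x[i : i + 1] @ w, y[i : i + 1]).backward()
    per_sample_grads.append(w.grad.clone())

mean_of_grads = t.stack(per_sample_grads).mean(dim=0)
t.testing.assert_close(batch_grad, mean_of_grads)
```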
Next, we get our training set via the helper function get_mnist. This helper function uses torchvision.datasets.MNIST to load the data, and then torch.utils.data.Subset to return a subset of it. Don't worry about the details of this function; it's not the kind of thing you'll need to know by heart.
We then define our optimizer, using torch.optim.Adam. The torch.optim module provides a wide variety of optimizers, such as Adam, SGD, and RMSprop. Adam is one of the most popular, and is a strong default choice in many cases. We'll discuss optimizers in more detail tomorrow, but for now it's enough to understand that the optimizer calculates the amount to update parameters by (as a function of those parameters' gradients, and sometimes other inputs such as running averages of past gradients), and performs this update step. The first argument passed to our optimizer is the parameters of our model (because these are the values that will be updated via gradient descent), and you can also pass keyword arguments to the optimizer which change its behaviour (e.g. the learning rate).
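A minimal sketch of what `optimizer.step()` does with the gradients, on a single made-up parameter vector `w` with loss `(w ** 2).sum()`. Plain SGD subtracts `lr * grad`. Adam also keeps running averages of past gradients; on its very first step with default settings, this works out to moving each coordinate by roughly `lr` against the sign of its gradient.

```python
import torch as t


def one_step(optimizer_class, lr: float) -> t.Tensor:
    """Take one optimizer step on loss = sum(w**2), starting from w = [1, -2]."""
    w = t.nn.Parameter(t.tensor([1.0, -2.0]))
    optimizer = optimizer_class([w], lr=lr)
    loss = (w**2).sum()  # gradient is 2 * w = [2.0, -4.0]
    loss.backward()
    optimizer.step()  # the optimizer reads w.grad and updates w in place
    return w.detach()


sgd_w = one_step(t.optim.SGD, lr=0.1)    # w - lr * grad = [0.8, -1.6]
adam_w = one_step(t.optim.Adam, lr=0.1)  # roughly w - lr * sign(grad) = [0.9, -1.9]
print(sgd_w, adam_w)
```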
Lastly, we have the actual training loop. We iterate through our training data, and for each batch we:
- Evaluate our model on the batch of data, to get the logits for our class predictions,
- Calculate the loss between our logits and the true class labels,
- Backpropagate the loss through our model (this step accumulates gradients in our model parameters),
- Step our optimizer, which is what actually updates the model parameters,
- Zero the gradients of our optimizer, ready for the next step.
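Because backpropagation accumulates gradients rather than overwriting them, the last step matters. This minimal sketch (with a made-up scalar parameter) calls backward twice without zeroing, shows the two gradients being summed, and then clears them with zero_grad. In recent PyTorch versions zero_grad sets gradients to None by default rather than to zero tensors, so the check below allows either.

```python
import torch as t

w = t.nn.Parameter(t.tensor(3.0))
optimizer = t.optim.SGD([w], lr=0.1)

for _ in range(2):
    loss = w**2      # d(loss)/dw = 2 * w = 6.0
    loss.backward()  # adds 6.0 into w.grad each time, rather than overwriting it

accumulated = w.grad.item()
print(accumulated)  # 12.0: the gradients from both backward calls were summed

optimizer.zero_grad()  # clear w.grad before the next training step
print(w.grad)          # None (or a zero tensor, depending on PyTorch version/settings)
```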
Cross entropy loss
The formula for cross entropy loss over a batch of size $N$ is:

$$
\begin{aligned}
l &= \frac{1}{N} \sum_{n=1}^{N} l_n \\
l_n &= -\log p_{n, y_n}
\end{aligned}
$$
where $p_{n, c}$ is the probability the model assigns to class $c$ for sample $n$, and $y_{n}$ is the true label for this sample.
See this dropdown, if you're still confused about this formula, and how this relates to the information-theoretic general formula for cross entropy.
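The cross entropy of a distribution $p$ relative to a distribution $q$ is:

$$
H(q, p) = -\sum_{c=1}^{C} q(c) \log p(c)
$$

In our case, for a single sample $n$, $q$ is the true distribution (i.e. the one-hot encoded label, which equals one for $c = y_n$ and zero otherwise), and $p$ is our model's output distribution for that sample. With these substitutions, every term except $c = y_n$ vanishes, leaving $H(q, p) = -\log p_{n, y_n} = l_n$; averaging over the batch gives the formula for $l$ given above.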
See this dropdown, if you're confused about how this is the same as the PyTorch definition.
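The PyTorch documentation defines cross entropy loss (for targets given as class probabilities, with the default mean reduction) as:

$$
\ell(x, y)=\frac{1}{N}\sum_{n=1}^N l_n, \quad l_n=-\sum_{c=1}^C w_c \, y_{n, c} \log \frac{\exp \left(x_{n, c}\right)}{\sum_{i=1}^C \exp \left(x_{n, i}\right)}
$$

where $w_c$ are the weights (which all equal one by default), $p_{n, c} = \frac{\exp \left(x_{n, c}\right)}{\sum_{i=1}^C \exp \left(x_{n, i}\right)}$ are the probabilities, and $y_{n, c}$ are the true labels (which are one-hot encoded, i.e. their value is one at the correct label $c = y_n$ and zero everywhere else). With these, $l_n$ reduces to $-\log p_{n, y_n}$, so $\ell$ is the formula we saw above (i.e. the mean of the negative log probabilities of the correct classes).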
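The function torch.nn.functional.cross_entropy (used as F.cross_entropy in the code above) expects the unnormalized logits as its first input, rather than probabilities. We get probabilities from logits by applying the softmax function:

$$
p_{n, c} = \frac{\exp(x_{n, c})}{\sum_{c'=1}^{C} \exp(x_{n, c'})}
$$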
where $x_{n, c}$ is the model's output for class $c$ and sample $n$, and $C$ is the number of classes (in the case of MNIST, $C = 10$).
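A quick numerical check of these formulas, using random made-up logits for a batch of $N = 4$ samples and $C = 10$ classes. Computing the softmax by hand, picking out the probability of each true label, and averaging the negative log probabilities should match torch.nn.functional.cross_entropy applied to the raw logits.

```python
import torch as t
import torch.nn.functional as F

t.manual_seed(0)
N, C = 4, 10
logits = t.randn(N, C)           # made-up unnormalized model outputs x_{n,c}
labels = t.tensor([3, 0, 9, 3])  # made-up true classes y_n

# Softmax: p_{n,c} = exp(x_{n,c}) / sum_{c'} exp(x_{n,c'})
probs = logits.exp() / logits.exp().sum(dim=1, keepdim=True)

# l_n = -log p_{n, y_n}, then average over the batch
per_sample_loss = -probs[t.arange(N), labels].log()
manual_loss = per_sample_loss.mean()

builtin_loss = F.cross_entropy(logits, labels)  # takes logits, not probabilities
t.testing.assert_close(manual_loss, builtin_loss)
print(manual_loss.item(), builtin_loss.item())
```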
Some terminology notes:
- When we say logits, we mean the output of the model before applying softmax. We can uniquely define a distribution with a set of logits, just like we can define a distribution with a set of probabilities (and sometimes it's easier to think of a distribution in terms of logits, as we'll see later in the course).
- When we say unnormalized, we mean the denominator term $\sum_{c'} \exp(x_{n, c'})$ isn't necessarily equal to 1. Adding the same constant to all the logits of a sample doesn't change any of the probabilities, so we can always choose the constant that makes this term equal to 1. For these normalized logits we have $p_{n, c} = \exp(x_{n, c})$, i.e. $x_{n, c} = \log p_{n, c}$, which is why we call them the log probabilities (or log probs). Writing $l_{n, c} = -\log p_{n, c}$ (so that $l_n = l_{n, y_n}$), this is the relation $p_{n, c} = \exp(-l_{n, c})$.
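Both terminology points can be checked directly. In this sketch (made-up logits for a single sample), adding a constant to every logit leaves the softmax unchanged, and subtracting the logsumexp is the particular shift that makes the softmax denominator equal to 1. The resulting normalized logits are exactly what torch's log_softmax returns.

```python
import torch as t

x = t.tensor([[2.0, 1.0, -1.0]])  # made-up logits for one sample, shape (1, C)

# Adding the same constant to every logit doesn't change the probabilities
t.testing.assert_close(x.softmax(dim=-1), (x + 5.0).softmax(dim=-1))

# Shifting by the logsumexp makes sum_{c'} exp(x_{n,c'}) equal to 1
normalized = x - x.logsumexp(dim=-1, keepdim=True)
t.testing.assert_close(normalized.exp().sum(dim=-1), t.tensor([1.0]))

# These normalized logits are the log probs: x_{n,c} = log p_{n,c}
t.testing.assert_close(normalized, x.log_softmax(dim=-1))
t.testing.assert_close(normalized.exp(), x.softmax(dim=-1))
```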
If you're interested in the intuition behind cross entropy as a loss function, see this post on KL divergence (note that KL divergence and cross entropy differ by the entropy of the true distribution, which is independent of our model's predictions, so minimizing cross entropy is equivalent to minimizing KL divergence). Also see these two videos:
Aside - dataclasses
Sometimes, when we have a lot of different input parameters to our model, it can be helpful to use dataclasses to keep track of them all. Dataclasses are a special kind of class which come with automatically generated methods for initialising and printing (i.e. no need to define an __init__ or __repr__). Another advantage of using them is autocompletion: when you type args. in VSCode, you'll get a dropdown of all your different dataclass attributes, which can be useful when you've forgotten what you called a variable!
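A standalone sketch of what @dataclass generates, using a hypothetical ExampleArgs class (not part of the course code) with the same fields as the training arguments. The generated __init__ accepts keyword overrides, __repr__ prints every field, __eq__ compares field by field, and dataclasses.replace makes a modified copy.

```python
from dataclasses import dataclass, replace


@dataclass
class ExampleArgs:  # hypothetical config class, for illustration only
    batch_size: int = 64
    epochs: int = 3
    learning_rate: float = 1e-3


args = ExampleArgs(batch_size=128)  # override one field, keep the other defaults
print(args)  # ExampleArgs(batch_size=128, epochs=3, learning_rate=0.001)

# The generated __eq__ compares all fields
assert args == ExampleArgs(128, 3, 1e-3)

# replace() returns a new instance with some fields changed; args is left untouched
high_lr_args = replace(args, learning_rate=1e-2)
print(high_lr_args.learning_rate, args.learning_rate)  # 0.01 0.001
```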
Here's an example of how we might rewrite our training code above using dataclasses. We've wrapped all the training code inside a single function called train, which takes a SimpleMLPTrainingArgs object as its only argument.
```python
from dataclasses import dataclass

# Assumes `t`, `F`, `DataLoader`, `tqdm`, `SimpleMLP`, `get_mnist`, `device` and `line` are already
# imported / defined, as in the training loop above.


@dataclass
class SimpleMLPTrainingArgs:
    """
    Defining this class implicitly creates an __init__ method, which sets arguments as below, e.g.
    self.batch_size=64. Any of these fields can also be overridden when you create an instance, e.g.
    SimpleMLPTrainingArgs(batch_size=128).
    """

    batch_size: int = 64
    epochs: int = 3
    learning_rate: float = 1e-3


def train(args: SimpleMLPTrainingArgs) -> tuple[list[float], SimpleMLP]:
    """
    Trains the model, using training parameters from the `args` object. Returns the loss list and
    the trained model.
    """
    model = SimpleMLP().to(device)

    mnist_trainset, _ = get_mnist()
    mnist_trainloader = DataLoader(mnist_trainset, batch_size=args.batch_size, shuffle=True)

    optimizer = t.optim.Adam(model.parameters(), lr=args.learning_rate)
    loss_list = []

    for epoch in range(args.epochs):
        pbar = tqdm(mnist_trainloader)

        for imgs, labels in pbar:
            # Move data to device, perform forward pass
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)

            # Calculate loss, perform backward pass, update parameters, then reset gradients
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Update logs & progress bar
            loss_list.append(loss.item())
            pbar.set_postfix(epoch=f"{epoch + 1}/{args.epochs}", loss=f"{loss:.3f}")

    return loss_list, model


args = SimpleMLPTrainingArgs()
loss_list, model = train(args)

line(
    loss_list,
    x_max=args.epochs * len(mnist_trainset),
    labels={"x": "Examples seen", "y": "Cross entropy loss"},
    title="SimpleMLP training on MNIST",
    width=700,
)
```
Exercise - add a validation loop
```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵

You should spend up to ~20 minutes on this exercise. It is very important that you understand training loops and how they work, because we'll be doing a lot of model training in this way.
```
Edit the train function above to include a validation loop. Train your model, making sure you measure the accuracy at the end of each epoch.
Here are a few tips to help you:
- You'll need a dataloader for the testset, just like we did for the trainset. It doesn't matter whether you shuffle the testset or not, because we're not updating our model parameters during validation (we usually set `shuffle=False` for testsets).
- You can set the same batch size as for your training set (we'll discuss more optimal choices for this later in the course).
- During the validation step, you should be measuring accuracy, which is defined as the fraction of correctly classified images.
- Note that (unlike loss) accuracy should only be logged after you've gone through the whole validation set. This is because your model doesn't update between validation batches, so the per-batch results are all parts of a single measurement and should be combined into one number.
- Computing accuracy is meant to be a very short operation, so you shouldn't need a progress bar.
- You can wrap your forward pass in `with t.inference_mode():` to make sure that your model is in inference mode during validation (i.e. no computational graph is recorded, so no gradients are tracked).
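A minimal sketch of the accuracy computation and inference_mode from the tips above, on made-up logits and labels for a single batch. In the actual exercise you would sum the correct counts over every batch of the test set before dividing by the test set size.

```python
import torch as t

# Made-up model outputs for a batch of 4 images and 3 classes, plus the true labels
logits = t.tensor([[2.0, 0.1, -1.0], [0.0, 3.0, 1.0], [0.5, 0.2, 0.9], [1.0, 1.5, 0.0]])
labels = t.tensor([0, 1, 0, 1])

predictions = logits.argmax(dim=1)                  # predicted class per image: [0, 1, 2, 1]
num_correct = (predictions == labels).sum().item()  # 3
accuracy = num_correct / len(labels)                # 0.75

# Inside inference_mode, no computational graph is recorded, so outputs don't require grad
w = t.ones(3, requires_grad=True)
with t.inference_mode():
    out = (w * 2).sum()
print(predictions.tolist(), accuracy, out.requires_grad)  # [0, 1, 2, 1] 0.75 False
```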
```python
def train(args: SimpleMLPTrainingArgs) -> tuple[list[float], list[float], SimpleMLP]:
    """
    Trains the model, using training parameters from the `args` object.

    Returns:
        The lists of loss & accuracy, and the model.
    """
    # YOUR CODE HERE - add a validation loop to the train function from above
    return loss_list, accuracy_list, model


args = SimpleMLPTrainingArgs()
loss_list, accuracy_list, model = train(args)

line(
    y=[loss_list, [0.1] + accuracy_list],  # before training, we assume random-guessing accuracy of 10%
    use_secondary_yaxis=True,
    x_max=args.epochs * len(mnist_trainset),
    labels={"x": "Num examples seen", "y1": "Cross entropy loss", "y2": "Test Accuracy"},
    title="SimpleMLP training on MNIST",
    width=800,
)
```
Help - I'm not sure how to measure correct classifications.
You can take the argmax of your model's output using torch.argmax (with the keyword argument dim to specify the dimension you want to take the max over).
Help - I get RuntimeError: expected scalar type Float but found Byte.
This is commonly because one of your operations is between tensors with the wrong datatypes (e.g. int and float). You can try adding assert or logging statements in your code, or alternatively if you're in VSCode then you can try navigating to the error line and checking your dtypes using VSCode's built-in debugger.
Solution
```python
def train(args: SimpleMLPTrainingArgs) -> tuple[list[float], list[float], SimpleMLP]:
    """
    Trains the model, using training parameters from the args object.

    Returns:
        The lists of loss & accuracy, and the model.
    """
    model = SimpleMLP().to(device)

    mnist_trainset, mnist_testset = get_mnist()
    mnist_trainloader = DataLoader(mnist_trainset, batch_size=args.batch_size, shuffle=True)
    mnist_testloader = DataLoader(mnist_testset, batch_size=args.batch_size, shuffle=False)

    optimizer = t.optim.Adam(model.parameters(), lr=args.learning_rate)
    loss_list = []
    accuracy_list = []

    for epoch in range(args.epochs):
        # Training loop
        pbar = tqdm(mnist_trainloader)
        for imgs, labels in pbar:
            # Move data to device, perform forward pass
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)

            # Calculate loss, perform backward pass, update parameters, then reset gradients
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Update logs & progress bar
            loss_list.append(loss.item())
            pbar.set_postfix(epoch=f"{epoch + 1}/{args.epochs}", loss=f"{loss:.3f}")

        # Validation loop
        num_correct_classifications = 0
        for imgs, labels in mnist_testloader:
            # Move data to device, perform forward pass in inference mode
            imgs, labels = imgs.to(device), labels.to(device)
            with t.inference_mode():
                logits = model(imgs)

            # Compute num correct by comparing argmaxed logits to true labels
            predictions = t.argmax(logits, dim=1)
            num_correct_classifications += (predictions == labels).sum().item()

        # Compute & log total accuracy (once per epoch, over the whole test set)
        accuracy = num_correct_classifications / len(mnist_testset)
        accuracy_list.append(accuracy)

    return loss_list, accuracy_list, model
```
You should find that after the first epoch, the model is already doing much better than random chance (10%), reaching an accuracy above 80%, and that it improves slightly in subsequent epochs.