1️⃣ Making your own modules
Learning Objectives
- Learn how to create your own modules in PyTorch, by inheriting from nn.Module
- Assemble the pieces together to create a simple fully-connected network, to classify MNIST digits
Note - from this point on we'll start referring to the PyTorch documentation pages quite a lot. We'll often reproduce the relevant content within this material when we want to highlight it for you; however, being able to use documentation pages to find answers to specific questions and to help you debug is also an important skill.
Subclassing nn.Module
One of the most basic parts of PyTorch that you will see over and over is the nn.Module class. All types of neural net components inherit from it, from the simplest nn.ReLU to the most complex nn.Transformer. Often, a complex nn.Module will have sub-Modules which implement smaller pieces of its functionality.
Other common Modules you'll see include:
- nn.Linear, for fully-connected layers with or without a bias
- nn.Conv2d, for a two-dimensional convolution (we'll see more of these in a future section)
- nn.Softmax, which implements the softmax function
The list goes on, including activation functions, normalizations, pooling, attention, and more. You can see all the Modules that PyTorch provides here. You can also create your own Modules, as we will do often!
The Module class provides a lot of functionality, but we'll only cover a little bit of it here.
In this section, we'll add another layer of abstraction to all the linear operations we've done in previous sections, by packaging them inside nn.Module objects.
__init__ and forward
A subclass of nn.Module usually looks something like this:
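import torch.nn as nn
from torch import Tensor


class MyModule(nn.Module):
    def __init__(self, arg1, arg2):
        super().__init__()
        # Initialization code, e.g. storing hyperparameters or creating sub-Modules
        self.arg1 = arg1
        self.arg2 = arg2

    def forward(self, x: Tensor) -> Tensor:
        # Forward pass code: compute and return the output (identity here as a placeholder)
        return x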
The initialization sets up attributes that will be used for the life of the Module, like its parameters, hyperparameters, or other sub-Modules it might need to use. These are usually added to the instance with something like self.attribute = attr, where attr might be provided as an argument. Some modules are simple enough that they don't need any persistent attributes, and in this case you can skip the __init__.
The forward method is called on each forward pass of the Module, possibly using the attributes that were set up in the __init__. It should take in the input, do whatever it's supposed to do, and return the result. Subclassing nn.Module automatically makes instances of your class callable, so you can do model(x) on an input x to invoke the forward method.
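A minimal sketch of this pattern, using a hypothetical `Scale` module (not part of PyTorch) that stores a fixed factor in `__init__` and multiplies its input by it in `forward`. The point is that calling the instance, `scale(x)`, runs `forward` for you.

```python
import torch
import torch.nn as nn
from torch import Tensor


class Scale(nn.Module):
    """Hypothetical example module: multiplies its input by a fixed factor."""

    def __init__(self, factor: float):
        super().__init__()
        # Persistent attribute: set up once, then used on every forward pass
        self.factor = factor

    def forward(self, x: Tensor) -> Tensor:
        return self.factor * x


scale = Scale(3.0)
x = torch.tensor([1.0, -2.0, 0.5])
out = scale(x)  # calling the module dispatches to Scale.forward
```

In practice you should call `scale(x)` rather than `scale.forward(x)` directly, since the call also runs any hooks registered on the module.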
The nn.Parameter class
An nn.Parameter is a special type of Tensor: it's the class that torch provides for storing the weights and biases of a Module. It has some special properties for doing this:
- If a Parameter is set as an attribute of a Module, it will be auto-detected by torch and returned when you call module.parameters() (along with all the other Parameters associated with the Module, or any of the Module's sub-modules!).
- This makes it easy to pass all the parameters of a model into an optimizer and update them all at once.
When you create a Module that has weights or biases, be sure to wrap them in nn.Parameter so that torch can detect and update them appropriately:
import torch.nn as nn
from torch import Tensor

class MyModule(nn.Module):
    def __init__(self, weights: Tensor, biases: Tensor):
        super().__init__()
        self.weights = nn.Parameter(weights)  # wrapping a tensor in nn.Parameter
        self.biases = nn.Parameter(biases)
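A minimal sketch of which attributes `parameters()` actually picks up, using a hypothetical `WithParams` module that holds a tensor wrapped in `nn.Parameter`, a plain tensor, and a built-in `nn.Linear` sub-module. Only the plain tensor is missed; everything else reaches the optimizer (plain SGD here, chosen just for illustration).

```python
import torch
import torch.nn as nn


class WithParams(nn.Module):
    """Hypothetical module contrasting registered and unregistered tensors."""

    def __init__(self):
        super().__init__()
        self.registered = nn.Parameter(torch.zeros(3))  # detected by parameters()
        self.unregistered = torch.zeros(3)  # plain tensor: NOT detected
        self.sub = nn.Linear(3, 2)  # sub-module: its weight and bias are detected too


m = WithParams()
names = [name for name, _ in m.named_parameters()]
print(names)  # ['registered', 'sub.weight', 'sub.bias']

# All detected parameters can be handed to an optimizer in one go
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
```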
Printing information with extra_repr
Another useful method is called extra_repr. This allows you to format the string representation of your Module in a way that's more informative than the default. For example, the following:
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, arg1, arg2):
        super().__init__()
        # Initialization code: store the values that extra_repr will display
        self.arg1 = arg1
        self.arg2 = arg2

    def extra_repr(self) -> str:
        return f"arg1={self.arg1}, arg2={self.arg2}"
will result in the output "MyModule(arg1=..., arg2=...)" (with the actual values of arg1 and arg2 filled in) when you print an instance of this module. You might want to take this opportunity to print out useful invariant information about the module. The Python built-in function getattr might be helpful here (it can be used e.g. as getattr(self, "arg1"), which returns the same as self.arg1 would). For simple modules, it's fine not to implement extra_repr.
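A small sketch of `extra_repr` in action, using a hypothetical `Scale` module with two attributes and the `getattr` trick described above. The same string also shows up when the module is nested inside a container such as `nn.Sequential`.

```python
import torch.nn as nn
from torch import Tensor


class Scale(nn.Module):
    """Hypothetical module with an informative extra_repr."""

    def __init__(self, factor: float, learnable: bool = False):
        super().__init__()
        self.factor = factor
        self.learnable = learnable

    def forward(self, x: Tensor) -> Tensor:
        return self.factor * x

    def extra_repr(self) -> str:
        # Build "key=value" pairs with getattr
        return ", ".join(f"{key}={getattr(self, key)}" for key in ["factor", "learnable"])


print(Scale(2.0))  # Scale(factor=2.0, learnable=False)
print(nn.Sequential(Scale(2.0), nn.ReLU()))  # nested modules reuse extra_repr
```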
ReLU
The first module you should implement is ReLU. This will be relatively simple, since it doesn't take any arguments (so we only need to implement forward). Make sure you look at the PyTorch documentation page for ReLU so that you're comfortable with how it works.
ReLU is defined as the element-wise maximum between the input and a tensor of zeros. It's one of the simplest nonlinear activation functions. Nonlinearities are essential because linear operations compose to make more linear operations, which is very limiting. On the other hand, the universal approximation theorem (UAT) tells us that, if we use nonlinear activation functions, a sufficiently large neural network can approximate any continuous function (on a closed, bounded domain) arbitrarily well. It's worth emphasizing that the theory of the UAT and what networks look like in practice are very different. In particular, many versions of the UAT are based on a shallow but extremely wide neural network, whereas most of the power of modern neural networks comes from their ability to compose layers: feeding the output of one layer into the input of another, creating increasingly expressive functions. We'll explore this idea more when we study circuits in next week's interpretability material.
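A quick numerical sketch of why the nonlinearity matters, with random matrices standing in for two bias-free layers: stacking two linear maps gives exactly the same result as one linear map with matrix `W2 @ W1`, whereas ReLU is not even additive, so it can't be folded away like that.

```python
import torch

torch.manual_seed(0)
W1 = torch.randn(5, 4)  # first "layer": R^4 -> R^5
W2 = torch.randn(3, 5)  # second "layer": R^5 -> R^3
x = torch.randn(4)

# Two stacked linear maps equal a single linear map with matrix W2 @ W1
two_layers = W2 @ (W1 @ x)
one_layer = (W2 @ W1) @ x
print(torch.allclose(two_layers, one_layer, atol=1e-5))  # True

# ReLU is not linear: relu(a + b) differs from relu(a) + relu(b) here
a, b = torch.tensor(1.0), torch.tensor(-1.0)
print(torch.relu(a + b), torch.relu(a) + torch.relu(b))  # tensor(0.) tensor(1.)
```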
Exercise - implement ReLU
```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to ~10 minutes on this exercise.
```
You should fill in the forward method of the ReLU class below.
class ReLU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        raise NotImplementedError()


tests.test_relu(ReLU)
Solution
class ReLU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return t.maximum(x, t.tensor(0.0))
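As an optional sanity check, which isn't part of the exercise, here are a few equivalent ways of computing ReLU in PyTorch. Comparing the solution's `t.maximum` approach against `clamp` and the built-in `t.relu` on a small hand-picked input shows they agree.

```python
import torch as t

x = t.tensor([-2.0, -0.5, 0.0, 1.5])

relu_max = t.maximum(x, t.tensor(0.0))  # the solution's approach
relu_clamp = x.clamp(min=0.0)  # clamp negatives up to zero
relu_builtin = t.relu(x)  # PyTorch's built-in ReLU

print(relu_max)  # negatives become 0, positives are unchanged
```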
Linear
Now implement your own Linear module. This applies a simple linear transformation, with a weight matrix and optional bias vector. The PyTorch documentation page is here. Note that this is the first Module you'll implement that has learnable weights and biases.
Question - what type do you think these variables should be?
They have to be torch.Tensor objects wrapped in nn.Parameter in order for nn.Module to recognize them. If you forget to do this, module.parameters() won't include these tensors, which prevents an optimizer from being able to modify them during training.
Also, in tomorrow's exercises we'll be building a ResNet and loading in weights from a pretrained model, and this is hard to do if you haven't registered all your parameters!
For any layer, initialization is very important for the stability of training: with a bad initialization, your model will take much longer to converge or may completely fail to learn anything. The default PyTorch behavior isn't necessarily optimal and you can often improve performance by using something more custom, but we'll follow it for today because it's simple and works decently well.
Each float in the weight and bias tensors is drawn independently from the uniform distribution on the interval:
$$\left[-\frac{1}{\sqrt{N_{in}}}, \frac{1}{\sqrt{N_{in}}}\right]$$
where $N_{in}$ is the number of inputs contributing to each output value. The rough intuition for this is that it keeps the variance of the activations at each layer constant, since each one is calculated by taking the sum over $N_{in}$ inputs multiplied by the weights (and standard deviation of the sum of independent random variables scales as the square root of number of variables).
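A rough empirical sketch of this intuition, assuming unit-variance Gaussian inputs and ignoring the bias: sample weights uniformly from $\pm 1/\sqrt{N_{in}}$ and measure the output variance for very different values of $N_{in}$. It should stay roughly constant (around 1/3, since a uniform distribution on $[-a, a]$ has variance $a^2/3$) instead of growing with $N_{in}$. The helper `output_variance` is hypothetical, written only for this check.

```python
import torch

torch.manual_seed(0)


def output_variance(n_in: int, n_out: int = 256, batch: int = 1024) -> float:
    """Hypothetical helper: variance of x @ W.T for uniform weights in [-1/sqrt(n_in), 1/sqrt(n_in)]."""
    bound = 1 / n_in**0.5
    weight = bound * (2 * torch.rand(n_out, n_in) - 1)  # uniform on [-bound, bound]
    x = torch.randn(batch, n_in)  # unit-variance inputs
    y = x @ weight.T  # bias omitted for simplicity
    return y.var().item()


for n_in in [16, 256, 4096]:
    print(n_in, round(output_variance(n_in), 3))  # each should be roughly 0.333
```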
This initialization technique is called uniform Kaiming initialization. A few last notes on initialization methods:
- Kaiming often has a different constant in the numerator depending on what the target variance is, also there are uniform & normal variants of it (we'll only be using the uniform variant)
- Xavier initialization is the other well-known technique, and differs in that it uses $N_{in} + N_{out}$ in the denominator (this makes sense when also considering variance scaling of backward passes as well as forward passes - see the next dropdown for technical details)
Technical details (derivation of distribution)
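The key intuition behind Kaiming initialisation (and others like it) is that we want the variance of our activations to be the same through all layers of the model when we initialize. Suppose $x$ and $y$ are activations from two adjacent layers, and $w$ are the weights connecting them (so we have $y_i = \sum_j w_{ij} x_j + b_i$, where $b$ is the bias). With $N_{x}$ as the number of neurons in layer $x$, and assuming the weights and activations are independent with zero mean (and ignoring the bias), we have:
$$\sigma_y^2 = \operatorname{Var}\Big(\sum_j w_{ij} x_j\Big) = \sum_j \operatorname{Var}(w_{ij}) \operatorname{Var}(x_j) = N_x \sigma_x^2 \operatorname{Var}(w_{ij})$$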
For this to be the same as $\sigma_x^2$, we need $\operatorname{Var}(w_{ij}) = \frac{1}{N_x}$, so the standard deviation is $\frac{1}{\sqrt{N_x}}$.
This is not exactly the case for the Kaiming uniform distribution, which has variance $\frac{(2 / \sqrt{N_x})^2}{12} = \frac{1}{3 N_x}$ (a uniform distribution on an interval of width $L$ has variance $\frac{L^2}{12}$), and as far as I'm aware there's no principled reason why PyTorch does this. But the most important thing is that the variance scales as $O(1 / N_x)$, rather than the exact scaling constant.
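There are other initializations with some theoretical justification. For instance, Xavier initialization has a uniform distribution in the interval:
$$\left[-\frac{\sqrt{6}}{\sqrt{N_{in} + N_{out}}}, \frac{\sqrt{6}}{\sqrt{N_{in} + N_{out}}}\right]$$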
which is motivated by the idea of both keeping the variance of activations constant and keeping the gradients constant when we backpropagate.
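To see this bound in practice, here's a sketch using PyTorch's `nn.init.xavier_uniform_` (with its default gain of 1) on a weight tensor in PyTorch's `(out_features, in_features)` layout. The sizes are arbitrary choices. The samples should stay within $\sqrt{6 / (N_{in} + N_{out})}$ and have variance close to $2 / (N_{in} + N_{out})$.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
n_in, n_out = 300, 100
weight = torch.empty(n_out, n_in)  # (out_features, in_features), so fan_in = 300
nn.init.xavier_uniform_(weight)  # gain defaults to 1

bound = math.sqrt(6 / (n_in + n_out))
print(weight.abs().max().item() <= bound + 1e-6)  # True (small slack for float32 rounding)
# Variance of U[-bound, bound] is bound**2 / 3 = 2 / (n_in + n_out)
print(weight.var().item(), 2 / (n_in + n_out))
```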
However, you don't need to worry about any of this here, just implement Kaiming He uniform with a bound of $\frac{1}{\sqrt{N_{in}}}$!
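If you want to confirm that this matches PyTorch's own default, here's a sketch that checks a freshly created `nn.Linear`, assuming a recent PyTorch version (where `nn.Linear` initializes its weight with `kaiming_uniform_(..., a=math.sqrt(5))`, which works out to the same $\pm 1/\sqrt{N_{in}}$ bound, and draws its bias from the same interval). The layer sizes are arbitrary.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
n_in, n_out = 400, 50
layer = nn.Linear(n_in, n_out)
bound = 1 / math.sqrt(n_in)  # 0.05 for n_in = 400

print(layer.weight.abs().max().item() <= bound + 1e-6)  # True
print(layer.bias.abs().max().item() <= bound + 1e-6)  # True
# Uniform on [-bound, bound] has variance bound**2 / 3 = 1 / (3 * n_in)
print(layer.weight.var().item(), 1 / (3 * n_in))
```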
Exercise - implement Linear
```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to ~10 minutes on this exercise.
```
Remember, you should define the weights (and bias, if bias=True) in the __init__ block. Also, make sure not to mix up bias (which is the boolean parameter to __init__) and self.bias (which should either be the actual bias tensor, or None if bias is false).
You should also fill in forward (which will multiply the input by the weight matrix and add the bias, if present).
Lastly, you should fill in extra_repr to give a string representation of the Linear module. There are no tests for this method, you should just make sure it's suitably informative (this will help when printing out your model later on).
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias=True):
        """
        A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.
        """
        super().__init__()
        raise NotImplementedError()

    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (*, in_features)
        Return: shape (*, out_features)
        """
        raise NotImplementedError()

    def extra_repr(self) -> str:
        raise NotImplementedError()


tests.test_linear_parameters(Linear, bias=False)
tests.test_linear_parameters(Linear, bias=True)
tests.test_linear_forward(Linear, bias=False)
tests.test_linear_forward(Linear, bias=True)
Help - when I print my Linear module, it also prints a large tensor.
This is because you've (correctly) defined self.bias as either torch.Tensor or None, rather than set it to the boolean value of bias used in initialisation.
To fix this, you will need to change extra_repr so that it prints the boolean value of bias rather than the value of self.bias.
Solution
import einops
import numpy as np
import torch as t
import torch.nn as nn
from torch import Tensor


class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias=True):
        """
        A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Kaiming uniform: sample from U[-1/sqrt(in_features), 1/sqrt(in_features)]
        sf = 1 / np.sqrt(in_features)

        weight = sf * (2 * t.rand(out_features, in_features) - 1)
        self.weight = nn.Parameter(weight)

        if bias:
            bias = sf * (2 * t.rand(out_features) - 1)
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (*, in_features)
        Return: shape (*, out_features)
        """
        x = einops.einsum(x, self.weight, "... in_feats, out_feats in_feats -> ... out_feats")
        if self.bias is not None:
            x += self.bias
        return x

    def extra_repr(self) -> str:
        # note, we need to use `self.bias is not None`, because `self.bias` is either a tensor or
        # None, not bool
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )
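A standalone check of the shape convention `(*, in_features) -> (*, out_features)`, independent of the exercise code: the einsum in the solution computes the same thing as `x @ weight.T + bias`, which is also what `torch.nn.functional.linear` computes. This sketch uses random tensors with two leading batch dimensions and `torch.einsum` (rather than einops) so it has no extra dependencies.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, out_features = 6, 4
weight = torch.randn(out_features, in_features)  # (out, in), PyTorch's layout
bias = torch.randn(out_features)
x = torch.randn(2, 3, in_features)  # arbitrary leading dims "*"

manual = x @ weight.T + bias  # matmul broadcasts over the leading dims
via_einsum = torch.einsum("...i,oi->...o", x, weight) + bias  # mirrors the einops pattern
builtin = F.linear(x, weight, bias)

print(manual.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(manual, builtin, atol=1e-5))  # True
```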
Flatten
Lastly, we've given you the Flatten module rather than including it as an exercise (because it's simple but quite finicky to implement). This is a standardised way to rearrange our tensors so that they can be fed into a linear layer. It's a bit like einops.rearrange, but more specialised and less flexible (it flattens over some contiguous range of dimensions, rather than allowing for general reshape operations). By default we use Flatten(start_dim=1, end_dim=-1), which means flattening over the dimensions input.shape[1:], in other words over all except the batch dimension.
Make sure you understand what this module is doing before moving on.
class Flatten(nn.Module):
    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        """
        Flatten out dimensions from start_dim to end_dim, inclusive of both.
        """
        shape = input.shape

        # Get start & end dims, handling negative indexing for end dim
        start_dim = self.start_dim
        end_dim = self.end_dim if self.end_dim >= 0 else len(shape) + self.end_dim

        # Get the shapes to the left / right of flattened dims, as well as size of flattened middle
        shape_left = shape[:start_dim]
        shape_right = shape[end_dim + 1 :]
        shape_middle = t.prod(t.tensor(shape[start_dim : end_dim + 1])).item()

        return t.reshape(input, shape_left + (shape_middle,) + shape_right)

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={getattr(self, key)}" for key in ["start_dim", "end_dim"]])
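To check your understanding of Flatten, it may help to compare against PyTorch's built-in `torch.flatten`, which takes the same `start_dim` and `end_dim` arguments (both inclusive). This sketch shows the shapes it produces on a random MNIST-sized batch; the batch size of 8 is arbitrary.

```python
import torch

x = torch.randn(8, 1, 28, 28)  # a batch of 8 single-channel 28x28 "images"

# Default Flatten behaviour: flatten everything except the batch dimension
flat_all_but_batch = torch.flatten(x, start_dim=1, end_dim=-1)
print(flat_all_but_batch.shape)  # torch.Size([8, 784])

# Flattening only a middle range: dims 2 and 3 (height and width)
flat_spatial = torch.flatten(x, start_dim=2, end_dim=3)
print(flat_spatial.shape)  # torch.Size([8, 1, 784])

# Equivalent explicit reshape that keeps the batch dimension
reshaped = x.reshape(x.shape[0], -1)
```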
Simple Multi-Layer Perceptron
Now, we can put these modules together to create a neural network. We'll create one of the simplest networks that can separate data which isn't linearly separable: a single linear layer, followed by a nonlinear function (ReLU), followed by another linear layer. This type of architecture (alternating linear layers and nonlinear functions) is often called a multi-layer perceptron (MLP).
The output of this network will have 10 dimensions, corresponding to the 10 classes of MNIST digits. We can then use the softmax function $x_i \to \frac{e^{x_i}}{\sum_j e^{x_j}}$ to turn these values into probabilities. However, it's common practice for the output of a neural network to be the values before we take softmax, rather than after. We call these pre-softmax values the logits.
Question - can you see what makes logits non-unique (i.e. why any given set of probabilities might correspond to several different possible sets of logits)?
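Softmax is invariant to translating the logits. If you add some constant $c$ to all logits $x_i$, then the new probabilities are:
$$p_i' = \frac{e^{x_i + c}}{\sum_j e^{x_j + c}} = \frac{e^c e^{x_i}}{e^c \sum_j e^{x_j}} = \frac{e^{x_i}}{\sum_j e^{x_j}} = p_i$$
in other words, the probabilities don't change.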
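A short numerical sketch of this invariance, with arbitrary example logits. It also shows why it's useful: a numerically stable softmax can subtract the largest logit before exponentiating. The `stable_softmax` helper is hypothetical, written for illustration; `torch.softmax` already handles this internally.

```python
import torch

logits = torch.tensor([2.0, 1.0, -1.0])
c = 5.0

p = torch.softmax(logits, dim=-1)
p_shifted = torch.softmax(logits + c, dim=-1)  # same constant added to every logit
print(torch.allclose(p, p_shifted))  # True


def stable_softmax(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: softmax computed after shifting so the largest logit is 0."""
    z = x - x.max(dim=-1, keepdim=True).values
    e = z.exp()
    return e / e.sum(dim=-1, keepdim=True)


print(torch.allclose(stable_softmax(logits), p))  # True
# exp(1000.0) overflows to inf in float32, but the shifted version is fine
print(stable_softmax(torch.tensor([1000.0, 1000.0])))  # tensor([0.5000, 0.5000])
```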
We can define logprobs as the log of the probabilities, i.e. $y_i = \log p_i$. Unlike logits, these are uniquely defined.
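A sketch of computing logprobs from logits with `torch.log_softmax`, checked against the identity $\log p_i = x_i - \log \sum_j e^{x_j}$ (using `torch.logsumexp`). Shifting the logits changes them, but the logprobs stay the same. The example logits are arbitrary.

```python
import torch

logits = torch.tensor([2.0, 1.0, -1.0])

logprobs = torch.log_softmax(logits, dim=-1)
# log p_i = x_i - log(sum_j exp(x_j))
manual = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
print(torch.allclose(logprobs, manual))  # True

# Shifted logits give identical logprobs
shifted = torch.log_softmax(logits + 5.0, dim=-1)
print(torch.allclose(shifted, logprobs))  # True

# Exponentiating recovers probabilities that sum to 1 (up to float rounding)
print(logprobs.exp().sum().item())
```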
Exercise - implement the simple MLP
```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to ~20 minutes on this exercise.
```
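The diagram below shows what your MLP should look like (a Flatten layer, then a Linear layer from 28 × 28 = 784 inputs to 100 hidden units, a ReLU, and a final Linear layer from 100 to 10 outputs):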
Please ask a TA (or message the Slack group) if any part of this diagram is unclear.
class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError()

    def forward(self, x: Tensor) -> Tensor:
        raise NotImplementedError()


tests.test_mlp_module(SimpleMLP)
tests.test_mlp_forward(SimpleMLP)
Solution
class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.linear1 = Linear(in_features=28 * 28, out_features=100)
        self.relu = ReLU()
        self.linear2 = Linear(in_features=100, out_features=10)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear2(self.relu(self.linear1(self.flatten(x))))
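For comparison, here's a sketch of the same architecture (Flatten, Linear(784, 100), ReLU, Linear(100, 10), as in the solution) built from PyTorch's built-in `nn.Flatten`, `nn.Linear` and `nn.ReLU` with `nn.Sequential` instead of the classes you wrote. It runs a shape check on a random placeholder batch (not real MNIST data) and counts the parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
reference_mlp = nn.Sequential(
    nn.Flatten(),  # (batch, 1, 28, 28) -> (batch, 784)
    nn.Linear(28 * 28, 100),
    nn.ReLU(),
    nn.Linear(100, 10),  # 10 logits, one per digit class
)

dummy_images = torch.randn(32, 1, 28, 28)  # placeholder batch, not real MNIST data
logits = reference_mlp(dummy_images)
print(logits.shape)  # torch.Size([32, 10])

n_params = sum(p.numel() for p in reference_mlp.parameters())
print(n_params)  # 784*100 + 100 + 100*10 + 10 = 79510
```

Your SimpleMLP should have the same number of parameters, since it uses the same layer sizes.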
In the next section, we'll learn how to train and evaluate our model on real data.