4️⃣ ResNets

Learning Objectives
  • Learn about skip connections, and how they help overcome the degradation problem
  • Learn about batch normalization, and why it is used in training
  • Assemble your own ResNet, and load in weights from PyTorch's ResNet implementation

Reading

You should move on once you can answer the following questions:

"Batch Normalization allows us to be less careful about initialization." Explain this statement.

Weight initialisation methods like Xavier (which we encountered yesterday) aim to give the activations approximately the same distribution across layers at initialisation. Batch normalization enforces this directly: it renormalizes the activations at each layer as signals pass through the network, and keeps doing so throughout training, so the exact choice of initialisation matters much less.

Give at least 2 reasons why batch normalization improves the performance of neural networks.

Reasons you can give here include:

  • Input normalization avoids extreme activation values, which helps stabilize gradient-based optimization methods.
  • Internal covariate shift is reduced, i.e. the mean and standard deviation of each layer's inputs are kept stable as the network trains.
  • Regularisation effect: each example is normalized using statistics from its minibatch, which injects a small amount of noise that acts as a regularizer.

Note that some of these points overlap, because they gesture at the same underlying ideas.

If you have an input tensor of size (batch, channels, width, height), and you apply a batchnorm layer, how many learned parameters will there be?

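The learned parameters are a scale (weight) and a shift (bias) for each channel, so there will be 2 * channels learned parameters. The mean and variance are also computed per channel (each calculation is done across the batch, width, and height dimensions), but these are statistics of the data rather than learned parameters; their running estimates are stored as buffers, which we'll meet below.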
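A quick way to check this count, sketched with PyTorch's built-in nn.BatchNorm2d (the channel count of 16 is an arbitrary example):

```python
import torch.nn as nn

bn = nn.BatchNorm2d(16)  # 16 channels, an arbitrary example

# Learned parameters: one weight (scale) and one bias (shift) per channel
n_params = sum(p.numel() for p in bn.parameters())

# The per-channel statistics live in buffers, not parameters
buffer_names = [name for name, _ in bn.named_buffers()]

print(n_params)      # 32
print(buffer_names)  # ['running_mean', 'running_var', 'num_batches_tracked']
```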

In the paper, the diagram shows additive skip connections (i.e. F(x) + x). One can also form concatenated skip connections, by "gluing together" F(x) and x into a single tensor. Give one advantage and one disadvantage of these, relative to additive connections.

One advantage of concatenation: the subsequent layers can re-use intermediate representations directly, maintaining more information, which can lead to better performance. Also, it still works if the two tensors have different numbers of channels (though their spatial dimensions must still match). One disadvantage: it is less compact, so there may be more weights to learn in subsequent layers.
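A small shape comparison of the two kinds of skip connection. The random tensors below are arbitrary stand-ins for a block's input x and its residual branch output F(x):

```python
import torch as t

x = t.randn(2, 64, 8, 8)    # block input, shape (batch, channels, height, width)
fx = t.randn(2, 64, 8, 8)   # stand-in for the residual branch output F(x)

added = fx + x                   # additive skip: shapes must match exactly
concat = t.cat([fx, x], dim=1)   # concatenated skip: stacks along the channel dim

fx_narrow = t.randn(2, 32, 8, 8)              # branch output with a different channel count
concat_narrow = t.cat([fx_narrow, x], dim=1)  # still fine, since spatial dims match

print(added.shape)          # torch.Size([2, 64, 8, 8])
print(concat.shape)         # torch.Size([2, 128, 8, 8])
print(concat_narrow.shape)  # torch.Size([2, 96, 8, 8])
```

The growing channel count in the concatenated case is exactly the "less compact" cost: the next layer needs weights for all of those extra input channels.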

Crucially, both the addition and concatenation methods preserve information from the input, to at least some degree of fidelity. For instance, you can [use calculus to show](https://theaisummer.com/skip-connections/#:~:text=residual%20skip%20connections.-,ResNet%3A%20skip%20connections%C2%A0via%C2%A0addition,-The%20core%20idea) that both methods help mitigate the vanishing gradients problem.
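A one-line version of that calculus argument, writing $\mathcal{L}$ for the loss and $I$ for the identity (this is our own notation, not necessarily the linked article's). For an additive skip connection:

$$y = F(x) + x \quad\Rightarrow\quad \frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\left(\frac{\partial F}{\partial x} + I\right)$$

Even when $\partial F / \partial x$ is tiny, the identity term passes $\partial \mathcal{L} / \partial y$ straight back to $x$. For a concatenated skip $y = [F(x);\, x]$ with output parts $y_F$ and $y_x$, the gradient splits as

$$\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y_F}\,\frac{\partial F}{\partial x} + \frac{\partial \mathcal{L}}{\partial y_x},$$

so again one term reaches $x$ without passing through $F$.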

In this section, we'll do a more advanced version of the exercise in part 1. Rather than building a relatively simple network in which computation can be easily represented by a sequence of simple layers, we're going to build a more complex architecture which requires us to define nested blocks.

We'll start by defining a few more nn.Module objects, which we hadn't needed before.

Sequential

Firstly, now that we're working with large and complex architectures, we should create our own version of nn.Sequential. As the name suggests, when an nn.Sequential is fed an input, it applies each of its submodules in turn, with the output from one module feeding into the next one.

The implementation is given to you below. A few notes:

  • In initialization, we add to the _modules dictionary.
    • This is an ordered dictionary, which preserves the order in which elements are added (since Python 3.7 the builtin dict does this too).
    • When we call self.parameters(), this recursively goes through all modules in self._modules, and returns the params in those modules. This means we can nest sequentials within sequentials!
  • The special __getitem__ and __setitem__ methods determine behaviour when we get and set modules within the sequential.
  • The repr of the base class nn.Module already recursively prints out the submodules, so we don't need to write anything in extra_repr.
    • To see how this works in practice, try defining a Sequential which takes a sequence of modules that you've defined above, and see what it looks like when you print it.

Don't worry about deeply understanding this code. The main takeaway is that nn.Sequential is a useful list-like object to store modules, and apply them all sequentially.

Aside - initializing Sequential with an OrderedDict

The actual nn.Sequential module can be initialized with an ordered dictionary, rather than a list of modules. For instance, rather than doing this:

seq = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 30)
)

we can do this:

from collections import OrderedDict
seq = nn.Sequential(OrderedDict([
    ("linear1", nn.Linear(10, 20)),
    ("relu", nn.ReLU()),
    ("linear2", nn.Linear(20, 30))
]))

This is handy if we want to give each module a descriptive name.

The Sequential implementation below doesn't allow the input to be an OrderedDict. As a bonus exercise, can you rewrite the __init__, __getitem__ and __setitem__ methods to allow the input to be an OrderedDict? If you do this, you'll actually be able to match your eventual ResNet model names exactly to the PyTorch implementation.

class Sequential(nn.Module):
    _modules: dict[str, nn.Module]

    def __init__(self, *modules: nn.Module):
        super().__init__()
        for index, mod in enumerate(modules):
            self._modules[str(index)] = mod

    def __getitem__(self, index: int) -> nn.Module:
        index %= len(self._modules)  # deal with negative indices
        return self._modules[str(index)]

    def __setitem__(self, index: int, module: nn.Module) -> None:
        index %= len(self._modules)  # deal with negative indices
        self._modules[str(index)] = module

    def forward(self, x: Tensor) -> Tensor:
        """Chain each module together, with the output from one feeding into the next one."""
        for mod in self._modules.values():
            x = mod(x)
        return x
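To see the recursive parameter collection in action, the sketch below re-declares a trimmed copy of the Sequential class above (omitting __setitem__) so that it runs on its own, then nests one Sequential inside another. The layer sizes are arbitrary:

```python
import torch as t
import torch.nn as nn
from torch import Tensor


class Sequential(nn.Module):
    _modules: dict[str, nn.Module]

    def __init__(self, *modules: nn.Module):
        super().__init__()
        for index, mod in enumerate(modules):
            self._modules[str(index)] = mod

    def __getitem__(self, index: int) -> nn.Module:
        index %= len(self._modules)  # deal with negative indices
        return self._modules[str(index)]

    def forward(self, x: Tensor) -> Tensor:
        for mod in self._modules.values():
            x = mod(x)
        return x


inner = Sequential(nn.Linear(4, 8), nn.ReLU())
outer = Sequential(inner, nn.Linear(8, 2))  # a Sequential nested inside a Sequential

# parameters() and state_dict() both recurse through the nested _modules
n_params = sum(p.numel() for p in outer.parameters())  # (4*8 + 8) + (8*2 + 2)
keys = list(outer.state_dict().keys())
out = outer(t.randn(3, 4))

print(n_params)   # 58
print(keys)       # ['0.0.weight', '0.0.bias', '1.weight', '1.bias']
print(out.shape)  # torch.Size([3, 2])
print(outer[-1])  # Linear(in_features=8, out_features=2, bias=True)
```

The dotted names like `0.0.weight` come from the nesting: the first `0` is the inner Sequential's key in `outer`, the second is the Linear's key inside it.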

BatchNorm2d

Now, we'll implement our BatchNorm2d, the layer described in the reading material you hopefully read above. You'll be implementing it according to the PyTorch docs (with affine=True and track_running_stats=True).

The primary function of batchnorm is to normalize the activations of each layer within the neural network during training. For each channel, it normalizes the activations to have mean 0 and standard deviation 1, computing the statistics over the batch (and, for BatchNorm2d, over the height and width dimensions too), and then applies a learned scale and shift. This normalization helps mitigate the internal covariate shift problem, which refers to the change in the distribution of layer inputs as the network trains. This becomes a particularly big problem as we build deeper networks, because there's more opportunity for the activation distribution to change over time.
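A minimal sketch of the per-channel normalization, checked against PyTorch's nn.BatchNorm2d in training mode (at initialization its weight is 1 and bias is 0, so the affine step is a no-op). The input sizes, scale and offset are arbitrary:

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
x = t.randn(8, 3, 5, 5) * 4 + 2  # (batch, channels, height, width), deliberately not standardized

# One mean and variance per channel: reduce over batch, height and width (dims 0, 2, 3)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
x_hat = (x - mean) / t.sqrt(var + 1e-5)  # 1e-5 is nn.BatchNorm2d's default eps

bn = nn.BatchNorm2d(3).train()
out = bn(x)

print(x_hat.mean(dim=(0, 2, 3)))                 # approximately 0 for each channel
print(x_hat.var(dim=(0, 2, 3), unbiased=False))  # approximately 1 for each channel
```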

Buffers

A question that might have occurred to you as you read about batchnorm - how does averaging over input data work in inference mode, if you only have a single input rather than a batch? The answer is that during training mode we compute a running average of our data's mean and variance, and we use this running average in inference mode.

How do we store these moving averages? We want them to be saved and loaded with the model (because we need these values in order to run our model), but we don't want to update them using gradient descent (so we don't want to use nn.Parameter). Instead, we use PyTorch's buffers feature. Buffers are tensors which are included in model.state_dict() (and so are saved & loaded with the rest of the model) but not in model.parameters().
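A minimal illustration of the difference, using a hypothetical toy module `RunningStats` that holds one parameter and one buffer:

```python
import torch as t
import torch.nn as nn


class RunningStats(nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self, num_features: int):
        super().__init__()
        self.weight = nn.Parameter(t.ones(num_features))              # learned by gradient descent
        self.register_buffer("running_mean", t.zeros(num_features))  # saved with the model, not learned


m = RunningStats(3)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

print(param_names)   # ['weight']
print(buffer_names)  # ['running_mean']
print(state_keys)    # ['weight', 'running_mean']
```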

You can create a buffer by calling self.register_buffer from inside an nn.Module. We've initialized the necessary buffers for you in the __init__ method below: you'll need a running mean and variance, as well as a counter for the number of batches seen. (The counter isn't strictly necessary here: with a fixed momentum, the running mean & variance are updated using an exponential moving average, whose update rule doesn't depend on the number of previous updates. We include it so that our state dict matches the PyTorch implementation, which also uses it to compute a cumulative average when momentum=None.)
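For the curious, here is where PyTorch does use the counter: with momentum=None, nn.BatchNorm2d keeps a simple (cumulative) average of the batch statistics. A sketch with arbitrary input sizes; we don't need this mode in our own implementation:

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
bn = nn.BatchNorm2d(3, momentum=None).train()  # momentum=None: cumulative average

batch_means = []
for _ in range(3):
    x = t.randn(4, 3, 2, 2) + 5.0
    bn(x)  # in training mode, each call updates the running stats and the counter
    batch_means.append(x.mean(dim=(0, 2, 3)))

# With momentum=None, the running mean is the plain average of the batch means so far
expected = t.stack(batch_means).mean(dim=0)
print(bn.num_batches_tracked)  # tensor(3)
```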

Train and Eval Modes

Okay, so we have buffers, but how can we make them behave differently in different modes, i.e. updating the running mean & variance in training mode, and using the stored values in eval mode? The answer is the training attribute of the nn.Module class, a boolean that gets set when we call self.train() or self.eval(). In the case of batch norm, your code should look like this:

if self.training:
    # Use this data's mean & variance to normalize, then use it to update the buffers
    ...
else:
    # Use the buffer mean & variance to normalize
    ...
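Calling .eval() or .train() on a model sets the training flag on the model and, recursively, on every submodule, which is why a single call switches all the batchnorm layers at once. A small sketch with an arbitrary model:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(0.5))

model.eval()   # sets .training = False on the model and every submodule
eval_flags = [m.training for m in model.modules()]

model.train()  # sets .training = True everywhere again
train_flags = [m.training for m in model.modules()]

print(eval_flags)   # [False, False, False, False]
print(train_flags)  # [True, True, True, True]
```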

The other commonly used module which has different behaviour in training and eval modes is Dropout: in eval mode this module passes its input through unchanged, but in training mode it zeroes out each input value with probability p and scales the remaining values by 1 / (1 - p).
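A quick demonstration with nn.Dropout and p = 0.5 (so surviving values are scaled by 2); the input of 1000 ones is arbitrary:

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
p = 0.5
drop = nn.Dropout(p=p)
x = t.ones(1000)

drop.train()
y_train = drop(x)  # each entry zeroed with probability p, survivors scaled by 1 / (1 - p)

drop.eval()
y_eval = drop(x)   # identity in eval mode

frac_zeroed = (y_train == 0).float().mean().item()
print(sorted(set(y_train.tolist())))  # [0.0, 2.0]
print(t.equal(y_eval, x))             # True
```

The 1 / (1 - p) scaling keeps the expected value of each output equal to its input, so the layer needs no correction when it switches to eval mode.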

Note that other normalization modules we'll meet later in this course, like LayerNorm, don't behave differently in training and eval modes, because they don't normalize over the batch dimension.

Exercise - implement BatchNorm2d

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 15-30 minutes on this exercise.
```

Implement BatchNorm2d according to the PyTorch docs. We're implementing it with affine=True and track_running_stats=True. All the parameters are defined for you in the __init__ method; your job is to fill in the forward and extra_repr methods.

A few final tips:

  • Remember to use weight and bias in the forward pass, after normalizing: you should multiply by weight and add bias.
  • All your tensors (weight, bias, running_mean and running_var) are vectors of length num_features; this should help you figure out which dimensions you're operating on.
  • Remember that the shape of x is (batch, num_features, height, width), which doesn't broadcast with (num_features,). The easiest way to fix this is to reshape the latter to something like (1, num_features, 1, 1), or optionally just (num_features, 1, 1).

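A small sketch of the broadcasting issue from the last tip, with arbitrary sizes (num_features=3, width=5) and a per-channel vector standing in for weight:

```python
import torch as t

batch, num_features, height, width = 2, 3, 4, 5
x = t.randn(batch, num_features, height, width)
weight = t.tensor([1.0, 2.0, 3.0])  # one value per channel, shape (num_features,)

# Broadcasting aligns trailing dims, so the (3,) vector is matched against width=5 and fails
try:
    x * weight
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True

# Reshaping to (1, num_features, 1, 1) lines the vector up with the channel dimension
scaled = x * weight.reshape(1, num_features, 1, 1)

print(broadcast_failed)                         # True
print(t.allclose(scaled[:, 2], 3.0 * x[:, 2]))  # True: channel 2 was scaled by weight[2]
```

Beware that if width happened to equal num_features, `x * weight` would not raise: it would silently scale along the width dimension instead of the channels.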
class BatchNorm2d(nn.Module):
    # The type hints below aren't functional, they're just for documentation
    running_mean: Float[Tensor, "num_features"]
    running_var: Float[Tensor, "num_features"]
    num_batches_tracked: Int[Tensor, ""]  # This is how we denote a scalar tensor

    def __init__(self, num_features: int, eps=1e-05, momentum=0.1):
        """
        Like nn.BatchNorm2d with track_running_stats=True and affine=True.

        Name the learnable affine parameters `weight` and `bias` in that order.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        self.weight = nn.Parameter(t.ones(num_features))
        self.bias = nn.Parameter(t.zeros(num_features))

        self.register_buffer("running_mean", t.zeros(num_features))
        self.register_buffer("running_var", t.ones(num_features))
        self.register_buffer("num_batches_tracked", t.tensor(0))

    def forward(self, x: Tensor) -> Tensor:
        """
        Normalize each channel.

        Compute the variance using `torch.var(x, unbiased=False)`
        Hint: you may also find it helpful to use the argument `keepdim`.

        x: shape (batch, channels, height, width)
        Return: shape (batch, channels, height, width)
        """
        raise NotImplementedError()

    def extra_repr(self) -> str:
        raise NotImplementedError()


tests.test_batchnorm2d_module(BatchNorm2d)
tests.test_batchnorm2d_forward(BatchNorm2d)
tests.test_batchnorm2d_running_mean(BatchNorm2d)
Help - I'm stuck on this implementation, and need a template.

The easiest way is to structure it like this (we've omitted the reshaping to make sure the mean & variance broadcasts correctly):

if self.training:
    mean = ... # mean of new data
    var = ... # variance of new data
    self.running_mean = ... # update running mean using exponential moving average
    self.running_var = ... # update running variance using exponential moving average
    self.num_batches_tracked += 1
else:
    mean = self.running_mean
    var = self.running_var
x_normed = ... # normalize x using mean and var (make sure mean and var are broadcastable with x)
x_affine = ... # apply affine transformation from self.weight and self.bias (again, be careful of broadcasting)
return x_affine
Help - I'm not sure how to implement the running_mean and running_var formula

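To track the running mean, we use an exponentially weighted moving average. Writing $\mu$ for self.momentum, the update at step $T$ is

$$\text{running\_mean}_T = (1-\mu) \cdot \text{running\_mean}_{T-1} + \mu \cdot \text{mean}_{T},$$

and since running_mean starts at zero, unrolling this recursion gives

$$\text{running\_mean}_T = \sum_{t=1}^{T} \mu (1-\mu)^{T-t} \cdot \text{mean}_{t}.$$

We update the running variance with the same recursive formula (note that running_var starts at one rather than zero, so its unrolled sum picks up an extra $(1-\mu)^T$ term).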
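A small numerical check of both forms for the running mean, compared against PyTorch's own nn.BatchNorm2d (momentum 0.1, the number of steps and the input sizes are arbitrary choices):

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
mu = 0.1  # momentum
bn = nn.BatchNorm2d(3, momentum=mu).train()

running_mean = t.zeros(3)  # same initial value as the buffer
batch_means = []
for _ in range(5):
    x = t.randn(4, 3, 2, 2) + 3.0
    bn(x)  # PyTorch updates its running_mean buffer in training mode
    m = x.mean(dim=(0, 2, 3))
    batch_means.append(m)
    running_mean = (1 - mu) * running_mean + mu * m  # recursive update

# Unrolled sum (valid because running_mean starts at zero)
T = len(batch_means)
unrolled = sum(mu * (1 - mu) ** (T - step) * batch_means[step - 1] for step in range(1, T + 1))

print(bn.running_mean)
print(running_mean)
print(unrolled)  # all three should agree up to float rounding
```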

Solution
def forward(self, x: Tensor) -> Tensor:
    """
    Normalize each channel.

    Compute the variance using `torch.var(x, unbiased=False)`
    Hint: you may also find it helpful to use the argument `keepdim`.

    x: shape (batch, channels, height, width)
    Return: shape (batch, channels, height, width)
    """
    # Calculating mean and var over all dims except for the channel dim
    if self.training:
        # Take mean over all dimensions except the feature dimension
        mean = x.mean(dim=(0, 2, 3))
        var = x.var(dim=(0, 2, 3), unbiased=False)
        # Updating running mean and variance, in line with PyTorch documentation
        # (detached, so the buffers don't hold on to the autograd graph)
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.detach()
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.detach()
        self.num_batches_tracked += 1
    else:
        mean = self.running_mean
        var = self.running_var

    # Rearranging these so they can be broadcast against x
    reshape = lambda v: einops.rearrange(v, "channels -> 1 channels 1 1")

    # Normalize, then apply affine transformation from self.weight & self.bias
    x_normed = (x - reshape(mean)) / (reshape(var) + self.eps).sqrt()
    x_affine = x_normed * reshape(self.weight) + reshape(self.bias)
    return x_affine


def extra_repr(self) -> str:
    return ", ".join(
        [f"{key}={getattr(self, key)}" for key in ["num_features", "eps", "momentum"]]
    )

AveragePool

Let's end our collection of nn.Modules with an easy one 🙂

The ResNet has a Linear layer with 1000 outputs at the end in order to produce classification logits for each of the 1000 classes. Any Linear needs to have a constant number of input features, but the ResNet is supposed to be compatible with arbitrary height and width, so we can't just do a pooling operation with a fixed kernel size and stride.

Luckily, the simplest possible solution works decently: take the mean over the spatial dimensions. Intuitively, each position has an equal "vote" for what objects it can "see".

Exercise - implement AveragePool

```yaml
Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪

You should spend up to 5-10 minutes on this exercise.
```

This should be a pretty straightforward implementation; it doesn't have any weights or parameters of any kind, so you only need to implement the forward method.

class AveragePool(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (batch, channels, height, width)
        Return: shape (batch, channels)
        """
        raise NotImplementedError()


tests.test_averagepool(AveragePool)
Solution
class AveragePool(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (batch, channels, height, width)
        Return: shape (batch, channels)
        """
        return t.mean(x, dim=(2, 3))
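For comparison, torchvision's ResNet uses nn.AdaptiveAvgPool2d((1, 1)) followed by a flatten at this point. A quick sketch checking that it agrees with the mean over spatial dimensions for a couple of arbitrary input sizes:

```python
import torch as t
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d((1, 1))

results = []
for height, width in [(7, 7), (5, 9)]:  # works for any spatial size
    x = t.randn(2, 512, height, width)
    ours = t.mean(x, dim=(2, 3))        # same computation as the AveragePool solution
    ref = pool(x).flatten(start_dim=1)  # (2, 512, 1, 1) -> (2, 512)
    results.append((tuple(ours.shape), t.allclose(ours, ref, atol=1e-5)))

print(results)  # [((2, 512), True), ((2, 512), True)]
```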

Building ResNet

Now we have all the building blocks we need to start assembling your own ResNet! The following diagram describes the architecture of ResNet34 - the other versions are broadly similar.

Note - unless otherwise noted, you should assume convolutions have kernel_size=3, stride=1, padding=1 (this is a shape preserving convolution i.e. the width & height of the input and output will be the same). None of the convolutions have biases.

You don't have to understand every detail in this diagram before proceeding; specific points will be clarified as we go through each exercise.

Question: why do we not care about including biases in the convolutional layers?

Every convolution layer in this network is followed by a batch normalization layer. The first operation in the batch normalization layer is to subtract the mean of each output channel. But a convolutional bias just adds some scalar b to each output channel, increasing the mean by b. This means that for any b added, the batch normalization will subtract b to exactly negate the bias term.
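A quick numerical check of this argument, with arbitrary sizes: two convolutions sharing the same weights, one with a bias and one without, give identical outputs once followed by a batchnorm layer in training mode.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
conv_with_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
conv_no_bias = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
with t.no_grad():
    conv_no_bias.weight.copy_(conv_with_bias.weight)  # same weights, only the bias differs

bn = nn.BatchNorm2d(8).train()  # training mode: normalizes with the current batch's statistics
x = t.randn(4, 3, 6, 6)

out_with_bias = bn(conv_with_bias(x))
out_no_bias = bn(conv_no_bias(x))
print(t.allclose(out_with_bias, out_no_bias, atol=1e-4))  # True: the bias is subtracted away
```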

Help - I'm confused about how the nested subgraphs work.

The right-most block in the diagram, ResidualBlock, is nested inside BlockGroup multiple times. When you see ResidualBlock in BlockGroup, you should visualise a copy of ResidualBlock sitting in that position.

Similarly, BlockGroup is nested multiple times (four to be precise) in the full ResNet34 architecture.

Exercise - implement ResidualBlock

```yaml
Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 20-30 minutes on this exercise.
```

Implement ResidualBlock by referring to the diagram (i.e. the right-most of the three hierarchical diagrams above).

The left branch starts with a strided convolution which changes the number of features from in_feats to out_feats. It uses the default conv parameters (kernel_size=3, padding=1), except that its stride is given by first_stride. The second convolution uses all default parameters and maps from out_feats to out_feats (meaning it's fully shape preserving).

As for the right branch, this is meant to be a skip connection. The problem is that we can't always use a plain skip connection, because the shapes might not match up (and so we couldn't add the two branches together at the end). The left branch is fully shape preserving if and only if first_stride == 1 and in_feats == out_feats. If this is true, we set the right branch to be the identity (that's what the "OPTIONAL" annotation refers to). Otherwise, we set the right branch to be a 1x1 convolution with stride first_stride and zero padding, mapping from in_feats to out_feats, followed by a batchnorm layer. This is in a sense the simplest operation that matches the left branch's output shape: the convolution just keeps the pixels from a ::first_stride slice across the height and width dimensions, and takes a linear combination of the channels at each of them.
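A sketch checking both claims with PyTorch's nn.Conv2d (the feature counts and spatial sizes are arbitrary): the strided 1x1 conv produces the same output shape as the left branch's strided 3x3 conv, even for odd sizes, and it equals a ::first_stride slice followed by a per-pixel channel mix.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
in_feats, out_feats, first_stride = 64, 128, 2

# Left branch's first conv vs. the right branch's 1x1 conv (no biases, as in ResNet)
left_conv = nn.Conv2d(in_feats, out_feats, kernel_size=3, stride=first_stride, padding=1, bias=False)
right_conv = nn.Conv2d(in_feats, out_feats, kernel_size=1, stride=first_stride, padding=0, bias=False)

shapes = []
for size in [8, 7, 15]:  # even and odd spatial sizes
    x = t.randn(1, in_feats, size, size)
    shapes.append((tuple(left_conv(x).shape), tuple(right_conv(x).shape)))

# A strided 1x1 conv is a ::first_stride slice followed by a per-pixel channel mix
x = t.randn(1, in_feats, 7, 7)
sliced = x[:, :, ::first_stride, ::first_stride]
manual = t.einsum("oi,bihw->bohw", right_conv.weight[:, :, 0, 0], sliced)

print(shapes)  # left and right shapes agree for every size
print(t.allclose(right_conv(x), manual, atol=1e-4))  # True
```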

class ResidualBlock(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, first_stride=1):
        """
        A single residual block with optional downsampling.

        For compatibility with the pretrained model, declare the left side branch first using a
        `Sequential`.

        If first_stride > 1 or in_feats != out_feats, the optional (conv + bn) should be present on
        the right branch. Declare it second using another `Sequential`.
        """
        super().__init__()
        is_shape_preserving = (first_stride == 1) and (
            in_feats == out_feats
        )  # determines if right branch is identity

        raise NotImplementedError()

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute the forward pass. If no downsampling block is present, the addition should just add
        the left branch's output to the input.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / stride, width / stride)
        """
        raise NotImplementedError()


tests.test_residual_block(ResidualBlock)
Solution
class ResidualBlock(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, first_stride=1):
        """
        A single residual block with optional downsampling.
        For compatibility with the pretrained model, declare the left side branch first using a
        Sequential.
        If first_stride > 1 or in_feats != out_feats, the optional (conv + bn) should be present on
        the right branch. Declare it second using another Sequential.
        """
        super().__init__()
        is_shape_preserving = (first_stride == 1) and (
            in_feats == out_feats
        )  # determines if right branch is identity

        self.left = Sequential(
            Conv2d(in_feats, out_feats, kernel_size=3, stride=first_stride, padding=1),
            BatchNorm2d(out_feats),
            ReLU(),
            Conv2d(out_feats, out_feats, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(out_feats),
        )
        self.right = (
            nn.Identity()
            if is_shape_preserving
            else Sequential(
                Conv2d(in_feats, out_feats, kernel_size=1, stride=first_stride),
                BatchNorm2d(out_feats),
            )
        )
        self.relu = ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute the forward pass. If no downsampling block is present, the addition should just add
        the left branch's output to the input.
        x: shape (batch, in_feats, height, width)
        Return: shape (batch, out_feats, height / stride, width / stride)
        """
        x_left = self.left(x)
        x_right = self.right(x)
        return self.relu(x_left + x_right)

Exercise - implement BlockGroup

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 10-15 minutes on this exercise.
```

Implement BlockGroup according to the diagram. There should be n_blocks total blocks in the group. Only the first block has the possibility of having a right branch (because we might have either first_stride > 1 or in_feats != out_feats), but every subsequent block will have the identity instead of a right branch.

Help - I don't understand why all blocks after the first one won't have a right branch.

  • The first_stride argument only gets applied to the first block, by definition (i.e. the purpose of the BlockGroup is to downsample the input by first_stride just once, not on every single block).
  • After we pass through the first block, we can guarantee that the number of channels will be out_feats, so every subsequent block will have out_feats input channels and out_feats output channels.

Combining these two facts, we see that every subsequent block will have a shape-preserving left branch, so it can have the identity as its right branch.
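The same reasoning as a few lines of plain Python. `right_branch_flags` is a hypothetical helper written only to illustrate the argument; it is not part of the exercise:

```python
def right_branch_flags(n_blocks: int, in_feats: int, out_feats: int, first_stride: int) -> list[bool]:
    """Hypothetical helper: for each block in a BlockGroup, does it need a (conv + bn) right branch?"""
    flags = []
    for i in range(n_blocks):
        stride = first_stride if i == 0 else 1        # only the first block is strided
        block_in = in_feats if i == 0 else out_feats  # later blocks receive out_feats channels
        is_shape_preserving = stride == 1 and block_in == out_feats
        flags.append(not is_shape_preserving)
    return flags


print(right_branch_flags(3, 64, 64, 1))   # [False, False, False]  (ResNet34's first group)
print(right_branch_flags(4, 64, 128, 2))  # [True, False, False, False]  (its second group)
```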

class BlockGroup(nn.Module):
    def __init__(self, n_blocks: int, in_feats: int, out_feats: int, first_stride=1):
        """
        An n_blocks-long sequence of ResidualBlock where only the first block uses the provided
        stride.
        """
        super().__init__()
        # YOUR CODE HERE - define all components of block group
        raise NotImplementedError()

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute the forward pass.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / first_stride, width / first_stride)
        """
        raise NotImplementedError()


tests.test_block_group(BlockGroup)
Solution
class BlockGroup(nn.Module):
    def __init__(self, n_blocks: int, in_feats: int, out_feats: int, first_stride=1):
        """
        An n_blocks-long sequence of ResidualBlock where only the first block uses the provided
        stride.
        """
        super().__init__()
        self.blocks = Sequential(
            ResidualBlock(in_feats, out_feats, first_stride),
            *[ResidualBlock(out_feats, out_feats) for _ in range(n_blocks - 1)],
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute the forward pass.
        x: shape (batch, in_feats, height, width)
        Return: shape (batch, out_feats, height / first_stride, width / first_stride)
        """
        return self.blocks(x)

Exercise - implement ResNet34

```yaml
Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 30-45 minutes on this exercise. This can sometimes involve a lot of fiddly debugging.
```

Last step! Assemble ResNet34 using the diagram.

To test your implementation, you can use the helper function print_param_count, which prints out a stylized dataframe comparing your model's parameter count to the PyTorch implementation. Alternatively, you can use torchinfo.summary to inspect the architecture of both the reference resnet34 and your own:

resnet = models.resnet34()
print(torchinfo.summary(resnet, input_size=(1, 3, 64, 64)))
print(torchinfo.summary(my_resnet, input_size=(1, 3, 64, 64)))

Both will give you the shape & size of each of your model's parameters & buffers, and code is provided for both of these methods below.

Note - in order to copy weights from the reference model to your implementation (which we'll do after this exercise), you'll need to have all the parameters defined in the same order as they are in the reference model - in other words, the rows from the two halves of the dataframe created via print_param_count should perfectly match up with each other. This can be a bit fiddly to get right, especially if the names of your parameters are different to the names in the PyTorch implementation. We recommend you look at the __init__ methods of the solution if you're stuck (since it's the order that things are defined in for the various ResNet modules which determines the order of the rows in the dataframe).

This 1-to-1 weight comparison won't always be possible during model replications. For example, when we replicate GPT2-Small next week we'll be defining the attention weight matrices differently (in a way that's more conducive to interpretability research). In these cases, you'll need to resort to different debugging methods, like running the models on the same input and checking they give the same output. You can also break this down into smaller steps by running individual submodules, and by checking shapes before checking values. However, in this case we don't need to resort to that, because our implementation is equivalent to the reference model's implementation.

As a more general point, tweaking your model until all the layers match up might be a difficult and frustrating exercise at times, but it's a pretty good example of the kind of low-level model implementation and debugging that is important for your growth as ML engineers! So don't be disheartened if you find it hard to get exactly right (although we certainly recommend looking at the solutions and moving on if you're stuck on this particular exercise for more than ~45 minutes).

class ResNet34(nn.Module):
    def __init__(
        self,
        n_blocks_per_group=[3, 4, 6, 3],
        out_features_per_group=[64, 128, 256, 512],
        first_strides_per_group=[1, 2, 2, 2],
        n_classes=1000,
    ):
        super().__init__()
        out_feats0 = 64
        self.n_blocks_per_group = n_blocks_per_group
        self.out_features_per_group = out_features_per_group
        self.first_strides_per_group = first_strides_per_group
        self.n_classes = n_classes

        # YOUR CODE HERE - define all components of resnet34
        raise NotImplementedError()

    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (batch, channels, height, width)
        Return: shape (batch, n_classes)
        """
        raise NotImplementedError()


my_resnet = ResNet34()

# (1) Test via helper function `print_param_count`
target_resnet = (
    models.resnet34()
)  # without supplying a `weights` argument, we just initialize with random weights
utils.print_param_count(my_resnet, target_resnet)

# (2) Test via `torchinfo.summary`
print("My model:", torchinfo.summary(my_resnet, input_size=(1, 3, 64, 64)), sep="\n")
print(
    "\nReference model:",
    torchinfo.summary(target_resnet, input_size=(1, 3, 64, 64), depth=2),
    sep="\n",
)
Help - I'm not sure how to construct each of the BlockGroups.

Each BlockGroup takes arguments n_blocks, in_feats, out_feats and first_stride. In the initialisation of ResNet34 below, we're given a list of n_blocks, out_feats and first_stride for each of the BlockGroups. To find in_feats for each block, it suffices to note two things:

1. The first in_feats should be 64, because the input is coming from the convolutional layer with 64 output channels.
2. The out_feats of each BlockGroup should be equal to the in_feats of the subsequent BlockGroup (because the BlockGroups are stacked one after the other, with no operations in between to change the shape).

You can use these two facts to construct a list in_features_per_group, and then create your BlockGroups by zipping through all four lists.
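A minimal sketch of that bookkeeping, using ResNet34's default group settings from the __init__ signature:

```python
n_blocks_per_group = [3, 4, 6, 3]
out_features_per_group = [64, 128, 256, 512]
first_strides_per_group = [1, 2, 2, 2]

# The first group receives the 64 channels from the 7x7 conv;
# every later group receives the previous group's out_feats
in_features_per_group = [64] + out_features_per_group[:-1]

group_args = list(
    zip(n_blocks_per_group, in_features_per_group, out_features_per_group, first_strides_per_group)
)

print(in_features_per_group)  # [64, 64, 128, 256]
print(group_args[1])          # (4, 64, 128, 2): n_blocks, in_feats, out_feats, first_stride
```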

Help - I'm not sure how to construct the 7x7 conv at the very start.

The stride, padding & output channels are given in the diagram; the only thing not provided is in_channels. Recall that the input to this layer is an RGB image. Can you deduce from this how many input channels your layer should have?

Help - I'm getting the right total parameter count, but my rows don't match up, and I'm not sure how to debug this.

We'll use an example case to illustrate how to debug this. In the following case, our rows match up until the 21st row where we have our first discrepancy:

We can see that the first discrepancy occurs at the first parameter from residual_layers.1, meaning something in the second BlockGroup in our sequential of blockgroups. We can see that the first blockgroup only had left branches but no right branches (this is because for the very first blockgroup we had in_feats == out_feats == 64 and also first_strides_per_group[0] == 1, meaning this first blockgroup was shape-preserving and it didn't need a right branch). So it's the presence of a right branch that's causing the mismatch.

Looking closer at the dataframe, we see that the left-hand parameter (from our model) has shape (128, 64, 1, 1) and has right in its name, so we deduce it's the 1x1 convolutional weight from the right branch. But the parameter from the PyTorch model has shape (128, 64, 3, 3), i.e. it's a convolutional weight with a 3x3 kernel, so it must be from the left branch (it also matches the naming convention for the left-branch convolutional weight from the first blockgroup, row index 3 in the dataframe). So we've now figured out what the problem is: our implementation defines the right branch before the left branch in the ResidualBlock.__init__ method, and to match param orders with the PyTorch model we should swap them around.

Solution
class ResNet34(nn.Module):
    def __init__(
        self,
        n_blocks_per_group=[3, 4, 6, 3],
        out_features_per_group=[64, 128, 256, 512],
        first_strides_per_group=[1, 2, 2, 2],
        n_classes=1000,
    ):
        super().__init__()
        out_feats0 = 64
        self.n_blocks_per_group = n_blocks_per_group
        self.out_features_per_group = out_features_per_group
        self.first_strides_per_group = first_strides_per_group
        self.n_classes = n_classes

        # Initial layers: 7x7 conv, batchnorm, ReLU, then max pool (224x224 -> 56x56)
        self.in_layers = Sequential(
            Conv2d(3, out_feats0, kernel_size=7, stride=2, padding=3),
            BatchNorm2d(out_feats0),
            ReLU(),
            MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Each block group takes the previous group's output features as its input features
        residual_layers = []
        for i in range(len(n_blocks_per_group)):
            residual_layers.append(
                BlockGroup(
                    n_blocks=n_blocks_per_group[i],
                    in_feats=[out_feats0, *self.out_features_per_group][i],
                    out_feats=self.out_features_per_group[i],
                    first_stride=self.first_strides_per_group[i],
                )
            )
        self.residual_layers = Sequential(*residual_layers)

        # Output layers: average pool over (height, width), then a linear classifier
        self.out_layers = Sequential(
            AveragePool(),
            Linear(out_features_per_group[-1], n_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (batch, channels, height, width)
        Return: shape (batch, n_classes)
        """
        post_first_conv_block = self.in_layers(x)
        post_block_groups = self.residual_layers(post_first_conv_block)
        logits = self.out_layers(post_block_groups)
        return logits
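As a sanity check on the architecture, here's a short sketch tracing the spatial size of a 224x224 input through the layers above, using the standard output-size formula for convolutions and max pooling (assuming dilation 1). It assumes, as in the solution, that each block's 3x3 convs use padding=1 and that only the first block in each group has a stride other than 1. It's plain Python, so it doesn't depend on your module implementations.

```python
def conv_out_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output height/width of a conv or max-pool layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 224
size = conv_out_size(size, kernel_size=7, stride=2, padding=3)  # first conv: 112
size = conv_out_size(size, kernel_size=3, stride=2, padding=1)  # max pool: 56
print("after in_layers:", size)

sizes_per_group = []
for first_stride in [1, 2, 2, 2]:
    # Only the first block in each group can downsample; later blocks keep the size fixed
    size = conv_out_size(size, kernel_size=3, stride=first_stride, padding=1)
    sizes_per_group.append(size)
print("after each block group:", sizes_per_group)  # [56, 28, 14, 7]
```

So the final block group outputs a (batch, 512, 7, 7) tensor, which AveragePool reduces to (batch, 512) before the linear layer. The right branch's 1x1 conv (stride 2, no padding) gives the same downsampled size, e.g. (56 - 1) // 2 + 1 = 28, which is what lets the two branches be added.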

Copying over weights

Now that you've built your ResNet34, we'll copy weights over from PyTorch's pretrained resnet to yours. This is another good way to verify that you've designed the architecture correctly (although if you've passed all the tests above and your parameter counts and order match up, it's very likely that this code will also work).

def copy_weights(my_resnet: ResNet34, pretrained_resnet: models.resnet.ResNet) -> ResNet34:
    """Copy over the weights of `pretrained_resnet` to your resnet."""

    # Get the state dictionaries for each model, check they have the same number of parameters &
    # buffers
    mydict = my_resnet.state_dict()
    pretraineddict = pretrained_resnet.state_dict()
    assert len(mydict) == len(pretraineddict), "Mismatching state dictionaries."

    # Define a dictionary mapping the names of your parameters / buffers to their values in the
    # pretrained model
    state_dict_to_load = {
        mykey: pretrainedvalue
        for (mykey, myvalue), (pretrainedkey, pretrainedvalue) in zip(
            mydict.items(), pretraineddict.items()
        )
    }

    # Load in this dictionary to your model
    my_resnet.load_state_dict(state_dict_to_load)

    return my_resnet


pretrained_resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1).to(device)
my_resnet = copy_weights(my_resnet, pretrained_resnet).to(device)
print("Weights copied successfully!")
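Note that copy_weights pairs entries purely by position. load_state_dict will already raise on a size mismatch, but its error only names your model's key. Here is a sketch of a variant (copy_weights_checked is a hypothetical name, not part of the exercise) that checks shapes pair by pair and reports both names, which makes ordering bugs like the left/right branch swap easier to spot. No shape check can catch two same-shaped tensors that are swapped, though.

```python
from collections import OrderedDict

import torch as t
from torch import nn


def copy_weights_checked(dst: nn.Module, src: nn.Module) -> nn.Module:
    """Copy src's parameters/buffers into dst by position, failing loudly on a shape mismatch."""
    dst_dict, src_dict = dst.state_dict(), src.state_dict()
    assert len(dst_dict) == len(src_dict), "Mismatching state dictionaries."
    state_dict_to_load = {}
    for (dst_key, dst_value), (src_key, src_value) in zip(dst_dict.items(), src_dict.items()):
        if dst_value.shape != src_value.shape:
            raise ValueError(
                f"Shape mismatch: {dst_key} {tuple(dst_value.shape)} vs "
                f"{src_key} {tuple(src_value.shape)}"
            )
        state_dict_to_load[dst_key] = src_value
    dst.load_state_dict(state_dict_to_load)
    return dst


# Same layers, different names: copying by position works
src = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
dst = nn.Sequential(OrderedDict(fc1=nn.Linear(4, 3), act=nn.ReLU(), fc2=nn.Linear(3, 2)))
copy_weights_checked(dst, src)

# Same number of entries, but registered in a different order: the check names both keys
wrong_order = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(4, 3))
try:
    copy_weights_checked(wrong_order, src)
except ValueError as e:
    print(e)  # Shape mismatch: 0.weight (2, 3) vs 0.weight (3, 4)
```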

This function uses the state_dict() method, which returns an OrderedDict object containing all the parameter/buffer names and their values. State dicts can be extracted from models, saved to your filesystem (this is a common way to store the results of training a model), and loaded back into a model using the load_state_dict method. (You can also load weights from a regular Python dict, since from Python 3.7 onwards the built-in dict is guaranteed to preserve insertion order.)
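Here's a minimal sketch of that round trip on a tiny model (chosen for illustration, not from the exercise). It shows that buffers such as batchnorm's running statistics appear in the state dict alongside parameters, and uses an in-memory buffer in place of a file path.

```python
import io
from collections import OrderedDict

import torch as t
from torch import nn

model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
state = model.state_dict()
print(isinstance(state, OrderedDict))  # True
print(list(state.keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked']

# Save to an in-memory buffer (a file path works the same way), then load it back
buffer = io.BytesIO()
t.save(state, buffer)
buffer.seek(0)
loaded = t.load(buffer, weights_only=True)

fresh_model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
fresh_model.load_state_dict(loaded)
# A plain dict is accepted too, and keeps the same key order
fresh_model.load_state_dict(dict(state))

same = all(
    t.equal(a, b) for a, b in zip(model.state_dict().values(), fresh_model.state_dict().values())
)
print(same)  # True
```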

Running Your Model

We've provided you with some images for your model to classify:

IMAGE_FILENAMES = [
    "chimpanzee.jpg",
    "golden_retriever.jpg",
    "platypus.jpg",
    "frogs.jpg",
    "fireworks.jpg",
    "astronaut.jpg",
    "iguana.jpg",
    "volcano.jpg",
    "goofy.jpg",
    "dragonfly.jpg",
]

IMAGE_FOLDER = section_dir / "resnet_inputs"

images = [Image.open(IMAGE_FOLDER / filename) for filename in IMAGE_FILENAMES]

Our images are of type PIL.Image.Image, so we can display them just by putting them on the last line of a notebook cell, or alternatively by using a function like IPython's display:

display(images[0])

We now need to define a transform object like we did for MNIST. We will use the same transforms to convert the PIL image to a tensor, and to normalize it. But we also want to resize the images to height=224, width=224, because not all of them start out with this size and we need them to be consistent before passing them through our model.

In the normalization step, we'll use a mean of [0.485, 0.456, 0.406], and a standard deviation of [0.229, 0.224, 0.225] (these are the mean and std dev of images from ImageNet). Note that the means and std devs have three elements, because ImageNet contains RGB rather than monochrome images, and we're normalising over each of the three RGB channels separately.

IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

IMAGENET_TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

prepared_images = t.stack([IMAGENET_TRANSFORM(img) for img in images], dim=0).to(device)
assert prepared_images.shape == (len(images), 3, IMAGE_SIZE, IMAGE_SIZE)

Exercise - verify your model's predictions

```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to ~10 minutes on this exercise.
```

Lastly, you should run your model on these prepared images, and verify that its predictions are the same as the pretrained model's predictions.

You can do this by filling in the predict function below, then running the code. We've also provided you with a file imagenet_labels.json, which you can use to get the actual class names of ImageNet data and see what your model's predictions actually are.

When you run the code, you should find that your top prediction probabilities are within about 0.01% of the reference model's probabilities most (but not all) of the time. This kind of discrepancy is not uncommon when the order of linear operations or other small implementation details differ between models, since these introduce floating point errors that compound as we move through the model. As a bonus exercise (which may or may not break your sanity), you're welcome to try to work through our implementation, compare it to the PyTorch model's implementation, and find where the discrepancy comes from!
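To see why the order of operations matters, here's a minimal float32 illustration (not taken from either ResNet) of floating point addition failing to be associative:

```python
import torch as t

a = t.tensor(1e8, dtype=t.float32)
b = t.tensor(1.0, dtype=t.float32)

# Adjacent float32 values near 1e8 are 8 apart, so 1e8 + 1 rounds back to 1e8
left = (a + b) - a   # gives 0.0
right = (a - a) + b  # exact, gives 1.0
print(left.item(), right.item())  # 0.0 1.0
```

Differences inside a real network are far smaller than this extreme case, but the same rounding effect is behind the tiny probability mismatches.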

Tip - the torch method torch.max will return a tuple of (values, indices) if you supply a dimension argument dim.
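For example, on a small made-up batch of probabilities:

```python
import torch as t

probs = t.tensor([[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]])
values, indices = t.max(probs, dim=-1)  # max over the last (class) dimension
print(values)   # tensor([0.7000, 0.6000])
print(indices)  # tensor([2, 0])
```

Without dim, t.max(probs) returns just the single largest element, tensor(0.7000).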

@t.inference_mode()
def predict(
    model: nn.Module, images: Float[Tensor, "batch rgb h w"]
) -> tuple[Float[Tensor, "batch"], Int[Tensor, "batch"]]:
    """
    Returns the maximum probability and predicted class for each image, as a tensor of floats and
    ints respectively.
    """
    model.eval()
    raise NotImplementedError()


with open(section_dir / "imagenet_labels.json") as f:
    imagenet_labels = list(json.load(f).values())

# Check your predictions match those of the pretrained model
my_probs, my_predictions = predict(my_resnet, prepared_images)
pretrained_probs, pretrained_predictions = predict(pretrained_resnet, prepared_images)
assert (my_predictions == pretrained_predictions).all()
t.testing.assert_close(my_probs, pretrained_probs, atol=5e-4, rtol=0)  # tolerance of 0.05%
print("All predictions match!")

# Print out your predictions, next to the corresponding images
for i, img in enumerate(images):
    table = Table("Model", "Prediction", "Probability")
    table.add_row("My ResNet", imagenet_labels[my_predictions[i]], f"{my_probs[i]:.3%}")
    table.add_row(
        "Reference Model",
        imagenet_labels[pretrained_predictions[i]],
        f"{pretrained_probs[i]:.3%}",
    )
    rprint(table)
    display(img)
Expected output (for the first image):
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Model           ┃ Prediction                         ┃ Probability ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ My ResNet       │ chimpanzee, chimp, Pan troglodytes │ 90.515%     │
│ Reference Model │ chimpanzee, chimp, Pan troglodytes │ 90.513%     │
└─────────────────┴────────────────────────────────────┴─────────────┘
Help! My model is predicting roughly the same percentage for every category!

This can indicate that your model weights are still randomly initialized, meaning the weight loading process didn't actually take effect, or that you accidentally reinitialized your model after loading the weights.

Solution
@t.inference_mode()
def predict(
    model: nn.Module, images: Float[Tensor, "batch rgb h w"]
) -> tuple[Float[Tensor, "batch"], Int[Tensor, "batch"]]:
    """
    Returns the maximum probability and predicted class for each image, as a tensor of floats and
    ints respectively.
    """
    model.eval()
    logits = model(images)
    probabilities = logits.softmax(dim=-1)
    return probabilities.max(dim=-1)

If you've done everything correctly, your version should give the same classifications, and the percentages should match to at least a couple of decimal places.

If it does, congratulations, you've now run an entire ResNet, using barely any code from torch.nn! The only things we used were nn.Module and nn.Parameter.

If it doesn't, you get to practice model debugging! Remember to use the utils.print_param_count function that was provided.

Aside - hooks

One problem you might have encountered is that your model outputs NaNs rather than actual numbers. When debugging this, it's useful to try and identify which module the error first appears in. This is a great use-case for hooks, which are something we'll be digging a lot more into during our mechanistic interpretability exercises later on.

A hook is basically a function which you can attach to a particular nn.Module, which gets executed during your model's forward or backward passes. Here, we'll only consider forward hooks. A hook function's type signature is:

def hook(module: nn.Module, inputs: tuple[Tensor, ...], output: Tensor) -> None:
    pass

The inputs argument is a tuple of the positional inputs to the module (often just one tensor), and the output argument is the output of the module. This hook gets registered to a module by calling module.register_forward_hook(hook). During every subsequent forward pass, the hook function will run right after the module's forward method has computed its output.
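Here's a minimal sketch of a forward hook that records each layer's output shape, on a tiny made-up model. It also shows that register_forward_hook returns a handle whose remove() method detaches just that hook, which is a cleaner alternative to clearing the module's private hook dictionaries.

```python
import torch as t
from torch import nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
output_shapes: dict[str, tuple[int, ...]] = {}


def make_hook(name: str):
    # Forward hooks are called as hook(module, inputs_tuple, output)
    def hook(module: nn.Module, inputs: tuple[t.Tensor, ...], output: t.Tensor) -> None:
        output_shapes[name] = tuple(output.shape)

    return hook


handles = [module.register_forward_hook(make_hook(name)) for name, module in model.named_children()]
model(t.randn(5, 8))

# Each handle removes exactly the hook it registered
for handle in handles:
    handle.remove()

print(output_shapes)  # {'0': (5, 4), '1': (5, 4), '2': (5, 2)}
```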

Here is some code which will check for NaNs in the output of each module, and raise a ValueError if it finds any. We've also given you an example tiny network which produces a NaN in the output of the second layer, to demonstrate it on.

class NanModule(nn.Module):
    """
    Define a module that always returns NaNs (we will use hooks to identify this error).
    """

    def forward(self, x):
        return t.full_like(x, float("nan"))


def hook_check_for_nan_output(module: nn.Module, input: tuple[Tensor], output: Tensor) -> None:
    """
    Hook function which detects when the output of a layer is NaN.
    """
    if t.isnan(output).any():
        raise ValueError(f"NaN output from {module}")


def add_hook(module: nn.Module) -> None:
    """
    Register our hook function in a module.

    Use model.apply(add_hook) to recursively apply the hook to model and all submodules.
    """
    module.register_forward_hook(hook_check_for_nan_output)


def remove_hooks(module: nn.Module) -> None:
    """
    Remove all hooks from module.

    Use module.apply(remove_hooks) to do this recursively.
    """
    module._backward_hooks.clear()
    module._forward_hooks.clear()
    module._forward_pre_hooks.clear()


# Create our model with a NaN in the middle, and apply a hook fn to it which checks for NaNs
model = nn.Sequential(nn.Identity(), NanModule(), nn.Identity())
model = model.apply(add_hook)

# Run the model, and our hook function should raise an error that gets caught by the try-except
try:
    input = t.randn(3)
    output = model(input)
except ValueError as e:
    print(e)

# Remove hooks at the end
model = model.apply(remove_hooks)

When you run this code, you should find it raising an error at the NanModule.

Important - when you're working with PyTorch hooks, make sure you remember to remove them at the end of each use! This is a classic source of bugs, and one of the things that make PyTorch hooks so janky. When we study TransformerLens in the next chapter, we'll use a version of hooks that is essentially the same under the hood, but comes with quite a few quality of life improvements!
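One way to make removal hard to forget is a context manager that detaches its hooks even when the forward pass raises, as our NaN check does. The sketch below is our own helper built on contextlib and the handles returned by register_forward_hook; temporary_forward_hook is a hypothetical name, not a PyTorch or TransformerLens API.

```python
from contextlib import contextmanager
from typing import Callable, Iterator

import torch as t
from torch import nn


@contextmanager
def temporary_forward_hook(model: nn.Module, hook: Callable) -> Iterator[None]:
    """Attach `hook` to `model` and all its submodules, and always detach it on exit."""
    handles = [module.register_forward_hook(hook) for module in model.modules()]
    try:
        yield
    finally:
        # Runs even if the forward pass raised, so no hooks are left behind
        for handle in handles:
            handle.remove()


def check_for_nan(module: nn.Module, inputs: tuple[t.Tensor, ...], output: t.Tensor) -> None:
    if t.isnan(output).any():
        raise ValueError(f"NaN output from {module}")


model = nn.Sequential(nn.Identity(), nn.Linear(3, 3), nn.Identity())
with t.no_grad():
    model[1].weight.fill_(float("nan"))  # deliberately corrupt the middle layer

try:
    with temporary_forward_hook(model, check_for_nan):
        model(t.randn(3))
except ValueError as e:
    print(e)  # NaN output from Linear(in_features=3, out_features=3, bias=True)

print(all(len(m._forward_hooks) == 0 for m in model.modules()))  # True
```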