[0.2] - CNNs & ResNets
Please report any problems or bugs in the #errata channel of the Slack group, and ask questions in the dedicated channels for this chapter of material.
Links to all other chapters: (0) Fundamentals, (1) Transformer Interpretability, (2) RL.

Introduction
This section is designed to get you familiar with basic neural networks: how they are structured, the basic operations (like linear layers and convolutions) that go into making them, and why they work as well as they do. You'll start by making very simple neural networks, and by the end of today you'll have assembled ResNet34, a much more complicated architecture.
A lecture video accompanies today's material; it provides some high-level understanding before you dive into the exercises.
Content & Learning Objectives
1️⃣ Making your own modules
In the first set of exercises, we'll cover the general structure of modules in PyTorch. You'll also implement your own basic modules, including for ReLU and Linear layers. You'll finish by assembling a very simple neural network.
Learning Objectives
- Learn how to create your own modules in PyTorch, by inheriting from nn.Module
- Assemble the pieces together to create a simple fully-connected network, to classify MNIST digits
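As a rough illustration of the module pattern described above, here is a minimal sketch: each module subclasses nn.Module, registers learnable tensors as nn.Parameter in __init__, and defines forward. The class names, the uniform initialization bound of 1/sqrt(in_features), and the hidden width of 100 are assumptions for illustration, not necessarily what the exercises use.

```python
import torch as t
import torch.nn as nn


class ReLU(nn.Module):
    """Elementwise max(x, 0); this module has no learnable parameters."""

    def forward(self, x: t.Tensor) -> t.Tensor:
        return t.maximum(x, t.tensor(0.0))


class Linear(nn.Module):
    """Computes x @ W.T + b, using the same weight layout (out, in) as nn.Linear."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Assumed init: uniform in [-1/sqrt(in_features), 1/sqrt(in_features)]
        bound = 1 / in_features**0.5
        self.weight = nn.Parameter(t.empty(out_features, in_features).uniform_(-bound, bound))
        self.bias = nn.Parameter(t.empty(out_features).uniform_(-bound, bound)) if bias else None

    def forward(self, x: t.Tensor) -> t.Tensor:
        x = x @ self.weight.T
        return x if self.bias is None else x + self.bias


class SimpleMLP(nn.Module):
    """Hypothetical MNIST classifier: flatten (batch, 1, 28, 28) -> Linear -> ReLU -> Linear."""

    def __init__(self, hidden: int = 100):
        super().__init__()
        self.flatten = nn.Flatten()  # (batch, 1, 28, 28) -> (batch, 784)
        self.linear1 = Linear(28 * 28, hidden)
        self.relu = ReLU()
        self.linear2 = Linear(hidden, 10)  # one logit per digit class

    def forward(self, x: t.Tensor) -> t.Tensor:
        return self.linear2(self.relu(self.linear1(self.flatten(x))))
```

Because the weights are wrapped in nn.Parameter and the submodules are assigned as attributes, model.parameters() finds all of them automatically, which is what an optimizer needs.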
2️⃣ Training Neural Networks
Here, you'll learn how to write a training loop in PyTorch. We'll keep it simple for today (and later on we'll experiment with more modular and extensible designs).
Learning Objectives
- Understand how to work with transforms, datasets and dataloaders
- Understand the basic structure of a training loop
- Learn how to write your own validation loop
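To make the structure of a training loop and a validation loop concrete, here is a small sketch. It assumes a synthetic dataset in place of MNIST (so it runs without downloads), and the model, Adam optimizer, learning rate, batch size and epoch count are arbitrary illustrative choices rather than the exercise's settings.

```python
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split

t.manual_seed(0)

# Hypothetical synthetic stand-in for MNIST (the real exercises load torchvision's MNIST
# with transforms); label is 1 when an image's mean pixel is positive, so there is signal.
images = t.randn(512, 1, 28, 28)
labels = (images.mean(dim=(1, 2, 3)) > 0).long()
trainset, testset = random_split(TensorDataset(images, labels), [448, 64])

# DataLoaders batch the data; shuffle only the training set
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
test_loader = DataLoader(testset, batch_size=64, shuffle=False)

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 2))
optimizer = t.optim.Adam(model.parameters(), lr=1e-3)

losses, accuracies = [], []
for epoch in range(3):
    # Training loop: forward pass, loss, zero old grads, backward pass, parameter update
    model.train()
    for imgs, targets in train_loader:
        logits = model(imgs)
        loss = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # Validation loop: no gradient tracking, just count correct predictions on held-out data
    model.eval()
    n_correct = 0
    with t.inference_mode():
        for imgs, targets in test_loader:
            n_correct += (model(imgs).argmax(dim=-1) == targets).sum().item()
    accuracies.append(n_correct / len(testset))
```

Calling optimizer.zero_grad() before loss.backward() matters because PyTorch accumulates gradients across backward calls by default.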
3️⃣ Convolutions
In this section, you'll read about convolutions, and implement them as an nn.Module (not from scratch; we leave that to the bonus exercises). You'll also learn about maxpooling, and implement that as well.
Learning Objectives
- Learn how convolutions work, and why they are useful for vision models
- Implement your own convolution and maxpooling layers
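One way to read "implement them as an nn.Module (not from scratch)" is to wrap torch.nn.functional.conv2d and max_pool2d in modules that own their weights and hyperparameters. The sketch below assumes exactly that; it omits a conv bias (a common choice when a batch norm follows), and the helper conv_output_size is a hypothetical name for the standard output-size formula floor((size + 2 * padding - kernel_size) / stride) + 1.

```python
from typing import Optional

import torch as t
import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Module):
    """2D convolution as an nn.Module; the arithmetic itself is delegated to F.conv2d."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0):
        super().__init__()
        self.stride, self.padding = stride, padding
        # Assumed init: uniform, scaled by fan-in (in_channels * k * k), as in the Linear case
        fan_in = in_channels * kernel_size * kernel_size
        bound = 1 / fan_in**0.5
        self.weight = nn.Parameter(
            t.empty(out_channels, in_channels, kernel_size, kernel_size).uniform_(-bound, bound)
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        # (batch, in_channels, height, width) -> (batch, out_channels, out_height, out_width)
        return F.conv2d(x, self.weight, stride=self.stride, padding=self.padding)


class MaxPool2d(nn.Module):
    """Takes the max over each kernel_size x kernel_size window (stride defaults to kernel_size)."""

    def __init__(self, kernel_size: int, stride: Optional[int] = None, padding: int = 0):
        super().__init__()
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

    def forward(self, x: t.Tensor) -> t.Tensor:
        return F.max_pool2d(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)


def conv_output_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1
```

For example, a 3x3 kernel with stride 2 and padding 1 maps a 32x32 input to 16x16, since (32 + 2 - 3) // 2 + 1 = 16.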
4️⃣ ResNets
Here, you'll combine all the pieces you've learned so far to assemble ResNet34, a much more complex architecture used for image classification.
Learning Objectives
- Learn about skip connections, and how they help overcome the degradation problem
- Learn about batch normalization, and why it is used in training
- Assemble your own ResNet, and load in weights from PyTorch's ResNet implementation
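Below is a sketch of the two ideas named above. batchnorm2d_train spells out what batch normalization computes in training mode (normalize each channel over the batch and spatial dimensions, then scale and shift), and ResidualBlock shows a skip connection: the block's input is added back to the output of its main branch, with a 1x1 conv + batch norm on the skip path when the shapes change. ResNet34 stacks blocks of this kind in four groups of 3, 4, 6 and 3 blocks with 64, 128, 256 and 512 channels. Both names, the left/right attribute names, and the use of PyTorch's built-in layers instead of your own are assumptions for illustration.

```python
import torch as t
import torch.nn as nn


def batchnorm2d_train(x: t.Tensor, weight: t.Tensor, bias: t.Tensor, eps: float = 1e-5) -> t.Tensor:
    """Training-mode batch norm for x of shape (batch, channels, height, width)."""
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # biased variance
    x_hat = (x - mean) / t.sqrt(var + eps)
    # Learnable per-channel scale (weight) and shift (bias)
    return x_hat * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)


class ResidualBlock(nn.Module):
    """Basic residual block: relu(left(x) + right(x))."""

    def __init__(self, in_feats: int, out_feats: int, first_stride: int = 1):
        super().__init__()
        # Main ("left") branch: conv -> BN -> ReLU -> conv -> BN
        self.left = nn.Sequential(
            nn.Conv2d(in_feats, out_feats, 3, stride=first_stride, padding=1, bias=False),
            nn.BatchNorm2d(out_feats),
            nn.ReLU(),
            nn.Conv2d(out_feats, out_feats, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_feats),
        )
        # Skip ("right") branch: identity if shapes already match, else 1x1 conv + BN to match them
        if first_stride == 1 and in_feats == out_feats:
            self.right = nn.Identity()
        else:
            self.right = nn.Sequential(
                nn.Conv2d(in_feats, out_feats, 1, stride=first_stride, bias=False),
                nn.BatchNorm2d(out_feats),
            )
        self.relu = nn.ReLU()

    def forward(self, x: t.Tensor) -> t.Tensor:
        return self.relu(self.left(x) + self.right(x))
```

One way to see why skip connections help with degradation: if the main branch learns to output zero, the block reduces to relu(x), so adding more blocks need not make the network worse at representing what shallower layers already compute.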
☆ Bonus - Feature Extraction
In this section, you'll learn how to repurpose your ResNet to perform a different task than it was designed for, using feature extraction.
Learning Objectives
- Understand the difference between feature extraction and finetuning
- Perform feature extraction on a pre-trained ResNet
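A minimal sketch of feature extraction, assuming the final classification layer is an nn.Linear stored at model.fc (as in torchvision's ResNets): freeze every existing parameter, then replace the head with a fresh layer, so only the new head is trained. Finetuning, by contrast, would leave some or all of the pre-trained parameters trainable. The function name and the TinyBackbone stand-in (used here instead of a real pre-trained ResNet) are hypothetical.

```python
import torch as t
import torch.nn as nn


def prepare_for_feature_extraction(model: nn.Module, n_classes: int) -> nn.Module:
    """Freeze all existing parameters and swap in a new, trainable classification head."""
    for param in model.parameters():
        param.requires_grad_(False)  # the backbone becomes a fixed feature extractor
    # Newly created layers have requires_grad=True by default, so only this head will train
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a pre-trained network with a `.fc` head."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: t.Tensor) -> t.Tensor:
        return self.fc(self.features(x))


model = prepare_for_feature_extraction(TinyBackbone(), n_classes=10)
# Pass only the trainable parameters to the optimizer
optimizer = t.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.01)
```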
☆ Bonus - Convolutions From Scratch
This section takes you through the low-level details of how to actually implement convolutions. It's not necessary to understand this section to complete the exercises, but it's a good way to get a deeper understanding of how convolutions work.
Learning Objectives
- Understand how array strides work, and why they're important for efficient linear operations
- Learn how to use as_strided to perform simple linear operations like trace and matrix multiplication
- Implement your own convolutions and maxpooling functions using stride-based methods
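To give a flavour of the stride-based approach, here is a sketch of trace and matrix multiplication using Tensor.as_strided, which creates a new view of the same memory with a chosen shape and strides (no copying). The function names are hypothetical. The key tricks: element (i, i) sits i * (s0 + s1) steps from the start, so a single stride of s0 + s1 walks the diagonal; and a stride of 0 repeats the same elements along an axis, which lets both matrices be viewed as (i, j, k) tensors without materializing copies.

```python
import torch as t


def as_strided_trace(mat: t.Tensor) -> t.Tensor:
    """Trace of a square matrix via a strided view of its diagonal."""
    n = mat.shape[0]
    s0, s1 = mat.stride()
    # One step of (s0 + s1) moves from element (i, i) to element (i + 1, i + 1)
    diag = mat.as_strided((n,), (s0 + s1,))
    return diag.sum()


def as_strided_matmul(a: t.Tensor, b: t.Tensor) -> t.Tensor:
    """(i, j) @ (j, k) -> (i, k) by viewing both as (i, j, k) and summing over j."""
    i, j = a.shape
    j2, k = b.shape
    assert j == j2, "inner dimensions must match"
    a_s0, a_s1 = a.stride()
    b_s0, b_s1 = b.stride()
    # Stride 0 along k repeats each row of a; stride 0 along i repeats all of b
    a_exp = a.as_strided((i, j, k), (a_s0, a_s1, 0))
    b_exp = b.as_strided((i, j, k), (0, b_s0, b_s1))
    return (a_exp * b_exp).sum(dim=1)
```

Because the strides are read from the tensors themselves rather than assumed, the same code also works for non-contiguous inputs such as a transposed matrix.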
Setup code
```python
import json
import sys
from collections import namedtuple
from dataclasses import dataclass
from pathlib import Path

import einops
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
from IPython.display import display
from jaxtyping import Float, Int
from PIL import Image
from rich import print as rprint
from rich.table import Table
from torch import Tensor
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms
from tqdm.notebook import tqdm

# Make sure exercises are in the path
chapter = "chapter0_fundamentals"
section = "part2_cnns"
# Search the parent directories of the working directory for the one containing the chapter folder
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

# True when this file is run directly (or in a notebook), False when it is imported
MAIN = __name__ == "__main__"

import part2_cnns.tests as tests
import part2_cnns.utils as utils
from plotly_utils import line
```
Help - I get a NumPy-related error
This is an annoying Colab-related issue which I haven't been able to find a satisfying fix for. If you restart the runtime (but don't delete it) and run just the imports cell above again (but not the %pip install cell), the problem should go away.