3️⃣ Training on MNIST from scratch

Learning Objectives
  • Implement more forward and backward functions, including for indexing, summing, and matrix multiplication
  • Learn how to build higher-level abstractions like parameters and modules on top of individual functions and tensors
  • Complete the process of building up a neural network from scratch and training it via gradient descent

Congrats on implementing backprop! Soon we'll be able to train a full model from scratch, but first we'll go through a bunch of backward functions which will be necessary for training (as well as ones that cover some interesting cases). These should be a lot like your log_back, multiply_back0 and multiply_back1 examples earlier.

More backward functions!

Note - some of these exercises can get a bit repetitive, and so you're welcome to skip through many of these exercises if you don't find them interesting, and/or you're pressed for time. The exercises in the section "Parameters & Modules" and beyond are much more conceptually valuable.

Additionally, most of these functions can be implemented simply in 1 or 2 lines, so if you find yourself writing a lot more than that then you might want to look at the solution instead.

Exercise - negative

```yaml
Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵⚪⚪⚪⚪

You should spend up to ~5 minutes on this exercise (it's not a trick question, it is as simple as it looks!).
```

torch.negative just performs -x elementwise. Make your own version negative using wrap_forward_fn. Note, you don't need to worry about unbroadcasting here because np.negative won't change the input shape (technically it can since np.negative(x, out) is actually the negative version of x broadcasted to the shape of out, but we won't be using it in this way during our exercises).

def negative_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward function for f(x) = -x elementwise."""
    raise NotImplementedError()


negative = wrap_forward_fn(np.negative)
BACK_FUNCS.add_back_func(np.negative, 0, negative_back)

tests.test_negative_back(Tensor)
Solution
def negative_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward function for f(x) = -x elementwise."""
    return -grad_out
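
If you ever want to sanity-check a backward function outside the test suite, a central-difference estimate is a handy trick. Below is a minimal sketch (not part of the exercises, and not something the tests rely on): it treats L = fn(x).sum(), so grad_out is an array of ones, and compares your analytic gradient against a numerical one.

def finite_diff_check(fn, back_fn, x: Arr, eps=1e-6) -> bool:
    """Compare an analytic backward function against a central-difference estimate."""
    out = fn(x)
    analytic = back_fn(np.ones_like(out), out, x)  # gradient of L = fn(x).sum() wrt x
    numeric = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        x_hi, x_lo = x.copy(), x.copy()
        x_hi[idx] += eps
        x_lo[idx] -= eps
        numeric[idx] = (fn(x_hi).sum() - fn(x_lo).sum()) / (2 * eps)
    return np.allclose(analytic, numeric, atol=1e-4)


assert finite_diff_check(np.negative, negative_back, np.random.randn(3, 4))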

Exercise - exp

```yaml
Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪

You should spend up to 5-10 minutes on this exercise.
```

Make your own version of torch.exp. The backward function should express the result in terms of the out parameter - this is more efficient than expressing it in terms of x.

def exp_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward function for f(x) = exp(x) elementwise."""
    raise NotImplementedError()


exp = wrap_forward_fn(np.exp)
BACK_FUNCS.add_back_func(np.exp, 0, exp_back)

tests.test_exp_back(Tensor)
Solution
def exp_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward function for f(x) = exp(x) elementwise."""
    return out * grad_out

Exercise - reshape

```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵⚪⚪⚪⚪

You should spend up to 5-10 minutes on this exercise.
```

reshape is a bit different than the functions we've dealt with so far: it changes the shape of the tensor, not its values. In other words, the backward function needs to be able to map from the gradient $\partial L / \partial \mathbf{x_r}$ to $\partial L / \partial \mathbf{x}$, where $\mathbf{x_r}$ is the reshaped version of $\mathbf{x}$.

Depending how you wrote wrap_forward_fn and backprop, you might need to go back and adjust them to handle this - if you're failing tests but think your implementation is correct, we recommend you go back to these functions and check them.

This function should just be a single line.

def reshape_back(grad_out: Arr, out: Arr, x: Arr, new_shape: tuple) -> Arr:
    """Backward function for torch.reshape."""
    raise NotImplementedError()


reshape = wrap_forward_fn(np.reshape)
BACK_FUNCS.add_back_func(np.reshape, 0, reshape_back)

tests.test_reshape_back(Tensor)
Solution (and explanation)

Explanation: the reshape operation that takes us from the tensor $\frac{\partial L}{\partial \mathbf{x_r}}$ to $\frac{\partial L}{\partial \mathbf{x}}$ is exactly the inverse of the forward reshape operation that produced $\mathbf{x_r}$ from $\mathbf{x}$. In other words, we want to take grad_out and reshape it back to the shape of x.

def reshape_back(grad_out: Arr, out: Arr, x: Arr, new_shape: tuple) -> Arr:
    """Backward function for torch.reshape."""
    return np.reshape(grad_out, x.shape)
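
As a quick shape check of the solution above (a small sketch, assuming the cell above has been run):

x = np.arange(6).reshape((2, 3)).astype(float)
out = np.reshape(x, (3, 2))                      # forward: (2, 3) -> (3, 2)
grad_x = reshape_back(np.ones_like(out), out, x, (3, 2))
assert grad_x.shape == x.shape                   # backward maps gradients back to (2, 3)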

Exercise - permute

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵⚪⚪⚪⚪

You should spend up to 5-10 minutes on this exercise.
```

In NumPy, the equivalent of torch.permute is called np.transpose, so we will wrap that. Permute is somewhat similar to reshape, but the difference is that it does actually change the order of elements in the underlying array.

Hint - just like with reshape, the inverse of a transposition is another transposition. You might find the function np.argsort useful for getting the inverse transposition.

This function should also just be a single line.

def permute_back(grad_out: Arr, out: Arr, x: Arr, axes: tuple) -> Arr:
    """
    Backward function for torch.permute. Works by inverting the transposition in the forward
    function.
    """
    raise NotImplementedError()


BACK_FUNCS.add_back_func(np.transpose, 0, permute_back)
permute = wrap_forward_fn(np.transpose)

tests.test_permute_back(Tensor)
Solution (and explanation)

The inverse of transposing with axes is transposing with np.argsort(axes). To see this: the forward transpose puts the input's axis axes[j] at position j of the output, so the inverse transposition axes_inv must satisfy axes_inv[axes[j]] = j, i.e. axes_inv is the inverse permutation of axes. In other words, axes_inv is the array of indices that sorts axes - which is exactly what np.argsort computes.

def permute_back(grad_out: Arr, out: Arr, x: Arr, axes: tuple) -> Arr:
    """
    Backward function for torch.permute. Works by inverting the transposition in the forward
    function.
    """
    return np.transpose(grad_out, np.argsort(axes))
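
You can verify the argsort claim directly in NumPy (a quick illustrative check):

axes = (2, 0, 1)
axes_inv = np.argsort(axes)                          # array([1, 2, 0])
x = np.random.randn(3, 4, 5)
y = np.transpose(x, axes)                            # shape (5, 3, 4)
assert np.array_equal(np.transpose(y, axes_inv), x)  # inverse transpose recovers x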

Exercise - sum

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵⚪⚪⚪

You should spend up to 20-30 minutes on this exercise.
```

The output can also be smaller than the input, such as when calling torch.sum.

Recall that when we looked at broadcasting, the backwards operation was summing over the broadcasted dimensions. This is because a broadcast operation effectively copies our tensor, giving it more gradient paths that we need to sum over. Similarly, the backwards operation for summing is broadcasting.

We can intuitively see this as follows: if we have some value L = L(x_summed) where x_summed is the result of summing x over some number of dimensions, then editing x_summed[i, j, ...] += delta has the same downstream effect as editing any one of the x values which were summed over to get x_summed[i, j, ...]. So to get the gradient of L wrt x, we need to copy (broadcast) the gradient of L wrt x_summed up to the full size of x.

Implementing sum_back should have 2 steps:

  1. Adding new dims if they were summed over with keepdim=False. You can do this with np.expand_dims, for example if arr has shape (2, 3) then np.expand_dims(arr, (0, 2)) has shape (1, 2, 1, 3), i.e. it's a new tensor with dummy dimensions created at indices 0 and 2.
  2. Broadcasting along dims that were summed over. Since after step (1) you've effectively reduced to the keepdim=True case, you can now use np.broadcast_to to get the correct shape.

Note, if you get weird errors that you can't explain, and these exceptions don't even go away when you use the solutions provided, this could mean that your implementation of wrap_forward_fn was wrong in a way which wasn't picked up by the tests. You should return to this function and try to fix it (or just use the solution).

def sum_back(grad_out: Arr, out: Arr, x: Arr, dim=None, keepdim=False):
    """Backward function for torch.sum"""
    raise NotImplementedError()


def _sum(x: Arr, dim=None, keepdim=False) -> Arr:
    """Like torch.sum, calling np.sum internally."""
    return np.sum(x, axis=dim, keepdims=keepdim)


sum = wrap_forward_fn(_sum)
BACK_FUNCS.add_back_func(_sum, 0, sum_back)

tests.test_sum_keepdim_false(Tensor)
tests.test_sum_keepdim_true(Tensor)
tests.test_sum_dim_none(Tensor)
tests.test_sum_nonscalar_grad_out(Tensor)
Help - I'm not sure how to handle the case where dim=None.

You can actually handle this pretty easily - if dim=None then grad_out will be a scalar, so it's always fine to broadcast it along the dims that were summed over! This means you can skip step (1), i.e. this step only needs to handle the case where keepdim=False and dim is not None.

Help - I get the error "Encountered error when running `backward` in the test for nonscalar grad_out."

This error is likely due to the fact that you're expanding your tensor in a way that doesn't refer to the dimensions being summed over (i.e. the dim argument).

Remember that in the previous exercise we assumed that the tensors were broadcastable with each other, and our functions could just internally call np.broadcast_to as a result. But here, one tensor is the sum over another tensor's dimensions, and if keepdim=False then they might not broadcast. For instance, if x.shape = (2, 5), out = x.sum(dim=1) has shape (2,) and grad_out.shape = (2,), then the tensors grad_out and x are not broadcastable.

How can you carefully handle the case where keepdim=False and dim doesn't just refer to dimensions at the start of the tensor? (Hint - try and use np.expand_dims).

Solution
def sum_back(grad_out: Arr, out: Arr, x: Arr, dim=None, keepdim=False):
    """Backward function for torch.sum"""
    # Step (1): if keepdim=False, then we need to add back in dims, so grad_out and x have the
    # same number of dims. We don't bother with the dim=None case, since then grad_out is a scalar
    # and this will be handled by our broadcasting in step (2).
    if (not keepdim) and (dim is not None):
        grad_out = np.expand_dims(grad_out, dim)
    # Step (2): repeat grad_out along the dims over which x was summed
    return np.broadcast_to(grad_out, x.shape)
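
To see the two steps in action on a concrete example (a quick sketch, assuming the solution above):

x = np.random.randn(2, 5)
out = x.sum(axis=1)                                   # shape (2,)
grad_x = sum_back(np.ones_like(out), out, x, dim=1, keepdim=False)
assert grad_x.shape == (2, 5)   # step (1) expands (2,) -> (2, 1), step (2) broadcasts -> (2, 5)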

Elementwise add, subtract, divide

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵⚪⚪⚪

You should spend up to 10-15 minutes on this exercise.
```

These are exactly analogous to the multiply case. Note that Python and NumPy have the notion of "floor division", which is a truncating integer division as in 7 // 3 = 2. You can ignore floor division: we only need the usual floating-point division, which is called "true division".
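
For reference, the scalar derivatives you'll need (each followed by unbroadcast, just like in the multiply case) are:

$$
\frac{\partial(x+y)}{\partial x} = \frac{\partial(x+y)}{\partial y} = 1, \qquad \frac{\partial(x-y)}{\partial y} = -1, \qquad \frac{\partial(x/y)}{\partial x} = \frac{1}{y}, \qquad \frac{\partial(x/y)}{\partial y} = -\frac{x}{y^2}
$$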

Use lambda functions to define and register the backward functions each in one line. We've given you the first one.

add = wrap_forward_fn(np.add)
subtract = wrap_forward_fn(np.subtract)
true_divide = wrap_forward_fn(np.true_divide)

BACK_FUNCS.add_back_func(np.add, 0, lambda grad_out, out, x, y: unbroadcast(grad_out, x))
# YOUR CODE HERE - continue adding to BACK_FUNCS, for each of the 3 functions & both arg orders

tests.test_add_broadcasted(Tensor)
tests.test_subtract_broadcasted(Tensor)
tests.test_truedivide_broadcasted(Tensor)
Solution
BACK_FUNCS.add_back_func(np.add, 0, lambda grad_out, out, x, y: unbroadcast(grad_out, x))
BACK_FUNCS.add_back_func(np.add, 1, lambda grad_out, out, x, y: unbroadcast(grad_out, y))
BACK_FUNCS.add_back_func(np.subtract, 0, lambda grad_out, out, x, y: unbroadcast(grad_out, x))
BACK_FUNCS.add_back_func(np.subtract, 1, lambda grad_out, out, x, y: unbroadcast(-grad_out, y))
BACK_FUNCS.add_back_func(
    np.true_divide, 0, lambda grad_out, out, x, y: unbroadcast(grad_out / y, x)
)
BACK_FUNCS.add_back_func(
    np.true_divide, 1, lambda grad_out, out, x, y: unbroadcast(grad_out * (-x / y**2), y)
)

Indexing

If we have the gradient of L wrt x[index], what is the gradient of L wrt x? The answer is it'll be an array of zeros, filled in with the values of dL/dx[index] at the appropriate index positions. For example, if x = [1, 2, 3] and L = x[0], then we trivially have dL/dx[0] = 1, and we can compute dL/dx = [1, 0, 0] in this way.

In its full generality, exactly how you can index a torch.Tensor is really complicated and there are quite a few cases to handle separately. Our implementation only handles 2 cases:

  • The index is an integer or tuple of integers.
  • The index is a tuple of (array or Tensor) representing coordinates. Each array is 1D and of equal length. Some coordinates may be repeated. This is Integer array indexing.

This latter case is very important, because it describes how we index correct logprobs / probabilities. For example, if we're training a classifier and we have tensors logprobs.shape = (batch_size, n_classes) and targets.shape = (batch_size,), then we index the correct logprobs using logprobs[arange(batch_size), targets] (note that the arange function was given to you earlier, when we defined the Tensor class - it's a simple wrapper around np.arange).

Index = int | tuple[int, ...] | tuple[Arr] | tuple[Tensor]


def coerce_index(index: Index):
    """Helper function: converts array of tensors to array of numpy arrays."""
    if isinstance(index, tuple) and all(isinstance(i, Tensor) for i in index):
        return tuple([i.array for i in index])
    else:
        return index


def _getitem(x: Arr, index: Index) -> Arr:
    """Like x[index] when x is a torch.Tensor."""
    return x[coerce_index(index)]


def getitem_back(grad_out: Arr, out: Arr, x: Arr, index: Index):
    """
    Backwards function for _getitem.

    Hint: use np.add.at(a, indices, b)
    This function works just like a[indices] += b, except that it allows for repeated indices.
    """
    new_grad_out = np.full_like(x, 0)
    np.add.at(new_grad_out, coerce_index(index), grad_out)
    return new_grad_out


getitem = wrap_forward_fn(_getitem)
BACK_FUNCS.add_back_func(_getitem, 0, getitem_back)
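
To see why np.add.at matters here: with ordinary fancy-index +=, NumPy buffers the update so repeated indices only contribute once, which would silently drop gradient whenever the same element is indexed multiple times. A quick pure-NumPy illustration:

grad = np.zeros(3)
grad[np.array([0, 0, 2])] += 1.0             # buffered: repeated index 0 only counted once
print(grad)                                  # [1. 0. 1.]

grad = np.zeros(3)
np.add.at(grad, np.array([0, 0, 2]), 1.0)    # unbuffered: repeats accumulate
print(grad)                                  # [2. 0. 1.]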

Non-Differentiable Functions

For functions like torch.argmax or torch.eq, there's no sensible way to define gradients with respect to the input tensor. For these, we will still use wrap_forward_fn because we still need to unbox the arguments and box the result, but by passing is_differentiable=False we can avoid doing any unnecessary computation.

We've given you this one as an example:

def _argmax(x: Arr, dim=None, keepdim=False):
    """Like torch.argmax."""
    result = np.argmax(x, axis=dim)
    if keepdim:
        return np.expand_dims(result, axis=([] if dim is None else dim))
    return result


argmax = wrap_forward_fn(_argmax, is_differentiable=False)

a = Tensor([1.0, 0.0, 3.0, 4.0], requires_grad=True)
b = a.argmax()
assert not b.requires_grad
assert b.recipe is None
assert b.item() == 3

In-Place Operations

Supporting in-place operations introduces substantial complexity and generally doesn't help performance that much. The problem is that if any of the inputs used in the backward function have been modified in-place since the forward pass, then the backward function will incorrectly calculate using the modified version. PyTorch will warn you when this causes a problem with the error "RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.".

Note - you don't have to fill anything in here; just run the cell. If you're curious, you can implement inplace operations as a bonus exercise at the end, but for now we just warn against inplace operations unless we specify otherwise.

def add_(x: Tensor, other: Tensor, alpha: float = 1.0) -> Tensor:
    """Like torch.add_. Compute x += other * alpha in-place and return tensor."""
    np.add(x.array, other.array * alpha, out=x.array)
    return x


def sub_(x: Tensor, other: Tensor, alpha: float = 1.0) -> Tensor:
    """Like torch.sub_. Compute x -= other * alpha in-place and return tensor."""
    np.subtract(x.array, other.array * alpha, out=x.array)
    return x


def safe_example():
    """This example should work properly."""
    a = Tensor([0.0, 1.0, 2.0, 3.0], requires_grad=True)
    b = Tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
    a.add_(b)
    c = a * b
    c.sum().backward()
    assert a.grad is not None and np.allclose(a.grad.array, [2.0, 3.0, 4.0, 5.0])
    assert b.grad is not None and np.allclose(b.grad.array, [2.0, 4.0, 6.0, 8.0])


def unsafe_example():
    """
    This example is expected to compute the wrong gradients, because dc/db is calculated using the
    modified a.
    """
    a = Tensor([0.0, 1.0, 2.0, 3.0], requires_grad=True)
    b = Tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
    c = a * b
    a.add_(b)
    c.sum().backward()
    if a.grad is not None and np.allclose(a.grad.array, [2.0, 3.0, 4.0, 5.0]):
        print("Grad wrt a is OK!")
    else:
        print("Grad wrt a is WRONG!")
    if b.grad is not None and np.allclose(b.grad.array, [0.0, 1.0, 2.0, 3.0]):
        print("Grad wrt b is OK!")
    else:
        print("Grad wrt b is WRONG!")


safe_example()
unsafe_example()

Mixed Scalar-Tensor Operations

You may have been wondering why our Tensor class has to define both __mul__ and __rmul__ magic methods.

Without __rmul__ defined, executing 2 * a when a is a Tensor would first try 2.__mul__(a). The built-in int class doesn't know how to multiply by a Tensor, so this returns NotImplemented, and Python then falls back to a.__rmul__(2).

Since we have defined __rmul__ for you at the start, and you implemented multiply to work with floats as arguments, the following should "just work".

a = Tensor([0, 1, 2, 3], requires_grad=True)
(a * 2).sum().backward()
b = Tensor([0, 1, 2, 3], requires_grad=True)
(2 * b).sum().backward()
assert a.grad is not None
assert b.grad is not None
assert np.allclose(a.grad.array, b.grad.array)
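
If you want to see this fallback protocol in isolation, here's a stripped-down sketch (a toy class, nothing to do with our Tensor specifically):

class Scaled:
    def __init__(self, value):
        self.value = value

    def __mul__(self, other):
        return Scaled(self.value * other)

    def __rmul__(self, other):
        # Python calls this because (2).__mul__(s) returns NotImplemented for an unknown type
        return Scaled(other * self.value)


s = Scaled(3)
assert (2 * s).value == 6  # int gives up, so Python falls back to s.__rmul__(2)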

Exercise - max

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵⚪⚪⚪

You should spend up to 10-15 minutes on this exercise.
```

Since this is an elementwise function, we can think about the scalar case. For scalar $x$, $y$, the derivative for $\max(x, y)$ wrt $x$ is 1 when $x > y$ and 0 when $x < y$. What should happen when $x = y$?

Intuitively, since $\max(x, x)$ is equivalent to the identity function, which has a derivative of 1 wrt $x$, it makes sense for our partial derivatives wrt $x$ and $y$ to sum to 1. The convention used by PyTorch is to split the derivative evenly between the two arguments. We will follow this behavior for compatibility, but it would be just as legitimate to say the derivative is 1 wrt $x$ and 0 wrt $y$, or any other combination that sums to one.

Help - I'm not sure how to implement this function.

Try returning grad_out * bool_sum, where bool_sum is an array constructed from the sum of two boolean arrays.

You can alternatively use np.where.

Help - I'm passing the first test but not the second.

This probably means that you haven't used unbroadcast. You'll need it to get your gradient back into the shape of the relevant input before you return it.

def maximum_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt x."""
    raise NotImplementedError()


def maximum_back1(grad_out: Arr, out: Arr, x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt y."""
    raise NotImplementedError()


maximum = wrap_forward_fn(np.maximum)
BACK_FUNCS.add_back_func(np.maximum, 0, maximum_back0)
BACK_FUNCS.add_back_func(np.maximum, 1, maximum_back1)

tests.test_maximum(Tensor)
tests.test_maximum_broadcasted(Tensor)
Solution
def maximum_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt x."""
    bool_sum = (x > y) + 0.5 * (x == y)
    return unbroadcast(grad_out * bool_sum, x)


def maximum_back1(grad_out: Arr, out: Arr, x: Arr, y: Arr):
    """Backwards function for max(x, y) wrt y."""
    bool_sum = (x < y) + 0.5 * (x == y)
    return unbroadcast(grad_out * bool_sum, y)

Exercise - functional ReLU

```yaml
Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵⚪⚪⚪⚪

You should spend ~5 minutes on this exercise.
```

A simple and correct ReLU function can be defined in terms of your maximum function. Note the PyTorch version also supports in-place operation, which we are punting to the bonus section for now.

Again, at $x = 0$ your derivative could reasonably be anything between 0 and 1 inclusive. Defining ReLU in terms of maximum gives 0.5 here, matching the convention from the previous exercise - which means you can just use the maximum function defined above!

def relu(x: Tensor) -> Tensor:
    """Like torch.nn.function.relu(x, inplace=False)."""
    raise NotImplementedError()


tests.test_relu(Tensor)
Solution
def relu(x: Tensor) -> Tensor:
    """Like torch.nn.function.relu(x, inplace=False)."""
    return maximum(x, 0.0)

Exercise - 2D matmul

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 20-25 minutes on this exercise.
```

Implement your version of torch.matmul, restricting it to the simpler case where both inputs are 2D (this means we don't need to worry about unbroadcasting or anything).

Note - although the solution to this exercise is very short (just one line), you may find the actual mathematical derivation a bit tricky. We've given hints to help you, which we recommend using if you're stuck.

def _matmul2d(x: Arr, y: Arr) -> Arr:
    """Matrix multiply restricted to the case where both inputs are exactly 2D."""
    return x @ y


def matmul2d_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr) -> Arr:
    raise NotImplementedError()


def matmul2d_back1(grad_out: Arr, out: Arr, x: Arr, y: Arr) -> Arr:
    raise NotImplementedError()


matmul = wrap_forward_fn(_matmul2d)
BACK_FUNCS.add_back_func(_matmul2d, 0, matmul2d_back0)
BACK_FUNCS.add_back_func(_matmul2d, 1, matmul2d_back1)

tests.test_matmul2d(Tensor)
Help - I need a hint about the math

Let $X$, $Y$ and $M$ denote the variables x, y and out, so we have the matrix relation $M = XY$. The object grad_out is a tensor with elements $\left[\text{grad\_out}\right]_{pq} = \frac{\partial L}{\partial M_{pq}}$.

The output of matmul2d_back0 should be the gradient of $L$ wrt $X$, i.e. it should have elements $\frac{\partial L}{\partial X_{i j}}$. Can you write this in terms of the elements of x, y, out and grad_out?

Help - I need the math explained

We can write $\frac{\partial L}{\partial X_{i j}}$ as:

$$
\begin{aligned}
\frac{\partial L}{\partial X_{ij}} &= \sum_{pq} \frac{\partial L}{\partial M_{pq}} \frac{\partial M_{pq}}{\partial X_{ij}} \\
&= \sum_{pqr} \left[\text{grad\_out}\right]_{pq} \frac{\partial (X_{pr} Y_{rq})}{\partial X_{ij}} \\
&= \sum_{q} \left[\text{grad\_out}\right]_{iq} Y_{jq} \\
&= \left[\text{grad\_out} \times Y^{\top}\right]_{ij}
\end{aligned}
$$

where the second line follows because $M_{pq} = \sum_r X_{pr} Y_{rq}$ (and we can rearrange the summands), and the third line follows because $\frac{\partial X_{pr}}{\partial X_{ij}} = 1 \text{ if } (p, r) = (i, j), \text{ else } 0$.

In other words, the x.grad attribute should be grad_out @ y.T.

You can calculate the gradient wrt y in a similar way - we leave this as an exercise for the reader.
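
(If you'd like to check your answer afterwards: following exactly the same steps for $Y$ gives

$$
\frac{\partial L}{\partial Y_{ij}} = \sum_{pq} \frac{\partial L}{\partial M_{pq}} \frac{\partial M_{pq}}{\partial Y_{ij}} = \sum_{p} \left[\text{grad\_out}\right]_{pj} X_{pi} = \left[X^{\top} \times \text{grad\_out}\right]_{ij}
$$

i.e. the y.grad attribute should be x.T @ grad_out.)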

Solution
def _matmul2d(x: Arr, y: Arr) -> Arr:
    """Matrix multiply restricted to the case where both inputs are exactly 2D."""
    return x @ y


def matmul2d_back0(grad_out: Arr, out: Arr, x: Arr, y: Arr) -> Arr:
    return grad_out @ y.T


def matmul2d_back1(grad_out: Arr, out: Arr, x: Arr, y: Arr) -> Arr:
    return x.T @ grad_out

Parameters & Modules

We've now written enough backward passes that we can go up a layer and write our own nn.Parameter and nn.Module. These are important abstractions that help us build up neural networks.

Below is a simple implementation of Parameter. It is itself a Tensor, it shares storage with the provided Tensor, and requires_grad is True by default - that's it! Make sure you understand the code in this cell, which tests the functionality of this class.

class Parameter(Tensor):
    def __init__(self, tensor: Tensor, requires_grad=True):
        """Share the array with the provided tensor."""
        return super().__init__(tensor.array, requires_grad=requires_grad)

    def __repr__(self):
        return f"Parameter containing:\n{super().__repr__()}"


x = Tensor([1.0, 2.0, 3.0])
p = Parameter(x)
assert p.requires_grad
assert p.array is x.array
assert (
    repr(p)
    == "Parameter containing:\nTensor(array([1., 2., 3.], dtype=float32), requires_grad=True)"
)
x.add_(Tensor(np.array(2.0)))
assert np.allclose(p.array, np.array([3.0, 4.0, 5.0])), (
    "in-place modifications to the original tensor should affect the parameter"
)

Just like torch.Tensor, the nn.Module class has a lot of functionality which we mostly don't care about today. We will just implement enough to get our network training.

Below is a simple implementation. We'll explain it a bit below (if you're already experienced in Python then this might be obvious to you and you can skip it).

  • Single-underscore attributes are a notational convention; they're not treated differently by Python but they're used to indicate that the attribute is private and shouldn't be accessed directly by anyone using the class.
    • _modules is a dict mapping module names to module objects. The modules method returns an iterator over these modules.
    • _parameters is similar, with added recursion to include submodule parameters.
  • Double-underscore attributes are special methods that determine how your class instance behaves when you use certain syntax.
    • __call__ determines what the module does when you call it like a function. In this case, module(*args, **kwargs) calls module.forward(*args, **kwargs) - this is why we only ever need to implement forward in the modules we've written so far.
    • __setattr__ manages attribute setting (i.e. running module.attr = value actually calls module.__setattr__("attr", value)). The default behaviour is to add the attribute to self.__dict__, but we've added custom logic so modules & parameters are also added to self._modules and self._parameters respectively - this is basically how logic like module.parameters() can work.
      • Note that there's a related method __getattr__ which specifies attribute getting behaviour when lookup in self.__dict__ fails.
class Module:
    _modules: dict[str, "Module"]
    _parameters: dict[str, Parameter]

    def __init__(self):
        self._modules: dict[str, "Module"] = {}
        self._parameters: dict[str, Parameter] = {}

    def modules(self) -> Iterator["Module"]:
        """Return the direct child modules of this module, not including self."""
        yield from self._modules.values()

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """
        Return an iterator over Module parameters.

        recurse: if True, the iterator includes parameters of submodules, recursively.
        """
        yield from self._parameters.values()
        if recurse:
            for mod in self.modules():
                yield from mod.parameters(recurse=True)

    def __setattr__(self, key: str, val: Any) -> None:
        """
        If val is a Parameter or Module, store it in the appropriate _parameters or _modules dict.
        Otherwise, call __setattr__ from the superclass.
        """
        if isinstance(val, Parameter):
            self._parameters[key] = val
        elif isinstance(val, Module):
            self._modules[key] = val
        super().__setattr__(key, val)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self):
        raise NotImplementedError("Subclasses must implement forward!")

    def __repr__(self):
        _indent = lambda s_, nSpaces: re.sub("\n", "\n" + (" " * nSpaces), s_)
        lines = [f"({key}): {_indent(repr(module), 2)}" for key, module in self._modules.items()]
        return "".join(
            [
                self.__class__.__name__ + "(",
                "\n  " + "\n  ".join(lines) + "\n" if lines else "",
                ")",
            ]
        )


class TestInnerModule(Module):
    def __init__(self):
        super().__init__()
        self.param1 = Parameter(Tensor([1.0]))
        self.param2 = Parameter(Tensor([2.0]))


class TestModule(Module):
    def __init__(self):
        super().__init__()
        self.inner = TestInnerModule()
        self.param3 = Parameter(Tensor([3.0]))


mod = TestModule()
assert list(mod.modules()) == [mod.inner]
assert list(mod.parameters()) == [mod.param3, mod.inner.param1, mod.inner.param2]
print("Manually verify that the repr looks reasonable:")
print(mod)
print("All tests for `Module` passed!")

Exercise - implement Linear

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 20-25 minutes on this exercise.
```

Now, let's go one level of abstraction higher and create a Linear module. This should inherit from Module and have __init__ & forward methods, just like your Linear module inheriting from nn.Module in previous exercises. In fact, your code can probably be very similar to your earlier implementation, except that here you'll need to use the methods we've already defined. You should be able to do everything you need in forward using just the matmul operator @, the transpose operator .T (which is equivalent to .permute(-1, -2), as you can see in the Tensor class above) and standard tensor addition +.

To restate the task in case you don't remember it from the previous exercises, you should:

  • Define self.weight and self.bias in __init__, with both tensors having a uniform distribution in the range [-sf, sf] where sf = 1/sqrt(in_features),
  • Write the appropriate affine operation in forward, i.e. multiplying by self.weight and adding self.bias if it exists.

Don't forget to wrap your weights as Parameter(Tensor(...)).

class Linear(Module):
    weight: Parameter
    bias: Parameter | None

    def __init__(self, in_features: int, out_features: int, bias=True):
        """
        A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        raise NotImplementedError()

    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (*, in_features)
        Return: shape (*, out_features)
        """
        raise NotImplementedError()

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )


linear = Linear(3, 4)
assert isinstance(linear.weight, Tensor)
assert linear.weight.requires_grad

input = Tensor([[1.0, 2.0, 3.0]])
output = linear(input)
assert output.requires_grad

expected_output = input @ linear.weight.T + linear.bias
np.testing.assert_allclose(output.array, expected_output.array)

print("All tests for `Linear` passed!")
Solution
class Linear(Module):
    weight: Parameter
    bias: Parameter | None

    def __init__(self, in_features: int, out_features: int, bias=True):
        """
        A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        sf = in_features**-0.5
        self.weight = Parameter(Tensor(sf * (2 * np.random.rand(out_features, in_features) - 1)))
        self.bias = Parameter(Tensor(sf * (2 * np.random.rand(out_features) - 1))) if bias else None

    def forward(self, x: Tensor) -> Tensor:
        """
        x: shape (*, in_features)
        Return: shape (*, out_features)
        """
        out = (
            x @ self.weight.T
        )  # transpose has been defined as .permute(-1, -2), see the Tensor class
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )

Finally, for the sake of completeness, we'll define a ReLU module:

class ReLU(Module):
    def forward(self, x: Tensor) -> Tensor:
        return relu(x)

Now we can define an MLP suitable for classifying MNIST, with zero PyTorch dependency!

class MLP(Module):
    def __init__(self):
        super().__init__()
        self.linear1 = Linear(28 * 28, 64)
        self.linear2 = Linear(64, 64)
        self.relu1 = ReLU()
        self.relu2 = ReLU()
        self.output = Linear(64, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = x.reshape((x.shape[0], 28 * 28))
        x = self.relu1(self.linear1(x))
        x = self.relu2(self.linear2(x))
        x = self.output(x)
        return x
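
Before training, it's worth a quick smoke test that shapes flow through the model correctly. Here's a minimal sketch - we assume, as in the training code below, that Tensor accepts a NumPy array, and that .shape returns a tuple:

model = MLP()
fake_batch = Tensor(np.random.randn(8, 28, 28).astype(np.float32))
logits = model(fake_batch)
assert logits.shape == (8, 10)  # forward reshapes to (batch, 784), so this should pass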

Exercise - implement cross_entropy

```yaml
Difficulty: 🔴🔴🟠⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-15 minutes on this exercise.
```

Make use of your integer array indexing to implement cross_entropy. See the documentation page here.

We discussed this briefly in the section on indexing earlier, but the kind of indexing you should be doing on your logprobs is logprobs[range(batch_size), true_labels], since this is equivalent to returning the vector of length batch_size with elements [logprobs[0, true_labels[0]], logprobs[1, true_labels[1]], ...]. Rather than using range, you should be using the arange function we've provided for you (this is equivalent to torch's torch.arange function, and is defined just below the Tensor class).

Note - if you're using the exp function, it's usually good to make your implementation numerically stable (since taking the exponential of large numbers is prone to overflow). The common solution here is to subtract the maximum value of the tensor from all elements. However, you don't need to worry about that here (consider it a bonus exercise).

def cross_entropy(logits: Tensor, true_labels: Tensor) -> Tensor:
    """Like torch.nn.functional.cross_entropy with reduction='none'.

    logits: shape (batch, classes)
    true_labels: shape (batch,). Each element is the index of the correct label in the logits.

    Return: shape (batch, ) containing the per-example loss.
    """
    raise NotImplementedError()


tests.test_cross_entropy(Tensor, cross_entropy)
Help - I'm not sure how to get logprobs from logits.

They're equal up to a constant: logprobs = logits - log(sum(exp(logits))) (where the sum is over the last dimension).

To see why this is true: let's define C = logits - logprobs. We know sum(exp(logits - C)) = sum(exp(logprobs)) = 1 (this is by definition of logprobs). Factoring out the exp(-C) term, we get exp(C) = sum(exp(logits)), hence C = log(sum(exp(logits))) as required.
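
You can also convince yourself numerically, with plain NumPy (a quick check, independent of our Tensor class):

logits_np = np.random.randn(5)
logprobs_np = logits_np - np.log(np.exp(logits_np).sum())
assert np.isclose(np.exp(logprobs_np).sum(), 1.0)  # valid log-probabilities sum to 1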

Solution
def cross_entropy(logits: Tensor, true_labels: Tensor) -> Tensor:
    """Like torch.nn.functional.cross_entropy with reduction='none'.

    logits: shape (batch, classes)
    true_labels: shape (batch,). Each element is the index of the correct label in the logits.

    Return: shape (batch,) containing the per-example loss.
    """
    batch_size = logits.shape[0]
    logprobs = logits - logits.exp().sum(-1, keepdim=True).log()
    return -logprobs[arange(0, batch_size), true_labels]

or alternatively, we can solve it in a slightly different way, which is still equivalent:

true = logits[arange(0, batch_size), true_labels]
return -log(exp(true) / exp(logits).sum(1))

NoGrad context manager

The last thing our backpropagation system needs is the ability to turn it off completely like torch.no_grad (or torch.inference_mode). We've given you an implementation below, which works by modifying the global grad_tracking_enabled variable.

A few notes on the actual python here (again for people who are more familiar with Python and understand it can skip this):

  • The global keyword is required in order to modify the global grad_tracking_enabled variable. We can still reference its value without this keyword, but we wouldn't be able to change it.
  • The special __enter__ and __exit__ methods are part of the protocol for context managers, which is a more pythonic way of doing this kind of thing. If we have a context manager block like with NoGrad(): ..., then we'll run NoGrad().__enter__() before any of the code in this block, and NoGrad().__exit__() after the block finishes.
class NoGrad:
    """Context manager that disables grad inside the block. Like torch.no_grad."""

    was_enabled: bool

    def __enter__(self):
        """
        Method which is called whenever the context manager is entered, i.e. at the start of the
        `with NoGrad():` block. This disables gradient tracking (but stores the value it had before,
        so we can set it back to this on exit).
        """
        global grad_tracking_enabled
        self.was_enabled = grad_tracking_enabled
        grad_tracking_enabled = False

    def __exit__(self, type, value, traceback):
        """
        Method which is called whenever we exit the context manager. This sets the global
        `grad_tracking_enabled` variable back to the value it had before we entered the context
        manager.
        """
        global grad_tracking_enabled
        grad_tracking_enabled = self.was_enabled


assert grad_tracking_enabled
with NoGrad():
    assert not grad_tracking_enabled
assert grad_tracking_enabled
print(
    "Verified that we've disabled gradients inside `NoGrad`, then set back to its previous "
    "value once we exit."
)

Exercise - implement SGD

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-20 minutes on this exercise.
```

In today's final exercise, you should implement the SGD class methods zero_grad and step. This should be pretty familiar if you've gone through yesterday's exercises on optimizers (although without all the bells and whistles from those exercises, because we're literally just implementing plain SGD with no momentum, weight decay or anything).

Important note - in yesterday's exercises it was important to use in-place operations so that we actually modify the existing tensor data rather than creating new tensors, and the same is true here. The in-place operation += is supported, since under the hood it calls __iadd__, which we've defined in our Tensor class (the same goes for subtraction via __isub__). We did discuss earlier how in-place operations are risky for backprop, and that's generally true - however, here we're using them for parameter updates, which aren't meant to be differentiated and which happen just before all gradients are zeroed, so they're safe in this particular context.

class SGD:
    def __init__(self, params: Iterable[Parameter], lr: float):
        """Vanilla SGD with no additional features."""
        self.params = list(params)
        self.lr = lr
        self.b = [None for _ in self.params]

    def zero_grad(self) -> None:
        """Iterates through params, and sets all grads to None."""
        raise NotImplementedError()

    def step(self) -> None:
        """Iterates through params, and updates each of them by subtracting `param.grad * lr`."""
        raise NotImplementedError()


tests.test_sgd(Parameter, Tensor, SGD)
Solution
class SGD:
    def __init__(self, params: Iterable[Parameter], lr: float):
        """Vanilla SGD with no additional features."""
        self.params = list(params)
        self.lr = lr
        self.b = [None for _ in self.params]

    def zero_grad(self) -> None:
        """Iterates through params, and sets all grads to None."""
        for p in self.params:
            p.grad = None

    def step(self) -> None:
        """Iterates through params, and updates each of them by subtracting `param.grad * lr`."""
        with NoGrad():
            for p in self.params:
                p -= p.grad * self.lr

Training Your Network

We've already looked at data loading and training loops earlier in the course, so we'll provide a minimal version of these today as well as the data loading code.

train_loader, test_loader = get_mnist()
visualize(train_loader)

To finish the day, below is some code for a training/testing loop for MNIST images, which also logs & plots the results.

Note, it's normal to encounter some bugs and glitches at this point - just go back and fix them until everything runs! Because backprop is annoying and fiddly and depends heavily on exactly how the implementation works (with too many edge cases to test all of them), you may have to resort to replacing your code with the reference solution until you find the source of the error. This is a bit frustrating, but we'd be lying if we said ML didn't have its share of slow debugging sessions!

def train(
    model: MLP,
    train_loader: DataLoader,
    optimizer: SGD,
    epoch: int,
    train_loss_list: list | None = None,
):
    print(f"Epoch: {epoch}")
    progress_bar = tqdm(train_loader)
    for data, target in progress_bar:
        data, target = Tensor(data.numpy()), Tensor(target.numpy())
        optimizer.zero_grad()
        output = model(data)
        loss = cross_entropy(output, target).sum() / len(output)
        loss.backward()
        progress_bar.set_description(f"Train set: Avg loss: {loss.item():.3f}")
        optimizer.step()
        if train_loss_list is not None:
            train_loss_list.append(loss.item())


def test(model: MLP, test_loader: DataLoader, test_accuracy_list: list | None = None):
    test_loss = 0
    test_accuracy = 0
    with NoGrad():
        for data, target in test_loader:
            data, target = Tensor(data.numpy()), Tensor(target.numpy())
            output: Tensor = model(data)
            test_loss += cross_entropy(output, target).sum().item()
            pred = output.argmax(dim=1, keepdim=True)
            test_accuracy += (pred == target.reshape(pred.shape)).sum().item()
    n_data = len(test_loader.dataset)
    test_loss /= n_data
    print(
        f"Test set:  Avg loss: {test_loss:.3f}, Accuracy: {test_accuracy}/{n_data} "
        f"({test_accuracy / n_data:.1%})"
    )
    if test_accuracy_list is not None:
        test_accuracy_list.append(test_accuracy / n_data)


num_epochs = 5
model = MLP()
start = time.time()
train_loss_list = []
test_accuracy_list = []
optimizer = SGD(model.parameters(), 0.01)
for epoch in range(num_epochs):
    train(model, train_loader, optimizer, epoch, train_loss_list)
    test(model, test_loader, test_accuracy_list)

print(f"\nCompleted in {time.time() - start: .2f}s")

line(
    [train_loss_list, test_accuracy_list],
    x_max=num_epochs,
    yaxis2_range=[0, 1],
    use_secondary_yaxis=True,
    labels={"x": "Batches seen", "y1": "Cross entropy loss", "y2": "Test accuracy"},
    title="MLP training on MNIST from scratch!",
    width=800,
)
Epoch: 0
Train set: Avg loss: 1.864: 100%|██████████| 118/118 [00:02<00:00, 55.00it/s]
Test set:  Avg loss: 1.848, Accuracy: 5763/10000 (57.6%)
Epoch: 1
Train set: Avg loss: 1.056: 100%|██████████| 118/118 [00:02<00:00, 56.57it/s]
Test set:  Avg loss: 0.998, Accuracy: 7843/10000 (78.4%)
Epoch: 2
Train set: Avg loss: 0.783: 100%|██████████| 118/118 [00:02<00:00, 54.62it/s]
Test set:  Avg loss: 0.652, Accuracy: 8315/10000 (83.2%)
Epoch: 3
Train set: Avg loss: 0.548: 100%|██████████| 118/118 [00:02<00:00, 54.98it/s]
Test set:  Avg loss: 0.518, Accuracy: 8600/10000 (86.0%)
Epoch: 4
Train set: Avg loss: 0.425: 100%|██████████| 118/118 [00:02<00:00, 56.93it/s]
Test set:  Avg loss: 0.447, Accuracy: 8768/10000 (87.7%)

Completed in  11.24s

Note - one difference between this training loop (if done correctly) and the one we used in earlier sections is that we're using SGD rather than Adam. You can try adapting your Adam code from the previous day's exercises to get the same results as you did in earlier sections.

If it works then congratulations - you've implemented a fully-functional autograd system!