1️⃣ Optimizers

Learning Objectives
  • Understand how different optimization algorithms work
  • Translate pseudocode for these algorithms into code
  • Understand the idea of loss landscapes, and how they can be used to visualize specific challenges in the optimization process

Reading

Some of these are strongly recommended, while others are optional. You can also jump back to these videos while working through the material, if you feel you need to.

Gradient Descent

Tomorrow, we'll look in detail at how the backpropagation algorithm works. But for now, let's take it as read that calling loss.backward() on a scalar loss will compute the gradients $\frac{\partial loss}{\partial w}$ for every parameter w in the model, and store these values in w.grad. How do we use these gradients to update our parameters in a way which decreases the loss?

A loss function can be any differentiable function such that we prefer a lower value. To apply gradient descent, we start by initializing the parameters to random values (the details of this are subtle), and then repeatedly compute the gradient of the loss with respect to the model parameters. It can be proven that for an infinitesimal step, moving in the direction of the gradient would increase the loss by the largest amount out of all possible directions.

We actually want to decrease the loss, so we subtract the gradient to go in the opposite direction. Taking infinitesimal steps is no good, so we pick some learning rate $\lambda$ (also called the step size) and scale our step by that amount to obtain the update rule for gradient descent:

$$\theta_t \leftarrow \theta_{t-1} - \lambda \nabla L(\theta_{t-1})$$

We know that an infinitesimal step will decrease the loss, but a finite step will only do so if the loss function is linear enough in the neighbourhood of the current parameters. If the loss function is too curved, we might actually increase our loss.
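
To make the update rule concrete, here's a minimal sketch of vanilla gradient descent written out by hand in PyTorch, on a toy one-parameter problem (the values here are made up purely for illustration):

import torch as t

# Toy problem: find w such that w * x is close to y
x, y = t.tensor(2.0), t.tensor(6.0)
w = t.tensor(1.0, requires_grad=True)
lr = 0.1  # learning rate / step size

for _ in range(20):
    loss = (w * x - y) ** 2      # scalar loss
    loss.backward()              # computes d(loss)/dw and stores it in w.grad
    with t.no_grad():            # parameter updates shouldn't be tracked by autograd
        w -= lr * w.grad         # theta <- theta - lr * gradient
    w.grad = None                # reset the gradient before the next step

print(w)  # should be close to the solution w = 3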

The biggest advantage of this algorithm is that for N bytes of parameters, you only need N additional bytes of memory to store the gradients, which are of the same shape as the parameters. GPU memory is very limited, so this is an extremely relevant consideration. The amount of computation needed is also minimal: one multiply and one add per parameter.

The biggest disadvantage is that we completely ignore the curvature of the loss function, which is not captured by the gradient (a vector of first-order partial derivatives). Intuitively, we could take a larger step in directions where the loss function is flat, and a smaller step in directions where it is very curved. Generally, you could represent this by some matrix $P$ that pre-multiplies the gradients to rescale them to account for the curvature. $P$ is called a preconditioner, and gradient descent is equivalent to approximating $P$ by the identity matrix, which is a very bad approximation.

Most competing optimizers can be interpreted as trying to do something more sensible for $P$, subject to the constraint that GPU memory is at a premium. In particular, constructing $P$ explicitly is infeasible, since it's an $N \times N$ matrix and N can be hundreds of billions. One idea is to use a diagonal $P$, which only requires N additional memory. An example of a more sophisticated scheme is Shampoo.
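
To make the idea concrete, here's a minimal sketch contrasting the (infeasible) full preconditioner with a diagonal one. The numbers are made up, and choosing the per-parameter scales well is exactly what the adaptive optimizers later in this section are about:

import torch as t

grads = t.tensor([1.0, 0.01])   # gradients for two parameters: one steep direction, one flat
theta = t.tensor([0.5, 0.5])
lr = 0.1

# Plain gradient descent is equivalent to using the identity matrix as preconditioner
P_identity = t.eye(2)
step_gd = lr * (P_identity @ grads)

# A diagonal preconditioner only needs one extra number per parameter
P_diag = t.tensor([0.1, 10.0])  # shrink the step in the steep direction, grow it in the flat one
step_precond = lr * P_diag * grads

theta_gd = theta - step_gd
theta_precond = theta - step_precond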

The algorithm is called Shampoo because you put shampoo on your hair before using conditioner, and this method is a pre-conditioner.

If you take away just one thing from this entire curriculum, please don't let it be this.

Stochastic Gradient Descent

The terms gradient descent and SGD are used loosely in deep learning. To be technical, there are three variations:

  • Batch gradient descent - the loss function is the loss over the entire dataset. This requires too much computation unless the dataset is small, so it is rarely used in deep learning.
  • Stochastic gradient descent - the loss function is the loss on a randomly selected example. Any particular loss may be completely in the wrong direction of the loss on the entire dataset, but in expectation it's in the right direction. This has some nice properties but doesn't parallelize well, so it is rarely used in deep learning.
  • Mini-batch gradient descent - the loss function is the loss on a batch of examples of size batch_size. This is the standard in deep learning.

The class torch.optim.SGD can be used for any of these by varying the number of examples passed in. We will be using only mini-batch gradient descent in this course.

Batch Size

In addition to choosing a learning rate or learning rate schedule, we need to choose the batch size or batch size schedule as well. Intuitively, using a larger batch means that the estimate of the gradient is closer to that of the true gradient over the entire dataset, but this requires more compute. Each element of the batch can be computed in parallel so with sufficient compute, one can increase the batch size without increasing wall-clock time. For small-scale experiments, a good heuristic is thus "fill up all of your GPU memory".

At a larger scale, we would expect diminishing returns of increasing the batch size, but empirically it's worse than that - a batch size that is too large generalizes more poorly in many scenarios. The intuition that a closer approximation to the true gradient is always better is therefore incorrect. See this paper for one discussion of this.

For a batch size schedule, most commonly you'll see batch sizes increase over the course of training. The intuition is that a rough estimate of the proper direction is good enough early in training, but later in training it's important to preserve our progress and not "bounce around" too much.

You will commonly see batch sizes that are a multiple of 32. One motivation for this is that when using CUDA, threads are grouped into "warps" of 32 threads which execute the same instructions in parallel. So a batch size of 64 would allow two warps to be fully utilized, whereas a size of 65 would require waiting for a third warp to finish. As batch sizes become larger, this wastage becomes less important.

Powers of two are also common - the idea here is that work can be recursively divided up among different GPUs or within a GPU. For example, a matrix multiplication can be expressed by recursively dividing each matrix into four equal blocks and performing eight smaller matrix multiplications between the blocks.
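
As a quick sanity check of this claim (a toy sketch, not how real GPU kernels are written), here's a matrix multiplication computed via four equal blocks and eight smaller multiplications:

import torch as t

A, B = t.randn(4, 4), t.randn(4, 4)

# Split each matrix into four equal blocks
A11, A12, A21, A22 = A[:2, :2], A[:2, 2:], A[2:, :2], A[2:, 2:]
B11, B12, B21, B22 = B[:2, :2], B[:2, 2:], B[2:, :2], B[2:, 2:]

# Eight smaller matrix multiplications between the blocks, then reassemble
C = t.cat([
    t.cat([A11 @ B11 + A12 @ B21, A11 @ B12 + A12 @ B22], dim=1),
    t.cat([A21 @ B11 + A22 @ B21, A21 @ B12 + A22 @ B22], dim=1),
], dim=0)

assert t.allclose(C, A @ B, atol=1e-5)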

In tomorrow's exercises, you'll have the option to explore batch sizes in more detail.

Common Themes in Gradient-Based Optimizers

Weight Decay

Weight decay means that on each iteration, in addition to a regular step, we also shrink each parameter very slightly towards 0 by multiplying a scaling factor close to 1, e.g. 0.9999. Empirically, this seems to help but there are no proofs that apply to deep neural networks.
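
In code, this is just one extra multiply per parameter alongside the usual gradient step. A minimal sketch (with made-up hyperparameter values; this "decoupled" form is the one used by AdamW, which you'll implement later):

import torch as t

lr, weight_decay = 1e-3, 0.01
w = t.randn(10, requires_grad=True)
grad = t.randn(10)  # stand-in for w.grad after calling loss.backward()

with t.no_grad():
    w *= 1 - lr * weight_decay  # shrink towards zero by a factor very close to 1
    w -= lr * grad              # then take the usual gradient step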

In the case of linear regression, weight decay is mathematically equivalent to having a prior that each parameter is Gaussian distributed - in other words it's very unlikely that the true parameter values are very positive or very negative. This is an example of "inductive bias" - we make an assumption that helps us in the case where it's justified, and hurts us in the case where it's not justified.

For a Linear layer, it's common practice to apply weight decay only to the weight and not the bias. It's also common to not apply weight decay to the parameters of a batch normalization layer. Again, there is empirical evidence (such as Jai et al 2018) and there are heuristic arguments to justify these choices, but no rigorous proofs. Note that PyTorch will implement weight decay on the weights and biases of linear layers by default - see the bonus exercises tomorrow for more on this.

Momentum

Momentum means that the step includes a term proportional to a moving average of past gradients. Distill.pub has a great article on momentum, which you should definitely read if you have time. Don't worry if you don't understand all of it; skimming parts of it can be very informative. For instance, the first half discusses the condition number (a very important concept in optimisation), and concludes by giving an intuitive argument for why we generally set the momentum parameter close to 1 for ill-conditioned problems (those with a very large condition number).
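
As a minimal sketch of the mechanics (you'll implement the full version in the SGD exercise below), the step is taken using a running buffer of past gradients rather than the current gradient alone. If the gradient stays roughly constant, the buffer grows towards a "terminal velocity" of $1/(1-\mu)$ times the gradient:

import torch as t

lr, mu = 0.1, 0.9
theta = t.tensor([0.0])
buffer = t.zeros_like(theta)   # running average of past gradients

for _ in range(5):
    grad = t.tensor([1.0])             # pretend we keep seeing the same gradient
    buffer = mu * buffer + grad        # b_t = mu * b_{t-1} + g_t
    theta = theta - lr * buffer        # the step uses the buffer, not the raw gradient
    print(buffer.item())               # grows towards 1 / (1 - mu) = 10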

Visualising optimization with pathological curvatures

A pathological curvature is a type of surface that is similar to ravines and is particularly tricky for plain SGD optimization. In words, pathological curvatures typically have a steep gradient in one direction with an optimum at the center, while in a second direction we have a slower gradient towards a (global) optimum. Let’s first create an example surface of this and visualize it. The code below creates 2 visualizations (3D and 2D) and also adds the minimum point to the plot (note this is the min in the visible region, not the global minimum).

def pathological_curve_loss(x: Tensor, y: Tensor):
    # Example of a pathological curvature. There are many more possible, feel free to experiment here!
    x_loss = t.tanh(x) ** 2 + 0.01 * t.abs(x)
    y_loss = t.sigmoid(y)
    return x_loss + y_loss


plot_fn(pathological_curve_loss, min_points=[(0, "y_min")])
Click to see the expected output

In terms of optimization, you can imagine that x and y are weight parameters, and the curvature represents the loss surface over the space of x and y. Note that in typical networks, we have many, many more parameters than two, and such curvatures can occur in multi-dimensional spaces as well.

Ideally, our optimization algorithm would find the center of the ravine and focus on optimizing the parameters in the direction of y. However, if we encounter a point along the ridges, the gradient is much greater in x than in y, and we might end up jumping from one side to the other. Due to the large gradients, we would have to reduce our learning rate, slowing down learning significantly.

To test our algorithms, we can implement a simple function to train two parameters on such a surface.

Exercise - implement opt_fn_with_sgd

```yaml Difficulty: 🔴🔴🔴⚪⚪ Importance: 🔵🔵🔵🔵⚪

You should spend up to 15-20 minutes on this exercise. ```

Implement the opt_fn_with_sgd function using torch.optim.SGD. This function optimizes parameters (x, y) (which represent coordinates at which we evaluate a function) using gradient descent on that function value. In other words, this should look just like your optimization loops in previous days' material, except rather than passing in model.parameters() to your optimizer, you pass in (xy,) (because it needs to be an iterable of parameters, not just a single parameter).

Remember, your update steps optimizer.step() will automatically change the values of xy inplace - this means that you shouldn't store past values like xy_list.append(xy) because then past elements of that list will be modified when xy is updated. Instead, you should use something like xy_list.append(xy.detach().clone()) to make sure you're returning a copy of the tensor, which won't continue to be modified.

We've also provided you with a function plot_fn_with_points, which plots a function as well as a list of points produced by functions like the one above. The code below starts from (2.5, 2.5) and adds the resulting trajectory of (x, y) coordinates to the contour plot. Does it find the minimum? Play with the learning rate and momentum a bit and see how close you can get within 100 iterations.

def opt_fn_with_sgd(
    fn: Callable, xy: Float[Tensor, "2"], lr=0.001, momentum=0.98, n_iters: int = 100
) -> Float[Tensor, "n_iters 2"]:
    """
    Optimize a given function starting from the specified point.

    xy: shape (2,). The (x, y) starting point.
    n_iters: number of steps.
    lr, momentum: parameters passed to the torch.optim.SGD optimizer.

    Return: (n_iters+1, 2). The (x, y) values, from initial values to values after step `n_iters`.
    """
    # Make sure tensor has requires_grad=True, otherwise it can't be optimized (more on this tomorrow!)
    assert xy.requires_grad

    raise NotImplementedError()


points = []

optimizer_list = [
    (optim.SGD, {"lr": 0.1, "momentum": 0.0}),
    (optim.SGD, {"lr": 0.02, "momentum": 0.99}),
]

for optimizer_class, params in optimizer_list:
    xy = t.tensor([2.5, 2.5], requires_grad=True)
    xys = opt_fn_with_sgd(
        pathological_curve_loss, xy=xy, lr=params["lr"], momentum=params["momentum"]
    )
    points.append((xys, optimizer_class, params))
    print(f"{params=}, last point={xys[-1]}")

plot_fn_with_points(pathological_curve_loss, points=points, min_points=[(0, "y_min")])
Click to see the expected output
params={'lr': 0.1, 'momentum': 0.0}, last point=tensor([0.2300, 1.4820])
params={'lr': 0.02, 'momentum': 0.99}, last point=tensor([ 0.7196, -6.4586])
Help - I'm not sure if my opt_fn_with_sgd is implemented properly.

With a learning rate of 0.02 and momentum of 0.99, my SGD was able to reach [ 0.8110, -6.3344] after 100 iterations.

Help - I'm getting Can't call numpy() on Tensor that requires grad.

This is a protective mechanism built into PyTorch. The idea is that once you convert your Tensor to NumPy, PyTorch can no longer track gradients, but you might not understand this and expect backprop to work on NumPy arrays.

All you need to do to convince PyTorch you're a responsible adult is to call detach() on the tensor first, which returns a view that does not require grad and isn't part of the computation graph.
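
For example:

import torch as t

x = t.tensor([1.0, 2.0], requires_grad=True)
# x.numpy()                 # this would raise the error above
x_np = x.detach().numpy()   # fine: detach() returns a view outside the computation graph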

Solution
def opt_fn_with_sgd(
    fn: Callable, xy: Float[Tensor, "2"], lr=0.001, momentum=0.98, n_iters: int = 100
) -> Float[Tensor, "n_iters 2"]:
    """
    Optimize a given function starting from the specified point.

    xy: shape (2,). The (x, y) starting point.
    n_iters: number of steps.
    lr, momentum: parameters passed to the torch.optim.SGD optimizer.

    Return: (n_iters+1, 2). The (x, y) values, from initial values to values after step n_iters.
    """
    # Make sure tensor has requires_grad=True, otherwise it can't be optimized (more on this tomorrow!)
    assert xy.requires_grad

    optimizer = optim.SGD((xy,), lr=lr, momentum=momentum)
    xy_list = [xy.detach().clone()]  # so we don't unintentionally modify past values in xy_list

    for i in range(n_iters):
        fn(xy[0], xy[1]).backward()
        optimizer.step()
        optimizer.zero_grad()
        xy_list.append(xy.detach().clone())

    return t.stack(xy_list)

Build Your Own Optimizers

Now let's build our own drop-in replacement for these three classes from torch.optim. For each of the exercises you'll have to translate pseudocode that we give you into actual code. If you want an extra challenge, you can try and work directly from the pseudocode in the PyTorch documentation page rather than what we give you.

A warning regarding in-place operations

Be careful with expressions like x = x + y and x += y. They are NOT equivalent in Python.

  • The first one allocates a new Tensor of the appropriate size and adds x and y to it, then rebinds x to point to the new variable. The original x is not modified.
  • The second one modifies the storage referred to by x to contain the sum of x and y - it is an "in-place" operation. x.add_(y) and torch.add(x, y, out=x) also work the same way.

Another example: if x and y are the same shape, then x = y won't change the value of x inplace, but x.copy_(y) will (i.e. changing its values to the values of y).

When you're updating parameters in your network you should use inplace operations (because your optimizer was passed an iterable of parameters, and so defining a new parameter value via theta = theta - step will take it out of the optimizer's scope - it will continue to point to the old, unmodified version).

However, be careful of using inplace operations where you shouldn't be - you don't want to accidentally do something like modify the gradients manually!
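
Here's a small demonstration of the difference, using data_ptr() (the address of a tensor's underlying storage) to check whether we're still looking at the same memory:

import torch as t

x = t.zeros(3)
y = t.ones(3)

ptr = x.data_ptr()
x = x + y                    # allocates a new tensor and rebinds the name x
print(x.data_ptr() == ptr)   # False - different storage

x = t.zeros(3)
ptr = x.data_ptr()
x += y                       # in-place: the original storage is modified
print(x.data_ptr() == ptr)   # True - same storage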

Exercise - implement SGD

```yaml Difficulty: 🔴🔴🔴🔴⚪ Importance: 🔵🔵🔵⚪⚪

You should spend up to 25-35 minutes on this exercise. This is the first of several exercises like it. The first will probably take the longest. ```

First, you should implement stochastic gradient descent. It should be like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0. The pseudocode simplifies to:

$ b_0 \leftarrow 0 \\ \text {for } t=1 \text { to } \ldots \text { do } \\ \quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\ \quad\; \text {if } \lambda \neq 0 \\ \quad\;\quad\; g_t \leftarrow g_t+\lambda \theta_{t-1} \\ \quad\; \text {if } \mu \neq 0 \\ \quad\;\quad\; b_t \leftarrow \mu b_{t-1} + g_t \\ \quad\;\quad\; g_t \leftarrow b_t \\ \quad\; \theta_t \leftarrow \theta_{t-1} - \gamma g_t $

where $\theta_t$ are the parameters, $g_t$ are the gradients (after being modified by operations like weight decay & momentum if necessary), and $b_t$ are the values we track to implement momentum.

Derivation of the simplified pseudocode

We start by removing the "if nesterov" and "if maximize" sections, since we're not using either of those. We also substitute $\tau=0$ since we're not using dampening. This gives us:

$ \text {for } t=1 \text { to } \ldots \text { do } \\ \quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\ \quad\; \text {if } \lambda \neq 0 \\ \quad\;\quad\; g_t \leftarrow g_t+\lambda \theta_{t-1} \\ \quad\; \text {if } \mu \neq 0 \\ \quad\;\quad\; \text{if } t>1 \\ \quad\;\quad\;\quad\; b_t \leftarrow \mu b_{t-1} + g_t \\ \quad\;\quad\; else \\ \quad\;\quad\;\quad\; b_t \leftarrow g_t \\ \quad\;\quad\; g_t \leftarrow b_t \\ \quad\; \theta_t \leftarrow \theta_{t-1} - \gamma g_t $

Finally, we observe that we can set $b_0 = 0$ and then remove the special case handling of the $t=1$ case, which gives us the pseudocode above.

You should complete the step method below, which implements the algorithm described by the pseudocode above. Note that we've added the torch.inference_mode decorator to the step method, which is equivalent to using the context manager with torch.inference_mode():. This is similar to torch.no_grad; the difference between them isn't worth getting into here but in general know that torch.inference_mode is mostly preferred.

The configurations used during tests.test_sgd will start simple (e.g. all parameters set to zero except lr) and gradually move to more complicated ones. This will help you track exactly where in your model the error is coming from.

You should also read the __init__ and zero_grad methods, making sure you understand how these work and what they are doing. Note that setting grad=None like the code below is treated as equivalent to setting grad equal to a tensor of zeros, i.e. the first time we're required to do an operation on the gradient it'll be replaced with this. Making it be None by default is the standard, so as to not use unnecessary memory.

class SGD:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
    ):
        """Implements SGD with momentum.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD
        """
        self.params = list(
            params
        )  # turn params into a list (it might be a generator, so iterating over it empties it)
        self.lr = lr
        self.mu = momentum
        self.lmda = weight_decay

        self.b = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        """Zeros all gradients of the parameters in `self.params`."""
        for param in self.params:
            param.grad = None

    @t.inference_mode()
    def step(self) -> None:
        """Performs a single optimization step of the SGD algorithm."""
        raise NotImplementedError()

    def __repr__(self) -> str:
        return f"SGD(lr={self.lr}, momentum={self.mu}, weight_decay={self.lmda})"


tests.test_sgd(SGD)
Solution
class SGD:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
    ):
        """Implements SGD with momentum.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD
        """
        self.params = list(
            params
        )  # turn params into a list (it might be a generator, so iterating over it empties it)
        self.lr = lr
        self.mu = momentum
        self.lmda = weight_decay

        self.b = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        """Zeros all gradients of the parameters in `self.params`."""
        for param in self.params:
            param.grad = None

    @t.inference_mode()
    def step(self) -> None:
        """Performs a single optimization step of the SGD algorithm."""
        for b, theta in zip(self.b, self.params):
            g = theta.grad
            if self.lmda != 0:
                g = (
                    g + self.lmda * theta
                )  # this shouldn't be inplace since we don't want to modify theta.grad
            if self.mu != 0:
                b.copy_(
                    self.mu * b + g
                )  # this does need to be inplace, since we're modifying the value in self.b
                g = b
            theta -= self.lr * g  # inplace operation, to modify params

    def __repr__(self) -> str:
        return f"SGD(lr={self.lr}, momentum={self.mu}, weight_decay={self.lmda})"


tests.test_sgd(SGD)

If you feel comfortable with this implementation, you can skim through the remaining ones, since there's diminishing marginal returns to be gained from doing the actual exercises. We still recommend you read the content on the optimizers before the actual exercises, because they contain useful theory to understand. If you want an extra challenge in the actual exercises, you can try and implement the optimization algorithms directly from the PyTorch documentation pseudocode rather than from the simplified pseudocode we give you.

RMSProp (and adaptive methods)

From SGD, we'll move onto discussing adaptive gradient descent methods. These are methods which automatically adjust the learning rate of each parameter during training, based on the size of gradients at previous steps. In a sense this is similar to how momentum operates in SGD, but we don't tend to describe SGD plus momentum as an adaptive method. When discussing momentum, we usually think of the analogy of a ball rolling down a hill, and the ball's velocity accelerates until it reaches some terminal velocity. The momentum parameter $\mu$ controls the terminal velocity: as $\mu \to 1$ the terminal velocity gets very high, which also means it can take a long time to adjust its speed when it enters new territory. In contrast, adaptive methods are better thought of as deliberate, conscious updates to the learning rate of parameters based on past values. They allow us to speed up when we need to, but without sacrificing our ability to adapt quickly when we enter new regimes.

The first adaptive method we'll look at is RMSprop. This is actually the second main adaptive method that was proposed in the optimization literature, after AdaGrad (however the problem with AdaGrad is that it decays the learning rates too quickly - this is the problem that RMSprop solves). RMSprop is similar to SGD, with an added dynamic: the size of parameter steps are scaled according to the variance of past gradients, with higher variance leading to smaller steps. Intuitively, if you're in a very monotonic region of the loss landscape then you want to take larger steps (since you know where you're going and you just want to get there quickly), whereas if you're in a very noisy region and possibly oscillating around minima then you want to take smaller steps.

One final note - when we're using non-adaptive methods like SGD we tend to have an inverse relationship between the learning rate and the batch size. Broadly speaking, this is because a larger batch size means our gradients will have smaller variance, and so we can safely use a larger learning rate. This generally isn't necessary for adaptive methods since the learning rates will be adjusted automatically during training based on the variance of our gradients - we don't need to manually scale them ourselves. Most commonly during optimization, we'll start with the default hyperparameters for whatever adaptive optimizer we're using, and then adjust from there.

Exercise - implement RMSprop

```yaml Difficulty: 🔴🔴🔴⚪⚪ Importance: 🔵🔵⚪⚪⚪

You should spend up to 15-25 minutes on this exercise. ```

Below, you should implement RMSprop in the same way as you implemented SGD. The pseudocode is slightly more complicated, since we now have to track 2 variables: $b_t$ for applying the momentum effect, and $v_t$ for tracking the variance of past gradients (we've called these b and v below).

Here is a link to the PyTorch version, alternatively you can use our simplified pseudocode again:

Click here for the simplified pseudocode

$ b_0 \leftarrow 0 \\ \text {for } t=1 \text { to } \ldots \text { do } \\ \quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\ \quad\; \text {if } \lambda \neq 0 \\ \quad\;\quad\; g_t \leftarrow g_t+\lambda \theta_{t-1} \\ \quad\; v_t \leftarrow \alpha v_{t-1} + (1-\alpha) g_t^2 \\ \quad\; g_t \leftarrow g_t / (\sqrt{v_t} + \epsilon) \\ \quad\; \text {if } \mu \neq 0 \\ \quad\;\quad\; b_t \leftarrow \mu b_{t-1} + g_t \\ \quad\;\quad\; g_t \leftarrow b_t \\ \quad\; \theta_t \leftarrow \theta_{t-1} - \gamma g_t $

Note that we've restructured the pseudocode slightly relative to the PyTorch docs, so that we divide $g_t$ by $\sqrt{v_t} + \epsilon$ before applying momentum. Both forms are equivalent though.

class RMSprop:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.01,
        alpha: float = 0.99,
        eps: float = 1e-08,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
    ):
        """Implements RMSprop.

        Like the PyTorch version, but assumes centered=False
            https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
        """
        self.params = list(params)  # turn params into a list (because it might be a generator)
        self.lr = lr
        self.eps = eps
        self.mu = momentum
        self.lmda = weight_decay
        self.alpha = alpha

        self.b = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        raise NotImplementedError()

    def __repr__(self) -> str:
        return (
            f"RMSprop(lr={self.lr}, eps={self.eps}, momentum={self.mu}, "
            f"weight_decay={self.lmda}, alpha={self.alpha})"
        )


tests.test_rmsprop(RMSprop)
Solution
class RMSprop:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.01,
        alpha: float = 0.99,
        eps: float = 1e-08,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
    ):
        """Implements RMSprop.

        Like the PyTorch version, but assumes centered=False
            https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
        """
        self.params = list(params)  # turn params into a list (because it might be a generator)
        self.lr = lr
        self.eps = eps
        self.mu = momentum
        self.lmda = weight_decay
        self.alpha = alpha

        self.b = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        for theta, b, v in zip(self.params, self.b, self.v):
            g = theta.grad
            if self.lmda != 0:
                g = g + self.lmda * theta
            v.copy_(
                self.alpha * v + (1 - self.alpha) * g.pow(2)
            )  # inplace operation, to modify value in self.v
            g = g / (v.sqrt() + self.eps)  # not inplace operation
            if self.mu > 0:
                b.copy_(self.mu * b + g)  # inplace operation, to modify value in self.b
                g = b
            theta -= self.lr * g  # inplace operation, to modify params

    def __repr__(self) -> str:
        return (
            f"RMSprop(lr={self.lr}, eps={self.eps}, momentum={self.mu}, "
            f"weight_decay={self.lmda}, alpha={self.alpha})"
        )


tests.test_rmsprop(RMSprop)

Adam, and "momentum"

We'll end by implementing Adam and AdamW, two of the most popular optimizers in deep learning. These combine the benefits of RMSprop and SGD with momentum: they have the same variance-based scaling as RMSprop, but also an update rule based on the first moment of the gradients (an exponentially weighted moving average of past gradients).

There's an important clarification to make here - the first order adjustment of Adam is sometimes called momentum as a shorthand, but there's an important sense in which it isn't. The key difference is that SGD's momentum causes acceleration until we hit terminal velocity, which could be very large for $\mu \approx 1$. In contrast, Adam's momentum is an exponentially weighted moving average - the parameter $\beta_1$ controls how quickly it adjusts (with a value closer to 1 meaning it adjusts to newer values more slowly), but it doesn't change the terminal velocity in any sense. Mathematically, the difference between these two is minimal (all you'd need to do is take Adam's update rule $m_t \leftarrow \beta_1 m_{t-1} + (1-\beta_1) g_t$ and change it to $m_t \leftarrow \beta_1 m_{t-1} + g_t$ for it to have the same qualitative behaviour as SGD), but this extra factor makes a lot of difference!
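
To see this concretely, suppose the gradient were constant at $g$ over many steps. The two buffers then settle at very different steady states:

$$b_t = \mu b_{t-1} + g \;\Rightarrow\; b_t \to \frac{g}{1-\mu} \quad \text{(SGD momentum, e.g. } 100g \text{ for } \mu = 0.99\text{)}$$

$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g \;\Rightarrow\; m_t \to g \quad \text{(Adam)}$$

So SGD's buffer can grow to many times the size of any individual gradient (that's the acceleration), whereas Adam's first moment is just a smoothed estimate of the typical gradient.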

Exercise - implement Adam

```yaml Difficulty: 🔴🔴🔴🔴⚪ Importance: 🔵🔵🔵⚪⚪

You should spend up to 15-20 minutes on this exercise. ```

This should just be an extension of your RMSprop implementation. You still have 2 variables to track, but now the variable $b_t$ for applying momentum has been replaced with $m_t$ for tracking the exponentially weighted moving average of first order moments.

Here's a link to the PyTorch version, alternatively you can use the simplified pseudocode below:

Click here for the simplified pseudocode

$ \text {for } t=1 \text { to } \ldots \text { do } \\ \quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\ \quad\; \text {if } \lambda \neq 0 \\ \quad\;\quad\; g_t \leftarrow g_t+\lambda \theta_{t-1} \\ \quad\; m_t \leftarrow \beta_1 m_{t-1} + (1-\beta_1) g_t \\ \quad\; v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \\ \quad\; \widehat{m_t} \leftarrow m_t / (1 - \beta_1^t) \\ \quad\; \widehat{v_t} \leftarrow v_t / (1 - \beta_2^t) \\ \quad\; \theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} / (\sqrt{\widehat{v_t}} + \epsilon) $

Note - we center our first & second moment estimators by dividing by $1 - \beta^t$, which means for this optimizer we do have to track the variable $t$ (make sure to remember to increment it after each use of the step function). We do this because Adam's exponentially weighted moving average would otherwise take a while to converge to the true mean (since its estimates initially behave like the truncated sum of a geometric series). We leave it as an exercise for the reader to derive this (hint - try assuming the expected value $\mathbb{E}[g_t] = g_0$ is the same for all $t$, what does the expression $\mathbb{E}[m_t]$ simplify to?).

class Adam:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        """Implements Adam.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
        """
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        self.t = 1

        self.m = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        raise NotImplementedError()

    def __repr__(self) -> str:
        return (
            f"Adam(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, "
            f"weight_decay={self.lmda})"
        )


tests.test_adam(Adam)
Solution
class Adam:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        """Implements Adam.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
        """
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        self.t = 1

        self.m = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        for theta, m, v in zip(self.params, self.m, self.v):
            g = theta.grad
            if self.lmda != 0:
                g = g + self.lmda * theta
            m.copy_(self.beta1 * m + (1 - self.beta1) * g)
            v.copy_(self.beta2 * v + (1 - self.beta2) * g.pow(2))
            m_hat = m / (1 - self.beta1**self.t)
            v_hat = v / (1 - self.beta2**self.t)
            theta -= self.lr * m_hat / (v_hat.sqrt() + self.eps)
        self.t += 1

    def __repr__(self) -> str:
        return (
            f"Adam(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, "
            f"weight_decay={self.lmda})"
        )

Exercise - implement AdamW

```yaml Difficulty: 🔴🔴⚪⚪⚪ Importance: 🔵🔵⚪⚪⚪

You should spend up to 10-15 minutes on this exercise. ```

Finally, you'll adapt your Adam implementation to implement AdamW. This is a very small modification of the Adam update rule: weight decay is applied directly to the weights $\theta_t$ themselves, rather than being added to the gradients $g_t$ and then fed through the first & second moment calculations. This "decoupled" weight decay acts as a pure shrinkage of the weights towards zero (analogous to a zero-mean Gaussian prior on the weights), rather than an L2-style penalty that gets rescaled by Adam's adaptive terms. This is seen as the more "correct" way to implement weight decay, and so AdamW is now generally preferred over Adam.

You can read more about this variant of Adam here. The PyTorch docs are here, and the pseudocode is again provided for you below (but for this exercise we do recommend trying to go without it - having to work with more complex pseudocode and parse out the bits that actually matter is a useful exercise!).

Click here for the simplified pseudocode

$ \text {for } t=1 \text { to } \ldots \text { do } \\ \quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\ \quad\; \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ \quad\; m_t \leftarrow \beta_1 m_{t-1} + (1-\beta_1) g_t \\ \quad\; v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \\ \quad\; \widehat{m_t} \leftarrow m_t / (1 - \beta_1^t) \\ \quad\; \widehat{v_t} \leftarrow v_t / (1 - \beta_2^t) \\ \quad\; \theta_t \leftarrow \theta_t - \gamma \widehat{m_t} / (\sqrt{\widehat{v_t}} + \epsilon) $

class AdamW:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        """Implements Adam.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
        """
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        self.t = 1

        self.m = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        raise NotImplementedError()

    def __repr__(self) -> str:
        return (
            f"AdamW(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, "
            f"weight_decay={self.lmda})"
        )


tests.test_adamw(AdamW)
Solution
class AdamW:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        """Implements AdamW.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
        """
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        self.t = 1

        self.m = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        for theta, m, v in zip(self.params, self.m, self.v):
            g = theta.grad
            theta *= 1 - self.lr * self.lmda  # apply weight decay directly to the weights (inplace)
            m.copy_(self.beta1 * m + (1 - self.beta1) * g)
            v.copy_(self.beta2 * v + (1 - self.beta2) * g.pow(2))
            m_hat = m / (1 - self.beta1**self.t)
            v_hat = v / (1 - self.beta2**self.t)
            theta -= self.lr * m_hat / (v_hat.sqrt() + self.eps)
        self.t += 1

    def __repr__(self) -> str:
        return (
            f"AdamW(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, "
            f"weight_decay={self.lmda})"
        )

Plotting multiple optimisers

Finally, we've provided some code which should allow you to plot more than one of your optimisers at once.

Exercise - experiment with different optimizers & params

```yaml Difficulty: 🔴🔴⚪⚪⚪ Importance: 🔵🔵🔵⚪⚪

You should spend up to 20-30 minutes on this exercise. ```

We've given you a function below which works just like opt_fn_with_sgd from earlier, but takes in a general optimizer and hyperparameters (as a dictionary of keyword arguments like lr and momentum).

You should use this function to play around with different optimizers and hyperparameters, comparing their performance. The code below gives one example of such a comparison, run it now and see what you get:

def opt_fn(
    fn: Callable,
    xy: Tensor,
    optimizer_class,
    optimizer_hyperparams: dict,
    n_iters: int = 100,
) -> Tensor:
    """Optimize the a given function starting from the specified point.

    optimizer_class: one of the optimizers you've defined, either SGD, RMSprop, or Adam
    optimzer_kwargs: keyword arguments passed to your optimiser (e.g. lr and weight_decay)
    """
    assert xy.requires_grad

    optimizer = optimizer_class([xy], **optimizer_hyperparams)

    xy_list = [
        xy.detach().clone()
    ]  # so that we don't unintentionally modify past values in `xy_list`

    for i in range(n_iters):
        fn(xy[0], xy[1]).backward()
        optimizer.step()
        optimizer.zero_grad()
        xy_list.append(xy.detach().clone())

    return t.stack(xy_list)


points = []

optimizer_list = [
    (SGD, {"lr": 0.03, "momentum": 0.99}),
    (RMSprop, {"lr": 0.02, "alpha": 0.99, "momentum": 0.8}),
    (Adam, {"lr": 0.2, "betas": (0.99, 0.99), "weight_decay": 0.005}),
    (AdamW, {"lr": 0.2, "betas": (0.99, 0.99), "weight_decay": 0.005}),
]

for optimizer_class, params in optimizer_list:
    xy = t.tensor([2.5, 2.5], requires_grad=True)
    xys = opt_fn(
        pathological_curve_loss,
        xy=xy,
        optimizer_class=optimizer_class,
        optimizer_hyperparams=params,
    )
    points.append((xys, optimizer_class, params))

plot_fn_with_points(pathological_curve_loss, min_points=[(0, "y_min")], points=points)
Click to see the expected output

Note that the focus shouldn't be on figuring out "which one is the best optimizer" - this loss landscape (and other examples we'll give you) were specifically designed to be pathological, and exhibit interesting kinds of behaviours from optimizers. The focus should instead be on understanding how the characteristics of optimizers we discussed in the previous sections are reflected visually in the plots produced on these loss landscapes. Some questions you might want to ask:

  • We discussed that Adam (and AdamW) center their first and second moments, so that the early values are large - otherwise they start off small and take a long time to grow. Is this reflected in the plots, i.e. with Adam/AdamW taking larger early steps relative to SGD or RMSprop?
  • The momentum used in SGD and RMSprop causes acceleration until "terminal velocity", which is usually a higher cap than Adam and AdamW. Is this reflected in the step size (and the instability) of those optimizers? Do Adam and AdamW seem to adapt slightly faster when they enter new terrain?
  • Are there any landscapes where weight decay is advantageous, and can you see why it would be?

Some more functions you might want to try out (with their minima marked on the plots):

def bivariate_gaussian(x, y, x_mean=0.0, y_mean=0.0, x_sig=1.0, y_sig=1.0):
    norm = 1 / (2 * np.pi * x_sig * y_sig)
    x_exp = 0.5 * ((x - x_mean) ** 2) / (x_sig**2)
    y_exp = 0.5 * ((y - y_mean) ** 2) / (y_sig**2)
    return norm * t.exp(-x_exp - y_exp)


means = [(1.0, -0.5), (-1.0, 0.5), (-0.5, -0.8)]


def neg_trimodal_func(x, y):
    """
    This function has 3 global minima, at `means`. Unstable methods can overshoot these minima, and
    non-adaptive methods can fail to converge to them in the first place given how shallow the
    gradients are everywhere except in the close vicinity of the minima.
    """
    z = -bivariate_gaussian(x, y, x_mean=means[0][0], y_mean=means[0][1], x_sig=0.2, y_sig=0.2)
    z -= bivariate_gaussian(x, y, x_mean=means[1][0], y_mean=means[1][1], x_sig=0.2, y_sig=0.2)
    z -= bivariate_gaussian(x, y, x_mean=means[2][0], y_mean=means[2][1], x_sig=0.2, y_sig=0.2)
    return z


plot_fn(neg_trimodal_func, x_range=(-2, 2), y_range=(-2, 2), min_points=means)
def rosenbrocks_banana_func(x: Tensor, y: Tensor, a=1, b=100) -> Tensor:
    """
    This function has a global minimum at `(a, a)` so in this case `(1, 1)`. It's characterized by a
    long, narrow, parabolic valley (parameterized by `y = x**2`). Various gradient descent methods
    have trouble navigating this valley because they often oscillate unstably (gradients from the
    `b`-term dwarf the gradients from the `a`-term).

    See more on this function: https://en.wikipedia.org/wiki/Rosenbrock_function.
    """
    return (a - x) ** 2 + b * (y - x**2) ** 2 + 1


plot_fn(
    rosenbrocks_banana_func,
    x_range=(-2.5, 2.5),
    y_range=(-2, 4),
    z_range=(0, 100),
    min_points=[(1, 1)],
)
Some example visualizations & observations

Let's start with the negative trimodal function. You should find that weight decay massively helps performance here, but this is for pretty uninteresting reasons - it essentially adds a slope towards the origin, and when the ball rolls towards the origin it will probably also get caught in one of the three minima. So it doesn't tell us much about the actual optimizers.

More interestingly, we can compare the optimizers when they have weight decay switched off. You should find that Adam can outperform SGD and RMSprop here, because the way it uses "momentum" is better suited to this task. For one thing, the first and second moment centering means it can take larger early steps relative to SGD and RMSprop (which both take a while to accelerate). For another, momentum causes RMSprop step sizes to increase in an unstable way, which is why it will overshoot the minima and get stuck on the other side without careful hyperparameter tuning. SGD is even worse - because of its lack of variance-based scaling, it'll utterly fail to move anywhere unless it starts out very close to one of the three minima.

optimizer_list = [
    (SGD, {"lr": 0.1, "momentum": 0.5}),
    (RMSprop, {"lr": 0.1, "alpha": 0.99, "momentum": 0.5}),
    (Adam, {"lr": 0.1, "betas": (0.9, 0.999)}),
]
points = []
for optimizer_class, params in optimizer_list:
    xy = t.tensor([1.0, 1.0], requires_grad=True)
    xys = opt_fn(neg_trimodal_func, xy=xy, optimizer_class=optimizer_class, optimizer_hyperparams=params)
    points.append((xys, optimizer_class, params))
plot_fn_with_points(neg_trimodal_func, points=points, x_range=(-2, 2), y_range=(-2, 2), min_points=means)

Next, Rosenbrock's banana. This function has a global minimum at (1, 1), inside a long, narrow, parabolic valley. Basic gradient descent often zigzags back and forth along the valley, making very slow progress. Momentum is absolutely essential to perform well on this task. This is a rare case where SGD plus momentum converges faster than Adam, because its higher terminal velocity enables larger step sizes, and the extreme slope of the loss landscape prevents the kind of instability that usually hinders SGD. However, some caveats: SGD requires a very small step size to prevent unstable oscillations (given how steep the valley is), whereas Adam is much more stable. Furthermore, if we extend the number of iterations, we see that Adam also converges, and it does so with fewer oscillations than SGD (it stays within the parabolic valley).

optimizer_list = [
    (SGD, {"lr": 0.001, "momentum": 0.99}),
    (Adam, {"lr": 0.1, "betas": (0.9, 0.999)}),
]
points = []
for optimizer_class, params in optimizer_list:
    xy = t.tensor([-1.5, 2.5], requires_grad=True)
    xys = opt_fn(
        rosenbrocks_banana_func, xy=xy, optimizer_class=optimizer_class, optimizer_hyperparams=params, n_iters=500
    )
    points.append((xys, optimizer_class, params))
plot_fn_with_points(
    rosenbrocks_banana_func, x_range=(-2.5, 2.5), y_range=(-2, 4), z_range=(0, 100), min_points=[(1, 1)], points=points
)

## Bonus - parameter groups

> If you're interested in these exercises then you can go through them, if not then you can move on to the next section (weights and biases).

Rather than passing a single iterable of parameters into an optimizer, you have the option to pass a list of parameter groups, each one with different hyperparameters. As an example of how this might work:

optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

The first argument here is a list of dictionaries, with each dictionary defining a separate parameter group. Each should contain a params key, which contains an iterable of parameters belonging to this group. The dictionaries may also contain hyperparameter keyword arguments. If a hyperparameter is not specified for a group, PyTorch uses the value passed as a top-level keyword argument. So the example above is equivalent to:

optim.SGD([
    {'params': model.base.parameters(), 'lr': 1e-2, 'momentum': 0.9},
    {'params': model.classifier.parameters(), 'lr': 1e-3, 'momentum': 0.9}
])

All hyperparameters have default values, with the exception of lr, which is why you need to specify it either as a keyword argument to the optimizer or separately in each group.

PyTorch optimisers will store all their params and hyperparams in the param_groups attribute, which is a list of dictionaries - this is why, when we want to modify an optimizer's learning rate (which we'll do later on in the course), even if we didn't specify any parameter groups we still need to use optimizer.param_groups[0]["lr"] = new_lr.
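
For example, a manual learning rate adjustment might look like this sketch (the model and new value are hypothetical):

from torch import nn, optim

model = nn.Linear(10, 10)
optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

# Even with no explicit groups, param_groups is a list containing one dict of params + hyperparameters
for group in optimizer.param_groups:
    group["lr"] = 1e-3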

### When to use parameter groups

Parameter groups can be useful in several different circumstances. A few examples:

- Finetuning a model by freezing earlier layers and only training later layers is an extreme form of parameter grouping. We can use the parameter group syntax to apply a milder version of this, where the earlier layers get a smaller learning rate. This allows those layers to adapt to the specifics of the problem, while making sure they don't forget all the useful features they've already learned.
- It's often good to treat weights and biases differently, e.g. effects like weight decay are often applied to weights but not biases. PyTorch doesn't differentiate between these two, so you'll have to do this manually using parameter groups (see the sketch below). You might do exactly this later in the course, if you choose the "train BERT from scratch" exercises during the transformers chapter. On the subject of transformers, weight decay is also often not applied to embeddings and layernorms in transformer models.
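
Here's a hedged sketch of the weights-vs-biases case from the list above (the model is hypothetical, and real code would usually also special-case layernorm / batchnorm parameters):

from torch import nn, optim

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))

# Apply weight decay to weights only, not to biases
decay = [p for name, p in model.named_parameters() if not name.endswith("bias")]
no_decay = [p for name, p in model.named_parameters() if name.endswith("bias")]

optimizer = optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)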

More generally, if you're trying to replicate a paper, it's important to be able to use all the same training details that the original authors did, so you can get the same results.

### Exercise - rewrite SGD to use parameter groups

>

> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵⚪⚪⚪⚪
> 
> You should spend up to 30-40 minutes on this exercise.
> It's somewhat useful to understand the idea of parameter groups, less so to know how they're actually implemented.
>

You should rewrite the SGD optimizer from the earlier exercises, to use param_groups. This will involve filling in the 3 methods __init__, zero_grad, and step. Some guidance:

- In __init__ you should create self.param_groups, which is a list of dictionaries, each one containing "params" as well as all the hyperparameters for that group. Remember the hierarchy for hyperparameters: "specified for group" > "specified as a keyword argument" > "default value".
- In zero_grad and step the logic is the same as before, but now you need a doubly nested for loop: once over the param groups in self.param_groups, and once over the params in each group. For the latter, make sure you're using the group-specific hyperparameters (i.e. the ones you hopefully stored in self.param_groups in the init method).

class SGD:
    def __init__(self, params, **kwargs):
        """Implements SGD with momentum.

        Accepts parameters in groups, or an iterable.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD
        """
        # Deal with case where we didn't supply groups, so we just make it into a single dictionary
        if not isinstance(params, (list, tuple)):
            params = [{"params": params}]

        # Make sure each group["params"] is a list of params not a generator (so we don't iterate
        # over & destroy it!)
        for p in params:
            p["params"] = list(p["params"])

        self.param_groups = []

        # YOUR CODE HERE - fill in self.param_groups
        raise NotImplementedError()

    def zero_grad(self) -> None:
        raise NotImplementedError()

    @t.inference_mode()
    def step(self) -> None:
        raise NotImplementedError()


tests.test_sgd_param_groups(SGD)
Solution
class SGD:
    def __init__(self, params, **kwargs):
        """Implements SGD with momentum.

        Accepts parameters in groups, or an iterable.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD
        """
        # Deal with case where we didn't supply groups, so we just make it into a single dictionary
        if not isinstance(params, (list, tuple)):
            params = [{"params": params}]

        # Make sure each group["params"] is a list of params not a generator (so we don't iterate
        # over & destroy it!)
        for p in params:
            p["params"] = list(p["params"])

        self.param_groups = []

        for param_group in params:
            # Set hyperparameters hierarchically: specified for group > specified as a keyword
            # argument > default value. We do this via dict merge (right takes precedence over left).
            param_group = {"momentum": 0.0, "weight_decay": 0.0, **kwargs, **param_group}

            # Check that "lr" is supplied
            assert "lr" in param_group, "Error: one of the param groups didn't specify 'lr'"

            # Set "params" and "b" in our group
            param_group["b"] = [t.zeros_like(p) for p in param_group["params"]]

            self.param_groups.append(param_group)

    def zero_grad(self) -> None:
        for param_group in self.param_groups:
            for p in param_group["params"]:
                p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        # loop through each param group
        for param_group in self.param_groups:
            # Get hparams for this group
            lmda = param_group["weight_decay"]
            mu = param_group["momentum"]
            lr = param_group["lr"]

            # Same code as for SGD implementation before, but using group-specific hparams
            for b, theta in zip(param_group["b"], param_group["params"]):
                g = theta.grad
                if lmda != 0:
                    g = g + lmda * theta  # not inplace, since we're not changing theta.grad
                if mu != 0:
                    b.copy_(mu * b + g)  # needs to be inplace, since we're changing self.b value
                    g = b
                theta -= lr * g  # inplace operation, to modify params


tests.test_sgd_param_groups(SGD)
## Bonus - Muon Optimizer

Hot off the press is a new optimizer called *Muon*. Muon is an optimizer specialized only for the parameters of a network that are *hidden* and *at least 2-dimensional*.

* For image classification, we skip anything that directly interfaces with the input or the output.
* For language models, this means skipping the embedding and unembedding layers. The dimensionality requirement also means skipping biases, $\gamma$ or $\beta$ in layernorms, etc.

All other parameters are optimized using Adam as per usual.

Comparison of Muon to the AdamW optimizer for the NanoGPT training speedrun. Taken from https://x.com/kellerjordan0/status/1842300916864844014

Muon has shown great promise for language models, shaving a massive 40%(!) off the training time for the [nanoGPT speedrun](https://www.tylerromero.com/posts/nanogpt-speedrun-worklog/), a collaborative project to train a model as performant as GPT-2 as fast as possible, based on [Andrej Karpathy's nanoGPT implementation](https://github.com/karpathy/llm.c/discussions/481). We might make this into an actual exercise later, but for now, here's a series of resources should you wish to implement it yourself:

* [Introduction to Muon](https://kellerjordan.github.io/posts/muon/)
* [NanoGPT Speedrun Project](https://github.com/KellerJordan/modded-nanogpt/) - the current world record is sub-3 minutes(!!) on an 8xH100 cluster (about 16 PFLOPS of compute). On a single consumer GPU (an RTX 3090, assuming no out-of-memory issues) with ~140 TFLOPS, this would take ~5 hours, which is still incredibly impressive.
* [Twitter thread on Muon](https://x.com/kellerjordan0/status/1842300916864844014)
* [Derivation of Muon](https://jeremybernste.in/writing/deriving-muon) - not required reading, but worth a look if you're curious about the math.
* [Reference Muon implementation in PyTorch](https://github.com/KellerJordan/Muon)