☆ Bonus
Congratulations on getting to the end of the main content! This section suggests some more Weights and Biases features to explore, along with some other experiments you can run.
Scaling Laws
These bonus exercises are taken directly from Jacob Hilton's online deep learning curriculum (which is what the original version of the ARENA course was based on).
First, you can start by reading the Chinchilla paper. This is a correction to the original scaling laws paper: parameter count scales linearly with token budget for compute-optimal models, not ~quadratically. The difference comes from using a separately-tuned learning rate schedule for each token budget, rather than using a single training run to measure performance for every token budget. This highlights the importance of hyperparameter tuning for measuring scaling law exponents.
You don't have to read the entire paper, just skim the graphs. Don't worry if they don't all make sense yet (it will be more illuminating when we study LLMs next week). Note that, although it specifically applies to language models, the key underlying ideas of tradeoffs between optimal dataset size and model size are generally applicable.
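To make the "linear vs. ~quadratic" distinction above concrete, here is a short worked derivation. It assumes the commonly used approximation that training compute is $C \approx 6ND$ for a model with $N$ parameters trained on $D$ tokens; that approximation is not stated in this section, and the exponents below are idealized consequences of each scaling relation rather than the fitted values reported in either paper.

$$
\begin{aligned}
&C \approx 6ND \\
&N \propto D \;\Rightarrow\; C \propto N^2 \;\Rightarrow\; N_{\text{opt}} \propto C^{1/2}, \quad D_{\text{opt}} \propto C^{1/2} \\
&N \propto D^2 \;\Rightarrow\; C \propto D^3 \;\Rightarrow\; N_{\text{opt}} \propto C^{2/3}, \quad D_{\text{opt}} \propto C^{1/3}
\end{aligned}
$$

Under the linear relation, extra compute is split evenly (in log terms) between model size and data. Under the quadratic relation, most of it goes into model size.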
Suggested exercise
Perform your own study of scaling laws for MNIST.
- Write a script to train a small CNN on MNIST, or find one you have written previously.
- Training for a single epoch only, vary the model size and dataset size. For the model size, multiply the width by powers of sqrt(2) (rounding if necessary - the idea is to vary the amount of compute used per forward pass by powers of 2). For the dataset size, multiply the fraction of the full dataset used by powers of 2 (i.e. 1, 1/2, 1/4, ...). To reduce noise, use a few random seeds and always use the full validation set.
- The learning rate will need to vary with model size. Either tune it carefully for each model size, or use the rule of thumb that for Adam, the learning rate should be proportional to the initialization scale, i.e. 1/sqrt(fan_in) for the standard Kaiming He initialization (which is what PyTorch generally uses by default).
  - Note - fan_in refers to the variable $N_{in}$, which is in_features for a linear layer, and in_channels * kernel_size * kernel_size for a convolutional layer - in other words, the number of input parameters/activations we take a sumproduct over to get each output activation.
- Plot the amount of compute used (on a log scale) against validation loss. The compute-efficient frontier should follow an approximate power law (a straight line when both axes are on a log scale). How does validation accuracy behave?
- Study how the compute-efficient model size varies with compute. This should also follow an approximate power law. Try to estimate its exponent.
- Repeat your entire experiment with 20% dropout to see how this affects the scaling exponents.
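The bookkeeping for the steps above can be sketched in plain Python. This is a minimal sketch, not a reference solution: every function name, the base width, and the base learning rate are hypothetical choices made for illustration. It covers the width and dataset-fraction sweeps, the 1/sqrt(fan_in) learning-rate rule, extracting a compute-efficient frontier from (compute, loss) pairs, and estimating a power-law exponent with a least-squares fit in log-log space. The training loop itself and how you measure compute are left to you.

```python
import math


def sweep_widths(base_width, num_steps):
    """Widths scaled by successive powers of sqrt(2), rounded to integers."""
    return [round(base_width * math.sqrt(2) ** k) for k in range(num_steps)]


def sweep_fractions(num_steps):
    """Dataset fractions 1, 1/2, 1/4, ..."""
    return [1 / 2 ** k for k in range(num_steps)]


def conv_fan_in(in_channels, kernel_size):
    """fan_in of a conv layer with a square kernel."""
    return in_channels * kernel_size * kernel_size


def adam_lr(fan_in, base_lr=1e-3, base_fan_in=144):
    """Learning rate proportional to 1/sqrt(fan_in).

    base_lr and base_fan_in are assumed anchor values: a layer with
    fan_in == base_fan_in gets exactly base_lr.
    """
    return base_lr * math.sqrt(base_fan_in / fan_in)


def efficient_frontier(runs):
    """Keep the (compute, loss) runs that beat every run using no more compute."""
    frontier, best_loss = [], float("inf")
    for compute, loss in sorted(runs):
        if loss < best_loss:
            frontier.append((compute, loss))
            best_loss = loss
    return frontier


def fit_power_law(xs, ys):
    """Least-squares fit of y = a * x**b in log-log space. Returns (a, b)."""
    log_x = [math.log(x) for x in xs]
    log_y = [math.log(y) for y in ys]
    n = len(log_x)
    mean_x, mean_y = sum(log_x) / n, sum(log_y) / n
    cov = sum((u - mean_x) * (v - mean_y) for u, v in zip(log_x, log_y))
    var = sum((u - mean_x) ** 2 for u in log_x)
    b = cov / var
    a = math.exp(mean_y - b * mean_x)
    return a, b
```

A typical use would be to collect one (compute, validation loss) pair per run, pass them through efficient_frontier, then call fit_power_law on the frontier points to read off the exponent b. How to apply adam_lr across a multi-layer model (one rate per layer via parameter groups, or a single rate from a representative layer) is a design choice this sketch leaves open.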
Other WandB features
Here are a few more Weights & Biases features you might also want to play around with:
- Logging media and objects in experiments - you'll be doing this during the RL week, and it's useful when you're training generative image models like VAEs and diffusion models.
- Code saving - this captures all Python source code files in the current directory and all subdirectories. It's great for reproducibility, and also for sharing your code with others.
- Saving and loading PyTorch models - you can do this easily using torch.save, but it's also possible to do this directly through Weights and Biases as an artifact.
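The three features above can be combined in one short function. This is a rough sketch rather than a recommended workflow: the function name, project name, artifact name, and checkpoint filename are all hypothetical, and it assumes you have wandb and torch installed and are logged in to Weights and Biases. The imports sit inside the function only so that defining the sketch does not itself require those packages.

```python
def log_samples_code_and_model(model, sample_images, project="bonus-demo"):
    """Log sample images, source code, and a model checkpoint to one W&B run.

    `model` is assumed to be a torch.nn.Module, and `sample_images` an
    iterable of arrays, tensors, or PIL images that wandb.Image accepts.
    """
    import torch
    import wandb

    # save_code=True asks W&B to record the script that launched the run
    run = wandb.init(project=project, save_code=True)

    # Code saving: upload Python source files under the current directory
    run.log_code(".")

    # Logging media: wrap each sample so it renders as an image in the UI
    run.log({"samples": [wandb.Image(img) for img in sample_images]})

    # Saving a model: write a checkpoint locally, then attach it as an artifact
    checkpoint_path = "model.pt"  # hypothetical filename
    torch.save(model.state_dict(), checkpoint_path)
    artifact = wandb.Artifact("mnist-cnn", type="model")  # hypothetical name
    artifact.add_file(checkpoint_path)
    run.log_artifact(artifact)

    run.finish()
```

Saving the state_dict rather than the whole model object is a choice made here for portability. Check the Weights and Biases documentation for the options each call accepts before relying on this.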
The Optimizer's Curse
The optimizer's curse applies to tuning hyperparameters. The main take-aways are:
- You can expect your best hyperparameter combination to actually underperform in the future. You chose it because it was the best on some metric, but that metric has an element of noise/luck, and the more combinations you test the larger this effect is.
- Look at the overall trends and correlations in context and try to make sense of the values you're seeing. Just because you ran a long search process doesn't mean your best output is really the best.
For more on this, see Preventing "Overfitting" of Cross-Validation Data by Andrew Ng.
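The first take-away can be illustrated with a small simulation. It is a toy sketch under loudly stated assumptions: every hyperparameter combination has exactly the same true score, and each evaluation adds independent Gaussian noise. The function name and all default values are made up for illustration. The function measures how much the winning combination's score drops, on average, when it is evaluated again.

```python
import random


def selection_gap(num_configs, noise_std=0.01, true_score=0.90,
                  num_trials=2000, seed=0):
    """Average of (best measured score) - (fresh re-evaluation of the winner).

    Assumes all configurations share the same true score, so any apparent
    winner is chosen purely by luck.
    """
    rng = random.Random(seed)
    total_gap = 0.0
    for _ in range(num_trials):
        # One noisy measurement per configuration
        measured = [true_score + rng.gauss(0, noise_std)
                    for _ in range(num_configs)]
        best_measured = max(measured)
        # Re-run the winner: same true score, independent noise
        rerun = true_score + rng.gauss(0, noise_std)
        total_gap += best_measured - rerun
    return total_gap / num_trials


if __name__ == "__main__":
    for k in (1, 5, 50):
        print(f"{k:>3} configs: average drop on re-evaluation = {selection_gap(k):.4f}")
```

With a single configuration there is no selection and the average drop is close to zero. As the number of combinations grows, the gap widens, even though no combination is actually better than any other.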