Exercise Status: All exercises complete and verified

☆ Bonus

Learning Objectives
  • Improve your RLHF implementation via techniques like differential learning rates, frozen layers, or adaptive KL penalties
  • Perform some exploratory mechanistic interpretability on RLHF'd models
  • Learn about the trlX library, which is designed to train transformers via RLHF in a way which abstracts away many of the low-level details

Extensions of today's RLHF exercises

Large models

We're already working with gpt2-medium, which is considerably larger than most of the models you worked with in the transformers & interpretability material. Can you go even larger, e.g. gpt2-xl or more?

See this page for a table of model properties, for all models currently supported by TransformerLens. Note that if you use different model classes then you might need to change some parts of your code (e.g. if the name of the hook point where you added the value head happens to be different). You might also need to make other adjustments e.g. a smaller batch size (or a larger number of minibatches per batch, which is equivalent to smaller minibatch sizes).

Differential Learning Rates / Frozen Layers

When doing any kind of finetuning, it's common practice to either freeze earlier layers or have a smaller learning rate for them. You may have seen this in the feature extraction with ResNet34 exercises in the first week. In the exercises here we've trained all layers of the model equally, but you might want to play around with differential learning rates.

Note that you can accomplish this using parameter groups - we already used parameter groups above to have a different learning rate for our base model and value head. It should be relatively straightforward to extend this to splitting parameters over different layers into different groups (hint - you can use itertools.chain to convert several iterables into a single iterable).
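As a sketch of how this might look (the model structure here is a toy stand-in for illustration, not the actual HookedTransformer class):

```python
import itertools
import torch
import torch.nn as nn

# Toy stand-in for a transformer: a stack of blocks plus a value head
model = nn.ModuleDict({
    "blocks": nn.ModuleList([nn.Linear(4, 4) for _ in range(6)]),
    "value_head": nn.Linear(4, 1),
})

n_layers = len(model["blocks"])
early_blocks = model["blocks"][: n_layers // 2]
late_blocks = model["blocks"][n_layers // 2 :]

optimizer = torch.optim.AdamW([
    # itertools.chain merges each block's parameter iterator into one iterable
    {"params": itertools.chain(*(b.parameters() for b in early_blocks)), "lr": 1e-5},
    {"params": itertools.chain(*(b.parameters() for b in late_blocks)), "lr": 1e-4},
    {"params": model["value_head"].parameters(), "lr": 3e-4},
])
```

The learning rates here are arbitrary; the point is just that each parameter group can carry its own `lr` (and other optimizer settings).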

You can also try entirely freezing earlier layers - this might also reduce your memory usage, and allow you to train larger models without getting cuda errors.
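For example (again with a toy stack of layers standing in for the transformer blocks):

```python
import torch.nn as nn

model = nn.Sequential(*[nn.Linear(4, 4) for _ in range(6)])  # toy stand-in

# Freeze the first 3 layers: no gradients (or optimizer state) will be kept for them
for layer in model[:3]:
    for p in layer.parameters():
        p.requires_grad = False

# Only pass the still-trainable parameters to your optimizer
trainable_params = [p for p in model.parameters() if p.requires_grad]
```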

Hyperparameter sweeps

You can do this to find the best possible hyperparameters for your RLHF training. Don't just measure reward - can you combine reward and average KL divergence into a better metric? Can you use wandb's built-in Bayesian search methods to sweep more effectively?
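For instance, a sweep config using wandb's Bayesian search might look something like this (the metric name and parameter ranges here are illustrative assumptions, not tested values):

```python
# Hypothetical wandb sweep config; "mean_reward_minus_kl" is assumed to be a
# metric you log yourself, combining reward and average KL divergence
sweep_config = {
    "method": "bayes",
    "metric": {"name": "mean_reward_minus_kl", "goal": "maximize"},
    "parameters": {
        "base_lr": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-3},
        "kl_coef": {"distribution": "log_uniform_values", "min": 0.05, "max": 2.0},
        "temperature": {"values": [0.7, 1.0, 1.3]},
    },
}
# Launch with: sweep_id = wandb.sweep(sweep_config, project="rlhf")
#              wandb.agent(sweep_id, function=train, count=20)
```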

Note - don't forget temperature when it comes to hyperparameter tuning. Temperature has an important effect on how the model learns, e.g. if the temperature is too high then the model will produce very high-variance outputs which will have very high KL with the reference distribution, and it'll be more likely to collapse into some incoherent mode.

Adaptive KL penalty

The KL divergence penalty coefficient can be modified adaptively based on the KL divergence between the current policy and the previous policy. If the KL divergence is outside a predefined target range, we can adjust the penalty coefficient to bring it closer to the target range. Here is an example implementation:

import numpy as np

class AdaptiveKLController:
    def __init__(self, init_kl_coef, hparams):
        self.value = init_kl_coef      # current KL penalty coefficient
        self.hparams = hparams         # expects `target` and `horizon` attributes

    def update(self, current, n_steps):
        # Proportional controller: nudge the coefficient so the observed KL
        # (`current`) moves towards `target`; the error is clipped for stability
        target = self.hparams.target
        proportional_error = np.clip(current / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.hparams.horizon
        self.value *= mult
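To see the controller in action, here's a short self-contained demo (the class is repeated so the snippet runs on its own; the target and horizon values are arbitrary):

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class KLHparams:  # hypothetical container for the controller's hyperparameters
    target: float = 6.0     # desired KL between policy and reference
    horizon: int = 10_000   # steps over which the coefficient anneals

class AdaptiveKLController:  # same controller as above, repeated for a runnable demo
    def __init__(self, init_kl_coef, hparams):
        self.value = init_kl_coef
        self.hparams = hparams

    def update(self, current, n_steps):
        proportional_error = np.clip(current / self.hparams.target - 1, -0.2, 0.2)
        self.value *= 1 + proportional_error * n_steps / self.hparams.horizon

ctrl = AdaptiveKLController(init_kl_coef=0.2, hparams=KLHparams())
ctrl.update(current=12.0, n_steps=100)   # observed KL too high -> coefficient grows
ctrl.update(current=1.0, n_steps=100)    # observed KL too low -> coefficient shrinks
```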

TRL / trlX

We've been focusing on building RLHF from the ground up, but there are several libraries which exist to abstract away many of the low-level implementation details we had to wrestle with. One of the best-known is TRL (Transformer Reinforcement Learning). The main docs page can be found here, and this page gives a quickstart guide. You may find it much easier to use this library than to implement everything yourself!

Read their documentation pages, and see what techniques they use to make RLHF more effective. Are there any that we haven't implemented here? Can you implement them yourself?

You might also be interested in trlX, an expanded fork of TRL built by CarperAI to handle larger models for online and offline training (although their APIs are pretty similar).

Learn a human preference reward model

We've been working with a pre-supplied reward function, but you can try and train your own!

We'll give some brief points of guidance here, for the task of training a reward function on the summarization task. Note that these instructions have been provided externally, so they've not yet been tested and might not work particularly well.

  1. Get a supervised baseline
    • Here is a link to download the dataset for the TL;DR challenge containing posts from the Reddit corpus. Each post contains keys content and summary which are the original post and the human-written summary respectively.
    • You should throw out all summaries shorter than 24 tokens or longer than 48 tokens (to diminish the effects of length on quality); and choose a random subset of ~100k summaries to train on.
    • Run training to maximize the log-likelihood of these summaries.
  2. Get reward model by training supervised baseline on human feedback
    • Download comparison data with the code azcopy copy "https://openaipublic.blob.core.windows.net/summarize-from-feedback/dataset/*" . --recursive
    • Modify GPT-2 architecture by adding a randomly-initialized reward head at the end of your model.
      • Architecturally this is similar to the value head from earlier, but it's not the same thing - here we're trying to learn what the human reward will be; we're not doing RL yet.
    • Train your model (starting with the base model given by the supervised baseline weights, and the reward head randomly initialized) to minimize loss = -log(sigmoid(reward_model(summary_0) - reward_model(summary_1))), where summary_0 is the summary preferred by the human labeler (this information should be in the comparison data you downloaded).
    • You should normalize reward model outputs, like we normalized rewards in RLHF in previous exercises.
  3. Fine-tune supervised baseline using PPO with reward model.
    • For these exercises we suggest using a larger model, ideally GPT2-Large or bigger. Remember you can freeze weights! Regardless, this will still take longer to train than your previous models.
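The pairwise loss from step 2 can be sketched as follows (NumPy used for illustration; `r_preferred` / `r_other` stand for your reward model's scalar outputs on the two summaries in each comparison):

```python
import numpy as np

def pairwise_preference_loss(r_preferred: np.ndarray, r_other: np.ndarray) -> float:
    """Mean of -log(sigmoid(r_preferred - r_other)) over a batch of comparisons.

    Minimizing this pushes the reward model to score the human-preferred
    summary higher than the rejected one.
    """
    diff = r_preferred - r_other
    # -log(sigmoid(x)) = log(1 + exp(-x)), computed stably with logaddexp
    return float(np.mean(np.logaddexp(0.0, -diff)))

# Rewards from a hypothetical reward model over a batch of 3 comparisons
r0 = np.array([1.5, 0.2, -0.3])   # human-preferred summaries
r1 = np.array([0.5, 0.4, -1.0])   # rejected summaries
loss = pairwise_preference_loss(r0, r1)
```

Note that when both summaries get the same reward, the loss is log(2), and it shrinks towards zero as the preferred summary's margin grows.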

Interp on RLHF'd models

Currently, very little mechanistic interpretability research has focused on RLHF'd models. In this blog post, Curt Tigges walks through an example of how we can use mech interp to analyze a model which has been finetuned with a sentiment-based reward function using trlX.

The flavour of the actual mech interp done here is very similar to the indirect object identification exercises you might have done during the transformers & interp week. If you didn't do these exercises, we recommend you do them before diving deep into this material.

Lastly, here's a Google Doc brainstorming some ideas for RLHF interpretability. You might find some ideas there (although most of these will be pretty vague goals so possibly too ambitious for a bonus exercise or 1-week project).

Suggested paper replications

As well as the papers in this section, you might be interested in browsing this GitHub repo, which contains links to a large number of RLHF-related papers.

Deep Reinforcement Learning from Human Preferences

This was the seminal paper in RLHF. They applied it to the domain of tasks like MuJoCo (which you might already have worked with during your PPO day). Can you set up a reward function and an interface which allows you to choose between two different sets of trajectories, and learn a reward function to maximize?

Some more technical details here - the authors train the reward function at the same time as they train the model. In other words, after a certain number of iterations of (rollout phase, learning phase), they add a third reward model learning phase, where the current policy generates many pairs of fixed-length trajectories and a human rater chooses which of each pair is better. They famously trained the Hopper agent to perform repeated backflips using just 900 queries.


Note - we strongly recommend doing the PPO exercises on MuJoCo before attempting this replication. We also recommend using Colab, since MuJoCo is notoriously difficult to install all the dependencies for!

Recursively Summarizing Books with Human Feedback

A major challenge for scaling ML is training models to perform tasks that are very difficult or time-consuming for humans to evaluate. To test scalable alignment techniques, the authors trained a model to summarize entire books, by first summarizing small sections of a book, then summarizing those summaries into a higher-level summary, and so on. A demonstration can be found here. There is also a repository containing code to run their models, including the supervised baseline, the trained reward model, and the RL fine-tuned policy.

You may also wish to do this in a less directed way - see the bonus exercise “Learn a human preference reward model” above.

Extensions for LoRA

Extend LoRA to MLP layers (optional)

This is a pretty finicky exercise, and mostly involves looking up the various locations you can add hook functions to. Modify HookedTransformer to add extra LoRA layers across the MLP's project-up (W_in) and project-down (W_out) layers.

My solution doesn't work for Llama models!

For non-GPT2 models this is even more annoying, as the architecture is different and involves Gated Linear Units (GLUs). We leave this to you to work out.

Mixed Precision (optional)

We can squeeze even more out of the training process by training in mixed precision. We can load the models in bfloat16, and train with that instead of float32.

Due to this issue on mixed precision, TransformerLens uses float32 for LayerNorms even if the rest of the model is using bfloat16 or float16. This means the intermediate activations are in float32, even though the weights are in bfloat16.

Exercise - complete LoraHooksMixedPrecision

To handle this, you should modify hooks to:

  • store the original dtype of the input
  • use the passed in dtype to convert the input
  • do the normal computations
  • convert the output back to the original dtype

You can use super().lora_hook_qkv(act, hook) to call the original lora_hook_qkv method, and then wrap this in the appropriate code to cast the types back and forth.
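The cast-and-restore pattern described above can be sketched in isolation like this (NumPy arrays stand in for the activation tensors; the real hooks operate on torch tensors, using `.to(...)` instead of `.astype(...)`):

```python
import numpy as np

def with_dtype_roundtrip(fn, compute_dtype=np.float16):
    """Wrap fn so its input is cast down to compute_dtype and its output is cast back."""
    def wrapped(x):
        orig_dtype = x.dtype                 # 1. store the original dtype of the input
        y = fn(x.astype(compute_dtype))      # 2-3. cast down, do the normal computation
        return y.astype(orig_dtype)          # 4. convert the output back
    return wrapped

lora_step = lambda x: x * 2.0                # stand-in for the low-precision LoRA computation
hook = with_dtype_roundtrip(lora_step)
out = hook(np.ones(4, dtype=np.float32))     # float32 in, float32 out
```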

class LoraHooksMixedPrecision(LoraHooks):
    """
    Defines the LoRA hooks needed for the Attention layer of the transformer, but allows for mixed precision.
    """

    def lora_hook_qkv(
        self, 
        resid_pre_normed: Float[Tensor, "batch pos d_model"], 
        hook: HookPoint
    ) -> Float[Tensor, "batch pos n_heads d_head"]:

        raise NotImplementedError()

    def lora_hook_out(
        self, attn_out: Float[Tensor, "batch pos n_heads d_head"], hook: HookPoint
    ) -> Float[Tensor, "batch pos n_heads d_head"]:
        raise NotImplementedError()
Solution
class LoraHooksMixedPrecision(LoraHooks):
    """
    Defines the LoRA hooks needed for the Attention layer of the transformer, but allows for mixed precision.
    """

    def lora_hook_qkv(
        self, 
        resid_pre_normed: Float[Tensor, "batch pos d_model"], 
        hook: HookPoint
    ) -> Float[Tensor, "batch pos n_heads d_head"]:

        # EXERCISE
        # raise NotImplementedError()
        # END EXERCISE
        # SOLUTION
        # Store the original dtype, cast down to the LoRA weights' dtype, run the
        # original hook, then cast the result back
        orig_dtype = resid_pre_normed.dtype
        resid_pre_normed = resid_pre_normed.to(self.dtype)

        lora_qkv_out = super().lora_hook_qkv(resid_pre_normed, hook)

        return lora_qkv_out.to(orig_dtype)
        # END SOLUTION

    def lora_hook_out(
        self, attn_out: Float[Tensor, "batch pos n_heads d_head"], hook: HookPoint
    ) -> Float[Tensor, "batch pos n_heads d_head"]:

        # EXERCISE
        # raise NotImplementedError()
        # END EXERCISE
        # SOLUTION
        orig_dtype = attn_out.dtype
        attn_out = attn_out.to(self.dtype)

        lora_attn_out = super().lora_hook_out(attn_out, hook)

        return lora_attn_out.to(orig_dtype)
        # END SOLUTION