[2.4] - RLHF
Please send any problems / bugs on the #errata channel in the Slack group, and ask any questions on the dedicated channels for this chapter of material.

Introduction
This section is designed to take you through a full implementation of RLHF (Reinforcement Learning from Human Feedback). Much of this follows on directly from the PPO implementation from yesterday, with only a few minor adjustments and new concepts. You'll (hopefully) be pleased to learn that we're dispensing with OpenAI's gym environment for this final day of exercises, and instead going back to our week 1 roots with TransformerLens!
We'll start by discussing how the RL setting we've used for tasks like CartPole and Atari fits into the world of autoregressive transformer language models. We'll then go through the standard parts of the PPO setup (e.g. objective function, memory buffer, rollout and learning phases) and show how to adapt them for our transformer. Finally, we'll put everything together into an RLHFTrainer class, and perform RLHF on our transformer!
Note - these exercises assume you're running on an A100 (either a virtual machine or Colab Pro+). If you're running on a machine with much less VRAM (<24GB), we recommend setting LOW_GPU_MEM = True below. This will switch the model we perform RLHF on from "gpt2-medium" to "gpt2-small", as well as adjusting some other parameters like the batch size, the number of tokens generated, and some hyperparameters.
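As a rough sketch of how this toggle might work (the flag name comes from the note above; the specific parameter values here are illustrative assumptions, not the exercises' exact settings):

```python
LOW_GPU_MEM = False  # set to True if you have < 24GB of VRAM

# Model and hyperparameters scale down when memory is limited
BASE_MODEL = "gpt2-small" if LOW_GPU_MEM else "gpt2-medium"
BATCH_SIZE = 64 if LOW_GPU_MEM else 128          # illustrative values
GEN_LEN = 20 if LOW_GPU_MEM else 30              # number of tokens generated
```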
For a lecture that provides some high-level understanding before you dive into the exercises, watch the video below:
Content & Learning Objectives
1️⃣ RLHF on transformer language models
Most of the exercises today build towards the implementation of the RLHFTrainer class, similar to how DQN and PPO have worked these last few days.
Learning Objectives
- Understand how the RL agent / action / environment paradigm works in the context of autoregressive transformer models
- Understand how the RLHF algorithm works, and how it fits on top of PPO
- Learn about value heads, and how they can be used to turn transformers into actor & critic networks with shared architectures
- Write a full RLHF training loop, and use it to train your transformer with the "maximize output of periods" reward function
- Observe and understand the instances of mode collapse that occur when training with this reward function
- Experiment with different reward functions & training hyperparameters
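As a preview of the value-head idea from the objectives above, here is a minimal sketch of how a transformer can be extended into a shared actor-critic architecture: a small MLP head maps the final residual-stream activations to a scalar value estimate per token position. The layer sizes and names here are illustrative assumptions, not the exercises' reference implementation.

```python
import torch as t
import torch.nn as nn


class ValueHead(nn.Module):
    """Illustrative value head: residual stream -> scalar value per position."""

    def __init__(self, d_model: int):
        super().__init__()
        # A small MLP on top of the (frozen or shared) transformer backbone
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, 1),
        )

    def forward(self, resid: t.Tensor) -> t.Tensor:
        # resid has shape (batch, seq, d_model); output is (batch, seq)
        return self.net(resid).squeeze(-1)
```

The transformer's unembedding acts as the actor (producing logits over next tokens), while this head acts as the critic, so both share the same backbone activations.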
2️⃣ LoRA
Learning Objectives
- Understand the mechanism behind Low-Rank Adaptors, and how they allow for fine-tuning with fewer resources.
- Implement LoRA in a transformer model.
- Fine-tune larger models that would otherwise require too much VRAM.
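To make the low-rank idea concrete, here is a minimal sketch of a LoRA-wrapped linear layer: the pretrained weight is frozen, and only a rank-r update B @ A (scaled by alpha / rank) is trained. This is an illustrative sketch, not the exercises' reference implementation; the rank and alpha defaults are assumptions.

```python
import torch as t
import torch.nn as nn


class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper: frozen base layer plus trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # only the adaptor is trained
        # A is small random init, B is zero init, so training starts at the base model
        self.A = nn.Parameter(t.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(t.zeros(base.out_features, rank))
        self.scale = alpha / rank

    def forward(self, x: t.Tensor) -> t.Tensor:
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)
```

Because B starts at zero, the wrapped layer initially computes exactly the same function as the base layer, and the number of trainable parameters is rank * (in_features + out_features) rather than in_features * out_features.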
3️⃣ GRPO LoRA
GRPO is a variant of PPO specialised for doing RLHF on LLMs. It forgoes the critic, and uses the average reward over many rollouts as a baseline instead.
Learning Objectives
- Understand and implement GRPO
- Use GRPO + LoRA together to finetune a model.
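The baseline trick described above (replacing the critic with the average reward over a group of rollouts) can be sketched in a few lines of pure Python. This follows the group-normalized advantage from the DeepSeekMath paper in spirit; the epsilon constant is an assumption for numerical stability.

```python
def grpo_advantages(rewards: list[float], eps: float = 1e-8) -> list[float]:
    """Compute group-relative advantages: no critic, just normalize each
    rollout's reward against the group mean and standard deviation."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]
```

Each prompt is sampled G times; rollouts that beat the group average get positive advantage, the rest get negative, so no learned value function is needed.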
☆ Bonus
This section offers some suggested ways to extend the core RLHF exercises.
Learning Objectives
- Improve your RLHF implementation via techniques like differential learning rates, frozen layers, or adaptive KL penalties
- Perform some exploratory mechanistic interpretability on RLHF'd models
- Learn about the trlX library, which is designed to train transformers via RLHF in a way which abstracts away many of the low-level details
Reading
- Illustrating Reinforcement Learning from Human Feedback (RLHF) (~10 minutes)
- An accessible and mostly non-technical introduction to RLHF, which discusses it in context of the full pipeline for training autoregressive transformer language models (starting with pretraining, which is what we did in the first day of last week).
- RLHF+ChatGPT: What you must know (~5 minutes)
- The first half of this video provides a high-level overview of RLHF, discussing things like mode collapse, and relates this to the shoggoth meme that many of you have likely seen!
- DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models (~20 minutes)
- Save reading this now until you get to the section for GRPO, and skim as required.
Setup code
import os
import sys
import time
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Literal
import einops
import numpy as np
import torch as t
import torch.nn as nn
import wandb
from eindex import eindex
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from tabulate import tabulate
from torch import Tensor
from tqdm import tqdm
from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
# Make sure exercises are in the path
chapter = "chapter2_rl"
section = "part4_rlhf"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
sys.path.append(str(exercises_dir))
from part4_rlhf import tests, tests_lora # , tl_ext
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
MAIN = __name__ == "__main__"