Exercise Status: All exercises complete and verified

3️⃣ GRPO LoRA

Learning Objectives
  • Understand and implement GRPO
  • Use GRPO + LoRA together to finetune a model.

Group Relative Policy Optimization (🚧 Under construction 🚧)

GRPO is a variant of PPO specialised for doing RLHF on LLMs. It was first described in the DeepSeekMath paper (February 2024), where it was used to fine-tune DeepSeek models for better performance on tasks that require reasoning, by reinforcing rollouts that lead to correct answers.

The main differences between PPO and GRPO are:
  • PPO uses a critic head to estimate the baseline. GRPO removes the critic entirely: it performs many rollouts per prompt and uses the average reward over those rollouts as the baseline.
  • PPO computes the advantages using GAE. GRPO simply uses the normalized rewards for the set of rollouts as the advantages.
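The group baseline can be sketched in a few lines (a minimal illustration with made-up rewards; `grpo_advantages` is a hypothetical helper, not part of the trainer below):

```python
import torch as t

def grpo_advantages(rewards: t.Tensor, eps: float = 1e-8) -> t.Tensor:
    """Group-relative advantages: each rollout's reward, normalized by the
    group's mean and std, replaces the critic's baseline estimate."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# One group of 4 rollouts from the same prompt, scored by a reward function:
rewards = t.tensor([1.0, 3.0, 2.0, 2.0])
advantages = grpo_advantages(rewards)
# Rollouts above the group mean get positive advantage, those below get negative.
```

Note there is no learned component here at all: the baseline is a pure statistic of the group's rewards.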

Letting $o_{1:T}$ be a sequence of tokens, the joy function (loss function but we maximize) is given as

$$ J_{\text{GRPO}}(\theta) = \widehat{\mathbb{E}} \left[ \frac{1}{T} \sum_{t=1}^{T} \left( \min\Big[ \rho_\theta(o_t \mid o_{<t}) \hat{A}_t,\ \operatorname{clip}\big(\rho_\theta(o_t \mid o_{<t}),\, 1-\epsilon,\, 1+\epsilon\big) \hat{A}_t \Big] - \beta\, D_{\text{KL}}\big[\pi_\theta \,\|\, \pi_{\text{ref}}\big] \right) \right] $$
where
$$ \rho_\theta(o_t \mid o_{<t}) = \frac{\pi_\theta(o_t \mid o_{<t})}{\pi_{\theta_{\text{old}}}(o_t \mid o_{<t})} $$
is the probability ratio of the new policy to the old policy.
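To make the ratio-and-clip term concrete, here is a minimal sketch of the clipped surrogate computed from per-token log-probabilities (illustrative only; `clipped_surrogate` is a hypothetical stand-in for the `calc_clipped_surrogate_objective` helper used later):

```python
import torch as t

def clipped_surrogate(
    logprobs_new: t.Tensor, logprobs_old: t.Tensor, advantages: t.Tensor, clip_coef: float = 0.2
) -> t.Tensor:
    # rho = pi_new / pi_old, computed stably from log-probabilities
    ratio = (logprobs_new - logprobs_old).exp()
    clipped = t.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
    # elementwise min: large policy moves earn no extra objective beyond the clip range
    return t.minimum(ratio * advantages, clipped * advantages).mean()
```

When the new and old policies coincide, the ratio is 1 everywhere and the objective reduces to the mean advantage.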

The advantages are now just the normalized rewards:

$$ \hat{A}_t = \frac{r_t - \mu_r}{\sigma_r} $$
where $\mu_r = \text{mean}(\mathbf{r})$ is the mean of the rewards vector and $\sigma_r = \text{std}(\mathbf{r})$ is its standard deviation, taken across the group of rollouts.

For the moment, we just subclass the existing TransformerWithValueHeadLora class and skip the value head. This is hacky, but it's a quick way to get the code working.

class TransformerWithLora(TransformerWithValueHeadLora):
    "We don't need the value head for training with GRPO"

    lora: nn.ModuleList
    lora_fwd_hooks: list[tuple[str, Callable]]
    dtype: t.dtype
    device: t.device

    def get_value_head_params(self):
        return iter([])  # no value head parameters

    @classmethod
    def from_pretrained(cls, *args, lora_alpha: float = 32, rank: int = 4, **kwargs):
        model = super(TransformerWithLora, cls).from_pretrained(*args, use_value_head=False, **kwargs)
        model.value_head_output = None
        return model

    @property
    def fwd_hooks(self):
        return self.lora_fwd_hooks  # no value head hook

    def forward_with_value_head(
        self, tokens: Int[Tensor, "batch seq"]
    ) -> Float[Tensor, "batch seq d_vocab"]:
        """
        Forward pass with LoRA enabled; the value head is skipped, so only the logits are returned.
        """
        logits, value = super().forward_with_value_head(tokens)
        assert value is None, "Value head got run somehow?"
        return logits

In GRPO-style training we optimize only the policy objective plus regularizers, without a value head or critic loss. This simplifies the architecture when your reward is available at the sequence level and you propagate it per generated token.

  • We re-use the optimizer and scheduler helpers.
  • In the rollout, we compute rewards per sample, optionally normalize them, and use them as advantages for all generated positions.
  • In learning, we maximize the clipped objective with an entropy bonus, and subtract the KL penalty computed against the frozen reference model.
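The rollout step above ("use them as advantages for all generated positions") is just a broadcast of one scalar per sequence across the generation length; a minimal sketch with assumed shapes:

```python
import torch as t

batch, gen_len = 4, 16
rewards = t.tensor([1.0, 3.0, 2.0, 2.0])  # one scalar reward per sampled sequence

# Normalize across the group, then share each sequence's advantage
# with every generated position (outcome supervision).
advantages_1d = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
advantages = advantages_1d[:, None].expand(batch, gen_len)
```

The solution below does the same broadcast with `einops.repeat`.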

Exercise: Construct GRPO trainer

Difficulty: 🔴🔴🔴🔴🔴
Importance: 🔵🔵🔵🔵🔵
You should spend up to 40 minutes on this exercise.

Construct a GRPO trainer class that inherits from RLHFTrainer and overrides the rollout_phase and learning_phase methods.

We recommend copying the solution for RLHFTrainer for PPO, and then modifying it to work for GRPO. This will mostly involve chopping parts out, or replacing parts (e.g. the calculation of the advantage).

You could also redefine calc_value_function_loss and compute_advantages, and then try to use RLHFTrainer as is.

The rough changes should be:
  • Drop the value head and the associated critic loss.
  • Use normalized rewards as the advantages. Advantages have shape (minibatch, seq_len) while rewards have shape (minibatch,), so we need to bridge this gap somehow. Section 4 of the GRPO paper gives two options:
    ◦ 4.1.2 Outcome Supervision: treat each reward as arriving at the end of its sequence, i.e. repeat the reward for every token in the sequence. We use this approach.
    ◦ 4.1.3 Process Supervision: query the reward function for every prefix directly, so the advantage becomes the returns

$$ \hat{A}_t = \sum_{t'=t}^{T} \tilde{\mathbf{r}}_{t'} $$
where $\tilde{\mathbf{r}} = (\mathbf{r} - \text{mean}(\mathbf{r})) / \text{std}(\mathbf{r})$ is the vector of normalized rewards. This can get expensive if done on a per-token basis, but for long CoTs the generation is typically broken into "thoughts" (e.g. sentences) rather than individual tokens.
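Under these definitions, the per-step returns are a reverse cumulative sum of the normalized step rewards; a sketch (hypothetical helper, assuming one reward per reasoning step):

```python
import torch as t

def process_supervision_advantages(step_rewards: t.Tensor, eps: float = 1e-8) -> t.Tensor:
    """step_rewards: (batch, n_steps), one reward per reasoning step.
    The advantage at step t is the sum of normalized rewards from t onwards."""
    normed = (step_rewards - step_rewards.mean()) / (step_rewards.std() + eps)
    # reverse cumulative sum implements A_t = sum over t' >= t of the normalized rewards
    return normed.flip(-1).cumsum(-1).flip(-1)
```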

Hint for `compute_rlhf_objective`

If you modify TransformerWithLora to return a tensor of zeros of appropriate size, and redefine calc_value_function_loss to just return zero, you should be able to use compute_rlhf_objective as is. We don't do that here, but just redefine the function and remove parts.

@dataclass
class GrpoArgs(RLHFArgs):
    lora_rank: int = 4
    lora_alpha: float = 32


class GrpoTrainer(RLHFTrainer):
    model: TransformerWithLora
    memory: ReplayMemory

    def __init__(self, args: RLHFArgs):
        # duplicates code from RLHFTrainerLora
        t.manual_seed(args.seed)
        self.args = args
        self.run_name = f"{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"

        self.model = TransformerWithLora.from_pretrained(args.base_model).to(device).train()
        self.ref_model = self.model  # plain calls skip the LoRA hooks, so this acts as the frozen reference
        self.optimizer, self.scheduler = get_optimizer_and_scheduler(self.args, self.model)
        self.prefix_len = len(self.model.to_str_tokens(self.args.prefix, prepend_bos=self.args.prepend_bos))

    def compute_rlhf_objective(self, minibatch: ReplayMinibatch):
        raise NotImplementedError()

    def rollout_phase(self) -> ReplayMemory:
        raise NotImplementedError()

print("Training GRPO model (example setup)")
grpo_args = GrpoArgs(
    use_wandb=False,
    kl_coef=2.5,
    total_phases=30,
    warmup_steps=0,
    reward_fn=reward_fn_char_count,
    base_lr=1e-3,
    # batch_size=8,
    # num_minibatches=2,
    gen_len=16,
)
grpo_trainer = GrpoTrainer(grpo_args)
# grpo_trainer.train()  # Uncomment to run a tiny smoke test
Solution
@dataclass
class GrpoArgs(RLHFArgs):
    lora_rank: int = 4
    lora_alpha: float = 32


class GrpoTrainer(RLHFTrainer):
    model: TransformerWithLora
    memory: ReplayMemory

    def __init__(self, args: RLHFArgs):
        # duplicates code from RLHFTrainerLora
        t.manual_seed(args.seed)
        self.args = args
        self.run_name = f"{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"

        self.model = TransformerWithLora.from_pretrained(args.base_model).to(device).train()
        self.ref_model = self.model  # plain calls skip the LoRA hooks, so this acts as the frozen reference
        self.optimizer, self.scheduler = get_optimizer_and_scheduler(self.args, self.model)
        self.prefix_len = len(self.model.to_str_tokens(self.args.prefix, prepend_bos=self.args.prepend_bos))

    def compute_rlhf_objective(self, minibatch: ReplayMinibatch):

        gen_len_slice = slice(-self.args.gen_len - 1, -1)

        logits = self.model.forward_with_value_head(minibatch.sample_ids)

        logprobs = get_logprobs(logits, minibatch.sample_ids, self.prefix_len)

        clipped_surrogate_objective = calc_clipped_surrogate_objective(
            logprobs,
            minibatch.logprobs,
            minibatch.advantages,
            self.args.clip_coef,
            self.args.gen_len,
        )
        entropy_bonus = calc_entropy_bonus(logits[:, gen_len_slice], self.args.ent_coef, self.args.gen_len)
        kl_penalty = calc_kl_penalty(
            logits[:, gen_len_slice],
            minibatch.ref_logits[:, gen_len_slice],
            self.args.kl_coef,
            self.args.gen_len,
        )

        ppo_objective_fn = clipped_surrogate_objective + entropy_bonus
        total_objective_function = ppo_objective_fn - kl_penalty

        if self.args.use_wandb:
            with t.inference_mode():
                logratio = logprobs - minibatch.logprobs
                ratio = logratio.exp()
                clipfracs = [((ratio - 1.0).abs() > self.args.clip_coef).float().mean().item()]
            wandb.log(
                dict(
                    total_steps=self.step,
                    lr=self.scheduler.get_last_lr()[0],
                    clipped_surrogate_objective=clipped_surrogate_objective.item(),
                    clipfrac=np.mean(clipfracs),
                    entropy_bonus=entropy_bonus.item(),
                    kl_penalty=kl_penalty.item(),
                ),
                step=self.step,
            )

        return total_objective_function

    def rollout_phase(self) -> ReplayMemory:

        sample_ids, samples = get_samples(
            self.model,
            prompt=self.args.prefix,
            batch_size=self.args.batch_size,
            gen_len=self.args.gen_len,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            prepend_bos=self.args.prepend_bos,
        )

        with t.inference_mode():
            logits = self.model.forward_with_value_head(sample_ids)
            ref_logits = self.ref_model(sample_ids)

        logprobs = get_logprobs(logits, sample_ids, self.prefix_len)

        rewards = self.args.reward_fn(samples)
        rewards_mean = rewards.mean().item()
        rewards_normed = normalize_reward(rewards) if self.args.normalize_reward else rewards

        advantages = rewards_normed

        if self.args.use_wandb:
            wandb.log({"mean_reward": rewards_mean}, step=self.step)

        n_log_samples = min(5, self.args.batch_size)
        ref_logprobs = get_logprobs(ref_logits[:n_log_samples], sample_ids[:n_log_samples], self.prefix_len).sum(-1)
        headers = ["Reward", "Ref logprobs", "Sample"]
        table_data = [[str(int(r)), f"{lp:.2f}", repr(s)] for r, lp, s in zip(rewards.tolist(), ref_logprobs, samples)]
        table = tabulate(table_data, headers, tablefmt="simple_grid", maxcolwidths=[None, None, 90])
        print(f"Phase {self.phase + 1:03}/{self.args.total_phases:03}, Mean reward: {rewards_mean:.4f}\n{table}\n")

        # ReplayMemory still expects a `values` tensor, so fill it with a dummy of the right shape (unused by GRPO)
        values = einops.repeat(advantages, "b -> b g", g=sample_ids.shape[1])
        advantages = einops.repeat(advantages, "b -> b g", g=logprobs.shape[1])
        return ReplayMemory(
            args=self.args,
            sample_ids=sample_ids,
            logprobs=logprobs,
            advantages=advantages,
            values=values,
            ref_logits=ref_logits,
        )