3️⃣ GRPO LoRA
Learning Objectives
- Understand and implement GRPO
- Use GRPO + LoRA together to finetune a model.
Group Relative Policy Optimization (🚧 Under construction 🚧)
GRPO is a variant of PPO specialised for doing RLHF on LLMs. It was first described in the DeepSeekMath paper (early 2024), where it was used to fine-tune DeepSeek models for better performance on reasoning tasks by reinforcing rollouts that lead to correct answers.
The main differences between PPO and GRPO are:
- PPO uses a critic head to estimate the baseline. GRPO removes the critic entirely; instead it performs many rollouts per prompt and uses the average reward over that group as the baseline.
- PPO computes the advantages using GAE. GRPO simply uses the group-normalized rewards as the advantages.

Letting $o^{(i)}_{1:T}$ be the $i$-th rollout in a group of $G$ rollouts, the joy function (a loss function, except we maximize it) is given (following the GRPO paper's outcome-supervision form, with a per-token KL penalty) as

$$J(\theta) = \frac{1}{G}\sum_{i=1}^{G}\frac{1}{T}\sum_{t=1}^{T}\left[\min\left(\rho^{(i)}_t A^{(i)}_t,\ \operatorname{clip}\left(\rho^{(i)}_t,\ 1-\epsilon,\ 1+\epsilon\right) A^{(i)}_t\right) - \beta\, \mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right]\right]$$

where $\rho^{(i)}_t = \pi_\theta(o^{(i)}_t \mid o^{(i)}_{<t}) \,/\, \pi_{\theta_{\mathrm{old}}}(o^{(i)}_t \mid o^{(i)}_{<t})$ is the probability ratio, $\epsilon$ is the clip coefficient, and $\beta$ weights the KL penalty against the frozen reference model.
The advantages are now just the group-normalized rewards, shared across every token of the rollout:

$$A^{(i)}_t = \frac{R^{(i)} - \operatorname{mean}\left(R^{(1)}, \dots, R^{(G)}\right)}{\operatorname{std}\left(R^{(1)}, \dots, R^{(G)}\right)}$$
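Concretely, the baseline-and-normalize step can be sketched as follows (`group_normalized_advantages` is a name made up for illustration, not part of the exercise code):

```python
import torch as t

def group_normalized_advantages(rewards: t.Tensor, eps: float = 1e-5) -> t.Tensor:
    """Map one scalar reward per rollout to a zero-mean, unit-std advantage.

    `rewards` has shape (group_size,). The group mean plays the role of
    PPO's learned baseline, so no critic network is needed.
    """
    return (rewards - rewards.mean()) / (rewards.std() + eps)

rewards = t.tensor([1.0, 3.0, 2.0, 2.0])
advantages = group_normalized_advantages(rewards)
# above-average rollouts get positive advantage, below-average negative
```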
For the moment, we just subclass the existing TransformerWithValueHeadLora class and skip the value head. This is hacky, but it's a quick way to get the code working.
```python
class TransformerWithLora(TransformerWithValueHeadLora):
    """We don't need the value head for training with GRPO."""

    lora: nn.ModuleList
    lora_fwd_hooks: list[tuple[str, Callable]]
    dtype: t.dtype
    device: t.device

    def get_value_head_params(self):
        return iter([])  # no value head parameters

    @classmethod
    def from_pretrained(cls, *args, lora_alpha: float = 32, rank: int = 4, **kwargs):
        model = super(TransformerWithLora, cls).from_pretrained(*args, use_value_head=False, **kwargs)
        model.value_head_output = None
        return model

    @property
    def fwd_hooks(self):
        return self.lora_fwd_hooks  # no value head hook

    def forward_with_value_head(
        self, tokens: Int[Tensor, "batch seq"]
    ) -> tuple[Float[Tensor, "batch seq d_vocab"], None]:
        """
        Forward pass with LoRA enabled, but the value head is not used.
        """
        logits, value = super().forward_with_value_head(tokens)
        assert value is None, "Value head got run somehow?"
        # keep the (logits, value) tuple shape so callers can unpack uniformly
        return logits, None
```
In GRPO-style training we optimize only the policy objective plus regularizers, without a value head or critic loss. This simplifies the architecture when your reward is available at the sequence level and you propagate it per generated token.
- We re-use the optimizer and scheduler helpers.
- In the rollout, we compute rewards per sample, optionally normalize them, and use them as advantages for all generated positions.
- In learning, we maximize the clipped objective with an entropy bonus, and subtract the KL penalty computed against the frozen reference model.
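As a standalone sketch of the learning-phase objective (the names and signatures here are our own, not the exercise's helper functions; the exercise computes the KL penalty and entropy bonus from full logits, while we use a simple per-token logprob estimator and omit the entropy bonus for brevity):

```python
import torch as t

def grpo_objective(
    logprobs: t.Tensor,      # (batch, gen_len), current policy
    old_logprobs: t.Tensor,  # (batch, gen_len), policy at rollout time
    ref_logprobs: t.Tensor,  # (batch, gen_len), frozen reference model
    advantages: t.Tensor,    # (batch, gen_len), repeated group-normalized rewards
    clip_coef: float = 0.2,
    kl_coef: float = 0.1,
) -> t.Tensor:
    """Clipped surrogate objective minus a crude logprob-based KL penalty."""
    ratio = (logprobs - old_logprobs).exp()
    clipped_ratio = t.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
    # take the pessimistic (min) branch, as in PPO's clipped surrogate
    surrogate = t.minimum(ratio * advantages, clipped_ratio * advantages).mean()
    # naive per-token KL estimate; a full implementation would use the logits
    kl_estimate = (logprobs - ref_logprobs).mean()
    return surrogate - kl_coef * kl_estimate
```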
Exercise: Construct GRPO trainer
Construct a GRPO trainer class that inherits from RLHFTrainer and overrides the rollout_phase and learning_phase methods.
We recommend copying the solution for RLHFTrainer for PPO, and then modifying it to work for GRPO.
This will mostly involve chopping parts out or replacing them (e.g. the calculation of the advantages).
You could also redefine calc_value_function_loss and compute_advantages, and then try to use RLHFTrainer as is.
The rough changes should be:
* Drop the value head and the associated critic loss.
* Use normalized rewards as the advantages. As advantages are of shape `(minibatch, seq_len)` and rewards are of shape `(minibatch,)`, we need to bridge that gap somehow. Section 4 of the GRPO paper gives two options:
  - 4.1.2 Outcome supervision: treat the normalized reward we get at the end of the sequence as the advantage of every token in it, i.e. repeat the reward for each generated token. We use this approach.
  - 4.1.3 Process supervision: query the reward function for every prefix directly, and take the advantage at each token to be the sum of the normalized rewards of the following steps (the returns).
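The outcome-supervision broadcast amounts to one line (a sketch with made-up shapes):

```python
import torch as t

gen_len = 5
rewards = t.tensor([0.5, -1.0, 1.5])  # one (normalized) reward per rollout

# outcome supervision: every generated token inherits its rollout's reward
advantages = rewards[:, None].expand(-1, gen_len)  # shape (3, 5)
# equivalently: einops.repeat(rewards, "b -> b g", g=gen_len)
```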
Hint for `compute_rlhf_objective`
If you modify TransformerWithLora to return a tensor of zeros of the appropriate size, and redefine calc_value_function_loss to just return zero, you should be able to use compute_rlhf_objective as-is.
We don't do that here; instead we redefine the function and remove the parts we don't need.
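The zeros trick from the hint can be sketched like this (standalone illustration; the names mirror the exercise code but nothing here is wired into it):

```python
import torch as t

def forward_with_dummy_value_head(logits: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
    """Return fake all-zero 'values' so PPO-shaped code keeps working unchanged."""
    batch, seq, _d_vocab = logits.shape
    return logits, t.zeros(batch, seq)

def calc_value_function_loss(values: t.Tensor, returns: t.Tensor, vf_coef: float) -> t.Tensor:
    """No critic in GRPO, so the critic loss is identically zero."""
    return t.tensor(0.0)

logits = t.randn(2, 3, 50)
logits_out, values = forward_with_dummy_value_head(logits)
loss = calc_value_function_loss(values, values, vf_coef=0.5)
```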
```python
@dataclass
class GrpoArgs(RLHFArgs):
    lora_rank: int = 4
    lora_alpha: float = 32


class GrpoTrainer(RLHFTrainer):
    model: TransformerWithLora
    memory: ReplayMemory

    def __init__(self, args: GrpoArgs):
        # duplicates code from RLHFTrainerLora
        t.manual_seed(args.seed)
        self.args = args
        self.run_name = f"{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"
        self.model = (
            TransformerWithLora.from_pretrained(args.base_model, rank=args.lora_rank, lora_alpha=args.lora_alpha)
            .to(device)
            .train()
        )
        # LoRA leaves the base weights frozen, so the same module can serve as the reference model
        self.ref_model = self.model
        self.optimizer, self.scheduler = get_optimizer_and_scheduler(self.args, self.model)
        self.prefix_len = len(self.model.to_str_tokens(self.args.prefix, prepend_bos=self.args.prepend_bos))

    def compute_rlhf_objective(self, minibatch: ReplayMinibatch):
        raise NotImplementedError()

    def rollout_phase(self) -> ReplayMemory:
        raise NotImplementedError()
```
```python
print("Training GRPO model (example setup)")

grpo_args = GrpoArgs(
    use_wandb=False,
    kl_coef=2.5,
    total_phases=30,
    warmup_steps=0,
    reward_fn=reward_fn_char_count,
    base_lr=1e-3,
    # batch_size=8,
    # num_minibatches=2,
    gen_len=16,
)
grpo_trainer = GrpoTrainer(grpo_args)
# grpo_trainer.train()  # uncomment to run a tiny smoke test
```
Solution
```python
@dataclass
class GrpoArgs(RLHFArgs):
    lora_rank: int = 4
    lora_alpha: float = 32


class GrpoTrainer(RLHFTrainer):
    model: TransformerWithLora
    memory: ReplayMemory

    def __init__(self, args: GrpoArgs):
        # duplicates code from RLHFTrainerLora
        t.manual_seed(args.seed)
        self.args = args
        self.run_name = f"{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"
        self.model = (
            TransformerWithLora.from_pretrained(args.base_model, rank=args.lora_rank, lora_alpha=args.lora_alpha)
            .to(device)
            .train()
        )
        # LoRA leaves the base weights frozen, so the same module can serve as the reference model
        self.ref_model = self.model
        self.optimizer, self.scheduler = get_optimizer_and_scheduler(self.args, self.model)
        self.prefix_len = len(self.model.to_str_tokens(self.args.prefix, prepend_bos=self.args.prepend_bos))

    def compute_rlhf_objective(self, minibatch: ReplayMinibatch):
        gen_len_slice = slice(-self.args.gen_len - 1, -1)

        logits, _ = self.model.forward_with_value_head(minibatch.sample_ids)
        logprobs = get_logprobs(logits, minibatch.sample_ids, self.prefix_len)

        clipped_surrogate_objective = calc_clipped_surrogate_objective(
            logprobs,
            minibatch.logprobs,
            minibatch.advantages,
            self.args.clip_coef,
            self.args.gen_len,
        )
        entropy_bonus = calc_entropy_bonus(logits[:, gen_len_slice], self.args.ent_coef, self.args.gen_len)
        kl_penalty = calc_kl_penalty(
            logits[:, gen_len_slice],
            minibatch.ref_logits[:, gen_len_slice],
            self.args.kl_coef,
            self.args.gen_len,
        )

        # no value function loss: GRPO has no critic
        ppo_objective_fn = clipped_surrogate_objective + entropy_bonus
        total_objective_function = ppo_objective_fn - kl_penalty

        if self.args.use_wandb:
            with t.inference_mode():
                logratio = logprobs - minibatch.logprobs
                ratio = logratio.exp()
                clipfracs = [((ratio - 1.0).abs() > self.args.clip_coef).float().mean().item()]
            wandb.log(
                dict(
                    total_steps=self.step,
                    lr=self.scheduler.get_last_lr()[0],
                    clipped_surrogate_objective=clipped_surrogate_objective.item(),
                    clipfrac=np.mean(clipfracs),
                    entropy_bonus=entropy_bonus.item(),
                    kl_penalty=kl_penalty.item(),
                ),
                step=self.step,
            )

        return total_objective_function

    def rollout_phase(self) -> ReplayMemory:
        sample_ids, samples = get_samples(
            self.model,
            prompt=self.args.prefix,
            batch_size=self.args.batch_size,
            gen_len=self.args.gen_len,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            prepend_bos=self.args.prepend_bos,
        )

        with t.inference_mode():
            logits, _ = self.model.forward_with_value_head(sample_ids)
            ref_logits = self.ref_model(sample_ids)

        logprobs = get_logprobs(logits, sample_ids, self.prefix_len)

        rewards = self.args.reward_fn(samples)
        rewards_mean = rewards.mean().item()
        rewards_normed = normalize_reward(rewards) if self.args.normalize_reward else rewards
        # group-normalized rewards serve directly as the advantages
        advantages = rewards_normed

        if self.args.use_wandb:
            wandb.log({"mean_reward": rewards_mean}, step=self.step)

        n_log_samples = min(5, self.args.batch_size)
        ref_logprobs = get_logprobs(ref_logits[:n_log_samples], sample_ids[:n_log_samples], self.prefix_len).sum(-1)
        headers = ["Reward", "Ref logprobs", "Sample"]
        table_data = [[str(int(r)), f"{lp:.2f}", repr(s)] for r, lp, s in zip(rewards.tolist(), ref_logprobs, samples)]
        table = tabulate(table_data, headers, tablefmt="simple_grid", maxcolwidths=[None, None, 90])
        print(f"Phase {self.phase + 1:03}/{self.args.total_phases:03}, Mean reward: {rewards_mean:.4f}\n{table}\n")

        # outcome supervision: every generated token inherits its rollout's advantage.
        # `values` is unused by GRPO; we store a same-shaped placeholder to satisfy ReplayMemory.
        values = einops.repeat(advantages, "b -> b g", g=sample_ids.shape[1])
        advantages = einops.repeat(advantages, "b -> b g", g=logprobs.shape[1])

        return ReplayMemory(
            args=self.args,
            sample_ids=sample_ids,
            logprobs=logprobs,
            advantages=advantages,
            values=values,
            ref_logits=ref_logits,
        )
```