2️⃣ LoRA Fine-Tuning
Learning Objectives
- Understand the mechanism behind Low-Rank Adaptors, and how they allow for fine-tuning with fewer resources.
- Implement LoRA in a transformer model.
- Fine-tune larger models that would otherwise require too much VRAM to train.
Go to the RLHF training args class we defined at the start of the previous section and set `RUN_BASE_RLHF = False`. This will skip all the expensive training runs from Section 1, so you can easily rerun the file.
Low-Rank Adaptors (🚧 Under construction 🚧)
In the previous section, we needed to keep two copies of the model in memory: $\pi_{ppo}$ to train, and $\pi_{base}$ as a reference. If our model is already large enough to max out VRAM, we clearly can't keep two copies of it in memory.
Moreover, what is found in practice is that the changes from fine-tuning are often quite minor, in that the update to the weights tends to be approximately low-rank. So, why not set up fine-tuning such that only a low-rank update to the weights can be learned?
LoRA (Low-Rank Adaptation) allows us to fine-tune only a small number of additional parameters while keeping the rest of the parameters fixed. We see in the diagram below that LoRA:
- Keeps the original linear layer weights $W : (d_{in} \times d_{out})$ fixed
- Adds additional low-rank matrices $A : (d_{in} \times r)$ and $B : (r \times d_{out})$, such that the fine-tuned linear layer $\tilde{W}$ performs the operation $\tilde{W}(x) = W(x) + B(A(x))$.

Since both operations are linear, the final output is equivalent to adding a low-rank matrix to $W$. Once training is complete, one can set $\tilde{W} = W + BA$, which "bakes" the adaptor into the model. This means the final fine-tuned model is architecturally identical to the original.
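As a quick sanity check, here is a minimal sketch (plain PyTorch, arbitrary sizes, including the `lora_alpha / rank` scaling introduced in the exercise below) verifying that the adaptor can be baked into the weights:

```python
import torch as t

d_in, d_out, rank, lora_alpha = 16, 16, 4, 32.0

W = t.randn(d_in, d_out)   # frozen base weights
A = t.randn(d_in, rank)    # project-down matrix
B = t.randn(rank, d_out)   # project-up matrix
x = t.randn(8, d_in)

# fine-tuned forward pass: base output plus the scaled low-rank update
out_lora = x @ W + (x @ A) @ B * lora_alpha / rank

# "bake" the adaptor into the weights once training is done
W_merged = W + A @ B * lora_alpha / rank
out_merged = x @ W_merged

assert t.allclose(out_lora, out_merged, atol=1e-4)
```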
Exercise - complete Lora
Now, you'll implement the Lora class. This class should implement the basic LoRA block, which is a low-rank linear layer (no bias) written as two separate matrices, a project-down matrix $A$ and a project-up matrix $B$.
For simplicity, the Lora module actually handles n_inst instances of a low-rank linear layer, to make it easier to interface with multi-head attention later on.
You should
* Finish the __init__ method to define the model parameters A and B, and initialize them:
* A should be initialized with kaiming_uniform_ with $a = \sqrt{5}$.
* B should be initialized with zeros.
* Implement the forward method to compute the forward pass of the LoRA block f(x) = (x @ A) @ B * lora_alpha / rank.
* The larger the rank, the more we scale down the effect of the LoRA block.
* lora_alpha is a hyperparameter that controls the scale of the LoRA block, and is usually quite large (e.g. 32).
Why no bias?
If we included a bias, we couldn't bake the adaptor back into $W$: there is no matrix $\tilde{W}$ for which $\tilde{W}x = Wx + b$ for all $x$ unless $b = 0$, since matrix multiplication is linear rather than affine.
Why `(x @ A) @ B` over `x @ (A @ B)`?
The product A @ B has shape (n_inst, d_in, d_out), making it quite a large matrix even though its rank is at most r,
so it is very wasteful to perform the multiplication in this order.
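A back-of-envelope count makes this concrete (hypothetical sizes, using the usual ~2mnk FLOPs per (m×k)(k×n) matmul):

```python
d_in = d_out = 4096   # hypothetical d_model-sized projection
rank = 4
n_tokens = 1024       # number of input vectors per forward pass

# (x @ A) @ B: two skinny matmuls through the rank-r bottleneck
flops_xA_B = 2 * n_tokens * d_in * rank + 2 * n_tokens * rank * d_out

# x @ (A @ B): materializes a full (d_in, d_out) matrix, then a dense matmul
flops_x_AB = 2 * d_in * rank * d_out + 2 * n_tokens * d_in * d_out

print(f"(x @ A) @ B: {flops_xA_B:>14,} FLOPs")
print(f"x @ (A @ B): {flops_x_AB:>14,} FLOPs")
print(f"ratio: {flops_x_AB / flops_xA_B:.0f}x")
```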
class Lora(nn.Module):
"""
Module that implements the basic LoRA block.
- Input: tensor of shape (..., inst, d_in) (the inst dimension may be 1, and will broadcast against n_inst)
- Calculates intermediate activations of shape (..., inst, rank)
- Output: tensor of shape (..., inst, d_out)
"""
A: nn.Parameter # (n_inst, d_in, rank)
B: nn.Parameter # (n_inst, rank, d_out)
def __init__(
self,
d_in: int = 768,
d_out: int = 768,
rank: int = 4,
lora_alpha: float = 32,
n_inst: int | None = None,
dtype: t.dtype | None = None,
):
"""
Initialize the weights of the LoRA block.
- The A block should be initialized with kaiming uniform with a=sqrt(5)
- The B block should be initialized with zeros.
"""
super().__init__()
self.rank = rank
self.d_in = d_in
self.d_out = d_out
self.n_inst = 1 if n_inst is None else n_inst
self.lora_alpha = lora_alpha
self.dtype = dtype
# Define the model parameters here
raise NotImplementedError()
def forward(self, x: Float[Tensor, "... inst d_in"]) -> Float[Tensor, "... inst d_out"]:
"""
Computes the forward pass of the LoRA block f(x) = (x @ A) @ B * lora_alpha / rank
Args:
x: Tensor of shape (..., inst, d_in)
Returns:
out (..., inst, d_out) such that out[..., i, :] = (x[..., i] @ A[i]) @ B[i] * lora_alpha / rank
"""
if x.dtype != self.dtype:
x = x.to(self.dtype)
assert x.shape[-2] == self.n_inst or x.shape[-2] == 1, (
f"Expected inst dim {self.n_inst} or 1, got {x.shape[-2]}. (input shape was {x.shape=})"
)
raise NotImplementedError()
return out * self.lora_alpha / self.rank
model = HookedTransformer.from_pretrained("pythia-14m")
tests_lora.testing_lora(Lora)
Solution
class Lora(nn.Module):
"""
Module that implements the basic LoRA block.
- Input: tensor of shape (..., inst, d_in) (the inst dimension may be 1, and will broadcast against n_inst)
- Calculates intermediate activations of shape (..., inst, rank)
- Output: tensor of shape (..., inst, d_out)
"""
A: nn.Parameter # (n_inst, d_in, rank)
B: nn.Parameter # (n_inst, rank, d_out)
def __init__(
self,
d_in: int = 768,
d_out: int = 768,
rank: int = 4,
lora_alpha: float = 32,
n_inst: int | None = None,
dtype: t.dtype | None = None,
):
"""
Initialize the weights of the LoRA block.
- The A block should be initialized with kaiming uniform with a=sqrt(5)
- The B block should be initialized with zeros.
"""
super().__init__()
self.rank = rank
self.d_in = d_in
self.d_out = d_out
self.n_inst = 1 if n_inst is None else n_inst
self.lora_alpha = lora_alpha
self.dtype = dtype
# Define the model parameters here
self.A = nn.Parameter(t.empty(self.n_inst, d_in, rank, dtype=dtype))
self.B = nn.Parameter(t.zeros(self.n_inst, rank, d_out, dtype=dtype))
nn.init.kaiming_uniform_(self.A, a=5**0.5)
def forward(self, x: Float[Tensor, "... inst d_in"]) -> Float[Tensor, "... inst d_out"]:
"""
Computes the forward pass of the LoRA block f(x) = (x @ A) @ B * lora_alpha / rank
Args:
x: Tensor of shape (..., inst, d_in)
Returns:
out (..., inst, d_out) such that out[..., i, :] = (x[..., i] @ A[i]) @ B[i] * lora_alpha / rank
"""
if x.dtype != self.dtype:
x = x.to(self.dtype)
assert x.shape[-2] == self.n_inst or x.shape[-2] == 1, (
f"Expected inst dim {self.n_inst} or 1, got {x.shape[-2]}. (input shape was {x.shape=})"
)
# force order of operations (x A) B
tmp = einops.einsum(x, self.A, "... inst d_in, inst d_in rank -> ... inst rank")
out = einops.einsum(tmp, self.B, "... inst rank, inst rank d_out -> ... inst d_out")
return out * self.lora_alpha / self.rank
model = HookedTransformer.from_pretrained("pythia-14m")
tests_lora.testing_lora(Lora)
Attention LoRA
Next, we want to add code that adds the Low-Rank Adaptor to the attention layers. The original implementation of LoRA adds low-rank adaptors across all matrices in the attention layers ($W_Q, W_K, W_V, W_O$), leaving the MLP layers untouched.
Due to the way TransformerLens is implemented, we need to add hooks that cache the inputs to the attention sublayers, and then separately add hooks to modify the output of the attention sublayers, by loading the cached input from earlier, running it through the adaptor, and then adding it back to the output.
That is, we need to:
- Store the input to the attention $W_Q$, $W_K$, and $W_V$ sublayers (all three receive the same normalized input), as well as the input to the $W_O$ sublayer (z).
- Modify the output of the attention $W_Q$, $W_K$, $W_V$, and $W_O$ sublayers (q, k, v, attn_out) by:
  - Running the cached input through the adaptor, and
  - Adding the adaptor's output to the original output of the module, and returning the result instead.

Exercise - complete LoraHooks
Now you'll implement the LoraHooks class. This class should define LoRA modules for the linear projections inside the attention layer of the transformer.
The following methods have been implemented for you:
- The hook function `store_hook_attn_normalized` caches the input to queries, keys, and values.
- The hook function `store_hook_z` caches the input to $W_O$.
- The method `list_fwd_hooks` returns a list of hook_point names and functions to call for the forward pass of the model using LoRA.
You should
* Define self.lora_q, self.lora_k, self.lora_v, self.lora_o of appropriate sizes.
* Implement the lora_hook_qkv method to apply the LoRA modules to the query/key/value outputs of the attention layer.
- This function should check the hook location (hook.name) and apply the appropriate LoRA modules to the appropriate input.
- Note that normalized : Float[Tensor, "batch pos d_model"] is the input to the attention layer, which we need to repeat for each head before passing to the LoRA modules.
* Implement the lora_hook_out method to apply the LoRA modules to the output of the attention layer.
- Note that transformer_lens doesn't hook the output of $W_O$ before the heads are summed over. Luckily for us, LoRA performs a linear operation, so we can sum the output of the LoRA module for $W_O$ over each head, and then add the result to the output!
For example, for two heads:
What's the deal with n_qo_heads and n_kv_heads?
TL;DR: use n_qo_heads for the number of query and output heads, and n_kv_heads for the number of key and value heads.
For gpt2, we have the same number of heads in each layer, and each head has linear projections
* $W_Q : (n_{heads},d_{model}, d_{head})$
* $W_K : (n_{heads},d_{model}, d_{head})$
* $W_V : (n_{heads},d_{model}, d_{head})$
* $W_O : (n_{heads}, d_{head}, d_{model})$
This turns out to be costly on memory when using KV-caching, so a solution proposed was grouped-query attention, where there are fewer key and value heads than query heads. The query heads are put into groups, and each group shares the same key and value head.

We can see this in the Llama family of models, which makes use of this technique.
model = transformer_lens.HookedTransformer.from_pretrained("meta-llama/Llama-3.2-1B")
print(f"{model.cfg.n_heads=}")
print(f"{model.cfg.n_key_value_heads=}")
print(f"Group size: {model.cfg.n_heads // model.cfg.n_key_value_heads}")
print(f"{(model.W_K[0][:4] == model.W_K[0][0]).all()=}")
Loaded pretrained model meta-llama/Llama-3.2-1B into HookedTransformer
model.cfg.n_heads=32
model.cfg.n_key_value_heads=8
Group size: 4
(model.W_K[0][:4] == model.W_K[0][0]).all()=tensor(True, device='cuda:0')
Note that TransformerLens presents W_K and W_V as the same shape as W_Q, but this is a lie: they are only a repeated view of the true key and value matrices, which are smaller. We can see this by looking inside the attention layer: _W_K holds the real weights, and W_K is a repeated view of it.
print(f"{model.blocks[0].attn._W_K.shape=}")
print(f"{model.blocks[0].attn.W_K.shape=}")
model.blocks[0].attn._W_K.shape=torch.Size([8, 2048, 64])
model.blocks[0].attn.W_K.shape=torch.Size([32, 2048, 64])
class LoraHooks(nn.Module):
"""
Defines the LoRA hooks needed for the Attention Layers of the transformer.
(Could be modified to add LoRA to the MLP layers)
"""
lora_q: Lora
lora_k: Lora
lora_v: Lora
lora_o: Lora
cache_qkv_in: Float[Tensor, "batch pos d_model"] = None
cache_z: Float[Tensor, "batch pos n_heads d_head"] = None
def __init__(
self,
layer_idx: int,
cfg: HookedTransformerConfig,
lora_alpha: float = 32,
rank: int = 4,
dtype: t.dtype = None,
):
super().__init__()
self.layer_idx = layer_idx
self.rank = rank
self.lora_alpha = lora_alpha
self.dtype = dtype
self.n_qo_heads = n_qo_heads = cfg.n_heads
self.n_kv_heads = n_kv_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads
d_model, d_head = cfg.d_model, cfg.d_head
raise NotImplementedError()
def store_hook_attn_normalized(self, normalized: Float[Tensor, "batch pos d_model"], hook: HookPoint) -> None:
"""
Cache the input to query/key/value.
"""
self.cache_qkv_in = normalized
def store_hook_z(self, z: Float[Tensor, "batch pos n_heads d_head"], hook: HookPoint) -> None:
"""
Cache the input to $W_O$.
"""
self.cache_z = z
def list_fwd_hooks(self) -> list[tuple[str, Callable]]:
"""
Returns a list of hook_point names and functions to call for the forward pass of
the model using LoRA.
"""
fwd_hooks = []
# Attention Hooks qkv
fwd_hooks.append((f"blocks.{self.layer_idx}.ln1.hook_normalized", self.store_hook_attn_normalized))
fwd_hooks.append((f"blocks.{self.layer_idx}.attn.hook_q", self.lora_hook_qkv))
fwd_hooks.append((f"blocks.{self.layer_idx}.attn.hook_k", self.lora_hook_qkv))
fwd_hooks.append((f"blocks.{self.layer_idx}.attn.hook_v", self.lora_hook_qkv))
# Attention Hooks z/out
fwd_hooks.append((f"blocks.{self.layer_idx}.attn.hook_z", self.store_hook_z))
fwd_hooks.append((f"blocks.{self.layer_idx}.hook_attn_out", self.lora_hook_out))
return fwd_hooks
def lora_hook_qkv(
self, qkv_hook_out: Float[Tensor, "batch pos n_heads d_head"], hook: HookPoint
) -> Float[Tensor, "batch pos n_heads d_head"]:
"""
Applies the LoRA modules to query/key/value, based on the hook location.
Args:
hook_qkv_out: Float[Tensor, "batch pos n_heads d_head"]
The original output from query/key/value.
hook: HookPoint
Returns:
The original output from query/key/value, plus the output from the corresponding LoRA module.
"""
raise NotImplementedError()
def lora_hook_out(
self, attn_out: Float[Tensor, "batch pos n_heads d_head"], hook: HookPoint
) -> Float[Tensor, "batch pos n_heads d_head"]:
"""
Applies the LoRA modules to the output projection matrix W_O in the attention layer.
The output of the LoRA module is computed per head, so we sum over heads before adding
to the activation `attn_out`.
Args:
attn_out: Float[Tensor, "batch pos n_heads d_head"]
The output from the attention layer.
hook: HookPoint
Returns:
The original output from the attention layer, plus the output from the LoRA module.
"""
raise NotImplementedError()
tests_lora.testing_lora_hooks(LoraHooks)
tests_lora.testing_lora_hooks_qkv_dispatch_and_out(LoraHooks)
print("All tests for LoraHooks passed!")
Solution
class LoraHooks(nn.Module):
"""
Defines the LoRA hooks needed for the Attention Layers of the transformer.
(Could be modified to add LoRA to the MLP layers)
"""
lora_q: Lora
lora_k: Lora
lora_v: Lora
lora_o: Lora
cache_qkv_in: Float[Tensor, "batch pos d_model"] = None
cache_z: Float[Tensor, "batch pos n_heads d_head"] = None
def __init__(
self,
layer_idx: int,
cfg: HookedTransformerConfig,
lora_alpha: float = 32,
rank: int = 4,
dtype: t.dtype = None,
):
super().__init__()
self.layer_idx = layer_idx
self.rank = rank
self.lora_alpha = lora_alpha
self.dtype = dtype
self.n_qo_heads = n_qo_heads = cfg.n_heads
self.n_kv_heads = n_kv_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads
d_model, d_head = cfg.d_model, cfg.d_head
self.lora_q = Lora(d_model, d_head, n_inst=n_qo_heads, rank=rank, lora_alpha=lora_alpha, dtype=dtype)
self.lora_k = Lora(d_model, d_head, n_inst=n_kv_heads, rank=rank, lora_alpha=lora_alpha, dtype=dtype)
self.lora_v = Lora(d_model, d_head, n_inst=n_kv_heads, rank=rank, lora_alpha=lora_alpha, dtype=dtype)
self.lora_o = Lora(d_head, d_model, n_inst=n_qo_heads, rank=rank, lora_alpha=lora_alpha, dtype=dtype)
def store_hook_attn_normalized(self, normalized: Float[Tensor, "batch pos d_model"], hook: HookPoint) -> None:
"""
Cache the input to query/key/value.
"""
self.cache_qkv_in = normalized
def store_hook_z(self, z: Float[Tensor, "batch pos n_heads d_head"], hook: HookPoint) -> None:
"""
Cache the input to $W_O$.
"""
self.cache_z = z
def list_fwd_hooks(self) -> list[tuple[str, Callable]]:
"""
Returns a list of hook_point names and functions to call for the forward pass of
the model using LoRA.
"""
fwd_hooks = []
# Attention Hooks qkv
fwd_hooks.append((f"blocks.{self.layer_idx}.ln1.hook_normalized", self.store_hook_attn_normalized))
fwd_hooks.append((f"blocks.{self.layer_idx}.attn.hook_q", self.lora_hook_qkv))
fwd_hooks.append((f"blocks.{self.layer_idx}.attn.hook_k", self.lora_hook_qkv))
fwd_hooks.append((f"blocks.{self.layer_idx}.attn.hook_v", self.lora_hook_qkv))
# Attention Hooks z/out
fwd_hooks.append((f"blocks.{self.layer_idx}.attn.hook_z", self.store_hook_z))
fwd_hooks.append((f"blocks.{self.layer_idx}.hook_attn_out", self.lora_hook_out))
return fwd_hooks
def lora_hook_qkv(
self, qkv_hook_out: Float[Tensor, "batch pos n_heads d_head"], hook: HookPoint
) -> Float[Tensor, "batch pos n_heads d_head"]:
"""
Applies the LoRA modules to query/key/value, based on the hook location.
Args:
hook_qkv_out: Float[Tensor, "batch pos n_heads d_head"]
The original output from query/key/value.
hook: HookPoint
Returns:
The original output from query/key/value, plus the output from the corresponding LoRA module.
"""
hook_location = hook.name.split(".")[-1]
qkv_in = self.cache_qkv_in
qkv_in_repeated = einops.repeat(qkv_in, "batch pos d_model -> batch pos n_inst d_model", n_inst=1)
if hook_location == "hook_q":
return qkv_hook_out + self.lora_q(qkv_in_repeated)
elif hook_location == "hook_k":
return qkv_hook_out + self.lora_k(qkv_in_repeated)
elif hook_location == "hook_v":
return qkv_hook_out + self.lora_v(qkv_in_repeated)
else:
raise ValueError(f"Invalid hook location: {hook_location}")
def lora_hook_out(
self, attn_out: Float[Tensor, "batch pos n_heads d_head"], hook: HookPoint
) -> Float[Tensor, "batch pos n_heads d_head"]:
"""
Applies the LoRA modules to the output projection matrix W_O in the attention layer.
The output of the LoRA module is computed per head, so we sum over heads before adding
to the activation `attn_out`.
Args:
attn_out: Float[Tensor, "batch pos n_heads d_head"]
The output from the attention layer.
hook: HookPoint
Returns:
The original output from the attention layer, plus the output from the LoRA module.
"""
lora_result = self.lora_o(self.cache_z)
lora_attn_out = einops.einsum(lora_result, "... n_heads d_model -> ... d_model")
return attn_out + lora_attn_out
Training with LoRA
We can now define a modified form of the HookedTransformerWithValueHead class that includes a LoRA module attached to every attention layer.
This means when we train the model, we no longer need an additional reference model, but can simply turn the LoRA modules on and off. With the LoRA modules disabled, the model will act like the base model, as the original parameters of the model are not modified.
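This on/off equivalence relies on B being initialized to zeros: a freshly attached adaptor contributes exactly nothing, so the hooked model matches the base model at initialization. A minimal sketch of that property (plain tensors, arbitrary sizes, mirroring the Lora initialization above):

```python
import torch as t

d_model, d_head, rank, lora_alpha = 64, 8, 4, 32.0

x = t.randn(2, 10, d_model)   # stand-in for a cached attention input
W = t.randn(d_model, d_head)  # a frozen base projection

A = t.empty(d_model, rank)
t.nn.init.kaiming_uniform_(A, a=5**0.5)
B = t.zeros(rank, d_head)     # B starts at zero, as in the Lora module above

base_out = x @ W
lora_out = base_out + (x @ A) @ B * lora_alpha / rank

# at initialization the adaptor is a no-op: hooks-on equals hooks-off
assert t.equal(base_out, lora_out)
```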
Exercise - complete TransformerWithValueHeadLora
Now you'll implement the TransformerWithValueHeadLora class. This class attaches a LoraHooks module to every layer of the transformer, and runs forward passes and generation with the LoRA hooks enabled.
You should
* Define `setup_lora`, which will:
  - define `self.lora` as an `nn.ModuleList` of `LoraHooks`, one per layer.
  - define `self.lora_fwd_hooks` as a list of all the forward hooks for the LoRA modules.
  - You should make use of the `list_fwd_hooks()` method we defined for you earlier.
* Define `forward_with_value_head`, which will use the `fwd_hooks` property to run the model forward with the LoRA modules enabled.
  - Make use of the `with self.hooks(fwd_hooks=self.fwd_hooks)` context manager.
  - This function is quite simple, only a few lines.
* Define `generate`, which overrides the `generate` method in the parent class to use the LoRA hooks.
  - Also simple; should be a similar implementation to `forward_with_value_head`.
class TransformerWithValueHeadLora(HookedTransformerWithValueHead):
lora: nn.ModuleList
lora_fwd_hooks: list[tuple[str, Callable]]
dtype: t.dtype
device: t.device
use_value_head: bool
def base_model_params(self):
return (p for name, p in self.named_parameters() if "value_head" not in name and "lora" not in name)
def lora_params(self):
return self.lora.parameters()
# we use these for compatibility with get_optimizer_and_scheduler
def get_base_model_trainable_params(self):
return self.lora_params()
def get_value_head_params(self):
return (p for name, p in self.named_parameters() if "value_head" in name)
@classmethod
def from_pretrained(cls, *args, lora_alpha: float = 32, rank: int = 4, **kwargs):
model = super(TransformerWithValueHeadLora, cls).from_pretrained(*args, **kwargs)
model.setup_lora(lora_alpha=lora_alpha, rank=rank, **kwargs)
for param in model.base_model_params():
param.requires_grad = False
return model
def setup_lora(self, lora_alpha: float = 32, rank: int = 4, **kwargs):
"""
Initializes LoRA (Low-Rank Adaptation) for all attention layers in the transformer.
Steps of this function are:
- Creates a LoraHooks module for each transformer layer
- Creates the list of forward hooks for all layers
"""
raise NotImplementedError()
@property
def fwd_hooks(self):
return self.lora_fwd_hooks + [self.value_head_hook]
def forward_with_value_head(
self, tokens: Int[Tensor, "batch seq"]
) -> tuple[Float[Tensor, "batch seq d_vocab"], Float[Tensor, "batch seq"]]:
"""
Forward pass with LoRA enabled, including the value head outputs.
Args:
tokens: Int[Tensor, "batch seq"]
The input tokens to the transformer.
Returns:
logits: Float[Tensor, "batch seq d_vocab"]
The logits of the transformer.
value: Float[Tensor, "batch seq"]
The value head outputs for each token.
"""
raise NotImplementedError()
@t.no_grad()
def generate(self, tokens: Int[Tensor, "batch seq"], **kwargs) -> Int[Tensor, "batch seq"]:
"""
We override the generate method to use the LoRA hooks applied so that we don't need to update the previous training code.
This function should call generate on the parent class (HookedTransformer), but with the LoRA hooks applied.
We don't need to return the value head outputs during generation.
Args:
tokens: Int[Tensor, "batch seq"]
The input tokens to the transformer.
**kwargs:
Additional keyword arguments to pass to the base class generate method.
Returns:
gen_tokens: Int[Tensor, "batch gen_len"]
The generated tokens.
"""
raise NotImplementedError()
model = TransformerWithValueHeadLora.from_pretrained("pythia-14m").to(device)
tests_lora.test_lora_fwd_hooks_list(model)
tests_lora.test_lora_model_forward_methods(model)
print("All tests for TransformerWithValueHeadLora passed!")
Solution
class TransformerWithValueHeadLora(HookedTransformerWithValueHead):
lora: nn.ModuleList
lora_fwd_hooks: list[tuple[str, Callable]]
dtype: t.dtype
device: t.device
use_value_head: bool
def base_model_params(self):
return (p for name, p in self.named_parameters() if "value_head" not in name and "lora" not in name)
def lora_params(self):
return self.lora.parameters()
# we use these for compatibility with get_optimizer_and_scheduler
def get_base_model_trainable_params(self):
return self.lora_params()
def get_value_head_params(self):
return (p for name, p in self.named_parameters() if "value_head" in name)
@classmethod
def from_pretrained(cls, *args, lora_alpha: float = 32, rank: int = 4, **kwargs):
model = super(TransformerWithValueHeadLora, cls).from_pretrained(*args, **kwargs)
model.setup_lora(lora_alpha=lora_alpha, rank=rank, **kwargs)
for param in model.base_model_params():
param.requires_grad = False
return model
def setup_lora(self, lora_alpha: float = 32, rank: int = 4, **kwargs):
"""
Initializes LoRA (Low-Rank Adaptation) for all attention layers in the transformer.
Steps of this function are:
- Creates a LoraHooks module for each transformer layer
- Creates the list of forward hooks for all layers
"""
self.lora = nn.ModuleList(
[LoraHooks(layer_idx, self.cfg, lora_alpha, rank) for layer_idx in range(len(self.blocks))]
).to(device)
# create list of all hooks for all layers
self.lora_fwd_hooks = []
for layer_idx in range(len(self.blocks)):
self.lora_fwd_hooks.extend(self.lora[layer_idx].list_fwd_hooks())
@property
def fwd_hooks(self):
return self.lora_fwd_hooks + [self.value_head_hook]
def forward_with_value_head(
self, tokens: Int[Tensor, "batch seq"]
) -> tuple[Float[Tensor, "batch seq d_vocab"], Float[Tensor, "batch seq"]]:
"""
Forward pass with LoRA enabled, including the value head outputs.
Args:
tokens: Int[Tensor, "batch seq"]
The input tokens to the transformer.
Returns:
logits: Float[Tensor, "batch seq d_vocab"]
The logits of the transformer.
value: Float[Tensor, "batch seq"]
The value head outputs for each token.
"""
with self.hooks(fwd_hooks=self.fwd_hooks):
logits = self.forward(tokens)
value = self.value_head_output
return logits, value
@t.no_grad()
def generate(self, tokens: Int[Tensor, "batch seq"], **kwargs) -> Int[Tensor, "batch seq"]:
"""
We override the generate method to use the LoRA hooks applied so that we don't need to update the previous training code.
This function should call generate on the parent class (HookedTransformer), but with the LoRA hooks applied.
We don't need to return the value head outputs during generation.
Args:
tokens: Int[Tensor, "batch seq"]
The input tokens to the transformer.
**kwargs:
Additional keyword arguments to pass to the base class generate method.
Returns:
gen_tokens: Int[Tensor, "batch gen_len"]
The generated tokens.
"""
with self.hooks(fwd_hooks=self.lora_fwd_hooks):
gen_tokens = super().generate(tokens, **kwargs)
return gen_tokens
model = TransformerWithValueHeadLora.from_pretrained("pythia-14m").to(device)
tests_lora.test_lora_fwd_hooks_list(model)
tests_lora.test_lora_model_forward_methods(model)
print("All tests for TransformerWithValueHeadLora passed!")
Since we still need a reference policy, but we don't modify the base model weights directly, we can load the base model once and apply LoRA via forward hooks for training, while also using the same base weights (without LoRA hooks) as the frozen reference policy for the KL penalty.
Why do we only add LoRA to the attention layers?
The original LoRA paper adds adapters to attention projections. More recent work often also adds them to MLP layers, or even only to MLP. As a bonus exercise, you can try adding LoRA to the MLP layers too!
@dataclass
class RLHFArgsLora(RLHFArgs):
lora_rank: int = 4
lora_alpha: float = 32
dtype: t.dtype = None
class RLHFTrainerLora(RLHFTrainer):
model: TransformerWithValueHeadLora
memory: ReplayMemory
def __init__(self, args: RLHFArgsLora):
"""
Method that now loads the reference model and the lora_model.
"""
t.manual_seed(args.seed)
self.args = args
self.run_name = f"{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"
self.model = TransformerWithValueHeadLora.from_pretrained(
args.base_model, lora_alpha=args.lora_alpha, rank=args.lora_rank
)
self.model.to(device).train()
self.ref_model = self.model # no need for separate reference model!
self.optimizer, self.scheduler = get_optimizer_and_scheduler(self.args, self.model)
self.prefix_len = len(self.model.to_str_tokens(self.args.prefix, prepend_bos=self.args.prepend_bos))
print("Training LoRA model RLHF (example setup)")
lora_args = RLHFArgsLora(
use_wandb=False,
kl_coef=0.0,
total_phases=2,
warmup_steps=0,
reward_fn=reward_fn_char_count,
base_lr=1e-3,
batch_size=8,
num_minibatches=2,
gen_len=8,
)
lora_trainer = RLHFTrainerLora(lora_args)
lora_trainer.train() # runs a tiny smoke test with the args above