
Commit b3e1550

Add LoRA linear definition
Add LoRA linear definition. Pull out the linears from attention and allow a custom linear (e.g. a LoRA linear) to be passed in. If none is provided, construct an nn.Linear as before (current behaviour).

Differential Revision: [D73953776](https://our.internmc.facebook.com/intern/diff/D73953776/)
ghstack-source-id: 285402969
Pull Request resolved: #11044
1 parent 6a8d286 commit b3e1550

5 files changed (+182, -10 lines)

examples/models/llama/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ runtime.python_library(
     name = "llama_transformer",
     srcs = [
         "llama_transformer.py",
+        "lora.py",
         "rope.py",
         "attention.py",
         "model_args.py",

examples/models/llama/attention.py

Lines changed: 45 additions & 9 deletions
@@ -324,7 +324,28 @@ def update(
 
 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+        wq: Optional[nn.Module] = None,
+        wk: Optional[nn.Module] = None,
+        wv: Optional[nn.Module] = None,
+        wo: Optional[nn.Module] = None,
+    ):
+        """
+        Multi-head attention layer.
+
+        Args:
+            args (ModelArgs): Model configuration parameters.
+            layer_id (int): Layer index.
+            rope (Rope): Rotary position embedding module.
+            wq (Optional[nn.Module]): Query projection module. If None, use regular nn.Linear.
+            wk (Optional[nn.Module]): Key projection module. If None, use regular nn.Linear.
+            wv (Optional[nn.Module]): Value projection module. If None, use regular nn.Linear.
+            wo (Optional[nn.Module]): Output projection module. If None, use regular nn.Linear.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -349,19 +370,34 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
             self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
 
-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wq = (
+            wq
+            if wq is not None
+            else nn.Linear(
+                self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
+        )
+        self.wk = (
+            wk
+            if wk is not None
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wk = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wv = (
+            wv
+            if wv is not None
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wv = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wo = (
+            wo
+            if wo is not None
+            else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
 
         self.layer_id = layer_id
-
         self.rope = rope
 
         causal_mask = torch.tril(
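Since the projection modules are now constructor arguments, a caller can swap any of them for a custom linear. A minimal sketch of overriding only the query projection, assuming `args` is an already-populated ModelArgs and `rope` an already-constructed Rope (the rank/alpha values are made up):

from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.lora import LoRALinear

# Assumes `args` (ModelArgs) and `rope` (Rope) already exist; only wq is
# overridden here, so wk/wv/wo stay None and fall back to nn.Linear.
lora_wq = LoRALinear(
    in_dim=args.dim,
    out_dim=args.n_heads * args.head_dim,
    rank=8,       # illustrative
    alpha=16.0,   # illustrative
    dropout=0.0,
    use_bias=args.attention_qkv_bias,
)
attn = AttentionMHA(args, layer_id=0, rope=rope, wq=lora_wq)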

examples/models/llama/llama_transformer.py

Lines changed: 78 additions & 1 deletion
@@ -18,6 +18,7 @@
     ForwardOptions,
 )
 
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -254,7 +255,83 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
     layers = torch.nn.ModuleList()
     cls = ATTENTION_REGISTRY[model_args.attention_type]
     for layer_id in range(model_args.n_layers):
-        attention = cls(model_args, layer_id, rope)
+        wq = (
+            LoRALinear(
+                in_dim=model_args.dim,
+                out_dim=model_args.n_heads * model_args.head_dim,
+                rank=model_args.r,
+                alpha=model_args.lora_alpha,
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if model_args.target_modules is not None
+            and "q_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.dim,
+                    model_args.n_heads * model_args.head_dim,
+                    bias=model_args.attention_qkv_bias,
+                )
+            )
+        )
+
+        wk = (
+            LoRALinear(
+                in_dim=model_args.dim,
+                out_dim=model_args.n_kv_heads * model_args.head_dim,
+                rank=model_args.r,
+                alpha=model_args.lora_alpha,
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if model_args.target_modules is not None
+            and "k_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.dim,
+                    model_args.n_kv_heads * model_args.head_dim,
+                    bias=model_args.attention_qkv_bias,
+                )
+            )
+        )
+        wv = (
+            LoRALinear(
+                in_dim=model_args.dim,
+                out_dim=model_args.n_kv_heads * model_args.head_dim,
+                rank=model_args.r,
+                alpha=model_args.lora_alpha,
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if model_args.target_modules is not None
+            and "v_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.dim,
+                    model_args.n_kv_heads * model_args.head_dim,
+                    bias=model_args.attention_qkv_bias,
+                )
+            )
+        )
+
+        wo = (
+            LoRALinear(
+                in_dim=model_args.n_kv_heads * model_args.head_dim,
+                out_dim=model_args.dim,
+                rank=model_args.r,
+                alpha=model_args.lora_alpha,
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if model_args.target_modules is not None
+            and "output_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.n_heads * model_args.head_dim, model_args.dim, bias=False
+                )
+            )
+        )
+        attention = cls(model_args, layer_id, rope, wq, wk, wv, wo)
         transformer_block = TransformerBlock(model_args, attention)
         layers.append(transformer_block)

examples/models/llama/lora.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+
+
+class LoRALinear(nn.Module):
+    """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`."""
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.rank = rank
+        self.alpha = alpha
+        self.use_bias = use_bias
+        self.dropout = dropout
+
+        linear = nn.Linear(in_dim, out_dim, bias=use_bias)
+        weight = linear.weight
+        bias = linear.bias if self.use_bias else None
+        self.register_parameter("weight", nn.Parameter(weight))
+        self.register_parameter(
+            "bias", nn.Parameter(bias) if bias is not None else None
+        )
+
+        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
+        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
+        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.nn.functional.linear(x, self.weight, self.bias)
+        lora_out = self.lora_a(self.dropout(x))
+        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+
+        return out + lora_out
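LoRALinear's forward pass adds a scaled low-rank update to the base projection: y = linear(x, W, b) + (alpha / rank) * lora_b(lora_a(dropout(x))). A small, self-contained usage sketch with made-up dimensions:

import torch

from executorch.examples.models.llama.lora import LoRALinear

# Illustrative dimensions only.
layer = LoRALinear(in_dim=512, out_dim=512, rank=8, alpha=16.0)
x = torch.randn(2, 16, 512)  # (batch, seq_len, dim)
y = layer(x)
print(y.shape)  # torch.Size([2, 16, 512])

# Note: lora_a and lora_b use nn.Linear's default init here, so the low-rank
# path is non-zero from the start; loading trained adapter weights (or
# zero-initializing lora_b, as is common for LoRA) makes the layer start out
# equivalent to the base projection.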

examples/models/llama/model_args.py

Lines changed: 10 additions & 0 deletions
@@ -55,8 +55,18 @@ class ModelArgs:
     eos_count: int = 2
 
     quantization_args: Optional[dict] = None
+    # LoRA for QAT.
     lora_args: Optional[dict] = None
 
+    # LoRA arguments to set up a LoRA inference model.
+    # These arguments come directly from a torchtune LoRA config.
+    r: Optional[int] = None  # Rank.
+    lora_alpha: Optional[int] = None  # Alpha.
+    # Eg. q_proj, k_proj, v_proj, output_proj
+    target_modules: Optional[list] = None
+    peft_type: Optional[str] = None  # PEFT type.
+    base_model_name_or_path: Optional[str] = None  # Base model name or path.
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
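The new fields mirror a torchtune/PEFT-style adapter config. A hypothetical example of populating only the LoRA fields (values are illustrative and assume the remaining ModelArgs fields keep their defaults):

from executorch.examples.models.llama.model_args import ModelArgs

# Illustrative values; a real setup would read these from the adapter config.
args = ModelArgs(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    peft_type="LORA",
    base_model_name_or_path="path/to/base-model",
)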
