Commit d0118d7
refactor-attention
1 parent 7c150d4
2 files changed: +30 −20 lines

examples/models/llama/llama_transformer.py

Lines changed: 6 additions & 17 deletions
@@ -12,10 +12,7 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.attention import (
-    ATTENTION_REGISTRY,
-    ForwardOptions,
-)
+from executorch.examples.models.llama.attention import Attention, ForwardOptions
 
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
@@ -83,19 +80,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -117,7 +108,7 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +121,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
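The effect of this file's changes is that both classes now receive their collaborators from the caller: TransformerBlock takes a pre-built attention module, and Transformer takes the finished layer list plus the Rope instance instead of constructing them itself. A minimal sketch of the new wiring, assuming ModelArgs can be built with just the fields shown (the values are illustrative; the loop mirrors the construction that moves into model.py below):

import torch

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.llama_transformer import Transformer, TransformerBlock
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

# Illustrative arguments only; real configs set many more ModelArgs fields.
args = ModelArgs(dim=256, n_heads=8, n_layers=2, vocab_size=1000)

rope = Rope(args)
cls = ATTENTION_REGISTRY[args.attention_type]  # the attention class registered for this config

# Attention modules are built per layer and injected; TransformerBlock no longer consults the registry.
layers = torch.nn.ModuleList(
    [TransformerBlock(args, cls(args, layer_id, rope)) for layer_id in range(args.n_layers)]
)

# Transformer no longer builds rope or the layer list itself.
model = Transformer(args, layers, rope)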

examples/models/llama/model.py

Lines changed: 24 additions & 3 deletions
@@ -15,9 +15,13 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
-from executorch.examples.models.llama.llama_transformer import Transformer
-
+from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
+from executorch.examples.models.llama.llama_transformer import (
+    Transformer,
+    TransformerBlock,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope
 from torchao.utils import TorchAOBaseTensor
 
 try:
@@ -174,7 +178,24 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-            self.model_ = Transformer(model_args)
+
+            # Construct attention layers.
+            rope = Rope(model_args)
+            if model_args.attention_type not in ATTENTION_REGISTRY:
+                raise ValueError(
+                    f"Unknown attention type: {model_args.attention_type}. "
+                    f"Available: {list(ATTENTION_REGISTRY.keys())}"
+                )
+            layers = torch.nn.ModuleList()
+            cls = ATTENTION_REGISTRY[model_args.attention_type]
+            for layer_id in range(model_args.n_layers):
+                attention = cls(model_args, layer_id, rope)
+                transformer_block = TransformerBlock(model_args, attention)
+                layers.append(transformer_block)
+
+            # Construct transformer model.
+            self.model_ = Transformer(model_args, layers, rope)
+
         # Get checkpoint dtype.
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
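One practical consequence of moving the registry lookup here: an alternative attention implementation only needs a registry entry and the cls(model_args, layer_id, rope) constructor shape used in the loop above; llama_transformer.py stays untouched. A hedged sketch, assuming ATTENTION_REGISTRY is a plain dict keyed by attention_type and that "mha" is one of its registered keys; the TracedAttention class and the "traced_mha" key are illustrative and not part of this commit:

import torch.nn as nn

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY

class TracedAttention(nn.Module):
    """Hypothetical wrapper that delegates to an already-registered attention class."""

    def __init__(self, args, layer_id, rope):
        # Same constructor shape the loop in model.py uses: cls(args, layer_id, rope).
        super().__init__()
        self.layer_id = layer_id
        # "mha" is assumed to be a registered key; substitute any key present in ATTENTION_REGISTRY.
        self.inner = ATTENTION_REGISTRY["mha"](args, layer_id, rope)

    def forward(self, *inputs, **kwargs):
        # Delegate with whatever forward() signature the Attention interface defines.
        return self.inner(*inputs, **kwargs)

# Assumes the registry is a plain dict; a real integration might use a registration helper instead.
ATTENTION_REGISTRY["traced_mha"] = TracedAttention

Setting attention_type="traced_mha" on ModelArgs would then route the construction loop above through this class without any change to the transformer definition.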
