Skip to content

Commit b6a95cd

Browse files
committed
refactor-attention
1 parent 52b4483 commit b6a95cd

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

examples/models/llama/attention.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,12 @@ def forward(
162162

163163
@register_attention("mha")
164164
class AttentionMHA(Attention):
165-
def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
165+
def __init__(
166+
self,
167+
args: ModelArgs,
168+
layer_id: int,
169+
rope: Rope,
170+
):
166171
super().__init__()
167172
self.use_kv_cache = args.use_kv_cache
168173
self.n_heads = args.n_heads

examples/models/llama/llama_transformer.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8383

8484

8585
class TransformerBlock(nn.Module):
86-
def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
86+
def __init__(self, args: ModelArgs, attention: Attention):
8787
super().__init__()
8888
self.use_kv_cache = args.use_kv_cache
8989
self.n_heads = args.n_heads
@@ -94,8 +94,8 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
9494
f"Unknown attention type: {args.attention_type}. "
9595
f"Available: {list(ATTENTION_REGISTRY.keys())}"
9696
)
97-
cls = ATTENTION_REGISTRY[args.attention_type]
98-
self.attention = cls(args, layer_id, rope)
97+
98+
self.attention = attention
9999
if args.moe:
100100
self.block_sparse_moe = MOEFeedForward(args)
101101
else:
@@ -117,7 +117,7 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
117117

118118

119119
class Transformer(nn.Module):
120-
def __init__(self, params: ModelArgs):
120+
def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
121121
super().__init__()
122122
self.params = params
123123
self.vocab_size = params.vocab_size
@@ -130,10 +130,8 @@ def __init__(self, params: ModelArgs):
130130
if self.apply_embedding
131131
else None
132132
)
133-
self.rope = Rope(params)
134-
self.layers = torch.nn.ModuleList()
135-
for layer_id in range(params.n_layers):
136-
self.layers.append(TransformerBlock(layer_id, params, self.rope))
133+
self.layers = layers
134+
self.rope = rope
137135
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
138136
self.output = (
139137
nn.Linear(params.dim, params.vocab_size, bias=False)

examples/models/llama/model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from executorch.examples.models.llama.llama_transformer import Transformer
1919

2020
from executorch.examples.models.llama.model_args import ModelArgs
21+
from executorch.examples.models.llama.rope import Rope
2122

2223
try:
2324
from .fairseq2 import convert_to_llama_checkpoint
@@ -173,7 +174,19 @@ def __init__(self, **kwargs):
173174
# They possess all other metadata a tensor carries such as size, stride, requires_grad.
174175
with torch.device("meta"):
175176
# Model itself is loaded in default dtype, fp32.
176-
self.model_ = Transformer(model_args)
177+
178+
# Construct attention layers.
179+
rope = Rope(model_args)
180+
layers = nn.ModuleList()
181+
cls = ATTENTION_REGISTRY[model_args.attention_type]
182+
for layer_id in range(model_args.n_layers):
183+
attention = cls(model_args, layer_id, rope)
184+
transformer_block = TransformerBlock(model_args, attention)
185+
layers.append(transformer_block)
186+
187+
# Construct transformer model.
188+
self.model_ = Transformer(model_args, layers, rope)
189+
177190
# Get checkpoint dtype.
178191
if checkpoint:
179192
self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)

0 commit comments

Comments
 (0)