 import torch
 import torch.nn.functional as F

-from executorch.examples.models.llama.attention import (
-    ATTENTION_REGISTRY,
-    ForwardOptions,
-)
+from executorch.examples.models.llama.attention import Attention, ForwardOptions

 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
@@ -83,19 +80,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -117,7 +108,7 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:


 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +121,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
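For context, a minimal sketch of how a call site might wire up the refactored constructors: the Rope construction, the attention-registry lookup, and the per-layer loop that previously lived inside TransformerBlock and Transformer now happen outside, and the finished modules are injected. The build_transformer helper name and the rope import path below are assumptions for illustration, not part of this diff; the actual construction code in executorch may differ.

# Construction-site sketch under the refactored signatures above.
# build_transformer and the rope import path are assumed for illustration.
# TransformerBlock and Transformer are the classes edited in the diff above
# and are assumed to be importable from that module.
import torch.nn as nn

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope  # assumed path


def build_transformer(params: ModelArgs) -> Transformer:
    # Rope is created once at the call site and shared by all layers.
    rope = Rope(params)

    # The attention-type validation moves out of TransformerBlock.
    if params.attention_type not in ATTENTION_REGISTRY:
        raise ValueError(f"Unknown attention type: {params.attention_type}")
    attn_cls = ATTENTION_REGISTRY[params.attention_type]

    # Each block now receives an already-constructed attention module.
    layers = nn.ModuleList(
        TransformerBlock(params, attn_cls(params, layer_id, rope))
        for layer_id in range(params.n_layers)
    )

    # Transformer simply stores the injected layers and rope.
    return Transformer(params, layers, rope)

The design choice here is plain dependency injection: TransformerBlock and Transformer no longer know about the attention registry or how Rope is built, which keeps the model classes export-friendly and lets callers swap in custom attention implementations without touching this file.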