 # Please refer to README.md in the same folder for more information.

-from typing import Optional
+from typing import Any, Optional, Tuple

 import torch
 import torch.nn.functional as F

-from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
+from executorch.examples.models.llama.attention import (
+    ATTENTION_REGISTRY,
+    ForwardOptions,
+)

 from executorch.examples.models.llama.model_args import ModelArgs

@@ -148,17 +151,17 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos
+    def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
+        h, attn_options_update = self.attention.forward(
+            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )

         h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out, attn_options_update


 class Transformer(nn.Module):
@@ -185,27 +188,28 @@ def __init__(self, params: ModelArgs):
     def forward(
         self,
         tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
         h: Optional[torch.FloatTensor] = None,  # embeddings
-    ) -> torch.Tensor:
+        attn_options: Optional[ForwardOptions] = None,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
+
+        if attn_options is None:
+            attn_options = {}
         seqlen = h.shape[1]
-        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
+        freqs_cos, freqs_sin = self.rope.get_freqs(
+            attn_options.get("input_pos"), seqlen
+        )

+        attn_options_update = None
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options)
+            if attn_options_update is not None:
+                attn_options.update(**attn_options_update)

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
@@ -237,4 +241,4 @@ def forward(
             expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits

-        return logits
+        return logits, attn_options_update
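
Taken together, these hunks change the contract of `Transformer.forward`: `input_pos` no longer exists as a standalone argument but travels inside an optional `attn_options` dict typed as `ForwardOptions`, and both `TransformerBlock.forward` and `Transformer.forward` now return a tuple whose second element is an optional `attn_options_update`. The sketch below shows how a caller might adapt to the new signature. It is a hypothetical illustration, not code from this change: `model` and its `ModelArgs` configuration are assumed to exist already, and the only `ForwardOptions` key used is `"input_pos"`, which the diff itself reads via `attn_options.get("input_pos")`.

    import torch

    # Hypothetical caller-side sketch; `model` is assumed to be a Transformer from
    # this module, already constructed from a configured ModelArgs (not shown).
    prompt_tokens = torch.tensor([[1, 2, 3]], dtype=torch.long)

    # Old call pattern (before this change):
    #     logits = model(prompt_tokens, input_pos=torch.tensor([0]))
    # New call pattern: input_pos rides inside attn_options, and the forward
    # pass returns a tuple.
    logits, attn_options_update = model(
        prompt_tokens,
        attn_options={"input_pos": torch.tensor([0])},
    )

    # attn_options_update is None unless the attention implementation returned
    # extra state; mirroring the layer loop in Transformer.forward, a caller can
    # fold it back into the options passed on the next step.
    next_options = {"input_pos": torch.tensor([prompt_tokens.shape[1]])}
    if attn_options_update is not None:
        next_options.update(**attn_options_update)

Routing per-call state through a dict keeps the layer and model signatures stable as attention variants registered in `ATTENTION_REGISTRY` consume or return additional fields.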