
Commit ccf7082

sxu authored and facebook-github-bot committed
Pass ForwardOptions from top level module and also return any relevant state as output (#8186)
Summary: Pass a `ForwardOptions` argument (introduced by #8128) from the top level transformer, consolidate some existing inputs into it, and return any optional updates from the attention implementation.

Reviewed By: iseeyuan

Differential Revision: D69080123
1 parent 81f7c4f commit ccf7082
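
For reference, a minimal sketch of the new calling convention assembled from the diffs below (the `transformer` argument and the helper name are stand-ins, not defined in this commit): callers now pass tokens, an optional embedding tensor, and a `ForwardOptions` dict, and unpack a tuple of logits plus any attention-state update.

import torch

# Minimal sketch of the new calling convention; `transformer` is a stand-in
# for an instantiated eager llama Transformer (not defined in this commit).
def call_transformer(transformer, pos: int = 0):
    tokens = torch.tensor([[1]], dtype=torch.long)                        # one decode token
    attn_options = {"input_pos": torch.tensor([pos], dtype=torch.long)}   # ForwardOptions contents
    # Old: logits = transformer(tokens, input_pos)
    # New: (tokens, h, attn_options) in, (logits, attn_options_update) out.
    logits, attn_options_update = transformer(tokens, None, attn_options)
    return logits, attn_options_update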

File tree (6 files changed, +42 -28 lines changed):

examples/models/llama/attention.py
examples/models/llama/eval_llama_lib.py
examples/models/llama/evaluate/eager_eval.py
examples/models/llama/llama_transformer.py
examples/models/llama/model.py
extension/llm/export/builder.py


examples/models/llama/attention.py (+2 -2)
@@ -236,7 +236,7 @@ def forward(
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output)
+            return self.wo(output), None
 
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
@@ -252,4 +252,4 @@ def forward(
 
         output = self.wo(output)
 
-        return output
+        return output, None
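
The `None` second return value is the slot for the optional state update mentioned in the summary. A rough sketch of how an attention variant could use it (the `MyAttention` class and the `last_input_pos` key are hypothetical, not part of this commit):

from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn


class MyAttention(nn.Module):
    # Hypothetical attention variant illustrating the new return convention:
    # (output, optional dict of state updates) that the Transformer merges
    # back into the ForwardOptions it passes to later layers.
    def forward(
        self, x, freqs_cos, freqs_sin, input_pos: Optional[torch.Tensor] = None, **kwargs
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        output = x  # a real implementation computes attention here
        update = {"last_input_pos": input_pos}  # hypothetical state to surface
        return output, update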

examples/models/llama/eval_llama_lib.py (+6 -2)
@@ -65,7 +65,9 @@ def _model_call(self, inps):
                 result_logits = []
                 for pos in range(inps.shape[-1]):
                     pos_tensor = torch.tensor([pos], dtype=torch.int64)
-                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                    logits = self._model(
+                        inps[:, pos : pos + 1], None, {"input_pos": pos_tensor}
+                    )
                     result_logits.append(logits)
                 if self._generate_full_logits:
                     return torch.cat(result_logits, dim=1)
@@ -74,7 +76,9 @@ def _model_call(self, inps):
             else:
                 pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
                 # Batch process the whole sequence.
-                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+                logits = self._model(
+                    inps[:, : self._max_seq_length], None, {"input_pos": pos_tensor}
+                )
                 return logits
 
         else:
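
The per-position path above builds a fresh `ForwardOptions` dict for every step. A hypothetical greedy-decoding loop in the same style (the function, argument names, and logits shape assumption are illustrative, not from this commit):

import torch

# Hypothetical greedy-decode sketch showing the (tokens, h, attn_options)
# convention with a stateful KV cache: feed one position at a time and carry
# input_pos in the ForwardOptions dict. Assumes logits of shape (1, 1, vocab).
def greedy_decode(model, prompt_tokens: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    tokens = prompt_tokens.clone()  # shape (1, L)
    total_len = tokens.shape[1] + max_new_tokens
    for pos in range(total_len - 1):
        pos_tensor = torch.tensor([pos], dtype=torch.int64)
        logits, _update = model(tokens[:, pos : pos + 1], None, {"input_pos": pos_tensor})
        if pos >= prompt_tokens.shape[1] - 1:
            # Past the prompt: append the argmax token and keep decoding.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
    return tokens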

examples/models/llama/evaluate/eager_eval.py (+3 -1)
@@ -77,7 +77,9 @@ def _model_call(self, inps):
         if self._use_kv_cache:
             pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
             # Batch process the whole sequence.
-            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            logits = self._model(
+                inps[:, : self._max_seq_length], None, {"input_pos": pos_tensor}
+            )
             return logits
         else:
             return self._model(inps)

examples/models/llama/llama_transformer.py (+22 -18)
@@ -7,12 +7,15 @@
 
 # Please refer to README.md in the same folder for more information.
 
-from typing import Optional
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
+from executorch.examples.models.llama.attention import (
+    ATTENTION_REGISTRY,
+    ForwardOptions,
+)
 
 from executorch.examples.models.llama.model_args import ModelArgs
 
@@ -148,17 +151,17 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos
+    def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
+        h, attn_options_update = self.attention.forward(
+            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
 
         h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out, attn_options_update
 
 
 class Transformer(nn.Module):
@@ -185,27 +188,28 @@ def __init__(self, params: ModelArgs):
     def forward(
         self,
         tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
         h: Optional[torch.FloatTensor] = None,  # embeddings
-    ) -> torch.Tensor:
+        attn_options: Optional[ForwardOptions] = None,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
+
+        if attn_options is None:
+            attn_options = {}
         seqlen = h.shape[1]
-        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
+        freqs_cos, freqs_sin = self.rope.get_freqs(
+            attn_options.get("input_pos"), seqlen
+        )
 
+        attn_options_update = None
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options)
+            if attn_options_update is not None:
+                attn_options.update(**attn_options_update)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
@@ -237,4 +241,4 @@ def forward(
             expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits
 
-        return logits
+        return logits, attn_options_update
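
The definition of `ForwardOptions` lives in attention.py (introduced by #8128) and is not shown in this diff; from its use here it behaves like a dict. Whatever state the layers surface is also returned to the caller, so it can be threaded into the next call. A hedged caller-side sketch (names are illustrative, not from this commit):

from typing import Any, Dict, Tuple

import torch

# Hypothetical caller-side sketch: thread any state the attention layers
# surface back into the next step's ForwardOptions.
def step(model, tokens: torch.Tensor, attn_options: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, Any]]:
    logits, update = model(tokens, None, attn_options)  # new tuple return
    if update is not None:
        attn_options = {**attn_options, **update}       # carry state forward
    return logits, attn_options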

examples/models/llama/model.py (+8 -4)
@@ -289,16 +289,20 @@ def get_example_inputs_kvcache_sdpa(self):
         if self.enable_dynamic_shape:
             return (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
-                torch.tensor([0], dtype=torch.long),
+                None,
+                {"input_pos": torch.tensor([0], dtype=torch.long)},
             )
         else:
             return (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
-                torch.tensor(
-                    [0], dtype=torch.long
-                ),  # start_pos, what token of output are we on.
+                None,  # hidden state
+                {
+                    "input_pos": torch.tensor(
+                        [0], dtype=torch.long
+                    )  # start_pos, what token of output are we on.
+                },
             )
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
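
The example inputs now line up positionally with `Transformer.forward(tokens, h, attn_options)`, so they can be splatted straight into the eager model. A minimal sketch, assuming `llama_model` is the model wrapper exposing this method and `eager_transformer` is the underlying eager Transformer (names are assumptions):

import torch

# Sketch of a quick eager smoke test using the updated example inputs;
# the example tuple matches forward(tokens, h, attn_options) positionally.
def smoke_test(llama_model, eager_transformer) -> torch.Tensor:
    tokens, h, attn_options = llama_model.get_example_inputs_kvcache_sdpa()
    logits, _update = eager_transformer(tokens, h, attn_options)
    return logits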

extension/llm/export/builder.py (+1 -1)
@@ -170,7 +170,7 @@ def _get_dynamic_shape(self) -> Any:
             self.dynamic_shapes = ({1: dim},)
         elif self.enable_dynamic_shape:
             # Two input arguments: tokens and input_pos but input_pos is static shape
-            self.dynamic_shapes = ({1: dim}, {0: 1})
+            self.dynamic_shapes = ({1: dim}, None, {"input_pos": {0: 1}})
         else:
             # Two input arguments: tokens and input_pos but both are of static shape
             self.dynamic_shapes = None
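
The dynamic-shape spec mirrors the input structure, so it grows a `None` slot for the unused `h` argument and a dict entry for the `ForwardOptions` dict. A rough sketch of how the three-element spec lines up with the new example inputs, assuming the usual torch.export `Dim` convention (the `token_dim` bound is illustrative):

import torch
from torch.export import Dim

# Rough sketch (not from this commit) of how the three-element dynamic_shapes
# spec mirrors the (tokens, h, attn_options) example inputs.
token_dim = Dim("token_dim", max=128)  # hypothetical max sequence length

example_inputs = (
    torch.tensor([[2, 3, 4]], dtype=torch.long),           # tokens
    None,                                                   # h (embeddings) not used
    {"input_pos": torch.tensor([0], dtype=torch.long)},    # ForwardOptions
)

dynamic_shapes = (
    {1: token_dim},              # tokens: dim 1 (sequence length) is dynamic
    None,                        # no spec for the unused h argument
    {"input_pos": {0: 1}},       # input_pos keeps its static size-1 dim 0
)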
