Pass ForwardOptions from top level module and also return any relevant state as output #8186
@@ -289,16 +289,18 @@ def get_example_inputs_kvcache_sdpa(self):
         if self.enable_dynamic_shape:
             return (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
-                torch.tensor([0], dtype=torch.long),
+                {"input_pos": torch.tensor([0], dtype=torch.long)},
             )
         else:
             return (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
-                torch.tensor(
-                    [0], dtype=torch.long
-                ),  # start_pos, what token of output are we on.
+                {
+                    "input_pos": torch.tensor(
+                        [0], dtype=torch.long
+                    )  # start_pos, what token of output are we on.
+                },
             )

     def _transform_for_pre_quantization(self, checkpoint, model_args):

Comment: Is it because we are plugging in the …

Reply: Yeah, "input_pos" is the only thing the regular attention needs; other implementations can take additional/different parameters.
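A minimal sketch of why the second example input became a kwargs dict, assuming illustrative names (`SketchAttention`, `SketchTransformer`, `forward_options`) that are not from the ExecuTorch codebase: the top-level module forwards the dict by name, the regular attention only reads `input_pos`, and other attention implementations could read additional or different keys.

```python
import torch
import torch.nn as nn


class SketchAttention(nn.Module):
    def forward(self, x: torch.Tensor, input_pos: torch.Tensor, **kwargs):
        # Regular attention only needs "input_pos"; other implementations
        # could pull additional/different parameters out of kwargs.
        return x  # attention math elided for brevity


class SketchTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = SketchAttention()

    def forward(self, tokens: torch.Tensor, forward_options: dict) -> torch.Tensor:
        h = tokens.float().unsqueeze(-1)  # stand-in for token embedding
        return self.attn(h, **forward_options)


# Mirrors the new example-input shape: (tokens, {"input_pos": ...})
model = SketchTransformer()
tokens = torch.tensor([[1]], dtype=torch.long)
out = model(tokens, {"input_pos": torch.tensor([0], dtype=torch.long)})
```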
Comment: Is `None` a placeholder or something else?

Reply: It's to support additional output from each attention layer that'll be aggregated and returned by the top-level transformer. To avoid disruptions, if this is `None` the top-level transformer still returns a single logit output; otherwise it returns a tuple.

This was actually added in the previous PR that introduced the abstract attention class:
executorch/examples/models/llama/attention.py (line 30 in 440a3ac)
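A minimal sketch of the output convention described in that reply, using assumed names (`sketch_forward`, `layer_states`) rather than the actual ExecuTorch code: if no attention layer returns extra state, the module keeps its old single-logit output so existing callers are undisturbed; otherwise it returns a tuple of the logits plus the aggregated state.

```python
from typing import Any, List, Optional, Tuple, Union

import torch


def sketch_forward(
    logits: torch.Tensor, layer_states: List[Optional[Any]]
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Any]]]:
    # Aggregate whatever extra state the attention layers returned.
    aggregated = [s for s in layer_states if s is not None]
    if not aggregated:
        # No layer produced extra state: keep the old single-output signature.
        return logits
    return logits, aggregated


logits = torch.zeros(1, 32000)
print(type(sketch_forward(logits, [None, None])))       # <class 'torch.Tensor'>
print(type(sketch_forward(logits, [None, {"kv": 1}])))  # <class 'tuple'>
```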