Pass ForwardOptions from top level module and also return any relevant state as output #8186
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8186
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 867a885 with merge base dd31d93.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D69080123
@pytorchbot label "release notes: llama_transformer: Pass ForwardOptions from top level module and also return any relevant state as output"
Didn't find following labels among repository labels: release notes: llama_transformer: Pass ForwardOptions from top level module and also return any relevant state as output |
Pass ForwardOptions from top level module and also return any relevant state as output (pytorch#8186)
Summary: Pass a `ForwardOptions` argument (introduced by pytorch#8128) from the top level transformer, consolidate some existing inputs into it, and return any optional updates from the attention implementation.
Reviewed By: iseeyuan
Differential Revision: D69080123
Generally seems fine
@@ -236,7 +236,7 @@ def forward(
        assert input_pos is not None
        k, v = self.kv_cache.update(input_pos, k, v)
        output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-       return self.wo(output)
+       return self.wo(output), None
Is `None` a placeholder or something else?
It's to support additional output from each attention layer that will be aggregated and returned by the top-level transformer. To avoid disruptions, if this is `None` the top-level transformer still returns a single logits output; otherwise it returns a tuple.
This was actually added in the previous PR that introduced the abstract attention class:
) -> Tuple[torch.Tensor, Optional[Any]]:
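A minimal sketch of that contract, using illustrative class names rather than the actual ExecuTorch modules: each attention layer returns `(output, optional_state)`, and the top-level module only switches to a tuple return when some layer actually produced state.

```python
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn


class SketchAttention(nn.Module):
    """Stand-in for an attention layer following the (output, state) contract."""

    def forward(
        self, x: torch.Tensor, **kwargs: Any
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # A plain KV-cache attention has no extra state to report.
        return x, None


class SketchTransformer(nn.Module):
    """Stand-in for the top-level module aggregating per-layer state."""

    def __init__(self, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(SketchAttention() for _ in range(n_layers))

    def forward(self, h: torch.Tensor, **kwargs: Any):
        states = []
        for layer in self.layers:
            h, state = layer(h, **kwargs)
            if state is not None:
                states.append(state)
        # Keep the old single-tensor contract when no layer produced state.
        return h if not states else (h, states)


out = SketchTransformer()(torch.zeros(1, 4, 8))  # plain tensor: every state was None
```

Returning a plain tensor in the common case keeps existing callers and export paths unchanged.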
                [0], dtype=torch.long
            ),  # start_pos, what token of output are we on.
            {
                "input_pos": torch.tensor(
Is it because we are plugging in the `ForwardOptions` down the line?
Yeah, "input_pos" is the only thing the regular attention needs, other implementations can take additional/different parameters.
Pass ForwardOptions from top level module and also return any relevant state as output (pytorch#8186)
Summary: Pass a `ForwardOptions` argument (introduced by pytorch#8128) from the top level transformer, consolidate some existing inputs into it, and return any optional updates from the attention implementation.
Reviewed By: iseeyuan, cccclai
Differential Revision: D69080123
Summary: Pass a `ForwardOptions` argument (introduced by #8128) from the top level transformer, consolidate some existing inputs into it, and return any optional updates from the attention implementation. This is in preparation for a static shape attention implementation.
Differential Revision: D69080123
cc @mergennachin @cccclai @helunwencser @dvorjackz
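A rough, self-contained sketch of the signature change the summary describes; `ForwardOptions` here is a stand-in TypedDict (the real definition comes from pytorch#8128), and the model body is a placeholder rather than the actual Llama transformer.

```python
from typing import Any, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn


class ForwardOptions(TypedDict, total=False):
    # Stand-in definition; the real one is introduced by pytorch#8128.
    input_pos: torch.Tensor  # position of the current token(s) in the cache


class TinyTransformer(nn.Module):
    """Placeholder model showing the top-level signature, not the real Llama code."""

    def __init__(self, vocab: int = 32, dim: int = 8):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.out = nn.Linear(dim, vocab)

    def forward(
        self,
        tokens: torch.Tensor,
        attn_options: Optional[ForwardOptions] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
        h = self.emb(tokens)
        logits = self.out(h)
        attn_state = None  # a real attention implementation could populate this
        # Single logits tensor when there is no extra state, tuple otherwise.
        return logits if attn_state is None else (logits, attn_state)


logits = TinyTransformer()(
    torch.tensor([[1]], dtype=torch.long),
    {"input_pos": torch.tensor([0], dtype=torch.long)},
)
```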