[compile] Fix graphbreaks in moe split; scale_grad #2771
Changes from all commits: 359387b, 0165a4b, 2a8df60, 454868d, fe5c81e
```diff
@@ -209,11 +209,12 @@ def _attention_call(
         # This will use flash attention under the hood with support for custom masks.
         # Currently, it is used when sample packing is enabled (see torchtune.datasets.PackedDataset)
         if isinstance(mask, BlockMask):
-            log_once(
-                _log,
-                "Using flex attention for attention computation since a BlockMask was passed in.",
-                level=logging.DEBUG,
-            )
+            if not torch.compiler.is_compiling():
+                log_once(
+                    _log,
+                    "Using flex attention for attention computation since a BlockMask was passed in.",
+                    level=logging.DEBUG,
+                )
             if dropout_p > 0.0:
                 raise ValueError(
                     "Flex attention does not support dropout. Please set dropout to 0.0."
```

Review comment: Noob q: why do we only want to log this when we're not compiling?

Reply: Dynamo graph_breaks on log(), so this is only to avoid the graph break. But it's safe to log in normal non-compiling execution :)
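As a standalone sketch of the pattern discussed above (not the torchtune source; the toy function and logger below are hypothetical, while `torch.compiler.is_compiling()` is the real PyTorch API used in the diff): Dynamo evaluates `torch.compiler.is_compiling()` as `True` while tracing, so the guarded branch is skipped during compilation and the Python logging call never enters the graph, while eager runs still emit the message.

```python
import logging

import torch

_log = logging.getLogger(__name__)


def attention_like(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for _attention_call. Logging unconditionally here
    # would force a Dynamo graph break under torch.compile; the guard keeps
    # the log call out of the traced graph but still active in eager mode.
    if not torch.compiler.is_compiling():
        _log.debug("Using an eager-only debug log; skipped while compiling.")
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1)


compiled = torch.compile(attention_like)
out = compiled(torch.randn(2, 4, 8), torch.randn(2, 4, 8))
```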
Review comment: We may also want to add this to lora_finetune_distributed.py too (I think the logic should be the same there).
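If the same guard is applied in lora_finetune_distributed.py, one way to sanity-check that no graph break remains is to compile with `fullgraph=True`, which makes `torch.compile` raise instead of silently splitting the graph. A minimal sketch, assuming torch >= 2.1 and a hypothetical stand-in for a recipe step:

```python
import torch


def guarded_log_step(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical recipe-style step that logs only in eager execution.
    if not torch.compiler.is_compiling():
        print("running eagerly")
    return x * 2.0


# fullgraph=True errors out on any graph break, so a successful call
# confirms the guarded log does not split the compiled graph.
torch.compile(guarded_log_step, fullgraph=True)(torch.ones(3))
```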