-
Notifications
You must be signed in to change notification settings - Fork 31.5k
Bug Fix: Dynamically set return_lse flag in FlexAttention #40352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6a91a9d
2350c9f
b6baacf
2a053e7
1b69c01
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,7 +90,7 @@ def compile_friendly_flex_attention( | |
| value: torch.Tensor, | ||
| training=False, | ||
| **kwargs, | ||
| ) -> torch.Tensor: | ||
| ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: | ||
| # First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention | ||
| # Do not use compiled version if already compiling forward (it raises issues) | ||
| flex_attention_compiled = WrappedFlexAttention(training)() if not is_torchdynamo_compiling() else flex_attention | ||
|
|
@@ -243,7 +243,7 @@ def flex_attention_forward( | |
| head_mask: Optional[torch.Tensor] = None, | ||
| s_aux: Optional[torch.Tensor] = None, | ||
| **kwargs, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: | ||
| if head_mask is not None: | ||
| logger.warning_once( | ||
| "`flex_attention` does not support `head_mask`. Please set your attention to `eager` if you want this feature." | ||
|
|
@@ -290,7 +290,10 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): | |
| enable_gqa = False | ||
|
|
||
| kernel_options = kwargs.get("kernel_options") | ||
| attn_output, attention_weights = compile_friendly_flex_attention( | ||
| # On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is also bound to be removed once fixed in torch no? let's report it to torch team!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems to be known internally, see https://github.com/pytorch/pytorch/blob/639b8cc51ddebf10361f3840a6b0a244eb6092a1/torch/nn/attention/flex_attention.py#L1289-L1292 I wasn't aware of this but that's also because we mostly run tests on some accelerator, seldom on cpu
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For more context, the linked code would happen if we were not setting |
||
| return_lse = query.device.type != "cpu" | ||
|
|
||
| flex_attention_output = compile_friendly_flex_attention( | ||
| query, | ||
| key, | ||
| value, | ||
|
|
@@ -301,11 +304,16 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): | |
| kernel_options=kernel_options, | ||
| # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. | ||
| # For simplification, we thus always return it as no additional computations are introduced. | ||
| return_lse=True, | ||
| return_lse=return_lse, | ||
vasqu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| training=module.training, | ||
| ) | ||
| # lse is returned in float32 | ||
| attention_weights = attention_weights.to(value.dtype) | ||
| attn_output = attn_output.transpose(1, 2).contiguous() | ||
| if return_lse: | ||
| attention_output, lse = flex_attention_output # type: ignore[misc] | ||
| lse = lse.to(value.dtype) | ||
amd-lalithnc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
| attention_output = flex_attention_output # type: ignore[assignment] | ||
| lse = None | ||
|
|
||
| return attn_output, attention_weights | ||
| attention_output = attention_output.transpose(1, 2).contiguous() | ||
| return attention_output, lse | ||
Uh oh!
There was an error while loading. Please reload this page.