Guard against Flash Attention 3 failure #126
Conversation
Summary of Changes
Hello @qzzz95, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the stability and robustness of the attention mechanisms by introducing explicit checks for Flash Attention 3's head_dim prerequisite. It ensures that Flash Attention 3 is only utilized when its operational conditions are met, providing either a graceful fallback to alternative implementations or raising a clear error when explicitly requested under unsupported configurations.
Highlights
- Flash Attention 3 Prerequisite Enforcement: Added checks to ensure `flash_attn3` is only used when the `head_dim` (the last dimension of the query tensor `q`) is 256 or less, as required by the library. This prevents potential failures or incorrect behavior when the prerequisite is not met.
- Improved Auto-Selection Logic: When `attn_impl` is set to 'auto', the system now logs a warning and automatically falls back to other available attention implementations (such as Xformers or SDPA) if Flash Attention 3's `head_dim` prerequisite is not met, giving more robust and graceful handling of attention backend selection (see the sketch after this list).
- Explicit Error Handling for Invalid Usage: If `flash_attn_3` is explicitly requested (i.e., `attn_impl='flash_attn_3'`) but the `head_dim` exceeds 256, a `RuntimeError` is now raised. This prevents silent failures and gives the user clear, immediate feedback about the unsupported configuration.
- Consistent Application Across Attention Functions: These `head_dim` checks, warning messages, and error handling are applied consistently to both the standard `attention` function and the `long_context_attention` function, ensuring uniform behavior across attention contexts.
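As a rough illustration of the guard described above, the selection logic might look like the following minimal sketch. This is an assumption-based sketch, not the PR's actual code: the function name `select_attn_impl`, the constant `FLASH_ATTN_3_HEAD_DIM_LIMIT`, and the `"sdpa"` fallback label are all illustrative.

```python
import logging

logger = logging.getLogger(__name__)

# Assumed limit from the FlashAttention-3 kernels; the PR hard-codes 256.
FLASH_ATTN_3_HEAD_DIM_LIMIT = 256


def select_attn_impl(q, attn_impl="auto", flash_attn3_available=True):
    """Illustrative sketch of the head_dim guard described in the highlights."""
    head_dim = q.shape[-1]
    flash_attn3_satisfied = head_dim <= FLASH_ATTN_3_HEAD_DIM_LIMIT

    if attn_impl == "flash_attn_3":
        # Explicit request: fail loudly when the prerequisite is not met.
        if not flash_attn3_satisfied:
            raise RuntimeError(
                f"FlashAttention-3 does not support head_dim={head_dim}; "
                f"the maximum supported is {FLASH_ATTN_3_HEAD_DIM_LIMIT}"
            )
        return "flash_attn_3"

    if attn_impl == "auto":
        # Auto mode: warn and fall back to another backend (e.g. Xformers or SDPA).
        if flash_attn3_available and flash_attn3_satisfied:
            return "flash_attn_3"
        logger.warning(
            "flash_attn_3 unavailable or head_dim=%d > %d; falling back to SDPA",
            head_dim,
            FLASH_ATTN_3_HEAD_DIM_LIMIT,
        )
        return "sdpa"

    return attn_impl
```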
Code Review
This pull request introduces a safeguard for using FlashAttention-3 by checking if the head dimension (head_dim) meets its requirement of being less than or equal to 256. The implementation is logically sound, but there is code duplication and a magic number that should be addressed.
| "sage_attn", | ||
| "sparge_attn", | ||
| ] | ||
| flash_attn3_satisfied = q.shape[-1] <= 256 |
The logic for checking FlashAttention-3 compatibility, including the check q.shape[-1] <= 256, is duplicated in the long_context_attention function. This introduces maintainability issues: code duplication and a magic number. The value 256 should be defined as a module-level constant (e.g., FLASH_ATTN_3_HEAD_DIM_LIMIT = 256) for readability and easier updates.
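A minimal sketch of that suggestion, assuming both `attention` and `long_context_attention` live in the same module (the constant and helper names below are illustrative, not the repository's actual identifiers):

```python
# Shared, named limit instead of a magic number, usable by both
# attention() and long_context_attention().
FLASH_ATTN_3_HEAD_DIM_LIMIT = 256


def _flash_attn3_head_dim_ok(q) -> bool:
    """Return True if q's head_dim (its last dimension) is supported by FlashAttention-3."""
    return q.shape[-1] <= FLASH_ATTN_3_HEAD_DIM_LIMIT
```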
        return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
    elif attn_impl == "flash_attn_3":
        if not flash_attn3_satisfied:
            raise RuntimeError(f"find head_dim={q.shape[-1]} > 256, but flash_attn_3 only supports head dimension at most 256")
The phrasing "find head_dim=..." in this error message is unclear. A more direct and informative message would be more helpful. This message is also duplicated in long_context_attention at line 269.
raise RuntimeError(f"FlashAttention-3 does not support head_dim={q.shape[-1]}; the maximum supported is 256")| if flash_attn3_satisfied: | ||
| attn_func = LongContextAttention(attn_type=AttnType.FA3) | ||
| else: | ||
| raise RuntimeError(f"find head_dim={q.shape[-1]} > 256, but flash_attn_3 only supports head dimension at most 256") |