Remove packed tensor functions and improve dynamic mask attention #142
Conversation
Removes FlashDMAttnQKVPackedFunc, FlashDMAttnVarlenQKVPackedFunc, FlashDMAttnKVPackedFunc, and FlashDMAttnVarlenKVPackedFunc classes along with their corresponding helper functions to simplify the interface. Adds sequence length padding to ensure compatibility with hardware requirements by padding sequences to multiples of 128 tokens and properly handling the unpadding in backward passes. Fixes mask and bias tensor dimension alignment to use number of key heads instead of query heads for better multi-query and grouped-query attention support.
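A hedged sketch of the head-dimension fix, with shapes assumed from the description rather than taken from the PR's code:

```python
import torch

# Assumed layout: query (batch, q_len, num_heads, head_dim),
# key/value (batch, k_len, num_kv_heads, head_dim).
batch, q_len, k_len, head_dim = 2, 256, 256, 64
num_heads, num_kv_heads = 8, 2  # grouped-query attention: 4 query heads per KV head

query = torch.randn(batch, q_len, num_heads, head_dim, dtype=torch.float16)
key = torch.randn(batch, k_len, num_kv_heads, head_dim, dtype=torch.float16)
value = torch.randn(batch, k_len, num_kv_heads, head_dim, dtype=torch.float16)

# Mask and bias are allocated per KV head (not per query head), so they line up
# with the key/value tensors and broadcast across the grouped query heads.
attn_mask = torch.ones(batch, num_kv_heads, q_len, k_len, dtype=torch.bool)
attn_bias = torch.zeros(batch, num_kv_heads, q_len, k_len, dtype=torch.float16)
```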
- Adds a comprehensive docstring explaining function parameters and return values
- Separates query and key length variables for better clarity and accuracy
- Removes redundant kwargs to prevent duplicate parameter passing
- Updates the implementation parameter to use an explicit string value instead of a config reference
- Adds support for the keep_window_size parameter from the module configuration
Removes complex padding/unpadding logic and varlen functions in favor of a streamlined approach that handles attention masking directly through bias manipulation. Adds support for keep_window_size parameter to enable top-k attention window selection when key length exceeds the specified window size. Updates function signature to include key_length parameter and removes unnecessary helper functions for index manipulation and tensor reshaping.
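A minimal sketch of the top-k window selection, assuming the bias is shaped (batch, num_kv_heads, query_len, key_len) and that masking is expressed by writing the dtype minimum into positions outside the window (the helper name is illustrative, not the PR's API):

```python
import torch

def apply_keep_window(attn_bias: torch.Tensor, keep_window_size: int) -> torch.Tensor:
    """Keep only the top-`keep_window_size` bias values per query position along
    the key dimension and set everything else to the dtype minimum, which the
    softmax then treats as masked out."""
    key_length = attn_bias.size(-1)
    if key_length <= keep_window_size:
        return attn_bias  # window already covers all keys, nothing to mask
    min_dtype = torch.finfo(attn_bias.dtype).min
    values, indices = attn_bias.topk(keep_window_size, dim=-1)
    masked = torch.full_like(attn_bias, min_dtype)
    return masked.scatter(-1, indices, values)
```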
Pull Request Overview
This PR removes packed tensor functions and improves the dynamic mask attention implementation. The changes simplify the interface by eliminating QKV and KV packed function variants, while enhancing the core attention function with new features like window-based attention and improved parameter handling.
- Removed multiple packed tensor function classes and their corresponding public API functions
- Enhanced dynamic mask attention with new parameters like keep_window_size and improved mask handling
- Added sequence padding for hardware compatibility requirements (128-token alignment); a small padding sketch follows this list
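A minimal sketch of the 128-token alignment (hypothetical helper; the PR's actual padding code may differ), assuming inputs shaped (batch, seqlen, num_heads, head_dim):

```python
import torch
import torch.nn.functional as F

def pad_seqlen_to_multiple(x: torch.Tensor, multiple: int = 128):
    """Pad the sequence dimension (dim=1) up to the next multiple of `multiple`.
    Returns the padded tensor and the original length, so the forward output
    (and the gradients in the backward pass) can be sliced back to the real length."""
    seqlen = x.size(1)
    pad_len = (-seqlen) % multiple
    if pad_len == 0:
        return x, seqlen
    # F.pad pairs apply from the last dim backwards: head_dim, num_heads, seqlen.
    return F.pad(x, (0, 0, 0, 0, 0, pad_len)), seqlen

q_padded, q_len = pad_seqlen_to_multiple(torch.randn(2, 200, 8, 64))
print(q_padded.shape)  # torch.Size([2, 256, 8, 64]); slice [:, :q_len] after the kernel
```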
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py | Removed packed tensor utility functions and improved dynamic mask attention implementation with new windowing features |
| flash_dmattn/integrations/flash_dynamic_mask_attention.py | Enhanced wrapper function with better documentation and support for new parameters |
| flash_dmattn/flash_dmattn_interface.py | Removed packed tensor function classes and added sequence padding for hardware alignment |
Referenced code:

    min_dtype = torch.finfo(dtype).min
    batch_size, _, num_kv_heads, _ = key_states.shape

    if not all(k in globals() for k in ("_flash_fn")):
Copilot AI · Sep 1, 2025
The all() call here does not do what it appears to: ("_flash_fn") is a parenthesized string rather than a one-element tuple, so the generator iterates over its individual characters. This should be simplified to a direct membership check.
Suggested change:

    - if not all(k in globals() for k in ("_flash_fn")):
    + if "_flash_fn" not in globals():
Eliminate packed tensor functions to simplify the interface and add sequence padding for hardware compatibility. Enhance the dynamic mask attention function with clearer parameters and improved implementation, while streamlining the overall logic and supporting new configuration options.