
Conversation

@LoserCheems
Collaborator

Eliminate packed tensor functions to simplify the interface and add sequence padding for hardware compatibility. Enhance the dynamic mask attention function with clearer parameters and improved implementation, while streamlining the overall logic and supporting new configuration options.

Removes FlashDMAttnQKVPackedFunc, FlashDMAttnVarlenQKVPackedFunc, FlashDMAttnKVPackedFunc, and FlashDMAttnVarlenKVPackedFunc classes along with their corresponding helper functions to simplify the interface.

Adds sequence length padding to ensure compatibility with hardware requirements by padding sequences to multiples of 128 tokens and properly handling the unpadding in backward passes.
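As a rough sketch of what such padding can look like (illustrative only; the helper below and its shape layout are assumptions, not the code in this PR), the forward pass zero-pads the sequence dimension up to the next multiple of 128 and the unpad step slices back to the original length:

import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, multiple: int = 128, dim: int = 1):
    # Zero-pad `x` along `dim` so its size becomes a multiple of `multiple`.
    seqlen = x.size(dim)
    pad_len = (-seqlen) % multiple  # 0 when already aligned
    if pad_len == 0:
        return x, seqlen
    # F.pad lists padding for the last dimension first: (last_left, last_right, ..., dim_left, dim_right)
    pad = [0, 0] * (x.dim() - dim - 1) + [0, pad_len]
    return F.pad(x, tuple(pad)), seqlen

q = torch.randn(2, 200, 8, 64)            # (batch, seqlen_q, num_heads, head_dim)
q_padded, orig_len = pad_to_multiple(q)   # seqlen_q: 200 -> 256
q_unpadded = q_padded[:, :orig_len]       # the backward pass slices padded gradients the same way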

Fixes mask and bias tensor dimension alignment to use the number of key heads instead of query heads for better multi-query and grouped-query attention support (see the shape sketch below).
Adds comprehensive docstring explaining function parameters and return values
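For illustration of the key-head alignment, here is a hypothetical shape layout under grouped-query attention (the (batch, heads, seqlen_q, seqlen_k) ordering is an assumption): sizing the mask and bias by key heads lets every query head in a group share the same slice instead of duplicating it per query head.

import torch

batch, seqlen_q, seqlen_k, head_dim = 2, 128, 128, 64
num_q_heads, num_kv_heads = 16, 4          # GQA: 4 query heads share each KV head

query = torch.randn(batch, seqlen_q, num_q_heads, head_dim)
key   = torch.randn(batch, seqlen_k, num_kv_heads, head_dim)
value = torch.randn(batch, seqlen_k, num_kv_heads, head_dim)

# Mask and bias follow the *key* heads, so no per-query-head copies are materialized.
attn_bias = torch.zeros(batch, num_kv_heads, seqlen_q, seqlen_k)
attn_mask = torch.ones(batch, num_kv_heads, seqlen_q, seqlen_k, dtype=torch.bool)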

Separates query and key length variables for better clarity and accuracy

Removes redundant kwargs to prevent duplicate parameter passing

Updates implementation parameter to use explicit string value instead of config reference

Adds support for keep_window_size parameter from module configuration
Removes complex padding/unpadding logic and varlen functions in favor of a streamlined approach that handles attention masking directly through bias manipulation.

Adds support for keep_window_size parameter to enable top-k attention window selection when key length exceeds the specified window size.
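Conceptually the windowing behaves like the standalone sketch below (hypothetical code, not the function added in this PR): when the key length exceeds keep_window_size, only the keep_window_size highest-bias key positions per query row are kept, and everything else is pushed to the dtype minimum, mirroring the min_dtype masking used elsewhere in the integration.

import torch

def apply_keep_window(attn_bias: torch.Tensor, keep_window_size: int) -> torch.Tensor:
    # attn_bias: (batch, num_kv_heads, seqlen_q, seqlen_k)
    seqlen_k = attn_bias.size(-1)
    if seqlen_k <= keep_window_size:
        return attn_bias                            # short sequences are left untouched
    topk_idx = attn_bias.topk(keep_window_size, dim=-1).indices
    keep = torch.zeros_like(attn_bias, dtype=torch.bool).scatter_(-1, topk_idx, True)
    min_dtype = torch.finfo(attn_bias.dtype).min
    return attn_bias.masked_fill(~keep, min_dtype)  # masked-out keys get an effectively -inf bias

bias = torch.randn(1, 4, 16, 1024)
windowed = apply_keep_window(bias, keep_window_size=128)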

Updates function signature to include key_length parameter and removes unnecessary helper functions for index manipulation and tensor reshaping.
Contributor

Copilot AI left a comment


Pull Request Overview

This PR removes packed tensor functions and improves the dynamic mask attention implementation. The changes simplify the interface by eliminating QKV and KV packed function variants, while enhancing the core attention function with new features like window-based attention and improved parameter handling.

  • Removed multiple packed tensor function classes and their corresponding public API functions
  • Enhanced dynamic mask attention with new parameters like keep_window_size and improved mask handling
  • Added sequence padding for hardware compatibility requirements (128-token alignment)

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

File: flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
Description: Removed packed tensor utility functions and improved the dynamic mask attention implementation with new windowing features

File: flash_dmattn/integrations/flash_dynamic_mask_attention.py
Description: Enhanced the wrapper function with better documentation and support for new parameters

File: flash_dmattn/flash_dmattn_interface.py
Description: Removed packed tensor function classes and added sequence padding for hardware alignment


min_dtype = torch.finfo(dtype).min
batch_size, _, num_kv_heads, _ = key_states.shape

if not all(k in globals() for k in ("_flash_fn")):

Copilot AI commented on Sep 1, 2025


("_flash_fn") is a parenthesized string, not a tuple, so the generator iterates over its individual characters rather than over one key name; the all() wrapper is unnecessary in any case. This should be simplified to a direct membership check.

Suggested change
if not all(k in globals() for k in ("_flash_fn")):
if "_flash_fn" not in globals():

LoserCheems merged commit f8db33a into main on Sep 1, 2025.