# [FEATURE SUPPORT] Variable-Length Attention with Padding-Free Execution #188
Simplifies the varlen attention API by dropping explicit mask/bias inputs and associated gradients, reducing memory overhead and aligning with the underlying kernels. Avoids padding the key sequence length to multiples of 8 (still pads head size), relying on kernel support to handle ragged sizes and eliminating unnecessary work. Changes the default deterministic flag to False to favor performance; callers can still request deterministic behavior when needed. Updates saved tensors, sanitization, wrappers, returns, and docs to reflect the streamlined interface. Breaking change: callers must remove mask/bias arguments and any reliance on dbias gradients.
Introduces utilities to unpad/repad tensors and compute indices/cumulative seqlens for ragged batches, reusing mask-derived metadata across Q/K/V to reduce overhead. Handles static KV caches longer than the mask by safe slicing to avoid incorrect attention scores, and supports left-padded sequences and single-token decoding. Improves performance and correctness for attention paths that operate on variable-length inputs.
Introduces lazy resolution of attention kernels and padding helpers, plus a compile-friendly kwarg processor that adapts to kernel feature support. Enables variable-length execution via unpad/repad when masks are 2D, and padding-free/packed flows using position ids or precomputed sequence offsets. Adjusts is_causal for single-token queries and supports windowed attention with bias-safe top-k selection. Improves compatibility across kernel versions and torch.compile, adds deterministic control via env var, handles PEFT dtype quirks, and includes minor device safeguards. Raises a clear error when incompatible mask/bias shapes are mixed.
Allows passing window size as an argument and forwards it instead of always using the module default. Respects a provided causal flag from kwargs, falling back to the module value if absent. Clarifies attention mask/bias shapes to include 2D masks and per-head forms. Improves configurability and fixes ignored overrides.
Re-enables variable-length attention forward/backward and registers both with the extension. Simplifies the varlen API by removing mask/bias; uses empty placeholders and flags, and drops dbias from outputs. Enables paged KV cache for varlen forward, validates left padding, preserves zero_tensors/deterministic handling, and applies minor formatting cleanups.
## Pull Request Overview
This PR introduces comprehensive support for variable-length attention sequences in Flash Dynamic Mask Attention (FDMA), enabling padding-free and packed execution paths that significantly improve memory efficiency and performance for batches with heterogeneous sequence lengths. The implementation adds three execution paths: standard padded attention, unpadded variable-length attention with automatic unpacking/repacking, and packed sequences with zero-copy execution.
Key changes include:
- Re-enabling variable-length (varlen) forward and backward passes with simplified API (no mask/bias support)
- Adding padding/unpadding utilities for ragged batch handling with lazy kernel resolution
- Implementing packed sequence detection via position ID analysis for multi-turn dialogue scenarios
### Reviewed Changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 7 comments.

Summary per file:
| File | Description |
|---|---|
| flash_dmattn/utils/padding.py | New padding utilities for varlen attention with unpad/pad operations |
| flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py | Core dispatcher with lazy imports, kwarg processing, and packed sequence detection |
| flash_dmattn/integrations/flash_dynamic_mask_attention.py | Updated integration wrapper with window_size parameter and is_causal override support |
| flash_dmattn/flash_dmattn_interface.py | Re-enabled varlen functions with simplified signature (removed mask/bias parameters) |
| flash_dmattn/flash_dmattn_triton.py | Parameter name standardization (scale → softmax_scale) |
| flash_dmattn/flash_dmattn_flex.py | Parameter name standardization (scale → softmax_scale) |
| csrc/flash_dmattn/flash_api.cpp | C++ API updates with re-enabled varlen functions and removed mask/bias support |
| docs/integration_zh.md | New Chinese integration documentation for variable-length attention |
| docs/integration.md | Updated parameter names in examples (scale → softmax_scale) |
| docs/api_reference_zh.md | Updated Chinese API documentation with parameter name changes |
| docs/api_reference.md | Updated API documentation with parameter name changes |
| benchmarks/*.py | Updated benchmark scripts with parameter name standardization |
| README*.md | Updated examples with parameter name changes |
| .github/PULL_REQUEST_TEMPLATE/performance_optimization.yml | Spelling correction (optimisation → optimization) |
| .github/ISSUE_TEMPLATE/performance_issue.yml | Spelling correction (optimisation → optimization) |
```diff
 softcap = 0.0
 if deterministic is None:
-    deterministic = True
+    deterministic = False
```

**Copilot (AI), Oct 11, 2025:**
The default value change from True to False for deterministic mode is a breaking change that could affect reproducibility in existing workflows. Consider adding a deprecation warning or environment variable check to maintain backward compatibility.
Two further Copilot comments on `flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py` are marked outdated and resolved.
```python
if supports_mapping["deterministic"]:
    flash_kwargs["deterministic"] = (
        deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
```

**Copilot (AI), Oct 11, 2025:**
The environment variable check should use consistent naming. The docstring mentions FLASH_DMATTN_DETERMINISTIC but the code uses FLASH_ATTENTION_DETERMINISTIC.
Suggested change:

```diff
-    deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+    deterministic if deterministic is not None else os.getenv("FLASH_DMATTN_DETERMINISTIC", "0") == "1"
```
```python
if attention_mask.dim() == 4 and attention_bias.dim() == 3:
    attention_bias = attention_bias.unsqueeze(-2).expand(-1, -1, query_length, -1)
if attention_mask.dim() == 3 and attention_bias.dim() == 4:
    attention_mask = attention_mask.unsqueeze(-2).expand(-1, -1, query_length, -1)
```

**Copilot (AI), Oct 11, 2025:**
The expand operations create memory views that could be inefficient for large tensors. Consider using repeat if the expanded dimensions will be accessed multiple times, or document that these are views to manage memory expectations.
Suggested change:

```diff
-    attention_mask = attention_mask.unsqueeze(-2).expand(-1, -1, query_length, -1)
+    attention_mask = attention_mask.unsqueeze(-2).repeat(1, 1, query_length, 1)
```
```diff
 softcap = 0.0
 if deterministic is None:
-    deterministic = True
+    deterministic = False
```

**Copilot (AI), Oct 11, 2025:**
Similar to Comment 1, this default change in the varlen function could affect reproducibility. The change should be documented in migration notes or use environment variable fallback for backward compatibility.
```diff
     const float softcap,
     const bool return_softmax
 ) {
+
```

**Copilot (AI), Oct 11, 2025:**
[nitpick] Empty line added without clear purpose. While not harmful, it adds unnecessary whitespace to the codebase.
## Summary
This PR introduces comprehensive support for variable-length attention sequences in Flash Dynamic Mask Attention (FDMA), enabling padding-free and packed execution paths that significantly improve memory efficiency and performance for batches with heterogeneous sequence lengths.
Key additions:
- Re-enabled variable-length (varlen) forward and backward passes with a simplified API (no mask/bias support)
- Padding/unpadding utilities for ragged batch handling with lazy kernel resolution
- Packed sequence detection via position ID analysis for multi-turn dialogue scenarios
Motivation:
Modern LLM serving and training increasingly require efficient handling of variable-length sequences. Traditional padded attention wastes computation on padding tokens and inflates memory usage. This feature adds padding-free and packed execution paths that avoid that overhead.
## Design

### Architecture Overview
The implementation introduces three execution paths based on input format:
1. **Standard padded attention** (`flash_dmattn_func`): dense tensors of shape `(batch_size, seqlen, num_heads, head_dim)`.
2. **Unpadded variable-length** (`flash_dmattn_varlen_func`): tokens flattened to `(total_tokens, num_heads, head_dim)` with cumulative sequence lengths; used when a 2D `(batch_size, seq_len)` mask is detected.
3. **Packed sequences** (`flash_dmattn_varlen_func`): zero-copy execution driven by precomputed or derived `cu_seqlens`.
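The sketch below illustrates this three-way dispatch in simplified form; the function and the detection heuristics are illustrative placeholders, not the actual dispatcher internals.

```python
import torch

def choose_execution_path(attention_mask, position_ids):
    """Simplified illustration of the three-way dispatch described above (names are placeholders)."""
    if attention_mask is not None and attention_mask.dim() == 2:
        # A (batch_size, seq_len) padding mask triggers unpad -> varlen kernel -> repad.
        return "unpadded_varlen"
    if position_ids is not None and position_ids.shape[0] == 1:
        # In a packed batch of one row, a position reset after the first token
        # marks the start of a new sequence.
        if (position_ids[0, 1:] == 0).any():
            return "packed_varlen"
    return "standard_padded"
```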
### Key Components
1. **Lazy Kernel Resolution** (`lazy_import_flash_dynamic_mask_attention`)
2. **Padding Utilities** (`_upad_input`, `_pad_input`, `_get_unpad_data`); see the sketch below
3. **Feature-Aware Kwarg Processing** (`_process_flash_dynamic_mask_attention_kwargs`): adapts kwargs such as `dropout` and `sliding_window` to kernel feature support
4. **Packed Sequence Detection** (`_is_packed_sequence`, `prepare_fdma_kwargs_from_position_ids`): derives `cu_seqlens` from position resets
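As a concrete reference for the padding utilities above, here is a minimal sketch of how the mask-derived metadata (`indices`, `cu_seqlens`, max sequence length) can be computed from a 2D mask. The function name and signature are illustrative and may differ from the actual `_get_unpad_data`.

```python
import torch
import torch.nn.functional as F

def get_unpad_data(attention_mask: torch.Tensor):
    # attention_mask: (batch_size, seq_len) with 1 for real tokens, 0 for padding.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of real tokens in the flattened (batch_size * seq_len) layout.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # Cumulative sequence lengths with a leading zero: shape (batch_size + 1,).
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
```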
### Simplified Varlen API (Breaking Change)
The varlen path intentionally excludes mask and bias parameters to reduce memory overhead and stay aligned with the underlying kernels, which currently lack ragged mask/bias support.

Current varlen signature (a reconstruction is sketched below):
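A hedged reconstruction of the simplified signature, assembled from the parameters named elsewhere in this PR; argument names, order, and defaults are assumptions rather than the exact interface.

```python
def flash_dmattn_varlen_func(
    q,                       # (total_q, num_heads, head_dim)
    k,                       # (total_k, num_kv_heads, head_dim)
    v,                       # (total_k, num_kv_heads, head_dim)
    cu_seqlens_q,            # (batch_size + 1,) int32 cumulative query lengths
    cu_seqlens_k,            # (batch_size + 1,) int32 cumulative key lengths
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale=None,
    is_causal=False,
    softcap=0.0,
    deterministic=False,     # new default; opt in via True or FLASH_DMATTN_DETERMINISTIC=1
    return_attn_probs=False,
    block_table=None,        # optional paged KV cache block mapping
):
    ...
```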
Removed from varlen:
- `attn_mask` parameter and mask head stride handling
- `attn_bias` parameter and `dbias` gradient output

### Alternatives Considered
- Universal mask/bias support in varlen
- Always auto-detecting and using the varlen path
- Separate functions for packed vs. unpadded inputs; instead, both route through the unified `_flash_dynamic_mask_attention_forward` dispatcher

## Changes
### New Public APIs

1. `flash_dmattn_varlen_func` (re-enabled)
   - Returns `(output,)`, or `(output, softmax_lse, S_dmask)` if `return_attn_probs=True`
   - Use `flash_dmattn_func` if dynamic masks are required
2. Padding utilities (internal, but exposed for advanced users); see the usage sketch below
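A hedged end-to-end usage sketch of the unpad, varlen, repad flow described above. The import path and keyword names are assumptions; the actual helpers in `flash_dmattn/utils/padding.py` may differ in detail.

```python
import torch
import torch.nn.functional as F
from flash_dmattn import flash_dmattn_varlen_func  # import path assumed

B, S, H, D = 2, 512, 8, 64
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.ones(B, S, dtype=torch.bool, device="cuda")
mask[1, 300:] = False  # the second sequence is shorter than the first

# Mask-derived metadata for the ragged batch.
seqlens = mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())

# Drop padding tokens: (B, S, H, D) -> (total_tokens, H, D).
q_flat = q.reshape(B * S, H, D)[indices]
k_flat = k.reshape(B * S, H, D)[indices]
v_flat = v.reshape(B * S, H, D)[indices]

out_flat = flash_dmattn_varlen_func(
    q_flat, k_flat, v_flat,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    is_causal=True,
)

# Scatter the ragged output back into the padded layout.
out = torch.zeros(B * S, H, D, device=q.device, dtype=q.dtype)
out[indices] = out_flat
out = out.reshape(B, S, H, D)
```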
### Modified APIs
1. `flash_dynamic_mask_attention_forward` (integration wrapper)
   - New `window_size` parameter for sliding window attention
   - `is_causal` now respects a kwargs override instead of only the module default
   - Updated signature: a hedged reconstruction appears after this list
2. `_flash_dynamic_mask_attention_forward` (core dispatcher)
   - `cu_seqlens` support for external packers
   - Default `deterministic=False` (was `None`, which defaulted to `True`)
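For reference, a hedged reconstruction of the updated `flash_dynamic_mask_attention_forward` signature from item 1 above; parameter names beyond `window_size`, `is_causal`, and `softmax_scale` follow typical Hugging Face-style integration wrappers and are assumptions.

```python
import torch
from typing import Optional

def flash_dynamic_mask_attention_forward(
    module: torch.nn.Module,                         # attention module supplying defaults (e.g., is_causal)
    query: torch.Tensor,                             # (batch, num_heads, seq_len, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,   # 2D padding mask or 3D/4D per-head mask
    attention_bias: Optional[torch.Tensor] = None,   # optional additive bias
    *,
    window_size: Optional[int] = None,               # new: sliding-window size forwarded to the kernel
    softmax_scale: Optional[float] = None,
    is_causal: Optional[bool] = None,                # new: explicit override; falls back to the module value
    **kwargs,
):
    ...
```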
### Configuration Changes

- `FLASH_DMATTN_DETERMINISTIC=1` forces a deterministic backward pass globally
- `_lazy_define_process_function` lazily builds the compile-friendly kwarg processor

### CLI/Build Changes
None. Changes are runtime API only.
## Implementation Notes

### Key Components
**FlashAttnVarlenFunc** (`flash_dmattn_interface.py:457-573`)
- Saves `q, k, v, out, softmax_lse, cu_seqlens_q/k, max_seqlen_q/k` for backward
- Optional `block_table` for the paged KV cache

**Unpadding Pipeline** (`_upad_input`)
- Computes `indices` from the flattened mask
- Builds `cu_seqlens` via a cumulative sum of sequence lengths

**Packed Detection Logic** (`_is_packed_sequence`); a sketch follows below
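A minimal sketch of how packed-sequence boundaries can be recovered from position ids, assuming each packed sequence restarts its positions at 0; the helper name is illustrative and not necessarily what `prepare_fdma_kwargs_from_position_ids` does internally.

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
    # position_ids: (1, total_tokens) for a packed batch; a reset to 0 marks a new sequence.
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0, as_tuple=False).flatten().to(torch.int32)
    total = torch.tensor([pos.numel()], dtype=torch.int32, device=pos.device)
    cu_seqlens = torch.cat([starts, total])
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seqlen
```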
**PEFT Dtype Handling** (`fdma_peft_integration_check`)
- Reads `_pre_quantization_dtype` or infers the dtype from Linear layers

### Tricky Parts
**MPS device workaround**
- Required for compatibility with the metal-flash-sdpa kernel, which mutates input tensors.
**Static KV cache slicing**
Prevents attention over uninitialized cache positions when using fixed-size caches.
**Causal adjustment for decoding**
- Single-token queries (`max_seqlen_q == 1`) force `is_causal=False`

**Deterministic default change**
- Before: `deterministic=None`, so the kernel chose (usually `True`)
- After: `deterministic=False` to favor performance; opt in via kwarg or env var

### Performance Considerations
- No longer pads the key sequence length to `round_multiple(seqlen_k, 8)`, avoiding wasted computation

## Tests
### Unit Tests
New test coverage in `benchmarks/`:
- `forward_equivalence.py`
- `backward_equivalence.py`: `dq`, `dk`, `dv` in varlen mode; `dbias` output (expected removal)
- `grad_equivalence.py`

### Integration Tests (Manual validation required)
**Hugging Face model integration**
- `FlashDynamicMaskAttention` layer

**Paged KV cache**
- `block_table` indexing with synthetic block maps

### Coverage
## Docs

### Updated Documentation
- **API Reference** (`docs/api_reference.md`): `flash_dmattn_varlen_func` signature and examples
- **Integration Guide** (`docs/integration.md`)
- **Docstrings**:
  - `flash_dynamic_mask_attention_forward` with 2D mask behavior
  - `flash_dmattn_varlen_func` docstring
  - `window_size` and `is_causal` interaction

### Examples
Added an example to `examples/modeling/`.

### Missing Documentation (TODO)
## Known Limitations

### Varlen Mask/Bias Support (Development Required)
**Current status:** Variable-length attention (`flash_dmattn_varlen_func`) does not support the `attn_mask` and `attn_bias` parameters.

**Reason:** The underlying CUDA kernels in `csrc/flash_dmattn/src/` lack the infrastructure to handle per-head or per-token masks/biases in ragged tensor layouts. Specifically, `set_params_fprop`/`set_params_dgrad` assume contiguous batch layouts.

**Workaround:** Use the standard `flash_dmattn_func` if dynamic masks are required; it supports mask/bias inputs on padded batches.

**Future work:** Extending varlen to support mask/bias requires:
**Kernel changes:**
- `flash_fwd_kernel.h` to accept mask/bias pointers with ragged indexing
- `flash_bwd_kernel.h` to accumulate `dbias` with sequence-boundary awareness
- `generate_kernels.py` with new flags

**API changes:**
- Add `attn_mask` and `attn_bias` parameters to `flash_dmattn_varlen_func`
- Update `FlashAttnVarlenFunc.forward` to save mask/bias in the context
- Return `dbias` in the `FlashAttnVarlenFunc.backward` output tuple

**Integration changes:**
- `_flash_dynamic_mask_attention_forward` to pass mask/bias to the varlen path
- `_upad_input` to handle mask/bias unpadding (currently only unpads Q/K/V)

**Estimated effort:** 2-3 weeks for kernel development + validation
**Tracking issue:** Please create an issue titled "Support dynamic masks and biases in varlen attention" to track this work.
### Other Limitations

## Checklist

## Migration Guide for Breaking Changes

For users upgrading to this version:
- **If you use `flash_dmattn_func` (standard padded attention):** the signature is unchanged, but the backward pass now defaults to `deterministic=False`.
- **If you previously used `flash_dmattn_varlen_func` with mask/bias:** remove the `attn_mask`/`attn_bias` arguments and any reliance on `dbias` gradients, or switch to `flash_dmattn_func`, which still supports them.
- **If you relied on deterministic backward by default:** pass `deterministic=True` explicitly or set `FLASH_DMATTN_DETERMINISTIC=1`, as sketched below.
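A minimal sketch of the deterministic opt-in, assuming `flash_dmattn_func` accepts a `deterministic` keyword and that the environment variable is read as documented in this PR; the import path is an assumption.

```python
import os
import torch
from flash_dmattn import flash_dmattn_func  # import path assumed

# Option 1: opt in globally before the kernels are resolved.
os.environ["FLASH_DMATTN_DETERMINISTIC"] = "1"

# Option 2: opt in per call.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = flash_dmattn_func(q, k, v, deterministic=True)
```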
## Acknowledgments
This feature draws inspiration from:
- `_upad_input` utilities

Special thanks to the CUTLASS team for the flexible GEMM templates that make ragged attention feasible.
Note to reviewers: Please pay special attention to:
- `cu_seqlens` computation in edge cases (empty sequences, single-token)