Implement variable-length attention with mask and bias support #185
Conversation
Adds a working variable-length attention forward path with boolean mask and additive bias, supporting broadcast across head dims. Enforces dtype/device/contiguity and shape checks, initializes outputs when requested, and handles empty key sequences. Keeps Paged KV disabled pending fixes, retains decoding optimization for single-token queries, and exposes optional softmax output for debugging/inspection. Improves usability on Ampere+ GPUs while maintaining correctness and constraints.
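For orientation, here is a minimal sketch of how the packed variable-length path might be invoked, assuming the `flash_dmattn_varlen_func` export added by this PR. The keyword names (`cu_seqlens_q`, `max_seqlen_q`, `is_causal`) and the mask/bias shapes are illustrative assumptions, not the exact signature; consult `flash_dmattn/flash_dmattn_interface.py` for the real one.

```python
import torch
import torch.nn.functional as F
from flash_dmattn import flash_dmattn_varlen_func  # new export in this PR

# Two sequences (lengths 3 and 5) packed along the token axis: (total_tokens, nheads, head_dim).
device, dtype = "cuda", torch.float16
nheads, head_dim = 8, 64
seqlens = torch.tensor([3, 5], dtype=torch.int32, device=device)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # [0, 3, 8]
total, max_len = int(cu_seqlens[-1]), int(seqlens.max())

q = torch.randn(total, nheads, head_dim, device=device, dtype=dtype)
k = torch.randn(total, nheads, head_dim, device=device, dtype=dtype)
v = torch.randn(total, nheads, head_dim, device=device, dtype=dtype)

# Boolean mask and additive bias; a singleton head dimension is assumed to broadcast across heads.
mask = torch.ones(total, 1, max_len, device=device, dtype=torch.bool)
bias = torch.zeros(total, 1, max_len, device=device, dtype=dtype)

out = flash_dmattn_varlen_func(
    q, k, v,
    mask=mask, bias=bias,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_len, max_seqlen_k=max_len,
    is_causal=True,
)
```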
Enables sanitization in forward/backward to replace NaN and ±Inf with 0, improving numerical stability and preventing invalid value propagation. Stops treating bias gradient as an input and ensures the computed bias gradient is contiguous, avoiding layout issues in downstream kernels.
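As a rough illustration of what that sanitization amounts to in stock PyTorch (the actual replacement happens inside the library's forward/backward wrappers; the `sanitize_` helper name here is hypothetical):

```python
import torch

def sanitize_(t: torch.Tensor) -> torch.Tensor:
    # Replace NaN and ±Inf in place with 0 so invalid values do not
    # propagate into downstream kernels.
    return torch.nan_to_num_(t, nan=0.0, posinf=0.0, neginf=0.0)

dbias = torch.tensor([float("nan"), float("inf"), -float("inf"), 1.0])
sanitize_(dbias)              # tensor([0., 0., 0., 1.])
dbias = dbias.contiguous()    # the bias gradient is also made contiguous before reuse
```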
Removes the eager attention fallback and unconditionally routes attention through the flash dynamic mask backend, printing an install hint when the dependency is missing. Deprecates the past_key_value argument in favor of past_key_values across attention, layer, and LM APIs, switches types to a cache interface, and initializes the dynamic cache with config for correctness. Disables SDPA/flex/attention-backend support flags to reflect the single supported backend. Adds a buffer annotation to satisfy linting, drops unused decoder accessors, ignores unused attention weights, and updates a paper link to the HF mirror. Improves consistency of the attention API and enforces a single, performant backend.
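A sketch of the deprecation pattern described above; `resolve_cache_kwarg` is a hypothetical helper, not part of the library, shown only to illustrate how the legacy keyword can be funneled into `past_key_values` with a warning.

```python
import warnings
from typing import Any, Optional

def resolve_cache_kwarg(
    past_key_values: Optional[Any] = None,
    past_key_value: Optional[Any] = None,  # legacy keyword, kept for backward compatibility
) -> Optional[Any]:
    if past_key_value is not None:
        warnings.warn(
            "`past_key_value` is deprecated; pass `past_key_values` instead.",
            FutureWarning,
        )
        if past_key_values is None:
            past_key_values = past_key_value
    return past_key_values

# Legacy call sites keep working but emit a deprecation warning.
cache = resolve_cache_kwarg(past_key_value={"layer_0": None})
```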
Reorganizes documentation for faster onboarding: adds Quick Start, backend selection/comparison, and a clearer API reference. Clarifies attention mask/bias shapes with broadcasting support and updates the integration example to drop manual expansion and eager fallback. Adds concise install instructions, backend availability flags, and unified import guidance; removes redundant backend and summary sections for clarity.
…n classes for consistency
…n_forward for consistency
…ity and consistency
Introduces CUDA custom ops and fake/meta registrations for variable-length attention forward/backward using cumulative sequence lengths, enabling efficient packed ragged batches. Adds an autograd wrapper and public API exposing varlen attention with support for MQA/GQA, optional mask/bias, causal mode, softcap, deterministic backward, and optional attention probs/LSE for testing. Pads head dim and key seqlen to multiples of 8 for 16‑bit–friendly allocations, rounds workspace shapes, and sanitizes outputs; also supports paged KV via a block table. Improves performance and memory by avoiding per-sequence padding and aligning allocations to hardware-friendly sizes.
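To make the packed-ragged layout concrete, here is a small sketch of how cumulative sequence lengths index into a packed tensor and how the multiple-of-8 rounding mentioned above works. The helper names are assumptions for illustration, not library code.

```python
import torch

def pack_sequences(seqs):
    # Sketch of the packed (ragged) layout the varlen ops consume: per-sequence
    # (seqlen_i, nheads, head_dim) tensors concatenated along the token axis,
    # addressed through cumulative sequence lengths.
    lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0, dtype=torch.int32)
    packed = torch.cat(seqs, dim=0)                 # (total_tokens, nheads, head_dim)
    return packed, cu_seqlens, int(lengths.max())

def round_up(x: int, m: int = 8) -> int:
    # Mirrors the multiple-of-8 rounding of head_dim / key seqlen described above.
    return (x + m - 1) // m * m

seqs = [torch.randn(n, 4, 64) for n in (3, 7, 5)]
packed, cu_seqlens, max_seqlen = pack_sequences(seqs)
print(cu_seqlens.tolist(), round_up(max_seqlen))    # [0, 3, 10, 15] 8
```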
Pull Request Overview
This pull request introduces support for variable-length attention sequences with comprehensive mask and bias capabilities. The changes refactor parameters for consistency, deprecate legacy interfaces, and implement improved numerical stability through tensor sanitization.
Key changes include:
- Addition of variable-length attention forward and backward functions with proper mask/bias handling
- Parameter name consistency changes from `keep_window_size` to `window_size` and from `scale` to `softmax_scale` (see the sketch after this list)
- Enhanced numerical stability by enabling tensor sanitization to zero out NaN/Inf values
- Documentation updates and deprecated parameter handling in model integration
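A minimal before/after sketch of the rename. It assumes, for brevity, that both renamed keywords are accepted by the same `flash_dmattn_func` entry point and that inputs use a (batch, seqlen, heads, head_dim) layout; in the PR the renames actually land in the integration layer and the interface, so treat the call below as illustrative only.

```python
import torch
from flash_dmattn import flash_dmattn_func

q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
head_dim = q.shape[-1]

# Before this PR (deprecated names):
# out = flash_dmattn_func(q, k, v, keep_window_size=2048, scale=head_dim ** -0.5)

# After this PR (renamed for consistency):
out = flash_dmattn_func(q, k, v, window_size=2048, softmax_scale=head_dim ** -0.5)
```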
Reviewed Changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.
Summary per file:
| File | Description |
|---|---|
| flash_dmattn/integrations/flash_dynamic_mask_attention.py | Updates parameter names from keep_window_size to window_size |
| flash_dmattn/flash_dmattn_interface.py | Adds variable-length attention functions and enables tensor sanitization |
| flash_dmattn/__init__.py | Exports new flash_dmattn_varlen_func |
| examples/modeling/modeling_doge.py | Updates model to use new parameter names and removes unused code |
| examples/modeling/configuration_doge.py | Updates configuration parameter names |
| docs/api_reference_zh.md | New Chinese API documentation |
| docs/api_reference.md | Updated English API documentation |
| csrc/flash_dmattn/flash_api.cpp | Uncomments variable-length forward implementation |
| README_zh.md | Updates Chinese README examples |
| README.md | Updates English README examples |
    if ctx.seqlen_k_og % 8 != 0:
        dk = dk[:, : ctx.seqlen_k_og, :, :]
        dv = dv[:, : ctx.seqlen_k_og, :, :]
        if dbias is not None:
            dbias = dbias[..., : ctx.seqlen_k_og]
Copilot AI (Oct 9, 2025):
The variable ctx.seqlen_k_og is referenced but not defined in the context. This will cause an AttributeError at runtime.
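A toy illustration of the pattern this comment calls for: the forward pass must stash the original key length on the autograd context (here `seqlen_k_og`, following the review comment) before padding, so the backward pass can slice gradients back to that length. This is not the library's actual code.

```python
import torch

class PadToMultipleOf8(torch.autograd.Function):
    """Toy example: save the original seqlen on ctx before padding so
    backward can undo the padding on the gradient."""

    @staticmethod
    def forward(ctx, k):
        ctx.seqlen_k_og = k.shape[1]                       # save *before* padding
        pad = (-k.shape[1]) % 8                            # round seqlen up to a multiple of 8
        return torch.nn.functional.pad(k, (0, 0, 0, pad))  # pad the seqlen dimension

    @staticmethod
    def backward(ctx, dk):
        if ctx.seqlen_k_og % 8 != 0:
            dk = dk[:, : ctx.seqlen_k_og, :]               # slice the gradient back down
        return dk

k = torch.randn(2, 13, 64, requires_grad=True)
PadToMultipleOf8.apply(k).sum().backward()
assert k.grad.shape == k.shape
```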
        deterministic: bool,
    ) -> torch.Tensor:
    -   dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)]
    +   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
Copilot AI (Oct 9, 2025):
The variable dbias was removed from the maybe_contiguous call but is still used later in the function. This could cause issues if dbias is not contiguous.
Suggested change:

    -   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
    +   dout, q, k, v, mask, bias, out, dbias = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out, dbias)]
        deterministic: bool,
    ) -> torch.Tensor:
    -   dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)]
    +   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
Copilot AI (Oct 9, 2025):
Same issue as above: dbias was removed from the maybe_contiguous call but is still used in the function.
Suggested change:

    -   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
    +   dout, q, k, v, mask, bias, out, dbias = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out, dbias)]
        deterministic: bool,
        zero_tensors: bool = False,
    ) -> torch.Tensor:
        dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
Copilot AI (Oct 9, 2025):
Missing dbias in maybe_contiguous call while it's still used in the backward function.
Suggested change:

    -   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
    +   dout, q, k, v, mask, bias, out, dbias = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out, dbias)]
        deterministic: bool,
        zero_tensors: bool = False,
    ) -> torch.Tensor:
        dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
Copilot AI (Oct 9, 2025):
Missing dbias in maybe_contiguous call for the varlen backward fake function.
Suggested change:

    -   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
    +   dout, q, k, v, mask, bias, out, dbias = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out, dbias)]
…n_forward for clarity and consistency; rename keep_window_size to window_size and enhance FlashDynamicMaskAttentionKwargs documentation.
Introduce a variable-length attention forward path with mask and bias capabilities, enhancing numerical stability by zeroing NaN/Inf values. Refactor parameters for consistency and usability, and update documentation for clarity. Deprecate outdated arguments and enforce a single backend for attention operations. Improve performance on Ampere+ GPUs while maintaining correctness.