
Conversation

@LoserCheems
Collaborator

Introduce a variable-length attention forward path with mask and bias capabilities, enhancing numerical stability by zeroing NaN/Inf values. Refactor parameters for consistency and usability, and update documentation for clarity. Deprecate outdated arguments and enforce a single backend for attention operations. Improve performance on Ampere+ GPUs while maintaining correctness.

Adds a working variable-length attention forward path with boolean mask and additive bias, supporting broadcast across head dims. Enforces dtype/device/contiguity and shape checks, initializes outputs when requested, and handles empty key sequences.

Keeps Paged KV disabled pending fixes, retains decoding optimization for single-token queries, and exposes optional softmax output for debugging/inspection. Improves usability on Ampere+ GPUs while maintaining correctness and constraints.
Enables sanitization in forward/backward to replace NaN and ±Inf with 0, improving numerical stability and preventing invalid value propagation.
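
A minimal sketch of the kind of sanitization described above, using PyTorch's built-in replacement; the helper name is illustrative and not the library's internal API:

```python
import torch

def sanitize_(t: torch.Tensor) -> torch.Tensor:
    # Replace NaN and +/-Inf with 0 in place so invalid values do not
    # propagate through later layers or the backward pass.
    return torch.nan_to_num_(t, nan=0.0, posinf=0.0, neginf=0.0)
```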

Stops treating bias gradient as an input and ensures the computed bias gradient is contiguous, avoiding layout issues in downstream kernels.
Removes the eager attention fallback and unconditionally routes attention through the flash dynamic mask backend, printing an install hint when the dependency is missing.
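
A hedged sketch of the dispatch pattern this describes; the imported name and install hint below are assumptions and may not match the repository's exact wording:

```python
try:
    from flash_dmattn import flash_dmattn_func  # name assumed
except ImportError:
    flash_dmattn_func = None
    print(
        "flash_dmattn is required for the flash dynamic mask attention backend; "
        "please install it (see the project README for instructions)."
    )
```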

Deprecates the past_key_value argument in favor of past_key_values across attention, layer, and LM APIs, switches types to a cache interface, and initializes the dynamic cache with config for correctness.
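
A minimal sketch of the deprecation shim this implies; the helper and warning text are illustrative, not the model's actual code:

```python
import warnings

def resolve_cache(past_key_values=None, past_key_value=None):
    # Accept the legacy `past_key_value` keyword but steer callers toward
    # `past_key_values`, returning whichever cache was provided.
    if past_key_value is not None:
        warnings.warn(
            "`past_key_value` is deprecated; use `past_key_values` instead.",
            FutureWarning,
        )
        if past_key_values is None:
            past_key_values = past_key_value
    return past_key_values
```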

Disables SDPA/flex/attention-backend support flags to reflect the single supported backend.

Adds a buffer annotation to satisfy linting, drops unused decoder accessors, ignores unused attention weights, and updates a paper link to the HF mirror.

Improves consistency of the attention API and enforces a single, performant backend.
Reorganizes documentation for faster onboarding: adds Quick Start, backend selection/comparison, and a clearer API reference.

Clarifies attention mask/bias shapes with broadcasting support and updates the integration example to drop manual expansion and eager fallback.
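
A sketch of the broadcasting described above, assuming a (batch, heads, query_len, key_len) layout for mask and bias; the head dimension can be 1 and is broadcast across all heads (shapes and the commented call are assumptions, not the exact API):

```python
import torch

batch, num_heads, q_len, k_len, head_dim = 2, 8, 128, 128, 64
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, q_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, k_len, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, k_len, num_heads, head_dim, device=device, dtype=dtype)

# Boolean mask and additive bias with a singleton head dimension;
# no manual .expand() to num_heads is needed, the kernel broadcasts.
mask = torch.ones(batch, 1, q_len, k_len, device=device, dtype=torch.bool)
bias = torch.zeros(batch, 1, q_len, k_len, device=device, dtype=dtype)

# out = flash_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias)  # names assumed
```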

Adds concise install instructions, backend availability flags, and unified import guidance; removes redundant backend and summary sections for clarity.
Introduces CUDA custom ops and fake/meta registrations for variable-length attention forward/backward using cumulative sequence lengths, enabling efficient packed ragged batches.

Adds an autograd wrapper and public API exposing varlen attention with support for MQA/GQA, optional mask/bias, causal mode, softcap, deterministic backward, and optional attention probs/LSE for testing.

Pads head dim and key seqlen to multiples of 8 for 16‑bit–friendly allocations, rounds workspace shapes, and sanitizes outputs; also supports paged KV via a block table.
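
The multiple-of-8 padding can be sketched as a simple round-up; this helper is illustrative, not the library's internal code:

```python
def round_up_to_multiple(x: int, multiple: int = 8) -> int:
    # e.g. head_dim 96 -> 96, 100 -> 104; keeps allocations 16-bit friendly.
    return (x + multiple - 1) // multiple * multiple
```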

Improves performance and memory by avoiding per-sequence padding and aligning allocations to hardware-friendly sizes.
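
A hedged sketch of driving the varlen path with cumulative sequence lengths for a packed, unpadded batch; the commented call outlines assumed parameter names, so consult the API reference for the exact signature:

```python
import torch

# Three sequences of different lengths packed into one tensor with no padding.
seq_lens = torch.tensor([37, 128, 5], dtype=torch.int32)
cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)   # [0, 37, 165, 170]
max_seqlen = int(seq_lens.max())

total_tokens, num_heads, head_dim = int(seq_lens.sum()), 8, 64
q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)

# out = flash_dmattn_varlen_func(
#     q, k, v,
#     cu_seqlens_q=cu_seqlens.cuda(), cu_seqlens_k=cu_seqlens.cuda(),
#     max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
# )  # keyword names assumed
```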

Copilot AI left a comment


Pull Request Overview

This pull request introduces support for variable-length attention sequences with comprehensive mask and bias capabilities. The changes refactor parameters for consistency, deprecate legacy interfaces, and implement improved numerical stability through tensor sanitization.

Key changes include:

  • Addition of variable-length attention forward and backward functions with proper mask/bias handling
  • Parameter renames for consistency: keep_window_size to window_size and scale to softmax_scale (see the sketch after this list)
  • Enhanced numerical stability by enabling tensor sanitization to zero out NaN/Inf values
  • Documentation updates and deprecated parameter handling in model integration
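
As noted in the rename bullet above, a minimal sketch of adapting older call sites to the new parameter names; the shim and `attn_fn` are illustrative, not part of the library:

```python
def call_attention(attn_fn, q, k, v, **kwargs):
    # Map legacy keyword names onto the renamed ones so old call sites keep working.
    if "keep_window_size" in kwargs:
        kwargs["window_size"] = kwargs.pop("keep_window_size")
    if "scale" in kwargs:
        kwargs["softmax_scale"] = kwargs.pop("scale")
    return attn_fn(q, k, v, **kwargs)
```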

Reviewed Changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.

File | Description
--- | ---
flash_dmattn/integrations/flash_dynamic_mask_attention.py | Updates parameter names from keep_window_size to window_size
flash_dmattn/flash_dmattn_interface.py | Adds variable-length attention functions and enables tensor sanitization
flash_dmattn/__init__.py | Exports the new flash_dmattn_varlen_func
examples/modeling/modeling_doge.py | Updates the model to use the new parameter names and removes unused code
examples/modeling/configuration_doge.py | Updates configuration parameter names
docs/api_reference_zh.md | New Chinese API documentation
docs/api_reference.md | Updated English API documentation
csrc/flash_dmattn/flash_api.cpp | Uncomments the variable-length forward implementation
README_zh.md | Updates Chinese README examples
README.md | Updates English README examples


Comment on lines +638 to +642
if ctx.seqlen_k_og % 8 != 0:
    dk = dk[:, : ctx.seqlen_k_og, :, :]
    dv = dv[:, : ctx.seqlen_k_og, :, :]
    if dbias is not None:
        dbias = dbias[..., : ctx.seqlen_k_og]

Copilot AI Oct 9, 2025


The variable ctx.seqlen_k_og is referenced but not defined in the context. This will cause an AttributeError at runtime.

    deterministic: bool,
) -> torch.Tensor:
-   dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)]
+   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]

Copilot AI Oct 9, 2025


The variable dbias was removed from the maybe_contiguous call but is still used later in the function. This could cause issues if dbias is not contiguous.

Suggested change:
-   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
+   dout, q, k, v, mask, bias, out, dbias = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out, dbias)]

    deterministic: bool,
) -> torch.Tensor:
-   dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)]
+   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]

Copilot AI Oct 9, 2025


Same issue as above - dbias was removed from maybe_contiguous call but is still used in the function.

Suggested change:
-   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
+   dout, q, k, v, mask, bias, out, dbias = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out, dbias)]

    deterministic: bool,
    zero_tensors: bool = False,
) -> torch.Tensor:
    dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]

Copilot AI Oct 9, 2025


dbias is missing from the maybe_contiguous call even though it is still used later in the backward function.

Suggested change:
-   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
+   dout, q, k, v, mask, bias, out, dbias = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out, dbias)]

    deterministic: bool,
    zero_tensors: bool = False,
) -> torch.Tensor:
    dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]

Copilot AI Oct 9, 2025


Missing dbias in maybe_contiguous call for the varlen backward fake function.

Suggested change:
-   dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)]
+   dout, q, k, v, mask, bias, out, dbias = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out, dbias)]

…n_forward for clarity and consistency; rename keep_window_size to window_size and enhance FlashDynamicMaskAttentionKwargs documentation.
@LoserCheems LoserCheems merged commit 30d55b9 into main Oct 9, 2025