Conversation

@LoserCheems (Collaborator) commented on Aug 13, 2025

Fix #107

Description

Fix CUDA forward crash when seqlen_q == 1 in GQA/MQA mode (num_heads > num_heads_k). In this fast path, the kernel expects mask/bias as [B, H_k, ngroups, K] where ngroups = num_heads / num_heads_k. Previously, Python passed [B, H_k, 1, K], causing a shape mismatch. This PR adapts mask/bias inside C++ by creating zero-copy expanded views to the expected shape, keeping the Python API unchanged and avoiding extra memory.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)

Related Issues

  • Fixes #107

Changes Made

Code Changes

  • C++: Adjust the fixed-length forward fast path in csrc/flash_api.cpp:
    • In the seqlenq_ngroups_swapped path (decoding with GQA/MQA), create zero-copy broadcast views (a sketch follows this list):
      • mask_view = mask.expand({B, H_k, ngroups, K})
      • bias_view = bias.expand({B, H_k, ngroups, K})
    • Keep mask and bias const; do not mutate user inputs.
    • Continue to reshape/transpose q/out in the fast path; Python API unchanged.
  • No Python API changes.
  • No CUDA kernel changes.
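
A minimal sketch of the broadcast-view adaptation, assuming batch_size, num_heads_k, ngroups, and seqlen_k are already computed in the fast path; the helper and variable names are illustrative, not the exact code in csrc/flash_api.cpp:

#include <torch/torch.h>

// Zero-copy broadcast view used in the seqlenq_ngroups_swapped fast path.
// expand() gives the size-1 query dimension stride 0, so all ngroups rows
// alias the same underlying storage and nothing is copied.
static at::Tensor expand_mask_or_bias(const at::Tensor &t,
                                      int64_t batch_size, int64_t num_heads_k,
                                      int64_t ngroups, int64_t seqlen_k) {
    return t.expand({batch_size, num_heads_k, ngroups, seqlen_k});
}

// Inside the fast path (illustrative):
//   auto mask_view = expand_mask_or_bias(mask, batch_size, num_heads_k, ngroups, seqlen_k);
//   auto bias_view = expand_mask_or_bias(bias, batch_size, num_heads_k, ngroups, seqlen_k);

Because the views alias the original [B, H_k, 1, K] tensors, the kernel sees the expected [B, H_k, ngroups, K] shape with no extra allocation.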

Documentation

  • N/A

Testing

  • Reproduction script (before fix: crash; after fix: pass):
import torch
from flash_dmattn.flash_dmattn_interface import _wrapped_flash_dmattn_forward

torch.manual_seed(0)
device, dtype = "cuda", torch.float16
B, H, Hk, d = 2, 8, 2, 64
ngroups = H // Hk
seqlen_q, seqlen_k = 1, 128

q = torch.randn(B, seqlen_q, H, d, device=device, dtype=dtype)
k = torch.randn(B, seqlen_k, Hk, d, device=device, dtype=dtype)
v = torch.randn(B, seqlen_k, Hk, d, device=device, dtype=dtype)
mask = torch.ones(B, Hk, seqlen_q, seqlen_k, device=device, dtype=dtype)
bias = torch.zeros(B, Hk, seqlen_q, seqlen_k, device=device, dtype=dtype)

out, lse, S = _wrapped_flash_dmattn_forward(
    q, k, v, mask, bias, 1.0 / (d ** 0.5), is_causal=False, softcap=0.0, return_softmax=False
)
print(out.shape, lse.shape)  # Expected: (B, seqlen_q, H, d), (B, H, seqlen_q)
  • Sanity checks:
    • Existing demos/benchmarks run: benchmarks/forward_equivalence.py, forward_performance.py
    • Numerical parity checked on random inputs for both seqlen_q=1 and seqlen_q>1
    • Verified no extra memory is allocated for mask/bias (expand is zero-copy)
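
A standalone check of the zero-copy claim (illustrative, not part of the PR's test suite), using libtorch on CPU:

#include <torch/torch.h>
#include <cassert>

int main() {
    const int64_t B = 2, Hk = 2, ngroups = 4, K = 128;
    auto mask = torch::ones({B, Hk, 1, K}, torch::kFloat32);
    auto mask_view = mask.expand({B, Hk, ngroups, K});
    assert(mask_view.data_ptr() == mask.data_ptr());  // same storage: no copy
    assert(mask_view.stride(2) == 0);                 // broadcast dimension has stride 0
    assert(mask_view.stride(3) == 1);                 // last dimension stays contiguous
    return 0;
}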

Test Configuration

  • OS: Ubuntu 22.04 and Windows 11
  • Python: 3.10/3.11
  • PyTorch: 2.1–2.4
  • CUDA: 11.8/12.1
  • GPU: A100 (SM 8.0), RTX 4090 (SM 8.9)

Performance Impact

  • Neutral: No additional memory traffic for mask/bias (expand view).
  • No observed regression in forward benchmarks for decoding or full-context.

Breaking Changes

  • None. Python API and tensor contracts remain unchanged.

Checklist

  • Code follows style guidelines
  • Self-reviewed changes
  • Comments added for non-trivial logic (fast path shape adaptation)
  • No new warnings
  • Added/updated minimal tests or scripts to verify the fix
  • Local validations pass

CUDA-specific

  • Compiles without warnings on SM 8.0+
  • Tested on A100/4090
  • No extra memory allocations for mask/bias; no leaks

Additional Notes

  • The fix uses zero-copy expand for mask/bias to [B, H_k, ngroups, K] only when the fast path triggers (seqlen_q == 1 and H > H_k); a sketch of this predicate follows these notes. For other cases, behavior is unchanged.
  • This PR does not touch backward; gradients are unaffected.
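
For reference, a sketch of the trigger predicate (an assumption about its shape; the actual check in csrc/flash_api.cpp may test additional conditions):

#include <cstdint>

// Fast path applies to single-token decoding with grouped/multi-query attention.
bool seqlenq_ngroups_swapped(int64_t seqlen_q, int64_t num_heads, int64_t num_heads_k) {
    return seqlen_q == 1 && num_heads > num_heads_k;
}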

Uncomments previously disabled test cases to run the full suite of forward equivalence tests across various batch sizes, head configurations, sequence lengths, and causal/non-causal modes.

Adds two new edge case configurations with very short sequence lengths to improve test coverage.

Uncomments and activates previously disabled benchmark test cases to enable comprehensive performance testing across various parameter configurations.

Includes inference tests with different sequence lengths, batch size variations, head count and dimension testing, window size experiments, and non-causal attention benchmarks.

Also fixes the inference test parameter from 2 to 1 for proper single-token inference evaluation.

Updates error messages to reflect "flash dynamic mask attention" branding.

Adds contiguity checks for mask and bias tensors to ensure proper memory layout.

Handles tensor reshaping for grouped query attention scenarios by expanding mask and bias tensors to match the reshaped query dimensions, ensuring consistent tensor shapes throughout the attention computation.
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR fixes a CUDA forward crash that occurs when seqlen_q == 1 in GQA/MQA mode where num_heads > num_heads_k. The crash was caused by a shape mismatch where the kernel expected mask/bias tensors in the shape [B, H_k, ngroups, K] but received [B, H_k, 1, K].

  • Fixes shape mismatch in the fast path by expanding mask/bias tensors to the expected shape using zero-copy views
  • Updates error messages to reflect "flash dynamic mask attention" instead of "flash attention"
  • Enables previously commented test cases in benchmark files to improve test coverage

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

Files reviewed:
  • csrc/flash_api.cpp: core fix that creates expanded views of mask/bias tensors for the seqlen_q==1 fast path and updates error messages
  • benchmarks/forward_performance.py: uncomments test cases, including seqlen_q=1 inference scenarios, to validate the fix
  • benchmarks/forward_equivalence.py: uncomments various test configurations and adds a seqlen_q=1 test case

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
Copilot AI commented on Aug 13, 2025:

The stride check for mask is performed on the original mask tensor, but later the code uses mask_view which may have different strides after the expand operation. The stride check should be performed after creating mask_view or should check that the original mask is suitable for expansion.

Suggested change:
- TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+ // The stride check for mask should be performed after any expand/view operation.
+ // TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");

TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
Copilot AI commented on Aug 13, 2025:

The stride check for bias is performed on the original bias tensor, but later the code uses bias_view which may have different strides after the expand operation. The stride check should be performed after creating bias_view or should check that the original bias is suitable for expansion.
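
One possible way to address both review comments (a sketch, not necessarily what was merged) is to perform the last-dimension contiguity checks on the expanded views, since expand() preserves the original last-dimension stride:

// Illustrative: check contiguity on the views created in the fast path.
// Variable names mirror the fast-path code; this is not the merged change.
auto mask_view = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k});
auto bias_view = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k});
TORCH_CHECK(mask_view.stride(-1) == 1, "mask must have contiguous last dimension");
TORCH_CHECK(bias_view.stride(-1) == 1, "bias must have contiguous last dimension");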

@LoserCheems merged commit 820fecf into main on Aug 13, 2025


Development

Successfully merging this pull request may close these issues.

[BUG] q_len=1 with GQA/MQA crashes in CUDA forward due to mask/bias shape mismatch
