Fix CUDA forward crash when seqlen_q == 1 #108
Conversation
Uncomments previously disabled test cases to run the full suite of forward equivalence tests across various batch sizes, head configurations, sequence lengths, and causal/non-causal modes. Adds two new edge-case configurations with very short sequence lengths to improve test coverage.
Re-enables previously disabled benchmark test cases for comprehensive performance testing across parameter configurations: inference tests with different sequence lengths, batch-size variations, head-count and head-dimension sweeps, window-size experiments, and non-causal attention benchmarks. Also changes an inference test parameter from 2 to 1 so it exercises proper single-token inference.
Updates error messages to reflect "flash dynamic mask attention" branding. Adds contiguity checks for the mask and bias tensors to ensure proper memory layout. Handles tensor reshaping for grouped-query attention by expanding the mask and bias tensors to match the reshaped query dimensions, keeping tensor shapes consistent throughout the attention computation.
Pull Request Overview
This PR fixes a CUDA forward crash that occurs when seqlen_q == 1 in GQA/MQA mode where num_heads > num_heads_k. The crash was caused by a shape mismatch where the kernel expected mask/bias tensors in the shape [B, H_k, ngroups, K] but received [B, H_k, 1, K].
- Fixes the shape mismatch in the fast path by expanding mask/bias tensors to the expected shape using zero-copy views (see the sketch after this list)
- Updates error messages to reflect "flash dynamic mask attention" instead of "flash attention"
- Enables previously commented test cases in benchmark files to improve test coverage
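For concreteness, here is a minimal sketch of what the adaptation in `csrc/flash_api.cpp` might look like, assuming the fast path follows the usual seqlen_q == 1 "groups as query length" swap; the function and variable names below are illustrative, not necessarily those used in the actual patch:

```cpp
#include <torch/torch.h>

// Sketch only: adapt q/mask/bias for the seqlen_q == 1 GQA/MQA fast path.
// All identifiers here are assumptions for illustration.
void adapt_for_gqa_fast_path(torch::Tensor &q, torch::Tensor &mask,
                             torch::Tensor &bias,
                             int64_t batch, int64_t num_heads,
                             int64_t num_heads_k, int64_t head_dim,
                             int64_t seqlen_k) {
    const int64_t ngroups = num_heads / num_heads_k;
    // q: [B, 1, H, D] -> [B, ngroups, H_k, D], treating the groups as the
    // query-length dimension (the standard seqlen_q == 1 fast-path trick).
    q = q.reshape({batch, num_heads_k, ngroups, head_dim}).transpose(1, 2);
    // mask/bias: [B, H_k, 1, K] -> [B, H_k, ngroups, K].
    // expand() returns a zero-copy view: the broadcast dimension gets
    // stride 0, so no memory is allocated or copied.
    mask = mask.expand({batch, num_heads_k, ngroups, seqlen_k});
    bias = bias.expand({batch, num_heads_k, ngroups, seqlen_k});
}
```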
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| csrc/flash_api.cpp | Core fix: creates expanded views of mask/bias tensors for seqlen_q==1 fast path and updates error messages |
| benchmarks/forward_performance.py | Uncomments test cases including seqlen_q=1 inference scenarios to validate the fix |
| benchmarks/forward_equivalence.py | Uncomments various test configurations and adds seqlen_q=1 test case |
```cpp
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
```
Copilot AI commented on Aug 13, 2025:
The stride check for mask is performed on the original mask tensor, but later the code uses mask_view which may have different strides after the expand operation. The stride check should be performed after creating mask_view or should check that the original mask is suitable for expansion.
Suggested change:

```diff
-TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+// The stride check for mask should be performed after any expand/view operation.
+// TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
```
```cpp
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
```
Copilot AI commented on Aug 13, 2025:
The stride check for bias is performed on the original bias tensor, but later the code uses bias_view which may have different strides after the expand operation. The stride check should be performed after creating bias_view or should check that the original bias is suitable for expansion.
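One way to address both comments, sketched with assumed identifiers, is to run the contiguity checks on the views the kernel actually consumes rather than on the pre-expand tensors. (In practice, expanding a size-1 dimension leaves the last dimension's stride untouched, so the two checks are equivalent; moving them mainly improves clarity.)

```cpp
// Sketch: validate the views that will be handed to the kernel.
auto mask_view = mask.expand({batch, num_heads_k, ngroups, seqlen_k});
auto bias_view = bias.expand({batch, num_heads_k, ngroups, seqlen_k});
TORCH_CHECK(mask_view.stride(-1) == 1,
            "mask must have contiguous last dimension");
TORCH_CHECK(bias_view.stride(-1) == 1,
            "bias must have contiguous last dimension");
```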
Fix #107
Description
Fix CUDA forward crash when seqlen_q == 1 in GQA/MQA mode (num_heads > num_heads_k). In this fast path, the kernel expects mask/bias as [B, H_k, ngroups, K] where ngroups = num_heads / num_heads_k. Previously, Python passed [B, H_k, 1, K], causing a shape mismatch. This PR adapts mask/bias inside C++ by creating zero-copy expanded views to the expected shape, keeping the Python API unchanged and avoiding extra memory.
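Because expand() only creates a broadcast view, the zero-copy claim is straightforward to verify; a small illustrative check (names assumed):

```cpp
// Sketch: the expanded view aliases the original storage.
auto mask_view = mask.expand({batch, num_heads_k, ngroups, seqlen_k});
TORCH_CHECK(mask_view.data_ptr() == mask.data_ptr(),
            "expanded view should share storage with the original mask");
TORCH_CHECK(mask_view.stride(2) == 0,
            "broadcast dimension should have stride 0 (no copy)");
```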