Fix varlen mask and bias tensor shapes for all varlen attention functions #114
Conversation
Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com>
Good work @copilot, let's complete some tests.
Pull Request Overview
This PR fixes a critical bug in varlen attention functions where default mask and bias tensors were created with incorrect shapes, causing RuntimeError when the C++ backend expected different dimensions. The fix ensures that default tensors match the expected shapes for variable-length sequence processing.
- Updates tensor shape creation from batch-based to token-based dimensions
- Corrects shape calculations to use `total_q`/`total_tokens` and `num_heads_k` instead of `batch_size` and `num_heads` (see the layout sketch below)
- Updates documentation to reflect the correct expected tensor shapes
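For context, a quick sketch of the batch-based (padded) versus token-based (packed) layouts involved; the sequence lengths and head_dim below are illustrative, not taken from the PR.

```python
import torch

# Padded, batch-based layout: (batch_size, max_seqlen, num_heads, head_dim).
batch_size, max_seqlen, num_heads, head_dim = 3, 1024, 16, 64
padded = torch.randn(batch_size, max_seqlen, num_heads, head_dim)
seqlens = torch.tensor([1024, 768, 512])  # hypothetical per-sequence lengths summing to 2304

# Token-based (packed) layout used by the varlen functions:
# all valid tokens concatenated along dim 0 -> (total_tokens, num_heads, head_dim).
packed = torch.cat([padded[i, : seqlens[i]] for i in range(batch_size)], dim=0)
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)

print(packed.shape)  # torch.Size([2304, 16, 64]) -> total_tokens = 2304
print(cu_seqlens)    # tensor([   0, 1024, 1792, 2304], dtype=torch.int32)
```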
    is_grad = is_grad_enabled and qkv.requires_grad
    if mask is None:
-       mask = torch.ones((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device)
+       mask = torch.ones((total_tokens, num_heads, max_seqlen), dtype=qkv.dtype, device=qkv.device)
Copilot AI commented on Aug 22, 2025:
The mask tensor shape should be (total_tokens, num_heads, max_seqlen, max_seqlen) to maintain consistency with attention mask semantics. A 3D mask (total_tokens, num_heads, max_seqlen) cannot properly represent attention relationships between query and key positions.
Suggested change:
-       mask = torch.ones((total_tokens, num_heads, max_seqlen), dtype=qkv.dtype, device=qkv.device)
+       mask = torch.ones((total_tokens, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device)
        mask = torch.ones((total_tokens, num_heads, max_seqlen), dtype=qkv.dtype, device=qkv.device)
    if bias is None:
-       bias = torch.zeros((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device)
+       bias = torch.zeros((total_tokens, num_heads, max_seqlen), dtype=qkv.dtype, device=qkv.device)
Copilot AI commented on Aug 22, 2025:
The bias tensor shape should be (total_tokens, num_heads, max_seqlen, max_seqlen) to properly represent attention bias between all query-key pairs. A 3D bias tensor cannot encode relationships between different sequence positions.
Suggested change:
-       bias = torch.zeros((total_tokens, num_heads, max_seqlen), dtype=qkv.dtype, device=qkv.device)
+       bias = torch.zeros((total_tokens, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device)
Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com>
I've completed comprehensive tests for the varlen attention function bug fix! Three test files were added; all tests pass. Commit: ee3102f
The varlen attention functions were creating default mask and bias tensors with incorrect shapes, causing a RuntimeError when the C++ backend expected different dimensions.
Problem
When calling `flash_dmattn_varlen_func` (and related varlen functions) with the default `attn_mask=None` and `attn_bias=None`, a shape-mismatch RuntimeError was raised by the C++ backend.
Root Cause
The default mask and bias tensors were being created with shape `(batch_size, num_heads, max_seqlen_q, max_seqlen_k)`, while the C++ backend expects `(total_q, num_heads_k, max_seqlen_k)`, where (illustrated below):
- `total_q` = sum of all sequence lengths in the batch (first dimension of the query tensor)
- `num_heads_k` = number of key/value heads (second dimension of the key tensor)
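For concreteness, a small sketch (illustrative sizes, not the library's code) of how these quantities come straight from the inputs the varlen functions already receive:

```python
import torch

# Illustrative packed tensors; num_heads_k is deliberately smaller than num_heads
# here to show why the *key* tensor's head count is the relevant dimension (GQA case).
q = torch.randn(2304, 16, 64)   # (total_q, num_heads, head_dim)
k = torch.randn(2304, 8, 64)    # (total_k, num_heads_k, head_dim)
cu_seqlens_k = torch.tensor([0, 1024, 1792, 2304], dtype=torch.int32)

total_q = q.shape[0]                                               # 2304
num_heads_k = k.shape[1]                                           # 8
max_seqlen_k = int((cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max())   # 1024

default_shape = (total_q, num_heads_k, max_seqlen_k)               # (2304, 8, 1024)
print(default_shape)
```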
Solution
Fixed the default tensor shape creation in three varlen functions (sketched below):
- `FlashDMAttnVarlenFunc`: now creates `(total_q, num_heads_k, max_seqlen_k)`
- `FlashDMAttnVarlenQKVPackedFunc`: now creates `(total_tokens, num_heads, max_seqlen)`
- `FlashDMAttnVarlenKVPackedFunc`: now creates `(total_q, num_heads_k, max_seqlen_k)`
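A minimal sketch of what the fixed defaults look like for the unpacked variant; `default_mask_bias` is a hypothetical helper for illustration, not the actual code from this repo:

```python
import torch

def default_mask_bias(q, k, max_seqlen_k):
    """Sketch of the fixed default construction for the unpacked variant:
    all-ones mask and all-zeros bias with the token-based shape."""
    total_q = q.shape[0]        # packed query tokens
    num_heads_k = k.shape[1]    # key/value heads
    shape = (total_q, num_heads_k, max_seqlen_k)
    mask = torch.ones(shape, dtype=q.dtype, device=q.device)
    bias = torch.zeros(shape, dtype=q.dtype, device=q.device)
    return mask, bias

q = torch.randn(2304, 16, 64)
k = torch.randn(2304, 16, 64)
mask, bias = default_mask_bias(q, k, max_seqlen_k=1024)
print(mask.shape, bias.shape)  # torch.Size([2304, 16, 1024]) torch.Size([2304, 16, 1024])
# The QKV-packed variant does the analogous thing with (total_tokens, num_heads, max_seqlen).
```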
Example
The bug report scenario now works correctly:
- Before: creates mask/bias with shape `(3, 16, 1024, 1024)` → RuntimeError
- After: creates mask/bias with shape `(2304, 16, 1024)` → success (reproduced in the sketch below)
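A hypothetical end-to-end reproduction of that scenario. Only `flash_dmattn_varlen_func`, `attn_mask`, and `attn_bias` come from this PR; the import path, keyword names (`cu_seqlens_q`, `max_seqlen_k`, ...), sequence lengths, and head_dim are assumptions in the style of the FlashAttention varlen API, so check the repo's actual signature before running.

```python
import torch
from flash_dmattn import flash_dmattn_varlen_func  # import path is an assumption

device, dtype = "cuda", torch.float16
seqlens = [1024, 768, 512]                 # example lengths summing to 2304 tokens
total_q = sum(seqlens)
num_heads, head_dim = 16, 64               # head_dim chosen for illustration

cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=device)
cu_seqlens[1:] = torch.cumsum(torch.tensor(seqlens, device=device), dim=0)

q = torch.randn(total_q, num_heads, head_dim, dtype=dtype, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

# With attn_mask/attn_bias left as None, the fix makes the internally created
# defaults (2304, 16, 1024) instead of (3, 16, 1024, 1024), so the call succeeds.
out = flash_dmattn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,   # keyword names assumed
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    attn_mask=None, attn_bias=None,
)
print(out.shape)  # expected: torch.Size([2304, 16, 64])
```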
Tests Added
Added a comprehensive test suite to validate the fix; a minimal example of such a check is sketched below.
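The actual test file names are not reproduced above; as an illustration only, a minimal pytest-style check of the corrected default shape might look like this:

```python
import torch

def test_varlen_default_shapes():
    # Mirrors the bug-report scenario: 3 sequences of 1024/768/512 tokens, 16 heads.
    cu_seqlens = torch.tensor([0, 1024, 1792, 2304], dtype=torch.int32)
    total_q = int(cu_seqlens[-1])
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    num_heads_k = 16

    mask = torch.ones((total_q, num_heads_k, max_seqlen))
    bias = torch.zeros((total_q, num_heads_k, max_seqlen))

    assert mask.shape == bias.shape == (2304, 16, 1024)
```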
Also updated the documentation to reflect the correct expected tensor shapes for all varlen functions.
Fixes #113.