[BUG FIX] Unify masking utilities and improve performance #209
Conversation
Replaces the ad‑hoc mask builder with a shared utility to centralize top‑k/causal masking across CUDA, Triton, and Flex paths. Passes explicit tensor shapes and dtype min to ensure correctness, and adds block‑sparse support (block size 64) for dynamic CUDA. Handles optional masks safely when repeating KV for GQA. Updates benchmark bias to broadcast over queries (B,H,1,K) to reduce memory and match masking expectations. Improves consistency, reduces duplication, and prepares for extensible masking strategies.
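As a rough sketch of what such a shared utility could look like (the name `build_keep_mask` and the exact top-k convention are illustrative assumptions, not this PR's actual API):

```python
import torch

def build_keep_mask(
    attn_scores: torch.Tensor,  # (B, H, Q, K) pre-softmax scores
    batch_size: int,
    query_len: int,
    key_len: int,
    topk: int,
    causal: bool = True,
) -> torch.Tensor:
    """Combined causal + top-k boolean mask (True = keep), built from
    explicit shapes so cached decoding (Q < K) is handled correctly."""
    device = attn_scores.device
    keep = torch.ones(batch_size, 1, query_len, key_len, dtype=torch.bool, device=device)
    if causal:
        # Offset the diagonal so the last query row sees the whole key cache.
        tri = torch.ones(query_len, key_len, dtype=torch.bool, device=device)
        keep = keep & torch.tril(tri, diagonal=key_len - query_len)
    if topk < key_len:
        # Keep only the k highest-scoring keys per query, using the dtype
        # minimum as the sentinel for already-masked slots.
        min_val = torch.finfo(attn_scores.dtype).min
        idx = attn_scores.masked_fill(~keep, min_val).topk(topk, dim=-1).indices
        topk_keep = torch.zeros_like(attn_scores, dtype=torch.bool)
        topk_keep.scatter_(-1, idx, True)
        keep = keep & topk_keep
    return keep
```

Passing batch/query/key sizes explicitly, rather than inferring them inside each backend, is what keeps the cached-decoding case (query_len < key_len) unambiguous across CUDA, Triton, and Flex.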
Removes an unused symbol to clean up imports and silence linter warnings. Reduces clutter and avoids confusion with unreferenced utilities. Introduces no functional changes.
Replaces the local mask builder with a centralized utility to standardize top-k/causal masking across Python, CUDA, Triton, and Flex paths. Passes explicit batch/query/key sizes and dtype min, repeats masks only when present, and skips masked_fill when unneeded. Reduces duplication, improves consistency and maintainability, and streamlines GQA handling.
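The "repeat only when present, skip masked_fill when unneeded" behavior might look roughly like this at a call site (a hypothetical helper; assumes a boolean True-means-keep mask and a `repeat_kv` fan-out like the one sketched further below):

```python
import torch
from typing import Optional

def apply_optional_mask(
    attn_scores: torch.Tensor,          # (B, num_heads, Q, K)
    attn_mask: Optional[torch.Tensor],  # (B, num_kv_heads, Q, K) bool, True = keep
    num_queries_per_kv: int,
) -> torch.Tensor:
    if attn_mask is None:
        return attn_scores  # no mask: nothing to do
    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)  # fan out to num_heads
    if attn_mask.all():
        return attn_scores  # mask keeps everything: skip the fill entirely
    min_val = torch.finfo(attn_scores.dtype).min
    return attn_scores.masked_fill(~attn_mask, min_val)
```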
Replaces ad-hoc mask construction with a shared mask utility to unify top‑k/causal masking across Python, CUDA, Triton, and Flex paths. Reduces duplication and allows safely skipping masking when not required. Fixes gradient reporting in the Flex path by returning grads w.r.t. the original input tensors. Also clarifies shape handling and guards masked fill, improving robustness.
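The gradient-reporting pitfall this fixes can be reproduced in a few lines (illustrative shapes; `q_t` stands in for the tensor the Flex kernel actually consumes):

```python
import torch

q = torch.randn(2, 8, 128, 64, requires_grad=True)  # original leaf input
q_t = q.transpose(1, 2).contiguous()                 # intermediate fed to the kernel

q_t.sum().backward()   # stand-in for the Flex attention call + loss

print(q.grad.shape)    # torch.Size([2, 8, 128, 64]): grads land on the leaf
print(q_t.grad)        # None (non-leaf), so reporting q_t.grad is wrong
```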
Replaces duplicated mask logic with a shared utility to standardize top‑k/causal masking and dtype handling across CUDA, Triton, and Flex backends. Aligns attention bias to a broadcastable per‑query shape to cut memory and simplify kernel expectations. Removes redundant KV/mask/bias repetition in the Triton path, repeats conditionally for Flex, and makes GQA fan‑out explicit for correctness and performance.
Updates forward/backward equivalence benchmarks to create attention bias with a singleton query dimension so it broadcasts across queries. Aligns shapes with kernel expectations during cached decoding, reduces memory footprint, and prevents shape mismatches across CUDA, Triton, and Flex paths.
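Concretely, the bias described here might be constructed as follows (shapes are illustrative):

```python
import torch

B, H, Q, K, D = 2, 8, 128, 1024, 64
q = torch.randn(B, H, Q, D)
k = torch.randn(B, H, K, D)

# Singleton query dimension: (B, H, 1, K) instead of (B, H, Q, K),
# cutting the bias's memory footprint by a factor of Q.
attn_bias = torch.randn(B, H, 1, K, requires_grad=True)

scores = q @ k.transpose(-1, -2) / D**0.5 + attn_bias  # broadcasts over dim 2
```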
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (1)
benchmarks/backward_equivalence.py:268
Missing `repeat_kv` calls for Triton implementation. The Triton function expects `key_states`, `value_states`, `attn_mask`, and `attn_bias` to have a `num_heads` dimension (not `num_kv_heads`). The following code should be added before the transpose operations:

```python
# Repeat KV for multi-head attention (GQA support)
key_states = repeat_kv(key_states, num_queries_per_kv)
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
```

This is consistent with the forward implementation in forward_equivalence.py lines 242-245.
```python
# Ensure correct data types and memory layout for Triton function
query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
value_states = value_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
```
```python
# Ensure correct data types and memory layout for Triton function
query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
```
Copilot AI · Nov 13, 2025
Missing `repeat_kv` calls for Triton implementation. The Triton function expects `key_states`, `value_states`, `attn_mask`, and `attn_bias` to have a `num_heads` dimension (not `num_kv_heads`). The following code should be added before the transpose operations:

```python
# Repeat KV for multi-head attention (GQA support)
key_states = repeat_kv(key_states, num_queries_per_kv)
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_bias = repeat_kv(attn_bias, num_queries_per_kv)
```

This is consistent with the forward implementation in forward_performance.py lines 273-276 and forward_equivalence.py lines 242-245.
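For context, a `repeat_kv` helper in this kind of codebase typically follows the Hugging Face `transformers` convention, expanding the KV-head dimension so each group of query heads sees its repeated keys/values. A minimal sketch, assuming 4D `(batch, num_kv_heads, seq_len, head_dim)` inputs:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(B, num_kv_heads, S, D) -> (B, num_kv_heads * n_rep, S, D).
    Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states  # plain MHA: nothing to repeat
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
```

The same fan-out works for boolean masks and biases shaped `(batch, num_kv_heads, query_len, key_len)`, which is why the review suggests passing `attn_mask` and `attn_bias` through it as well.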
```python
batch_size, num_heads, query_len, _ = query_states.shape
_, num_kv_heads, key_len, _ = key_states.shape

num_queries_per_kv = num_heads // num_kv_heads
```
Copilot AI · Nov 13, 2025
Variable num_queries_per_kv is not used.
```python
num_queries_per_kv = num_heads // num_kv_heads
```
Copilot AI · Nov 13, 2025
Variable num_queries_per_kv is not used.
```python
num_queries_per_kv = num_heads // num_kv_heads
```
Copilot AI · Nov 13, 2025
Variable num_queries_per_kv is not used.
```python
num_queries_per_kv = num_heads // num_kv_heads
```
Copilot AI · Nov 13, 2025
Variable num_queries_per_kv is not used.
Summary
Root Cause
Changes
Reproduction
Tests
Compatibility
Checklist