Optimize Triton version: GQA, mask/bias broadcasting, skip inactive tiles, and stability fixes #200
Conversation
Adds forward support for GQA/MQA (different Q vs KV heads) with optional boolean mask and bias, including broadcasting across batch/head/seq dims and per-head routing. Switches to compile-time mask/bias flags, removes the scratchpad workaround, simplifies scaling, and indexes LSE/Out by Q heads. Skips masked-out tiles, tightens mask semantics (True = keep), and fixes backward mask handling. Introduces a contiguity helper, bumps pipeline stages, and errors out on Triton backward for GQA/MQA until implemented.
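For reference, a minimal PyTorch-level sketch of the semantics described above (names and shapes are illustrative, not the kernel API): each query head reads the KV head it maps to under GQA/MQA, mask and bias broadcast over batch/head/sequence dims, and the boolean mask keeps positions where it is True.

```python
import math
import torch

def reference_attention(q, k, v, mask=None, bias=None):
    # q: (B, Sq, Hq, D); k, v: (B, Sk, Hk, D) with Hq % Hk == 0 (GQA/MQA)
    b, sq, hq, d = q.shape
    hk = k.shape[2]
    # Route each query head to its KV head by repeating KV heads per group.
    k = k.repeat_interleave(hq // hk, dim=2)
    v = v.repeat_interleave(hq // hk, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(d)
    if bias is not None:
        scores = scores + bias                             # broadcasts over (B, Hq, Sq, Sk)
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))  # True = keep
    return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)
```

A tile whose mask is entirely False contributes nothing to the output, which is what allows the kernel to skip such tiles early.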
@ftgreat please test this version 🤗
Pull Request Overview
This PR adds GQA/MQA support to the Triton forward kernel while improving stability and performance through mask/bias broadcasting, early tile skipping, and tighter shape validation.
- Enables Triton forward kernel to handle differing Q and KV head counts (GQA/MQA) with proper head mapping
- Implements mask/bias broadcasting across batch, head, and sequence dimensions
- Optimizes masked workloads by skipping tiles with no active elements
    @@ -1,3 +1,4 @@
    from typing import Optional

Copilot AI · Oct 27, 2025
The typing.Optional import is used only for the maybe_contiguous function at line 1096. Consider keeping Optional from typing or switching to the T | None syntax if Python 3.10+ is the minimum supported version.
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,

Copilot AI · Oct 27, 2025
The CACHE_KEY_SEQLEN_Q and CACHE_KEY_SEQLEN_K parameters now have explicit tl.constexpr type hints, but their purpose and usage are not documented. Consider adding a comment explaining that these are compile-time constants for Triton's kernel cache key.
Suggested change:
    headdim,
+   # CACHE_KEY_SEQLEN_Q and CACHE_KEY_SEQLEN_K are compile-time constants used as part of Triton's kernel cache key.
+   # They help Triton select and cache the appropriate kernel configuration based on sequence lengths.
    lse_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)

Copilot AI · Oct 27, 2025
[nitpick] Using tl.full with float("-inf") instead of tl.zeros(...) - float("inf") is cleaner, but the string representation float("-inf") differs from the previous numeric subtraction. For consistency with line 120, consider using the same representation (either both the string form or both computed).
    if EVEN_M & EVEN_N:
    ...
    # Load q: it will stay in SRAM throughout
    if EVEN_M:

Copilot AI · Oct 27, 2025
The condition changed from EVEN_M & EVEN_N to just EVEN_M when loading the query tensor. This may cause incorrect behavior when EVEN_M=True but EVEN_N=False, as the query is loaded without any masking. The original code had a comment noting this was necessary to avoid a Triton bug. Verify that this change is safe or restore the original condition.
Suggested change:
-   if EVEN_M:
+   # The original condition required EVEN_M & EVEN_N to avoid a Triton bug.
+   if EVEN_M & EVEN_N:
    # scale acc_o
    acc_o_scale = tl.exp(m_i - m_ij)
    # Load k
    if EVEN_N:

Copilot AI · Oct 27, 2025
Changed from EVEN_N & EVEN_M to just EVEN_N for loading K and V tensors. The original comment stated 'If we just do "if EVEN_N", there seems to be some race condition'. Removing the EVEN_M check may reintroduce this race condition. This change should be verified or reverted.
Suggested change:
-   if EVEN_N:
+   if EVEN_N & EVEN_M:
flash_dmattn/flash_dmattn_triton.py (Outdated)
    batch, seqlen_q, nheads_q, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    ...
    assert nheads_q % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"

Copilot AI · Oct 27, 2025
The error message should clarify the actual constraint. It currently says 'divisible by KV heads' but the check is nheads_q % nheads_k == 0. Consider rephrasing to 'Number of query heads must be divisible by number of KV heads for GQA/MQA' for clarity.
Suggested change:
-   assert nheads_q % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"
+   assert nheads_q % nheads_k == 0, "Number of query heads must be divisible by number of KV heads for GQA/MQA"
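For illustration, the index math this divisibility check enables (plain Python with assumed identifiers, not taken from the kernel): a whole group of query heads shares one KV head.

```python
# With nheads_q = 8 and nheads_k = 2, each KV head serves a group of 4 query heads.
nheads_q, nheads_k = 8, 2
group = nheads_q // nheads_k
for off_h_q in range(nheads_q):
    off_h_k = off_h_q // group         # KV head used for this query head
    print(off_h_q, "->", off_h_k)      # 0..3 -> 0, 4..7 -> 1
```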
    _, seqlen_k, nheads_k, _ = k.shape
    ...

Copilot AI · Oct 27, 2025
Missing assertion that k.shape[0] == batch and k.shape[3] == d after unpacking. The forward function should validate that K and V have matching batch size and head dimension with Q for correctness. Similar validation exists for V's shape on the removed line but is now missing.
Suggested change:
-   _, seqlen_k, nheads_k, _ = k.shape
+   k_batch, seqlen_k, nheads_k, k_d = k.shape
+   v_batch, v_seqlen, v_nheads, v_d = v.shape
+   assert k_batch == batch, "Key tensor batch size must match Query tensor batch size"
+   assert k_d == d, "Key tensor head dimension must match Query tensor head dimension"
+   assert v_batch == batch, "Value tensor batch size must match Query tensor batch size"
+   assert v_d == d, "Value tensor head dimension must match Query tensor head dimension"
flash_dmattn/flash_dmattn_triton.py (Outdated)
    BLOCK_N=BLOCK_N,
    num_warps=num_warps,
-   num_stages=1,
+   num_stages=2,

Copilot AI · Oct 27, 2025
The number of pipeline stages changed from 1 to 2. While this can improve performance by overlapping memory and compute, it increases register pressure and may cause issues on older GPUs or with large block sizes. This change should be documented or made configurable if it affects compatibility.
    def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        return x.contiguous() if x is not None and x.stride(-1) != 1 else x

Copilot AI · Oct 27, 2025
The function returns x unchanged when x is None or when x.stride(-1) == 1, and calls .contiguous() otherwise; the `and` already short-circuits the None case. The condition reads more directly as return x if x is None or x.stride(-1) == 1 else x.contiguous(), which makes that short-circuit explicit.
Suggested change:
-   return x.contiguous() if x is not None and x.stride(-1) != 1 else x
+   return x if x is None or x.stride(-1) == 1 else x.contiguous()
flash_dmattn/flash_dmattn_triton.py (Outdated)
    if query.shape[2] != key.shape[2]:
        raise RuntimeError(
            "Triton backward for GQA/MQA (nheads_q != nheads_k) is not implemented yet. "
            "Use the CUDA backend for training or disable grad for Triton path."

Copilot AI · Oct 27, 2025
The error message should provide more actionable guidance. Consider adding information about how to switch to the CUDA backend or a reference to documentation. For example: 'Triton backward for GQA/MQA is not implemented yet. Use the CUDA backend by setting use_triton=False, or disable gradients for this operation.'
| "Use the CUDA backend for training or disable grad for Triton path." | |
| "To resolve this, use the CUDA backend by setting use_triton=False, or disable gradients for this operation. " | |
| "See the documentation for more details: https://github.com/your-repo/docs#backend-selection" |
Removes rigid 4D asserts and explicit expands for mask/bias to support broadcast-friendly inputs and avoid unnecessary memory overhead. Renames flags to lower-case and passes them through consistently to the kernel, updating conditional strides and placeholder tensors accordingly. Improves readability and aligns runtime flags with kernel expectations.
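A minimal sketch of the broadcast-friendly stride handling described above (the helper name and signature are assumptions, not the repo's code): size-1 mask/bias dimensions get stride 0 so the kernel revisits the same data instead of materializing an expanded copy.

```python
import torch

def broadcast_strides(t: torch.Tensor, batch: int, nheads: int, seqlen_q: int):
    # Return (batch, head, row) strides, substituting 0 for any broadcast (size-1) dim.
    sb = t.stride(0) if t.shape[0] == batch else 0
    sh = t.stride(1) if t.shape[1] == nheads else 0
    sm = t.stride(2) if t.shape[2] == seqlen_q else 0
    return sb, sh, sm

bias = torch.randn(1, 1, 1, 256, dtype=torch.float16)  # broadcasts over batch, heads, and query rows
print(broadcast_strides(bias, batch=4, nheads=8, seqlen_q=128))  # (0, 0, 0)
```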
Drops backward-path guards that blocked attention mask gradients and GQA/MQA head configs with the Triton backend. Expands training support without forcing CUDA fallback; relies on the underlying kernel for validation.
Enables kernel autotuning with multiple tile configs and warp counts, keyed by sequence lengths, causality, mask/bias presence, and head dim for better performance and correct cache separation. Defaults missing mask/bias to empty tensors up front to simplify the call path and stabilize the kernel signature. Removes hardcoded launch params to defer selection to the autotuner; standardizes num_stages in configs.
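A hedged sketch of what such an autotune setup looks like in Triton (the configs, key names, and kernel signature below are illustrative, not the PR's exact values): several tile shapes and warp counts are benchmarked, and the key keeps configurations cached separately per sequence-length bucket, causality, mask/bias presence, and head dimension.

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
    ],
    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "HAS_MASK", "HAS_BIAS", "BLOCK_HEADDIM"],
)
@triton.jit
def _fwd_kernel_sketch(
    Q, K, V, Out,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr, HAS_MASK: tl.constexpr, HAS_BIAS: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    pass  # body omitted; only the autotune wiring is shown here
```

Defaulting absent mask/bias to empty tensors keeps the argument list, and therefore the cache key, stable across calls.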
Supports mixed Q/KV head counts (GQA/MQA) by validating divisibility, updating grid/shape logic to query heads, and passing head-count ratios to the kernel. Makes mask and bias truly optional: supplies empty tensors with zero strides and passes has_mask/has_bias flags to the kernel, removing strict stride/layout assumptions. Improves robustness with clearer assertions and compile-time keys (is_causal/has_mask/has_bias/head dim), and standardizes forward to require explicit mask/bias args.
Enables backward pass with differing Q/K head counts (GQA/MQA) and broadcasting of mask/bias across 1, KV, or Q heads by remapping head offsets and conditioning pointer advances on feature presence. Introduces a compile-time switch for atomic accumulation that triggers under sequence parallelism or head-count mismatches to prevent write races, while avoiding atomics otherwise for performance. Improves correctness and flexibility across broader attention configurations.
Introduces a flag to perform atomic accumulation of gradients when tiles may contend, eliminating race conditions in the backward path. Retains fast masked stores for safe even/fully covered cases; applies atomic adds with appropriate masks otherwise. Improves numerical correctness for uneven M/N and variable head dims while preserving performance where possible.
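A sketch of the compile-time atomic switch the two notes above describe (a jitted helper with assumed names, not the PR's kernel): when sequence parallelism or shared KV heads mean several programs may write the same dk/dv rows, accumulation goes through atomic adds; otherwise plain or masked stores keep the fast path.

```python
import triton
import triton.language as tl

@triton.jit
def _store_dkv(
    dk_ptrs, dv_ptrs, dk, dv,
    offs_n, offs_d, seqlen_k, headdim,
    EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, ATOMIC_ADD: tl.constexpr,
):
    bounds = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)
    if ATOMIC_ADD:
        # Contending writers (sequence-parallel tiles or several Q heads per KV head).
        tl.atomic_add(dk_ptrs, dk, mask=bounds)
        tl.atomic_add(dv_ptrs, dv, mask=bounds)
    else:
        if EVEN_N & EVEN_HEADDIM:
            tl.store(dk_ptrs, dk)
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dk_ptrs, dk, mask=bounds)
            tl.store(dv_ptrs, dv, mask=bounds)
```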
Uses query head count for backward offsets to correct head/batch mapping, enabling GQA/MQA configurations without misindexing. Aligns masked score accumulation with the tensor dtype by using a zero literal that avoids unintended type promotion, improving numerical stability and performance across fp16/bf16.
Adds a comprehensive backward equivalence test suite validating Triton gradients against the Python prototype across many shapes and head dims, with accuracy and speed reporting. Enables the Triton test path via the test_type flag. Removes redundant .contiguous() calls before the Triton attention invocation to avoid extra copies and rely on stride-aware kernels, improving memory use and potential performance. Skips when Triton is unavailable and performs GPU memory cleanup between runs.
Improves correctness and stability of the backward path with GQA/MQA and broadcasted mask/bias:
- Unifies head indexing and fixes DK/DV/DBias head offsets; expands and reduces grads when head counts differ
- Adds broadcast-aware strides and rounded K length; pads head dim to 8 for aligned 16-bit storage and crops after
- Reworks bwd column kernel to handle optional mask/bias, accumulate bias grad to avoid atomics, and inline safe stores
- Enables sequence-parallel config and introduces dbias accumulation heuristic for better performance

Also refines masking/bias application and removes race-prone barriers, addressing correctness for uneven shapes and broadcasts. A padding sketch follows below.
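A small host-side sketch of the padding-and-crop step mentioned above, assuming "pads head dim to 8" means rounding the head dimension up to a multiple of 8 so 16-bit stores stay aligned (the helper name is hypothetical):

```python
import torch
import torch.nn.functional as F

def pad_headdim(x: torch.Tensor, multiple: int = 8) -> torch.Tensor:
    # Zero-pad the last dimension up to the next multiple of `multiple`.
    d = x.shape[-1]
    pad = (-d) % multiple
    return F.pad(x, (0, pad)) if pad else x

dq_padded = pad_headdim(torch.zeros(2, 128, 8, 100, dtype=torch.float16))
print(dq_padded.shape[-1])   # 104: kernel writes aligned 16-bit rows into the padded buffer
dq = dq_padded[..., :100]    # crop back to the original head dim afterwards
```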
Current benchmark performance comparison

🐍 PyTorch version: 2.9.0+cu128
🔥 Device: cuda
🎮 CUDA device: NVIDIA GeForce RTX 4090
💾 CUDA memory: 22.5 GB
🎲 Random seed: 42
📊 Test type: sdpa-vs-cuda-vs-triton
🔄 Runs: 3, Warmup: 2
🏆============================================================================🏆
🔥 Backward Pass Performance Benchmark 🔥
🏆============================================================================🏆
📊 Backward Pass Benchmark Results (averaged over 3 runs):
🔧 Configuration ⚡ SDPA-BWD 🚀 CUDA-BWD 🌟 Triton-BWD ✨ Flex-BWD 📈 Speedup
🔄----------------------------------------------------------------------------------------------------------------------------------------------------------------🔄
📊 B1 Hq2 Hkv1 Q256 K256 D64 W1024 C ⚡ 0.50ms 🚀 0.58ms 🌟 0.57ms ✨ N/A 📈 CUDA: 0.9x, Tri: 0.9x
📊 B1 Hq2 Hkv1 Q512 K512 D64 W1024 C ⚡ 0.48ms 🚀 0.37ms 🌟 0.53ms ✨ N/A 📈 CUDA: 1.3x, Tri: 0.9x
📊 B1 Hq2 Hkv1 Q1024 K1024 D64 W1024 C ⚡ 0.44ms 🚀 0.45ms 🌟 0.72ms ✨ N/A 📈 CUDA: 1.0x, Tri: 0.6x
📊 B1 Hq2 Hkv1 Q2048 K2048 D64 W1024 C ⚡ 0.61ms 🚀 0.45ms 🌟 1.05ms ✨ N/A 📈 CUDA: 1.3x, Tri: 0.6x
📊 B1 Hq2 Hkv1 Q4096 K4096 D64 W1024 C ⚡ 2.62ms 🚀 0.92ms 🌟 2.15ms ✨ N/A 📈 CUDA: 2.9x, Tri: 1.2x
📊 B1 Hq2 Hkv1 Q8192 K8192 D64 W1024 C ⚡ 10.12ms 🚀 2.02ms 🌟 5.73ms ✨ N/A 📈 CUDA: 5.0x, Tri: 1.8x
📊 B1 Hq2 Hkv1 Q16384 K16384 D64 W1024 C ⚡ 40.61ms 🚀 6.09ms 🌟 22.11ms ✨ N/A 📈 CUDA: 6.7x, Tri: 1.8x
🔄----------------------------------------------------------------------------------------------------------------------------------------------------------------🔄
Pull Request Overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
    ...
    # Load k and v, them will stay in SRAM throughout

Copilot AI · Oct 29, 2025
Corrected grammar from 'them will stay' to 'they will stay' in comment.
Suggested change:
-   # Load k and v, them will stay in SRAM throughout
+   # Load k and v, they will stay in SRAM throughout
    _, seqlen_k, nheads_k, _ = k.shape
    ...
    assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"
    assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"

Copilot AI · Oct 29, 2025
Corrected grammar from 'only support' to 'only supports' in error message.
Suggested change:
-   assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"
+   assert d <= 128, "FlashDynamicMaskAttention only supports head dimensions up to 128"
    _, seqlen_k, nheads_k, dk = k.shape
    ...
    assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"
    assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"

Copilot AI · Oct 29, 2025
Corrected grammar from 'only support' to 'only supports' in error message.
Suggested change:
-   assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"
+   assert d <= 128, "FlashDynamicMaskAttention only supports head dimensions up to 128"
    # Apply scaling
    acc_s = acc_s * softmax_scale
    ...
    lse_i = tl.load(LSE + offs_m_curr)

Copilot AI · Oct 29, 2025
The commented-out line 525 should be removed. If it's being kept for reference, add a comment explaining why the new implementation differs (handling -inf case).
Suggested change:
    lse_i = tl.load(LSE + offs_m_curr)
+   # Previous implementation did not handle the case where lse_i == -inf, which can result in NaNs.
    ctx.seqlen_k_bias_og = attn_bias.shape[-1] if attn_bias is not None else 0
    return o

Copilot AI · Oct 29, 2025
head_size_og is not stored in the context but is needed in the backward pass. Line 1179 computes head_size_og but does not save it to ctx, yet line 1214 in backward uses do.size(3) and assumes it matches. If the query was padded in forward, the original head size should be saved to ctx for proper gradient slicing in backward.
    or ((bias.shape[-2] == 1) and (seqlen_q > 1))
    ):
    if bias.shape[-2] == 1:
        dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)

Copilot AI · Oct 29, 2025
Using torch.zeros for dbias_expanded allocation when bias.shape[-2] == 1 may be inefficient. Consider using torch.empty and ensuring accumulation is handled correctly, as zeros initialization adds unnecessary overhead for a tensor that will be written to.
Suggested change:
-   dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
+   dbias_expanded = torch.empty(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
    dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
    else:
        dbias_expanded = torch.zeros(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)

Copilot AI · Oct 29, 2025
Using torch.zeros for dbias_expanded allocation may be inefficient. Consider using torch.empty if the tensor will be fully written to during computation, avoiding the overhead of zero initialization.
Suggested change:
-   dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
-   else:
-       dbias_expanded = torch.zeros(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
+   dbias_expanded = torch.empty(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
+   else:
+       dbias_expanded = torch.empty(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
-   # load
+   # Load o

Copilot AI · Oct 29, 2025
[nitpick] Corrected comment from 'Load o' to 'Load output' for clarity.
Suggested change:
-   # Load o
+   # Load output
Moves softmax scaling from logits to queries to avoid per-block scaling and keep bias unscaled. Initializes logits with bias, then adds the dot product and applies masks afterward. Improves numerical stability and likely reduces instruction count in the forward path.
Moves softmax scaling from logits to keys to reduce redundant multiplies and improve numerical stability/perf. Reworks score accumulation to start from bias, then matmul, then masking, preserving masking semantics while simplifying flow. Introduces the reciprocal scale and uses it when accumulating the query gradients so gradients are computed with unscaled keys.
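A reference-math sketch of the reworked accumulation order (plain tensors, not kernel code; the scale value and shapes are made up): the softmax scale is folded into one operand up front (q in forward, k in backward), the score tile starts from the unscaled bias, the dot product is added on top, and masking is applied last with True = keep.

```python
import torch

softmax_scale = 0.125
q = torch.randn(64, 64) * softmax_scale   # scale fused into the operand
k = torch.randn(128, 64)
bias = torch.randn(64, 128)
keep = torch.rand(64, 128) > 0.1          # boolean mask, True = keep

acc_s = bias.clone()                      # start from the (unscaled) bias
acc_s += q @ k.T                          # add the already-scaled QK^T tile
acc_s = acc_s.masked_fill(~keep, float("-inf"))
p = acc_s.softmax(dim=-1)
```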
Triton fwd/bwd: fuse softmax scale into operands, start from bias, preserve mask semantics

What changed
Why this helps
Correctness
Environment and setup
Performance summary (Triton vs SDPA)
Notes
cc: @ftgreat
Initializes the score accumulator within the bias branch and zero-initializes otherwise. Prevents referencing an undefined bias when disabled and improves compilation stability in both forward and backward kernels.
…as dtype checks

Add HAS_MASK, HAS_BIAS and HAS_INDICE to the autotune key to ensure different kernel configs are cached per mask/bias/indice usage. Also enforce bias dtype to match query dtype (only fp16/bf16) and standardize the mask dtype assert message.
Moves softmax scaling to the final dk update to cut register pressure and simplify accumulation. Aligns dq accumulation with unscaled k for more stable gradients.
…ard/_flash_dmattn_backward and update call sites
Ensures backward only returns a bias gradient when bias exists, keeping the signature consistent for biasless calls.
Summary
Design
Changes
Implementation Notes
Tests
Docs
Checklist
Additional Notes