Conversation

@LoserCheems
Collaborator

Summary

  • Add GQA/MQA support to the Triton forward kernel (Q heads can differ from KV heads).
  • Support broadcasting for attention mask and bias across batch, head, and sequence dims.
  • Improve stability and correctness with tighter shape checks and scalar head-indexing.
  • Optimize masked workloads by skipping compute when a tile has no active elements.
  • Keep backward GQA unimplemented for now (explicit error), with clear TODO.

Design

  • Head mapping (GQA): each query head h_q maps to KV head h_k = h_q // (Hq/Hk), aligning with CUDA’s h/h_k ratio; LSE/output are indexed by query head (see the sketch after this list).
  • Mask/Bias broadcasting: Accept head dimension in {1, Hk, Hq}; batch and sequence dims may be 1. Expand to a broadcasted view and compute the correct head index per tile.
  • Early-tile skip: If a tile’s dynamic mask has no active elements, skip the iteration entirely to avoid useless loads/compute.
  • Backward guard: Triton bwd raises a clear error when Hq != Hk to avoid silent incorrect gradients; CUDA path remains the training fallback for GQA.
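
A minimal Python sketch of the head-mapping and mask/bias head-selection rules above (reference semantics only, not the Triton kernel; all names are illustrative):

```python
def kv_head_index(h_q: int, nheads_q: int, nheads_k: int) -> int:
    # GQA/MQA mapping: each group of (nheads_q // nheads_k) query heads
    # shares one KV head, i.e. h_k = h_q // (Hq / Hk).
    assert nheads_q % nheads_k == 0
    return h_q // (nheads_q // nheads_k)

def mask_head_index(h_q: int, nheads_mask: int, nheads_q: int, nheads_k: int) -> int:
    # Mask/bias head dim may be 1 (shared), Hk (per KV head), or Hq (per Q head).
    if nheads_mask == 1:
        return 0
    if nheads_mask == nheads_k:
        return kv_head_index(h_q, nheads_q, nheads_k)
    assert nheads_mask == nheads_q
    return h_q

# Hq=16, Hk=4: query heads 0-3 read KV head 0, heads 4-7 read KV head 1, and so on.
print([kv_head_index(h, 16, 4) for h in range(16)])
```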

Changes

  • Public API remains the same: triton_dmattn_func(query, key, value, attn_mask=None, attn_bias=None, ...) (usage sketch after this list)
  • Forward (Triton):
    • Accept K/V with Hk different from Hq (Hq % Hk == 0).
    • Support mask/bias head dims in {1, Hk, Hq}; broadcast batch/seq dims in {1, B}, {1, Lq}, {1, Lk}.
    • Kernel arguments extended to include nheads_q, nheads_k, nheads_mask, nheads_bias, and h_h_k_ratio.
  • Backward (Triton):
    • If nheads_q != nheads_k, raise RuntimeError (“not implemented yet; use CUDA or disable grad”). No change for Hq == Hk.
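
A hypothetical usage sketch matching the shapes above (the import path and the (B, L, H, D) layout are assumed from this PR's code excerpts; adjust to the actual package layout):

```python
import torch
from flash_dmattn import triton_dmattn_func  # assumed import path

B, Lq, Lk, Hq, Hk, D = 2, 256, 256, 16, 4, 64
dev, dtype = "cuda", torch.float16

q = torch.randn(B, Lq, Hq, D, device=dev, dtype=dtype)
k = torch.randn(B, Lk, Hk, D, device=dev, dtype=dtype)  # fewer KV heads (GQA)
v = torch.randn(B, Lk, Hk, D, device=dev, dtype=dtype)

# Broadcastable mask/bias: head dim in {1, Hk, Hq}; batch/seq dims may be 1.
attn_mask = torch.ones(1, Hk, 1, Lk, device=dev, dtype=torch.bool)  # True = keep
attn_bias = torch.zeros(B, 1, Lq, Lk, device=dev, dtype=dtype)

out = triton_dmattn_func(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias)
assert torch.isfinite(out).all()  # output is indexed by the Hq query heads
```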

Implementation Notes

  • Head indexing:
    • off_hq comes from the program id and off_hk = off_hq // h_h_k_ratio; the mask/bias head index is 0 (if 1 head), off_hk (if Hk heads), or off_hq (if Hq heads).
  • Broadcasting:
    • Expand mask/bias to (B, Hmask/Hbias, Lq, Lk) so pointer arithmetic and strides are well-defined; strides are 0 for broadcast axes.
  • Stability fixes:
    • Replace tl.where-based scalar selection for head indices with scalar branches to avoid tensor/scalar ambiguity.
    • Strengthen K/V shape checks (seqlen and head count equality) and head_dim equality across Q/K/V.
    • When mask/bias are None, pass a valid dummy tensor pointer and rely on compile-time HAS_MASK/HAS_BIAS to avoid dereferencing.
  • Early skip:
    • Use tl.reduce_or on the boolean mask tile; if no element is True, skip the K/V loads and compute for that tile and continue (see the snippet after this list).
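
Two of the points above in a small PyTorch snippet (reference behaviour only; the kernel does the equivalent with zero strides and tile-level reductions):

```python
import torch

B, Hq, Lq, Lk, BLOCK_N = 2, 16, 128, 128, 32

# Broadcasting: expand() creates a view whose broadcast axes have stride 0,
# so pointer arithmetic can treat missing batch/head/seq dims uniformly.
mask = torch.rand(1, 1, 1, Lk) > 0.5            # batch/head/seq-q dims of size 1
mask_b = mask.expand(B, Hq, Lq, Lk)
print(mask_b.stride())                           # (0, 0, 0, 1)

# Early-tile skip: if a mask tile has no active element, the K/V loads and the
# matmul for that tile can be skipped entirely.
for n0 in range(0, Lk, BLOCK_N):
    tile = mask_b[0, 0, :, n0:n0 + BLOCK_N]
    if not tile.any():
        continue  # nothing to compute for this tile
    # ... load K/V for columns [n0, n0 + BLOCK_N) and accumulate ...
```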

Tests

  • Forward GQA sanity:
    • B=2, Lq=Lk=256, Hq=16, Hk=4, D=64/128, with/without causal, mask/bias head dims in {1, Hk, Hq}.
    • Verified output shapes and the absence of NaNs/Infs; outputs stayed finite across all configurations.
  • Equivalence:
    • Recommend running forward_equivalence.py with Triton in native GQA mode (no KV repetition) against CUDA GQA as reference.
  • Backward:
    • Confirmed Triton raises on Hq != Hk (as expected) and still works for Hq == Hk.

Docs

  • Notes:
    • Triton now supports GQA in forward; backward for GQA is not implemented and raises an error.
    • Mask/bias head broadcasting supported: head dims 1/Hk/Hq; batch/seq dims 1/B and 1/Lq, 1/Lk.
  • Follow-up: add a short section in integration.md about Triton GQA support and limitation for backward.

Checklist

  • Linked issue provided
  • API stable (no breaking changes)
  • Tests added or updated (recommend adding native-GQA Triton run in forward_equivalence.py)
  • Docs added or updated (to add a brief note on Triton GQA forward and bwd limitation)
  • No known performance regressions (no regressions observed in local forward runs; early-tile skip improves masked workloads)

Additional Notes

  • The Triton kernel launches over batch × Hq to match output/LSE layout; K/V are indexed by the grouped head off_hk.
  • For dynamic masks with sparse patterns, the early “Skip this iteration if no active elements” path avoids unnecessary global loads and math.
  • For training with GQA, the CUDA backend remains the recommended path until Triton backward supports Hq != Hk.

Adds forward support for GQA/MQA (different Q vs KV heads) with optional boolean mask and bias, including broadcasting across batch/head/seq dims and per-head routing.

Switches to compile-time mask/bias flags, removes the scratchpad workaround, simplifies scaling, and indexes LSE/Out by Q heads. Skips masked-out tiles, tightens mask semantics (True = keep), and fixes backward mask handling.

Introduces a contiguity helper, bumps pipeline stages, and errors out on Triton backward for GQA/MQA until implemented.
Copilot AI review requested due to automatic review settings October 27, 2025 10:02
@LoserCheems LoserCheems changed the title from “Triton forward: GQA, mask/bias broadcasting, skip inactive tiles, and stability fixes” to “Optimize triton version: GQA, mask/bias broadcasting, skip inactive tiles, and stability fixes” Oct 27, 2025
@LoserCheems
Collaborator Author

@ftgreat please test this version🤗

Contributor

Copilot AI left a comment


Pull Request Overview

This PR adds GQA/MQA support to the Triton forward kernel while improving stability and performance through mask/bias broadcasting, early tile skipping, and tighter shape validation.

  • Enables Triton forward kernel to handle differing Q and KV head counts (GQA/MQA) with proper head mapping
  • Implements mask/bias broadcasting across batch, head, and sequence dimensions
  • Optimizes masked workloads by skipping tiles with no active elements


@@ -1,3 +1,4 @@
from typing import Optional

Copilot AI Oct 27, 2025


The typing.Optional import is used only for the maybe_contiguous function at line 1096. Consider using Optional from typing or using T | None syntax if Python 3.10+ is the minimum supported version.

seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,

Copilot AI Oct 27, 2025


The CACHE_KEY_SEQLEN_Q and CACHE_KEY_SEQLEN_K parameters now have explicit tl.constexpr type hints, but their purpose and usage are not documented. Consider adding a comment explaining that these are compile-time constants for Triton's kernel cache key.

Suggested change
headdim,
headdim,
# CACHE_KEY_SEQLEN_Q and CACHE_KEY_SEQLEN_K are compile-time constants used as part of Triton's kernel cache key.
# They help Triton select and cache the appropriate kernel configuration based on sequence lengths.

Comment on lines +120 to +121
lse_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)

Copilot AI Oct 27, 2025


[nitpick] Using tl.full instead of tl.zeros(...) - float('inf') is cleaner, but the string representation float('-inf') differs from the previous numeric subtraction. For consistency with line 120, consider using the same representation (either both string or both computed).

if EVEN_M & EVEN_N:

# Load q: it will stay in SRAM throughout
if EVEN_M:

Copilot AI Oct 27, 2025


The condition changed from EVEN_M & EVEN_N to just EVEN_M when loading the query tensor. This may cause incorrect behavior when EVEN_M=True but EVEN_N=False, as the query is loaded without any masking. The original code had a comment noting this was necessary to avoid a Triton bug. Verify that this change is safe or restore the original condition.

Suggested change
if EVEN_M:
# The original condition required EVEN_M & EVEN_N to avoid a Triton bug.
if EVEN_M & EVEN_N:

# scale acc_o
acc_o_scale = tl.exp(m_i - m_ij)
# Load k
if EVEN_N:

Copilot AI Oct 27, 2025


Changed from EVEN_N & EVEN_M to just EVEN_N for loading K and V tensors. The original comment stated 'If we just do "if EVEN_N", there seems to be some race condition'. Removing the EVEN_M check may reintroduce this race condition. This change should be verified or reverted.

Suggested change
if EVEN_N:
if EVEN_N & EVEN_M:

batch, seqlen_q, nheads_q, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape

assert nheads_q % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"

Copilot AI Oct 27, 2025


The error message should clarify the actual constraint. It currently says 'divisible by KV heads' but the check is nheads_q % nheads_k == 0. Consider rephrasing to 'Number of query heads must be divisible by number of KV heads for GQA/MQA' for clarity.

Suggested change
assert nheads_q % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"
assert nheads_q % nheads_k == 0, "Number of query heads must be divisible by number of KV heads for GQA/MQA"

Comment on lines +871 to +872
_, seqlen_k, nheads_k, _ = k.shape


Copilot AI Oct 27, 2025


Missing assertion that k.shape[0] == batch and k.shape[3] == d after unpacking. The forward function should validate that K and V have matching batch size and head dimension with Q for correctness. Similar validation exists for V's shape on the removed line but is now missing.

Suggested change
_, seqlen_k, nheads_k, _ = k.shape
k_batch, seqlen_k, nheads_k, k_d = k.shape
v_batch, v_seqlen, v_nheads, v_d = v.shape
assert k_batch == batch, "Key tensor batch size must match Query tensor batch size"
assert k_d == d, "Key tensor head dimension must match Query tensor head dimension"
assert v_batch == batch, "Value tensor batch size must match Query tensor batch size"
assert v_d == d, "Value tensor head dimension must match Query tensor head dimension"

BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=1,
num_stages=2,

Copilot AI Oct 27, 2025


The number of pipeline stages changed from 1 to 2. While this can improve performance by overlapping memory and compute, it increases register pressure and may cause issues on older GPUs or with large block sizes. This change should be documented or made configurable if it affects compatibility.



def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

Copilot AI Oct 27, 2025


The function returns x when x is None, which is correct, but also returns x when x.stride(-1) == 1. However, if x is None, the expression x.stride(-1) will raise an AttributeError. The condition should be return x if x is None or x.stride(-1) == 1 else x.contiguous() to properly short-circuit.

Suggested change
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
return x if x is None or x.stride(-1) == 1 else x.contiguous()

if query.shape[2] != key.shape[2]:
raise RuntimeError(
"Triton backward for GQA/MQA (nheads_q != nheads_k) is not implemented yet. "
"Use the CUDA backend for training or disable grad for Triton path."

Copilot AI Oct 27, 2025


The error message should provide more actionable guidance. Consider adding information about how to switch to the CUDA backend or a reference to documentation. For example: 'Triton backward for GQA/MQA is not implemented yet. Use the CUDA backend by setting use_triton=False, or disable gradients for this operation.'

Suggested change
"Use the CUDA backend for training or disable grad for Triton path."
"To resolve this, use the CUDA backend by setting use_triton=False, or disable gradients for this operation. "
"See the documentation for more details: https://github.com/your-repo/docs#backend-selection"

Removes rigid 4D asserts and explicit expands for mask/bias to support broadcast-friendly inputs and avoid unnecessary memory overhead.

Renames flags to lower-case and passes them through consistently to the kernel, updating conditional strides and placeholder tensors accordingly.

Improves readability and aligns runtime flags with kernel expectations.
Drops backward-path guards that blocked attention mask gradients and GQA/MQA head configs with the Triton backend.

Expands training support without forcing CUDA fallback; relies on the underlying kernel for validation.
Enables kernel autotuning with multiple tile configs and warp counts, keyed by sequence lengths, causality, mask/bias presence, and head dim for better performance and correct cache separation.

Defaults missing mask/bias to empty tensors up front to simplify the call path and stabilize the kernel signature.

Removes hardcoded launch params to defer selection to the autotuner; standardizes num_stages in configs.
Supports mixed Q/KV head counts (GQA/MQA) by validating divisibility, updating grid/shape logic to query heads, and passing head-count ratios to the kernel.

Makes mask and bias truly optional: supplies empty tensors with zero strides and passes has_mask/has_bias flags to the kernel, removing strict stride/layout assumptions.

Improves robustness with clearer assertions and compile-time keys (is_causal/has_mask/has_bias/head dim), and standardizes the forward to require explicit mask/bias args.
Enables backward pass with differing Q/K head counts (GQA/MQA) and broadcasting of mask/bias across 1, KV, or Q heads by remapping head offsets and conditioning pointer advances on feature presence.

Introduces a compile-time switch for atomic accumulation that triggers under sequence parallelism or head-count mismatches to prevent write races, while avoiding atomics otherwise for performance.

Improves correctness and flexibility across broader attention configurations.
Introduces a flag to perform atomic accumulation of gradients when tiles may contend, eliminating race conditions in the backward path.

Retains fast masked stores for safe even/fully covered cases; applies atomic adds with appropriate masks otherwise.

Improves numerical correctness for uneven M/N and variable head dims while preserving performance where possible.
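
A PyTorch analogy for why contended tiles need accumulating writes rather than plain stores (not the kernel's atomics, just the write-race idea): when several tiles target the same destination rows, e.g. grouped query heads contributing to one dK/dV head or sequence-parallel tiles contributing to dQ, plain stores lose contributions while accumulation sums them.

```python
import torch

dst = torch.zeros(4)
idx = torch.tensor([0, 0, 1])                 # two contributions target row 0
contrib = torch.tensor([1.0, 2.0, 3.0])

# Plain store: with duplicate indices one contribution is silently dropped
# (which one survives depends on write order).
overwritten = dst.clone()
overwritten[idx] = contrib

# Accumulating write (what atomic adds give the kernel): contributions sum,
# so row 0 ends up 1.0 + 2.0 = 3.0.
accumulated = dst.clone()
accumulated.index_put_((idx,), contrib, accumulate=True)
print(overwritten, accumulated)
```
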
Uses query head count for backward offsets to correct head/batch mapping, enabling GQA/MQA configurations without misindexing.

Aligns masked score accumulation with the tensor dtype by using a zero literal that avoids unintended type promotion, improving numerical stability and performance across fp16/bf16.
Adds a comprehensive backward equivalence test suite validating Triton gradients against the Python prototype across many shapes and head dims, with accuracy and speed reporting. Enables the Triton test path via the test_type flag.

Removes redundant .contiguous() calls before the Triton attention invocation to avoid extra copies and rely on stride-aware kernels, improving memory use and potential performance.

Skips when Triton is unavailable and performs GPU memory cleanup between runs.
Improves correctness and stability of the backward path with GQA/MQA and broadcasted mask/bias:
- Unifies head indexing and fixes DK/DV/DBias head offsets; expands and reduces grads when head counts differ
- Adds broadcast-aware strides and rounded K length; pads head dim to 8 for aligned 16‑bit storage and crops after
- Reworks bwd column kernel to handle optional mask/bias, accumulate bias grad to avoid atomics, and inline safe stores
- Enables sequence-parallel config and introduces dbias accumulation heuristic for better performance

Also refines masking/bias application and removes race-prone barriers, addressing correctness for uneven shapes and broadcasts.
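
The "expands and reduces grads when head counts differ" step in the commit above can be pictured with a reference sketch (illustrative shapes, not the kernel's accumulation path): per-query-head dK contributions are summed over each group of query heads that shares a KV head.

```python
import torch

B, Lk, Hq, Hk, D = 2, 128, 16, 4, 64
group = Hq // Hk

# Hypothetical per-query-head dK contributions, laid out like K but with Hq heads.
dk_per_q_head = torch.randn(B, Lk, Hq, D)

# Reduce each group of query heads back onto its shared KV head (h_k = h_q // group).
dk = dk_per_q_head.view(B, Lk, Hk, group, D).sum(dim=3)   # -> (B, Lk, Hk, D)
```
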
@LoserCheems
Collaborator Author

Current benchmark performance comparison

🐍 PyTorch version: 2.9.0+cu128
🔥 Device: cuda
🎮 CUDA device: NVIDIA GeForce RTX 4090
💾 CUDA memory: 22.5 GB
🎲 Random seed: 42
📊 Test type: sdpa-vs-cuda-vs-triton
🔄 Runs: 3, Warmup: 2

🏆============================================================================🏆
🔥 Backward Pass Performance Benchmark 🔥
🏆============================================================================🏆

📊 Backward Pass Benchmark Results (averaged over 3 runs):
🔧 Configuration                                                ⚡ SDPA-BWD     🚀 CUDA-BWD     🌟 Triton-BWD   ✨ Flex-BWD        📈 Speedup        
🔄----------------------------------------------------------------------------------------------------------------------------------------------------------------🔄
📊 B1 Hq2 Hkv1 Q256 K256 D64 W1024 C                            ⚡ 0.50ms       🚀 0.58ms       🌟 0.57ms       ✨ N/A             📈 CUDA: 0.9x, Tri: 0.9x
📊 B1 Hq2 Hkv1 Q512 K512 D64 W1024 C                            ⚡ 0.48ms       🚀 0.37ms       🌟 0.53ms       ✨ N/A             📈 CUDA: 1.3x, Tri: 0.9x
📊 B1 Hq2 Hkv1 Q1024 K1024 D64 W1024 C                          ⚡ 0.44ms       🚀 0.45ms       🌟 0.72ms       ✨ N/A             📈 CUDA: 1.0x, Tri: 0.6x
📊 B1 Hq2 Hkv1 Q2048 K2048 D64 W1024 C                          ⚡ 0.61ms       🚀 0.45ms       🌟 1.05ms       ✨ N/A             📈 CUDA: 1.3x, Tri: 0.6x
📊 B1 Hq2 Hkv1 Q4096 K4096 D64 W1024 C                          ⚡ 2.62ms       🚀 0.92ms       🌟 2.15ms       ✨ N/A             📈 CUDA: 2.9x, Tri: 1.2x
📊 B1 Hq2 Hkv1 Q8192 K8192 D64 W1024 C                          ⚡ 10.12ms      🚀 2.02ms       🌟 5.73ms       ✨ N/A             📈 CUDA: 5.0x, Tri: 1.8x
📊 B1 Hq2 Hkv1 Q16384 K16384 D64 W1024 C                        ⚡ 40.61ms      🚀 6.09ms       🌟 22.11ms      ✨ N/A             📈 CUDA: 6.7x, Tri: 1.8x
🔄----------------------------------------------------------------------------------------------------------------------------------------------------------------🔄

@LoserCheems LoserCheems requested a review from Copilot October 29, 2025 16:07
Contributor

Copilot AI left a comment


Pull Request Overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.



# if we just call tl.load(k_ptrs), we get the wrong output!
if EVEN_N & EVEN_M:

# Load k and v, them will stay in SRAM throughout

Copilot AI Oct 29, 2025


Corrected grammar from 'them will stay' to 'they will stay' in comment.

Suggested change
# Load k and v, them will stay in SRAM throughout
# Load k and v, they will stay in SRAM throughout

_, seqlen_k, nheads_k, _ = k.shape

assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"
assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"

Copilot AI Oct 29, 2025


Corrected grammar from 'only support' to 'only supports' in error message.

Suggested change
assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"
assert d <= 128, "FlashDynamicMaskAttention only supports head dimensions up to 128"

_, seqlen_k, nheads_k, dk = k.shape

assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA"
assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"

Copilot AI Oct 29, 2025


Corrected grammar from 'only support' to 'only supports' in error message.

Suggested change
assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128"
assert d <= 128, "FlashDynamicMaskAttention only supports head dimensions up to 128"

# Apply scaling
acc_s = acc_s * softmax_scale

lse_i = tl.load(LSE + offs_m_curr)

Copilot AI Oct 29, 2025


The commented-out line 525 should be removed. If it's being kept for reference, add a comment explaining why the new implementation differs (handling -inf case).

Suggested change
lse_i = tl.load(LSE + offs_m_curr)
lse_i = tl.load(LSE + offs_m_curr)
# Previous implementation did not handle the case where lse_i == -inf, which can result in NaNs.

Comment on lines +1207 to 1208
ctx.seqlen_k_bias_og = attn_bias.shape[-1] if attn_bias is not None else 0
return o

Copilot AI Oct 29, 2025


The saved head_size_og is not stored in the context but is needed in backward pass. Line 1179 computes head_size_og but it's not saved to ctx, yet line 1214 in backward tries to use do.size(3) assuming it matches. If query was padded in forward, the original head size should be saved to ctx for proper gradient slicing in backward.

or ((bias.shape[-2] == 1) and (seqlen_q > 1))
):
if bias.shape[-2] == 1:
dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)

Copilot AI Oct 29, 2025


Using torch.zeros for dbias_expanded allocation when bias.shape[-2] == 1 may be inefficient. Consider using torch.empty and ensuring accumulation is handled correctly, as zeros initialization adds unnecessary overhead for a tensor that will be written to.

Suggested change
dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
dbias_expanded = torch.empty(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)

Comment on lines +1030 to +1032
dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
else:
dbias_expanded = torch.zeros(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)

Copilot AI Oct 29, 2025


Using torch.zeros for dbias_expanded allocation may be inefficient. Consider using torch.empty if the tensor will be fully written to during computation, avoiding the overhead of zero initialization.

Suggested change
dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
else:
dbias_expanded = torch.zeros(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
dbias_expanded = torch.empty(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)
else:
dbias_expanded = torch.empty(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype)

offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# load
# Load o

Copilot AI Oct 29, 2025


[nitpick] Corrected comment from 'Load o' to 'Load output' for clarity.

Suggested change
# Load o
# Load output

Moves softmax scaling from logits to queries to avoid per-block scaling and keep bias unscaled.
Initializes logits with bias, then adds the dot product and applies masks afterward.
Improves numerical stability and likely reduces instruction count in the forward path.
Moves softmax scaling from logits to keys to reduce redundant multiplies and improve numerical stability/perf.

Reworks score accumulation to start from bias, then matmul, then masking, preserving masking semantics while simplifying flow.

Introduces the reciprocal scale and uses it when accumulating the query gradients so gradients are computed with unscaled keys.
@LoserCheems
Collaborator Author

LoserCheems commented Oct 31, 2025

Triton fwd/bwd: fuse softmax scale into operands, start from bias, preserve mask semantics

What changed

  • Forward path in _fwd_kernel
    • Pre-scale queries once: multiply Q by softmax_scale and keep it in SRAM.
    • Build logits “bias-first”: initialize acc_s with bias (or zeros), then add QK^T, then apply masks.
    • Removes per-block logits scaling, keeping bias unscaled by construction.
  • Backward path in _bwd_kernel_one_col_block
    • Pre-scale keys once for the dot-product.
    • Introduce softmax_unscale = 1/softmax_scale and use it when accumulating dQ so gradients are w.r.t. the original (unscaled) keys.
    • Mirrors forward’s “bias-first → matmul → masks” score flow for stable numerics.
  • Preprocess stays the same in _bwd_preprocess_do_o_dot; no changes to the public API.
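
A plain-PyTorch sketch of the fused-scale, bias-first flow described above, on a single (M, N) tile (illustrative only; the kernel works on Triton tiles and its backward accumulates across blocks):

```python
import torch

M, N, D = 64, 64, 64
softmax_scale = D ** -0.5
q, k = torch.randn(M, D), torch.randn(N, D)
bias = torch.randn(M, N)
keep = torch.rand(M, N) > 0.1                   # True = keep

# Forward: pre-scale Q once, start the logits from the bias, add QK^T, then mask.
# The bias itself is never multiplied by softmax_scale.
q_scaled = q * softmax_scale
acc_s = bias.clone()
acc_s += q_scaled @ k.T
acc_s = acc_s.masked_fill(~keep, float("-inf"))

# Backward (dQ term, schematically): with K pre-scaled instead, multiplying by
# softmax_unscale = 1 / softmax_scale when accumulating dQ makes the pre-scaled
# keys contribute like the original, unscaled keys.
softmax_unscale = 1.0 / softmax_scale
k_scaled = k * softmax_scale
ds = torch.randn(M, N)                          # placeholder for the dS-like term
dq_contrib = (ds @ k_scaled) * softmax_unscale  # equals ds @ k
```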

Why this helps

  • Fusing the scale into one operand eliminates an extra multiply on the logits and avoids accidentally scaling the bias.
  • Keeping bias as the base for acc_s improves numerical stability and reduces instruction count.
  • The dQ unscale restores the correct gradient semantics when K is pre-scaled.

Correctness

  • No accuracy loss: forward outputs and backward gradients match the previous implementation. Forward/backward parity verified across the benchmark sweep.

Environment and setup

  • PyTorch 2.6.0+cu124, NVIDIA A800-SXM4-80GB (79.3 GB), cuda device, seed=42.
  • Test type: SDPA vs Triton; 3 runs per point with 2 warmups.

Performance summary (Triton vs SDPA)

  • Forward (before → after):
    • Average speedup: 2.71x → 5.56x.
    • Best speedup: 10.52x → 14.99x (B1 Hq2 Hkv1 Q1 K524288 D64 C).
    • Typical gains grow with sequence length; 8k–32k tokens see 5–14x.
  • Backward:
    • Small to mid sizes: ~parity to minor regressions (0.7–1.1x).
    • Long sequences (Q/K = 8k–16k): 1.5–2.0x.
  • CUDA/Flex paths were not part of this change (N/A in these runs).

Notes

  • Masking and bias broadcasting semantics are preserved.
  • Head dim up to 128 is supported as before.
  • Results are numerically identical pre/post change.

cc: @ftgreat

Initializes the score accumulator within the bias branch and zero-initializes otherwise.

Prevents referencing an undefined bias when disabled and improves compilation stability in both forward and backward kernels.
…as dtype checks

Add HAS_MASK, HAS_BIAS and HAS_INDICE to the autotune key to ensure different
kernel configs are cached per mask/bias/indice usage. Also enforce bias dtype
to match query dtype (only fp16/bf16) and standardize the mask dtype assert
message.
Moves softmax scaling to the final dk update to cut register pressure and simplify accumulation.
Aligns dq accumulation with unscaled k for more stable gradients.
…ard/_flash_dmattn_backward and update call sites
Ensures backward only returns a bias gradient when bias exists, keeping the signature consistent for biasless calls.
@LoserCheems LoserCheems merged commit 1b9dace into main Nov 7, 2025