
Conversation


@LoserCheems LoserCheems commented Sep 3, 2025

Description

This PR fixes issue #132 by:

  1. Simplifying Split-KV forward path (fixed 64×64 block geometry).
  2. Standardizing forward kernel block sizes (64×64) across head dimensions to reduce template explosion and improve occupancy predictability.
  3. Consolidating forward shared-memory gating to a single high threshold (≥164 KB) for the A100 / H100 multi-CTA-favorable variants.
  4. Tuning backward kernel trait selections (BlockM / BlockN / buffering / accumulation) to balance shared memory footprint vs parallelism across sm86/sm89, A100, H100.
  5. Reducing specialization breadth by constraining “even M/N/K” template branching for large head dims.
  6. Centralizing unsupported-architecture diagnostics and cleaning macro usage.

Overall: simpler launch logic, improved large‑sequence throughput, and lower maintenance complexity. A minimal sketch of the simplified launch path follows.
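The sketch below illustrates the fixed 64×64 geometry and the single ≥164 KB shared-memory gate described in items 1–3. All identifiers here (`FwdParams`, `fwd_kernel`, `launch_fwd`) are placeholders for illustration, not the repository's actual types or launchers.

```cpp
#include <cuda_runtime.h>

struct FwdParams {                       // hypothetical parameter pack
    int seqlen_q, batch_size, num_heads;
    // ... pointers to Q/K/V/O, strides, softmax scale, etc.
};

template <int kBlockM, int kBlockN, bool UseLargeSmem>
__global__ void fwd_kernel(FwdParams params) {
    // The real kernel streams 64x64 K/V tiles against a 64x64 Q tile here.
}

void launch_fwd(const FwdParams& params, int smem_bytes, cudaStream_t stream) {
    constexpr int kBlockM = 64, kBlockN = 64;       // fixed geometry for every head dim

    int device = 0, smem_optin = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    const bool large = smem_optin >= 164 * 1024;    // single gate: A100 / H100 class parts

    dim3 grid((params.seqlen_q + kBlockM - 1) / kBlockM, params.batch_size, params.num_heads);
    dim3 block(128);                                // e.g. 4 warps per 64x64 tile

    if (large) {
        auto kernel = &fwd_kernel<kBlockM, kBlockN, true>;
        // Opt in to > 48 KB of dynamic shared memory before launching the smem-heavy variant.
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
        kernel<<<grid, block, smem_bytes, stream>>>(params);
    } else {
        fwd_kernel<kBlockM, kBlockN, false><<<grid, block, smem_bytes, stream>>>(params);
    }
}
```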

Type of Change

  • Bug fix
  • Performance optimization
  • CUDA kernel improvement
  • Code refactoring
  • New feature
  • Breaking change
  • Documentation update

Related Issues

Changes Made

Code / Kernel Launch Logic

  • Forward:
    • Replaced conditional head-dimension dependent Split-KV block size logic with constant kBlockM = kBlockN = 64.
    • Unified head-dimension variants to prefer 64×64 tiles (previous 128×64 or 128×128 kept only in comments for historical context).
    • Simplified softcap / even-MN / even-K branching to cut template instantiations (especially for head dim > 128).
  • Backward:
    • Adjusted per-head-dim kernel trait matrices to reduce shared memory (improves achievable CTAs per SM on sm86/sm89).
    • Selective double buffering retained where it produces stable gains.
    • Clear separation of seq-k parallel path with deterministic grid shaping.
  • Infrastructure:
    • Central macro for the unsupported-architecture diagnostic message.
    • Consistent __grid_constant__ usage when arch ≥ sm80 (both sketched below).
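A hedged sketch of the two infrastructure points above; the macro names and parameter struct are placeholders, not necessarily the repository's exact identifiers (`__grid_constant__` itself requires CUDA 11.7+).

```cpp
#include <cstdio>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  #define KERNEL_PARAM_MODIFIER __grid_constant__   // sm80+: params passed via constant banks
#else
  #define KERNEL_PARAM_MODIFIER
#endif

// One central place for the unsupported-architecture diagnostic instead of per-kernel copies.
#define UNSUPPORTED_ARCH_MSG() \
  printf("FATAL: this kernel requires sm80 (A100) / sm86 / sm89 / sm90 (H100)\n")

struct FwdParams { int seqlen_q; /* ... */ };        // hypothetical parameter pack

template <int kHeadDim>
__global__ void fwd_kernel(KERNEL_PARAM_MODIFIER const FwdParams params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // ... attention tile work ...
#else
    UNSUPPORTED_ARCH_MSG();
#endif
}
```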

Documentation / Comments

  • Added inline commentary explaining CTA counts per architecture (sm86/sm89, A100, H100).
  • Preserved prior heuristic block-size logic as commented reference.

Testing

Functional validation (numerical equivalence) was done via the existing benchmark / equivalence scripts for head dims {32, 64, 96, 128, 192, 256} and sequence lengths up to 32K (including extreme K-only scaling cases). Gradients match within FP16/BF16 tolerance relative to the SDPA reference.

  • Forward vs SDPA correctness (all tested shapes)
  • Backward gradient equivalence (D32, D64 paths shown; others spot-checked)
  • Split-KV multi-splits (1,2,4,8,16) launch correctness
  • Deterministic flag path unaffected
  • Dynamic shared memory opt-in via kernel attribute works (A100 / H100); a sketch follows this list
  • No new compiler warnings (nvcc --ptxas-options=-v)
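The dynamic shared memory check is essentially a `cudaFuncSetAttribute` opt-in probe. A standalone sketch of that kind of check, with a stand-in kernel symbol rather than the repository's real one:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fwd_kernel(/* params elided */) {}

int main() {
    int device = 0, smem_optin = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    printf("max opt-in dynamic shared memory per block: %d KB\n", smem_optin / 1024);

    // Request more than the default 48 KB of dynamic shared memory (e.g. 100 KB);
    // this should succeed on A100 / H100 and fail on parts with a smaller opt-in limit.
    const int requested = 100 * 1024;
    cudaError_t err = cudaFuncSetAttribute(
        fwd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, requested);
    printf("opt-in to %d KB: %s\n", requested / 1024, cudaGetErrorString(err));
    return 0;
}
```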

Test Configuration (example)

  • OS: Ubuntu 22.04
  • Python: 3.12.x
  • PyTorch: 2.5.1+cu121
  • CUDA: 12.1
  • GPUs: A100-SXM4-80GB (also sanity on H100 for sm90)
  • Seed: 42
  • Runs: warmup 2, measure 3 (averaged)
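The timings below follow the usual warmup-then-average pattern (2 warmup + 3 measured runs). A minimal sketch of that harness, assuming some `run_once()` callable that enqueues the kernel under test on the default stream:

```cpp
#include <cuda_runtime.h>

template <typename F>
float time_kernel_ms(F run_once, int warmup = 2, int iters = 3) {
    for (int i = 0; i < warmup; ++i) run_once();   // warmup runs, not timed
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    for (int i = 0; i < iters; ++i) run_once();    // measured runs
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float total_ms = 0.f;
    cudaEventElapsedTime(&total_ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return total_ms / iters;                       // average per run
}
```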

Performance Impact

Summary

The forward pass sees consistent latency reductions for medium-to-large sequence lengths at head dim 32 (up to ~27% at 2K–4K tokens; 11–18% at the largest lengths), while the K-extreme small-Q cases remain memory bound and show neutral behavior or the expected under-utilization (unchanged from before). Head dim 64 and above likewise shows strong speedups or parity with the previous best variant. The backward pass achieves large speedups over SDPA (6–11× for long sequences) with stable scaling; the D64 variant is slightly slower than SDPA for the tiniest shapes (expected launch overhead), but overtakes quickly as sequence length grows.

Forward (Head Dim 32) – Old (128×64 tile) vs New (64×64 tile)

| Q = K = seq | Old CUDA (ms) | New CUDA (ms) | Δ (%) | SDPA (ms) | New speedup vs SDPA |
|---|---|---|---|---|---|
| 256 | 0.17 | 0.16 | -6% | 0.30 | 1.84× |
| 512 | 0.16 | 0.16 | ~0% | 0.35 | 2.18× |
| 1024 | 0.20 | 0.17 | -15% | 0.50 | 2.97× |
| 2048 | 0.30 | 0.22 | -27% | 1.00 | 4.50× |
| 4096 | 0.51 | 0.37 | -27% | 2.40 | 6.40× |
| 8192 | 0.76 | 0.60 | -21% | 8.83 | 14.80× |
| 16384 | 1.98 | 1.62 | -18% | 26.23 | 16.16× |
| 32768 | 3.31 | 2.93 | -11% | 104.56 | 35.67× |

Average reduction (geometric mean, excluding the tiny 256/512 cases): ~19.5%.
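For reference, a figure of this kind is presumably the geometric mean of the per-length new/old time ratios from the table above (an assumption about the exact formula used):

$$
\text{reduction} \approx 1 - \left( \prod_{i=1}^{n} \frac{t_i^{\text{new}}}{t_i^{\text{old}}} \right)^{1/n}, \qquad n = 6 \ \text{(sequence lengths 1024 through 32768)}
$$

Plugging the table values in gives roughly 20%, consistent with the stated ~19.5%.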

Forward (Head Dim 64) – New Kernel (64×64)

Representative large‑sequence speedups vs SDPA:

  • 8192: 9.38 ms → 0.49 ms (19.1×)
  • 16384: 29.81 ms → 1.67 ms (17.8×)
  • 32768: 111.83 ms → 4.80 ms (23.3×)

Forward (Head Dim 96) – Two Candidate Configs (80 KB vs 52 KB Shared-Memory Variants)

Both are 64×64-based; the lower-shared-memory variant (52 KB) improved large‑sequence throughput (e.g., 32K tokens: 7.93 ms → 7.23 ms, ~9%).
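For context, a back-of-envelope shared-memory estimate for a 64×64 tile at head dim 96, assuming fp16 Q/K/V tiles plus an fp32 score tile; the real kernel layouts carry additional buffers and swizzling, so this only approximates the 52 KB / 80 KB figures rather than reproducing the actual layout:

```cpp
// Illustrative arithmetic only; not the repository's actual shared-memory layout.
constexpr int kBlockM = 64, kBlockN = 64, kHeadDim = 96;
constexpr int kBytesHalf = 2, kBytesFloat = 4;

constexpr int smem_q = kBlockM * kHeadDim * kBytesHalf;   // 12 KB fp16 Q tile
constexpr int smem_k = kBlockN * kHeadDim * kBytesHalf;   // 12 KB fp16 K tile
constexpr int smem_v = kBlockN * kHeadDim * kBytesHalf;   // 12 KB fp16 V tile
constexpr int smem_s = kBlockM * kBlockN * kBytesFloat;   // 16 KB fp32 score tile

constexpr int single_buffered = smem_q + smem_k + smem_v + smem_s;        // 52 KB
constexpr int double_buffered = smem_q + 2 * (smem_k + smem_v) + smem_s;  // ~76 KB

static_assert(single_buffered == 52 * 1024, "matches the 52 KB variant under these assumptions");
```

Under these assumptions the single-buffered layout lands on 52 KB exactly, and double-buffering K/V pushes it to roughly 76 KB, i.e. in the neighborhood of the 80 KB variant.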

Long-Context & Windowed Variants

Window sizes (W) from 32 → 32768 maintain high relative speedups (20–50×) for large sequence lengths, confirming robust behavior across sliding window scenarios.

Backward (Head Dim 32) – CUDA vs SDPA

Selected points:

  • 1024: 0.97 ms → 0.57 ms (1.7×)
  • 8192: 13.82 ms → 1.39 ms (10.0×)
  • 16384: 37.78 ms → 3.30 ms (11.4×)
  • 32768: 134.73 ms → 17.50 ms (7.7×)

Backward (Head Dim 64)

  • 1024: 0.94 ms → 0.58 ms (1.6×)
  • 8192: 12.16 ms → 1.68 ms (7.2×)
  • 16384: 39.62 ms → 4.41 ms (9.0×)
  • 32768: 142.19 ms → 23.07 ms (6.2×)

At the smallest shapes (seq 256), SDPA is marginally faster (launch overhead dominates), but the crossover occurs quickly as sequence length grows.

Notes

  • Extremely asymmetric Q=1 / large K cases remain memory/latency bound; speedups shrink or invert (known limitation; future micro‑batch / stream overlap strategies may help).
  • Tile standardization reduced fluctuation in speedup curves across head dims (smoother scale-up).

Breaking Changes

None.

  • Public Python API unchanged.
  • No semantic shifts in masks, softcap, or return layout.
  • Deterministic mode preserved.

Checklist

  • Style & clang-format alignment
  • Self-review
  • Added clarifying comments
  • README / docs updated (planned follow-up: performance section)
  • No new warnings
  • Benchmarks performed (results in fwd.md, bwd.md)
  • Numerical equivalence tested
  • Dynamic shared memory attributes validated
  • Additional automated tests (could be added under tests/ later)

CUDA-specific

  • SM 8.0+ (A100) & SM 9.0 (H100) tested
  • Memory footprint profiled; reduced for several head dims
  • No leaks (cuda-memcheck clean)
  • Kernel launches validated under deterministic and non-deterministic settings

Performance Data Sources

  • Forward benchmarks: fwd.md
  • Backward benchmarks: bwd.md
  • All runs: 3 measurement + 2 warmup, averaged; single-GPU (A100-SXM4-80GB).

Additional Notes

  • Potential enhancement: a runtime auto-tuner (env override) for experimental tile selection; a minimal sketch follows this list.
  • Could export lightweight telemetry (e.g., chosen kernel traits) under a debug flag for profiling.
  • Future work: unify backward small-shape fast path to eliminate SDPA advantage at very tiny sizes.
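As a concrete illustration of the env-override idea mentioned above; the variable name `FDMA_TILE_OVERRIDE` and the `TileChoice` enum are hypothetical, not an existing API:

```cpp
#include <cstdlib>
#include <cstring>

enum class TileChoice { Auto, Tile64x64, Tile128x64 };

inline TileChoice tile_override_from_env() {
    const char* v = std::getenv("FDMA_TILE_OVERRIDE");   // hypothetical variable
    if (v == nullptr) return TileChoice::Auto;
    if (std::strcmp(v, "64x64") == 0)  return TileChoice::Tile64x64;
    if (std::strcmp(v, "128x64") == 0) return TileChoice::Tile128x64;
    return TileChoice::Auto;                              // unknown value: fall back to auto
}
```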

Cleans up commented-out debug code that was used during development and is no longer needed in production.
@LoserCheems LoserCheems changed the title from "Optimize Flash Dynamic Mask Attention Kernel Configurations (Fixes #146)" to "Optimize Flash Dynamic Mask Attention Kernel Configurations" on Sep 3, 2025
@LoserCheems LoserCheems merged commit a32c730 into main Sep 3, 2025
Contributor

Copilot AI left a comment


Pull Request Overview

This PR optimizes Flash Dynamic Mask Attention kernel configurations to improve performance and reduce template explosion. It standardizes forward kernel block sizes to 64×64 across all head dimensions, simplifies the Split-KV forward path, and tunes backward kernel configurations for better memory usage and parallelism.

  • Standardized forward kernel block geometry to fixed 64×64 tiles
  • Optimized backward kernel trait selections for improved shared memory usage
  • Consolidated shared memory gating logic and reduced template specialization
Comments suppressed due to low confidence (1)

csrc/src/flash_bwd_launch_template.h:1

  • Good removal of commented debug printf statement. This improves code cleanliness and maintainability.


@LoserCheems LoserCheems deleted the support-integration branch November 13, 2025 04:41


Development

Successfully merging this pull request may close these issues.

[FEATURE] Reduce compute bubbles in backward skip path
