
Releases: flash-algo/flash-sparse-attention

v1.0.1

05 Sep 05:12
a32c730


What's Changed

Full Changelog: v1.0.0...v1.0.1

flash-dmattn v1.0.0 Technical Report

29 Aug 02:07
ccc0b01



1. Overview

flash-dmattn is a high‑performance FlashAttention-style implementation optimized for large sequence lengths and structured sparsity via Dynamic Masks. It provides:

  • Unified block-level dynamic mask (block-sparse) skip logic in both forward and backward passes.
  • Fused softmax, normalization, and recomputation-friendly backward pipeline.
  • Smart shared memory aliasing to reduce footprint and enhance occupancy.
  • Support for bias, Log-Sum-Exp (LSE) caching, and optional softcap.
  • PyTorch Autograd compatibility and downstream model integration (example: Doge model, HuggingFace-style interface).

v1.0.0 Highlights:

  1. Unified sparse skip logic for both forward and backward (eliminates redundant compute on fully masked tiles).
  2. Improved numerical and performance consistency: coherent shared memory layout, aliasing, and barrier sequencing.
  3. Documentation, API stabilization, and extensibility groundwork for finer-grained sparsity (bit-packed, fragment-level) later.

Differences vs v0.3.0:

  • v0.3.0 only considered backward skip conceptually; v1.0.0 fully unifies forward + backward skip execution.
  • Added strict barrier ordering to prevent NaNs (notably in dK path) when reusing aliased shared memory regions.
  • Enhanced documentation, tests, and benchmarking.

2. Architecture

Layers:

  1. Python Integration: flash_dmattn_interface.py exposing user-friendly APIs (mirroring standard attention calls).
  2. Kernel Dispatch Layer: flash_dmattn_flex.py / flash_dmattn_triton.py selecting CUDA / Triton / hybrid code paths.
  3. C++/CUDA Core: flash_api.cpp + src/*.h (core kernels: flash_fwd_kernel.h, flash_bwd_kernel.h).
  4. Dynamic Mask Integration: integrations/flash_dynamic_mask_attention.py and helpers.
  5. Benchmarks & Validation: benchmarks/*_equivalence.py, *_performance.py.

Backward dataflow:
Q,K,V,dO (+ mask, bias, LSE) → block streaming → (block-sparse skip decision) → if active: recompute scores & softmax(P) → accumulate dV,dP,dQ,dK → write back.

3. Key Features

  • Block-level Dynamic Mask:
    • OR-reduction over (BlockM × BlockN) tile; if all zeros → skip.
  • Unified Skip (Forward + Backward):
    • Forward: skip QK^T, softmax, and P·V for fully masked tiles; safely advances pointers / outputs zeros.
    • Backward: skip the recompute plus the chain of 5 GEMMs (QK^T, dO·V^T→dP, P^T·dO→dV, dS·K→dQ, dS^T·Q→dK).
  • LSE Caching:
    • Ensures numerical stability: P derived via stored log-sum-exp.
  • Optional Softcap:
    • Scaling / clamping scores pre-softmax.
  • Shared Memory Aliasing:
    • sMask ↔ sP; sBias ↔ sdS with explicit barriers.
  • Mixed Precision:
    • FP16/BF16 inputs, FP32 accumulation.
  • Modular KernelTraits:
    • Controls block sizes, pipeline depth (double buffering), layouts.
  • Extensible Sparsity:
    • Design leaves room for bit-packed masks and fragment gating.

4. Algorithms & Kernels

4.1 Forward (Pseudo-code)

for m_block in M_tiles:
  load Q_tile
  for n_block in N_tiles_stream:
    load mask_block
    any_active = OR(mask_block)
    if not any_active:
        advance_pointers()
        continue
    load K_tile, V_tile
    S = Q_tile @ K_tile^T + bias_block
    S_masked = apply_mask(S, mask_block)
    P = softmax(S_masked)            # online softmax; running max/sum update the LSE cache
    O_partial += P @ V_tile
write O, LSE
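
For reference, the same loop can be expressed as a slow PyTorch sketch (illustrative only, not the library's kernel; the [B, H, L, D] layout follows Section 6, while tile sizes, float32 inputs, and helper names are assumptions):

import torch

def block_sparse_attention_reference(q, k, v, mask, bias=None, block_m=64, block_n=64):
    # q, k, v: [B, H, L, D] (float32 assumed for simplicity); mask: [B, H, L_q, L_k] boolean, True = attend
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    L_q, L_k = q.shape[-2], k.shape[-2]
    for m0 in range(0, L_q, block_m):
        m1 = min(m0 + block_m, L_q)
        # running online-softmax statistics for this row block
        row_max = torch.full(q.shape[:2] + (m1 - m0,), -1e30, dtype=q.dtype, device=q.device)
        row_sum = torch.zeros_like(row_max)
        acc = torch.zeros_like(q[..., m0:m1, :])
        for n0 in range(0, L_k, block_n):
            n1 = min(n0 + block_n, L_k)
            mask_blk = mask[..., m0:m1, n0:n1]
            if not mask_blk.any():                       # fully masked tile: skip all compute
                continue
            s = (q[..., m0:m1, :] @ k[..., n0:n1, :].transpose(-1, -2)) * scale
            if bias is not None:
                s = s + bias[..., m0:m1, n0:n1]
            s = s.masked_fill(~mask_blk, -1e30)          # finite sentinel avoids inf - inf
            new_max = torch.maximum(row_max, s.amax(dim=-1))
            p = torch.exp(s - new_max.unsqueeze(-1)) * mask_blk
            correction = torch.exp(row_max - new_max)
            row_sum = row_sum * correction + p.sum(dim=-1)
            acc = acc * correction.unsqueeze(-1) + p @ v[..., n0:n1, :]
            row_max = new_max
        out[..., m0:m1, :] = acc / row_sum.clamp_min(1e-20).unsqueeze(-1)
    return out

Fully masked tiles contribute nothing, and rows that are masked in every tile end up with row_sum = 0 and therefore a zero output block, mirroring the skip semantics above.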

4.2 Backward (Pseudo-code)

for m_block in reversed(M_tiles):
  load Q_tile, dO_tile
  init accum_dQ
  for n_block in N_tiles_stream:
    load mask_block
    any_active = OR(mask_block)
    if not any_active:
        advance_pointers_zero_side_outputs()
        continue
    load K_tile, V_tile
    # Recompute
    S = Q_tile @ K_tile^T + bias_block
    P = softmax(S, LSE_cache)
    # Grad chain
    dV += P^T @ dO_tile
    dP = dO_tile @ V_tile^T
    dS = g(P, dP)   # (dP - rowsum(dP ⊙ P)) * P, see 4.3
    dQ += dS @ K_tile
    dK += dS^T @ Q_tile
  write dQ, accumulate dK, dV

4.3 Softmax & Gradient

Given $S_{ij}$ and $LSE_i = \log \sum_k e^{S_{ik}}$,

$$ P_{ij} = \frac{e^{S_{ij}-LSE_i}}{\sum_k e^{S_{ik}-LSE_i}} $$

Backward:

$$ \frac{\partial \mathcal{L}}{\partial S_{ij}} = \left( \frac{\partial \mathcal{L}}{\partial P_{ij}} - \sum_{k} \frac{\partial \mathcal{L}}{\partial P_{ik}} P_{ik} \right) P_{ij} $$

Fully masked tile: $P=0 \Rightarrow dS=0$, all dependent GEMMs yield zero → safe to skip.
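
This identity is easy to check against autograd; a minimal float64 sketch:

import torch

torch.manual_seed(0)
S = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)   # scores for one small tile
W = torch.randn(4, 8, dtype=torch.float64)                       # arbitrary upstream gradient dL/dP
P = torch.softmax(S, dim=-1)

# autograd's dL/dS through the softmax
dS_autograd, = torch.autograd.grad(P, S, grad_outputs=W)

# closed form from above: dS_ij = (dP_ij - sum_k dP_ik * P_ik) * P_ij
dS_formula = (W - (W * P).sum(dim=-1, keepdim=True)) * P
assert torch.allclose(dS_autograd, dS_formula)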

4.4 Correctness of Skip

If a tile is entirely masked:

  • Forward contributions vanish (outputs zero block).
  • Backward intermediate tensors (S,P,dS,dP) logically zero; linear GEMMs on zero give zero.
    Therefore removing those computations preserves gradients.

5. Sparsity Logic & Performance

5.1 Active Tile Detection

  • Load mask tile into shared memory.
  • Parallel OR reduction across threads / warps.
  • any_active == false triggers the skip branch (the tensor-level sketch below shows the same reduction).
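
In the kernel this reduction runs over shared memory across threads and warps; the tensor-level equivalent is simply a per-tile any(). A sketch, assuming a boolean mask of shape [B, H, L_q, L_k] with block-aligned lengths:

import torch

def active_tile_map(mask, block_m=64, block_n=64):
    # mask: [B, H, L_q, L_k] boolean; returns which (BlockM x BlockN) tiles contain any active entry
    B, H, Lq, Lk = mask.shape
    assert Lq % block_m == 0 and Lk % block_n == 0, "sketch assumes block-aligned lengths"
    tiles = mask.reshape(B, H, Lq // block_m, block_m, Lk // block_n, block_n)
    return tiles.any(dim=-1).any(dim=-2)   # [B, H, L_q // block_m, L_k // block_n]

The mean of this map is the active fraction p used in the performance model below.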

5.2 Performance Model

Let active fraction $p$, skip overhead ratio $\varepsilon$:

$$ \text{Speedup} \approx \frac{1}{p + (1-p)\varepsilon} $$

Upper bound as $\varepsilon \to 0$: $1/p$.
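
Plugging representative numbers into the model (a quick illustration; both p and ε are workload- and hardware-dependent):

def skip_speedup(p, eps):
    # p: fraction of active tiles, eps: relative cost of the skip branch vs. a full tile
    return 1.0 / (p + (1.0 - p) * eps)

for p in (0.25, 0.5, 0.75):
    print(f"p={p}: ~{skip_speedup(p, eps=0.05):.2f}x")   # e.g. p=0.25 -> ~3.48x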

5.3 Influencing Factors

  • Latency of the OR reduction versus how early the skip decision can be placed in the pipeline.
  • Pipeline bubbles due to frequent divergent skip branches.
  • Memory bandwidth—mask format (bit-packed future) reduces load footprint.

5.4 Future Enhancements

  • Earlier gating (before K/V loads).
  • Adaptive density threshold.
  • Bit-packed + warp ballot fast OR.
  • Persistent CTA / work queue for load balancing.

6. API Summary

Primary function:
flash_dynamic_mask_attention(q, k, v, attn_mask=None, bias=None, softcap=None, causal=False, return_lse=False, ...)

Inputs:

  • q/k/v: [B, H, L, D] (k/v possibly different length)
  • attn_mask: block-aligned or internally sliced dynamic mask
  • bias: optional additive bias
  • softcap: optional scaling/clamp

Outputs:

  • O (and optionally LSE when requested).

Config:

  • Block sizes (e.g., 64×64) via traits
  • dtype: fp16 / bf16 (fp32 accum)
  • enable_skip (default on)
  • softcap scalar
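
A minimal call sketch based on the signature above; the import path, mask/bias shapes, and the (O, LSE) return convention when return_lse=True are assumptions and may differ from the installed package:

import torch
from flash_dmattn import flash_dynamic_mask_attention   # import path is an assumption

B, H, L, D = 2, 8, 4096, 64
q = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
attn_mask = torch.ones(B, H, L, L, device="cuda", dtype=torch.bool)    # dynamic (block-aligned) mask
bias = torch.zeros(B, H, L, L, device="cuda", dtype=torch.bfloat16)    # optional additive bias

out, lse = flash_dynamic_mask_attention(
    q, k, v, attn_mask=attn_mask, bias=bias, causal=False, return_lse=True
)   # assuming (O, LSE) is returned when return_lse=True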

7. Memory & Synchronization

  • Double buffering for streaming Q/K/V with cp.async fences.
  • Aliasing:
    • sMask reused as sP after consumption.
    • sBias reused as sdS after gradient consumption.
  • Critical barriers:
    1. Ensure mask fully read before overwriting region with P.
    2. Ensure dS fully consumed (dK finished) before alias region becomes bias.
  • Goal: minimize shared memory to enable larger tiles and higher occupancy.

8. Numerical Stability

  • LSE caching prevents overflow.
  • FP16/BF16 inputs + FP32 accumulation.
  • Skip path doesn't touch LSE entries of masked tiles.
  • Validation scripts: forward/backward/grad equivalence across lengths, densities.

9. Backward Compatibility & Upgrade

  • Same Python API; upgrading from v0.3.0 requires no code changes for standard use.
  • Internal layout symbols not part of public contract—custom kernels should revalidate alias expectations.
  • Future runtime stats API planned (non-breaking).

10. Known Limitations

  • Only block-aligned sparsity (no arbitrary coordinate compression yet).
  • Skip decision not yet moved ahead of K/V/dO loads.
  • No fragment-level (Tensor Core tile) sparsity gating yet.
  • No built-in distributed (multi-GPU) attention aggregation logic.
  • Triton path feature parity still evolving.

11. Testing & Validation

  • Numerical: compare against a dense scaled_dot_product_attention reference (a sketch of this check follows the list).
  • Sparsity: random masks of varying density; compare skip vs forced-dense output.
  • Regression: multi-block scenarios to guard prior dK NaN issue.
  • Benchmarks: measure kernel time vs density p.
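
A sketch of the dense-equivalence check referenced in the first item (the benchmarks/*_equivalence.py scripts are the authoritative versions; density, tolerances, and the commented-out kernel call are illustrative):

import torch

def dense_reference(q, k, v, mask, bias):
    # plain O(N^2) attention used as the numerical oracle
    scale = q.shape[-1] ** -0.5
    s = (q @ k.transpose(-1, -2)) * scale + bias
    s = s.masked_fill(~mask, float("-inf"))
    return torch.softmax(s, dim=-1) @ v

B, H, L, D = 1, 4, 512, 64
q, k, v = (torch.randn(B, H, L, D, device="cuda") for _ in range(3))   # float32 oracle inputs
mask = torch.rand(B, H, L, L, device="cuda") < 0.3                     # ~30% random density
mask |= torch.eye(L, dtype=torch.bool, device="cuda")                  # keep every query row non-empty
bias = torch.zeros(B, H, L, L, device="cuda")

ref = dense_reference(q, k, v, mask, bias)
# kernel under test (cast to bf16/fp16 as required, then compare):
# out = flash_dynamic_mask_attention(q.bfloat16(), k.bfloat16(), v.bfloat16(),
#                                    attn_mask=mask, bias=bias.bfloat16())
# torch.testing.assert_close(out.float(), ref, atol=2e-2, rtol=2e-2)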

12. Roadmap

  1. Early mask gating pre-load.
  2. Bit-packed mask + warp ballot OR.
  3. Adaptive skip threshold (disable when p high).
  4. Fragment-level MMA gating.
  5. Persistent CTA + work queue.
  6. Runtime counters: active/skipped tile counts, effective density.
  7. Distributed integration examples.

13. Safety & Robustness

  • Input validation: shapes / dtypes / device alignment (illustrated by the sketch after this list).
  • Mask alignment and slicing.
  • LSE + FP32 mitigate overflow.
  • Barriers enforce safe alias lifecycle.
  • Future fallback path for anomaly detection (planned).
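
A hypothetical Python-level sketch of the validation step (the library performs its own checks at the C++ boundary; every name and threshold here is illustrative):

import torch

def validate_inputs(q, k, v, attn_mask=None, bias=None):
    # hypothetical helper mirroring the checks described above
    assert q.is_cuda and q.device == k.device == v.device, "q/k/v must share one CUDA device"
    assert q.dtype == k.dtype == v.dtype and q.dtype in (torch.float16, torch.bfloat16), \
        "fp16/bf16 inputs (accumulation is fp32 internally)"
    B, H, Lq, D = q.shape
    assert k.shape[0] == v.shape[0] == B and k.shape[1] == v.shape[1] == H
    assert k.shape[-1] == v.shape[-1] == D and k.shape[-2] == v.shape[-2], \
        "k and v must share key length and head dim"
    for t in (attn_mask, bias):
        if t is not None:
            assert t.shape[-2:] == (Lq, k.shape[-2]), "mask/bias must cover [L_q, L_k]"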

14. Acknowledgements

  • Inspired by FlashAttention research and community.
  • Contributors: core maintainers & commit authors (see git history).
  • Ecosystem: PyTorch / CUTLASS / Triton.

15. Version Delta Summary

Changes vs v0.3.0:

  • Added forward skip bringing full forward/backward symmetry.
  • Fixed block size condition + enhanced documentation.
  • Shared memory alias + barrier ordering refinements (resolved dK NaNs).
  • Skip branch pointer advancement semantics aligned with dense path.
  • Comprehensive technical documentation and math derivations.

16. Formula Quick Reference

  1. Softmax:

$$ P_{ij} = \frac{e^{S_{ij}-LSE_i}}{\sum_k e^{S_{ik}-LSE_i}}, \quad LSE_i = \log \sum_k e^{S_{ik}} $$

  2. dS:

$$ dS_{ij} = \left(dP_{ij} - \sum_k dP_{ik} P_{ik}\right) P_{ij} $$

  3. Grad propagation:

$$ dQ = dS K,\quad dK = dS^T Q,\quad dV = P^T dO $$

  4. Skip predicate:

$$ \text{any\_active} = \bigvee_{(i,j)\in \text{tile}} \text{mask}_{ij} $$

17. Alias & Barrier Snippet

load mask -> sMask
any_active = or_reduce(sMask)
if any_active:
    compute S (applies mask)
    __syncthreads()               # barrier 1: mask fully read before its region is aliased
    softmax -> write P into aliased region (sP = sMask)
    ...                           # grad chain consumes dS (dQ, dK GEMMs)
__syncthreads()                   # barrier 2: dS fully consumed before the aliased region is reused
# reuse sBias region as sdS in the next iteration

18. Glossary

  • Block / Tile: matrix sub-region processed per step.
  • Skip: branch eliminating compute for fully masked tile.
  • LSE: log-sum-exp cache for stability.
  • Aliasing: reusing shared memory region across disjoint lifetimes.
  • Fragment-level: granularity of Tensor Core MMA fragments.

19. Integration

  • HuggingFace-style example: modeling_doge.py
  • Drop-in custom attention module inside transformer blocks.
  • Planned: wrapper matching scaled_dot_product_attention signature for rapid...

v0.3.0

26 Aug 16:17


What's Changed

Full Changelog: v0.2.0...v0.3.0

v0.2.0

25 Aug 14:23
f89d9f6


What's Changed

  • Remove unused CUDA generator includes for improved build performance by @LoserCheems in #105
  • [WIP] Support Backward for Dynamic Mask Attention by @LoserCheems in #106
  • Fix CUDA forward crash when seqlen_q == 1 by @LoserCheems in #108
  • Add backward pass support for FlashDynamicMaskAttention by @LoserCheems in #109
  • Fix varlen mask and bias tensor shapes for all varlen attention functions by @Copilot in #114
  • Refactor backward pass and optimize kernel configurations by @LoserCheems in #116
  • Integrate Flash Dynamic Mask Attention (FDMA) Into Transformers-Style Attention Flow by @LoserCheems in #118
  • Fixes attention mask/bias shape documentation by @LoserCheems in #123
  • Improve CUDA build configuration and fix gradient computation in attention by @LoserCheems in #124
  • Enhance backward pass support and optimization for CUDA architectures by @LoserCheems in #125
  • Bumps version to 0.2.0 by @LoserCheems in #126

Full Changelog: v0.1.0...v0.2.0

🎉 Flash-DMA v0.1.0

10 Aug 12:38
802613e


We're excited to announce the first official release of Flash-DMA (Flash Dynamic Mask Attention)!

🚀 What is Flash-DMA?

Flash-DMA is a high-performance attention implementation that combines:

  • Flash Attention's memory efficiency
  • Dynamic Mask Attention's sparse computation
  • Support for extremely long sequences (128K+ tokens)

✨ Key Features

🔥 Performance

  • Sparse Attention: Reduces computation from O(N²) to O(N·w) where w ≪ N
  • Memory Efficient: Maintains O(N) memory complexity
  • CUDA Accelerated: Custom sparse GEMM operations at kernel level

🛠️ Multiple Backends

  • CUDA Backend: Maximum performance with custom kernels
  • Triton Backend: Flexibility for research and development
  • Flex Backend: Integration with Transformers library

📏 Long Sequence Support

  • Efficiently handles sequences of 128K+ tokens
  • Dynamic masking when sequence length exceeds keep_window_size
  • Optimized memory layouts for large-scale processing

📦 Installation

Prerequisites

  • Python 3.9+
  • PyTorch 2.0+
  • CUDA 11.8+
  • NVIDIA GPU with Compute Capability 8.0+

Install from Source

git clone https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn
git submodule update --init --recursive
pip install .

What's Changed
