Releases: flash-algo/flash-sparse-attention
v1.0.1
What's Changed
- Improve performance of data processing module by @LoserCheems in #135
- Add tensor sanitization and improve test configurations by @LoserCheems in #136
- Enable head dimension 128 test cases for backward and forward equivalence by @LoserCheems in #137
- Fix attention mask handling for invalid topk values by @LoserCheems in #138
- Switch attention backend from flex to cuda by @LoserCheems in #139
- Replace attention_mask with cache_position for improved efficiency by @LoserCheems in #140
- Remove packed tensor functions and improve dynamic mask attention by @LoserCheems in #142
- Fix dynamic block-sparse attention kernel performance by @LoserCheems in #144
- Backward Kernel: Dynamic Mask Skip, Race Fixes, and 5–6× Speedup vs SDPA (Fixes #132) by @LoserCheems in #145
- Fix modeling example by @LoserCheems in #147
- Optimize Flash Dynamic Mask Attention Kernel Configurations by @LoserCheems in #148
Full Changelog: v1.0.0...v1.0.1
flash-dmattn v1.0.0 Technical Report
1. Overview
flash-dmattn is a high‑performance FlashAttention-style implementation optimized for large sequence lengths and structured sparsity via Dynamic Masks. It provides:
- Unified block-level dynamic mask (block-sparse) skip logic in both forward and backward passes.
- Fused softmax, normalization, and recomputation-friendly backward pipeline.
- Smart shared memory aliasing to reduce footprint and enhance occupancy.
- Support for bias, Log-Sum-Exp (LSE) caching, and optional softcap.
- PyTorch Autograd compatibility and downstream model integration (example: Doge model, HuggingFace-style interface).
v1.0.0 Highlights:
- Unified sparse skip logic for both forward and backward (eliminates redundant compute on fully masked tiles).
- Improved numerical and performance consistency: coherent shared memory layout, aliasing, and barrier sequencing.
- Documentation, API stabilization, and extensibility groundwork for finer-grained sparsity (bit-packed, fragment-level) later.
Differences vs v0.3.0:
- v0.3.0 only considered backward skip conceptually; v1.0.0 fully unifies forward + backward skip execution.
- Added strict barrier ordering to prevent NaNs (notably in dK path) when reusing aliased shared memory regions.
- Enhanced documentation, tests, and benchmarking.
2. Architecture
Layers:
- Python Integration: flash_dmattn_interface.py exposing user-friendly APIs (mirroring standard attention calls).
- Kernel Dispatch Layer: flash_dmattn_flex.py / flash_dmattn_triton.py selecting CUDA / Triton / hybrid code paths.
- C++/CUDA Core: flash_api.cpp + src/*.h (core kernels: flash_fwd_kernel.h, flash_bwd_kernel.h).
- Dynamic Mask Integration: integrations/flash_dynamic_mask_attention.py and helpers.
- Benchmarks & Validation: benchmarks/*_equivalence.py, *_performance.py.
Backward dataflow:
Q,K,V,dO (+ mask, bias, LSE) → block streaming → (block-sparse skip decision) → if active: recompute scores & softmax(P) → accumulate dV,dP,dQ,dK → write back.
3. Key Features
- Block-level Dynamic Mask:
- OR-reduction over (BlockM × BlockN) tile; if all zeros → skip.
- Unified Skip (Forward + Backward):
- Forward: skip QK^T, softmax, and P·V for fully masked tiles; safely advances pointers / outputs zeros.
- Backward: skip recompute + the chain of 5 GEMMs (QK^T, dO·V^T→dP, P^T·dO→dV, dS·K→dQ, dS^T·Q→dK).
- LSE Caching:
- Ensures numerical stability: P derived via stored log-sum-exp.
- Optional Softcap:
- Scaling / clamping scores pre-softmax.
- Shared Memory Aliasing:
- sMask ↔ sP; sBias ↔ sdS with explicit barriers.
- Mixed Precision:
- FP16/BF16 inputs, FP32 accumulation.
- Modular KernelTraits:
- Controls block sizes, pipeline depth (double buffering), layouts.
- Extensible Sparsity:
- Design leaves room for bit-packed masks and fragment gating.
4. Algorithms & Kernels
4.1 Forward (Pseudo-code)
for m_block in M_tiles:
    load Q_tile
    for n_block in N_tiles_stream:
        load mask_block
        any_active = OR(mask_block)
        if !any_active:
            advance_pointers()
            continue
        load K_tile, V_tile
        S = Q_tile @ K_tile^T + bias_block
        S_masked = apply_mask(S, mask_block)
        P = softmax(S_masked, LSE_cache)
        O_partial += P @ V_tile
    write O
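For reference, here is a dense single-head PyTorch sketch of this tile loop with the block-skip and online-softmax accumulation. It is a simplified illustration under stated assumptions (not the CUDA kernel): bias and softcap are omitted, and the function and variable names are illustrative.

```python
import torch

def block_sparse_forward_reference(q, k, v, mask, block_n=64):
    # q: [Lq, D], k/v: [Lk, D], mask: [Lq, Lk] bool (True = attend). Single head.
    q, k, v = q.float(), k.float(), v.float()          # fp32 reference accumulation
    Lq, D = q.shape
    scale = D ** -0.5
    acc = torch.zeros(Lq, D)
    row_max = torch.full((Lq, 1), float("-inf"))
    row_sum = torch.zeros(Lq, 1)
    for n0 in range(0, k.shape[0], block_n):
        m_blk = mask[:, n0:n0 + block_n]
        if not m_blk.any():                             # block-level skip: fully masked tile
            continue
        S = (q @ k[n0:n0 + block_n].T) * scale
        S = S.masked_fill(~m_blk, float("-inf"))
        new_max = torch.maximum(row_max, S.amax(dim=-1, keepdim=True))
        safe_max = torch.where(torch.isinf(new_max), torch.zeros_like(new_max), new_max)
        corr = torch.exp(row_max - safe_max)            # rescale previously accumulated partials
        P = torch.exp(S - safe_max)
        row_sum = row_sum * corr + P.sum(dim=-1, keepdim=True)
        acc = acc * corr + P @ v[n0:n0 + block_n]
        row_max = new_max
    out = acc / row_sum.clamp_min(torch.finfo(acc.dtype).tiny)   # fully masked rows stay zero
    lse = row_max + row_sum.clamp_min(torch.finfo(acc.dtype).tiny).log()
    return out, lse
```

On a fully active mask this reduces to standard dense attention; fully masked query rows produce zero output, consistent with the skip semantics described above.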
4.2 Backward (Pseudo-code)
for m_block in reversed(M_tiles):
    load Q_tile, dO_tile
    init accum_dQ
    for n_block in N_tiles_stream:
        load mask_block
        any_active = OR(mask_block)
        if !any_active:
            advance_pointers_zero_side_outputs()
            continue
        load K_tile, V_tile
        # Recompute
        S = Q_tile @ K_tile^T + bias_block
        P = softmax(S, LSE_cache)
        # Grad chain
        dV += P^T @ dO_tile
        dP = dO_tile @ V_tile^T
        dS = g(P, dP)  # (dP - (P ⊙ dP).sum(axis)) * P
        dQ += dS @ K_tile
        dK += dS^T @ Q_tile
    write dQ, accumulate dK, dV
4.3 Softmax & Gradient
Given scores S and the cached row-wise log-sum-exp LSE_i = log Σ_j exp(S_ij), the probabilities are recovered as P_ij = exp(S_ij − LSE_i).
Backward: dS_ij = P_ij (dP_ij − Σ_k P_ik dP_ik), i.e. dS = (dP − rowsum(P ⊙ dP)) ⊙ P, with dP = dO · V^T.
Fully masked tile: S, P, dP, and dS are all treated as zero, so the tile contributes nothing to dQ, dK, or dV.
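These relations can be sanity-checked against autograd with a tiny dense example (a minimal float64 sketch; scaling, bias, and masking are omitted):

```python
import torch

torch.manual_seed(0)
Lq, Lk, D = 8, 8, 16
q = torch.randn(Lq, D, dtype=torch.float64, requires_grad=True)
k = torch.randn(Lk, D, dtype=torch.float64, requires_grad=True)
v = torch.randn(Lk, D, dtype=torch.float64, requires_grad=True)
do = torch.randn(Lq, D, dtype=torch.float64)

S = q @ k.T
P = torch.softmax(S, dim=-1)             # equals exp(S - LSE) row-wise
(P @ v).backward(do)

with torch.no_grad():                     # manual gradient chain from 4.2 / 4.3
    dV = P.T @ do
    dP = do @ v.T
    dS = P * (dP - (P * dP).sum(dim=-1, keepdim=True))
    dQ = dS @ k
    dK = dS.T @ q

assert torch.allclose(dQ, q.grad) and torch.allclose(dK, k.grad) and torch.allclose(dV, v.grad)
```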
4.4 Correctness of Skip
If a tile is entirely masked:
- Forward contributions vanish (outputs zero block).
- Backward intermediate tensors (S,P,dS,dP) logically zero; linear GEMMs on zero give zero.
Therefore removing those computations preserves gradients.
5. Sparsity Logic & Performance
5.1 Active Tile Detection
- Load mask tile into shared memory.
- Parallel OR reduction across threads / warps.
- any_active=false triggers skip branch.
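In host-side PyTorch terms, the predicate is an any() reduction per tile; the kernel performs the same OR-reduction in shared memory. A small sketch (tile sizes are illustrative):

```python
import torch

def active_tile_map(mask, block_m=64, block_n=64):
    # mask: [Lq, Lk] bool. Returns a [tiles_m, tiles_n] map; False tiles can be skipped.
    Lq, Lk = mask.shape
    tiles_m = (Lq + block_m - 1) // block_m
    tiles_n = (Lk + block_n - 1) // block_n
    padded = torch.zeros(tiles_m * block_m, tiles_n * block_n, dtype=torch.bool)
    padded[:Lq, :Lk] = mask
    tiles = padded.view(tiles_m, block_m, tiles_n, block_n)
    return tiles.any(dim=3).any(dim=1)    # OR-reduction over each (block_m x block_n) tile
```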
5.2 Performance Model
Let the active fraction be p = (number of active tiles) / (total number of tiles); only active tiles pay for K/V loads, recompute, and the GEMM chain.
The idealized speedup from skipping is therefore bounded above by roughly 1/p, with the realized gain reduced by mask loads, the OR-reduction, and pipeline effects.
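As a quick worked example of this bound (a toy calculation, not a measured result):

```python
# Idealized upper bound on skip speedup for active-tile fraction p.
# Ignores mask loads, the OR-reduction, and pipeline effects, so measured gains are lower.
def skip_speedup_bound(p: float) -> float:
    return 1.0 / max(p, 1e-9)

print(skip_speedup_bound(0.25))   # 4.0 -> at most ~4x when 25% of tiles are active
```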
5.3 Influencing Factors
- Reduction latency vs early placement.
- Pipeline bubbles due to frequent divergent skip branches.
- Memory bandwidth: a bit-packed mask format (planned) would reduce the mask load footprint.
5.4 Future Enhancements
- Earlier gating (before K/V loads).
- Adaptive density threshold.
- Bit-packed + warp ballot fast OR.
- Persistent CTA / work queue for load balancing.
6. API Summary
Primary function:
flash_dynamic_mask_attention(q, k, v, attn_mask=None, bias=None, softcap=None, causal=False, return_lse=False, ...)
Inputs:
- q/k/v: [B, H, L, D] (k/v possibly different length)
- attn_mask: block-aligned or internally sliced dynamic mask
- bias: optional additive bias
- softcap: optional scaling/clamp
Outputs:
- O (and optionally LSE when requested).
Config:
- Block sizes (e.g., 64×64) via traits
- dtype: fp16 / bf16 (fp32 accum)
- enable_skip (default on)
- softcap scalar
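A minimal call sketch based on the signature and shapes above. The import path is an assumption (not stated in this report), so the actual call is left commented out; the attn_mask and bias shapes are illustrative.

```python
import torch

# Assumed import path; adjust to the installed package's interface module.
# from flash_dmattn.flash_dmattn_interface import flash_dynamic_mask_attention

B, H, L, D = 2, 8, 1024, 128
q = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
attn_mask = torch.rand(B, H, L, L, device="cuda") < 0.1    # ~10% density dynamic mask (illustrative shape)
bias = torch.zeros(B, H, L, L, device="cuda", dtype=torch.bfloat16)

# out, lse = flash_dynamic_mask_attention(
#     q, k, v, attn_mask=attn_mask, bias=bias, softcap=None, causal=False, return_lse=True
# )
```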
7. Memory & Synchronization
- Double buffering for streaming Q/K/V with cp.async fences.
- Aliasing:
- sMask reused as sP after consumption.
- sBias reused as sdS after gradient consumption.
- Critical barriers:
- Ensure mask fully read before overwriting region with P.
- Ensure dS fully consumed (dK finished) before alias region becomes bias.
Goal: minimize shared memory to enable larger tiles and higher occupancy.
8. Numerical Stability
- LSE caching prevents overflow.
- FP16/BF16 inputs + FP32 accumulation.
- Skip path doesn't touch LSE entries of masked tiles.
- Validation scripts: forward/backward/grad equivalence across lengths, densities.
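A small sketch of how the cached LSE recovers P stably with low-precision scores and FP32 accumulation (illustrative only, not kernel code):

```python
import torch

S = (torch.randn(4, 4) * 30.0).to(torch.float16)         # fp16 scores; naive exp(S) can overflow fp16
lse = torch.logsumexp(S.float(), dim=-1, keepdim=True)    # FP32 accumulation, as in the kernel
P = torch.exp(S.float() - lse)                            # P_ij = exp(S_ij - LSE_i) stays in [0, 1]
assert torch.allclose(P, torch.softmax(S.float(), dim=-1))
assert torch.allclose(P.sum(dim=-1), torch.ones(4))
```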
9. Backward Compatibility & Upgrade
- Same Python API; upgrading from v0.3.0 requires no code changes for standard use.
- Internal layout symbols are not part of the public contract; custom kernels should revalidate alias expectations.
- Future runtime stats API planned (non-breaking).
10. Known Limitations
- Only block-aligned sparsity (no arbitrary coordinate compression yet).
- Skip decision not yet moved ahead of K/V/dO loads.
- No fragment-level (Tensor Core tile) sparsity gating yet.
- No built-in distributed (multi-GPU) attention aggregation logic.
- Triton path feature parity still evolving.
11. Testing & Validation
- Numerical: compare to dense scaled_dot_product_attention (see the sketch after this list).
- Sparsity: random masks of varying density; compare skip vs forced-dense output.
- Regression: multi-block scenarios to guard against the prior dK NaN issue.
- Benchmarks: measure kernel time vs density p.
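A sketch of such a density sweep against dense SDPA. The flash_dynamic_mask_attention call is commented out because its exact import path is not shown in this report; the dense reference runs as-is.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, L, D = 1, 2, 256, 64
q, k, v = (torch.randn(B, H, L, D, device=device) for _ in range(3))

for density in (0.1, 0.5, 1.0):
    attn_mask = torch.rand(B, H, L, L, device=device) < density
    attn_mask[..., 0] = True                 # keep at least one key per query to avoid all-masked rows
    ref = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # out = flash_dynamic_mask_attention(q, k, v, attn_mask=attn_mask)
    # torch.testing.assert_close(out, ref, rtol=2e-3, atol=2e-3)
```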
12. Roadmap
- Early mask gating pre-load.
- Bit-packed mask + warp ballot OR.
- Adaptive skip threshold (disable when p high).
- Fragment-level MMA gating.
- Persistent CTA + work queue.
- Runtime counters: active/skipped tile counts, effective density.
- Distributed integration examples.
13. Safety & Robustness
- Input validation: shapes / dtypes / device alignment.
- Mask alignment and slicing.
- LSE + FP32 mitigate overflow.
- Barriers enforce safe alias lifecycle.
- Future fallback path for anomaly detection (planned).
14. Acknowledgements
- Inspired by FlashAttention research and community.
- Contributors: core maintainers & commit authors (see git history).
- Ecosystem: PyTorch / CUTLASS / Triton.
15. Version Delta Summary
Changes vs v0.3.0:
- Added forward skip bringing full forward/backward symmetry.
- Fixed block size condition + enhanced documentation.
- Shared memory alias + barrier ordering refinements (resolved dK NaNs).
- Skip branch pointer advancement semantics aligned with dense path.
- Comprehensive technical documentation and math derivations.
16. Formula Quick Reference
- Softmax: P_ij = exp(S_ij − LSE_i), with LSE_i = log Σ_j exp(S_ij)
- dS: dS = (dP − (P ⊙ dP).sum(axis)) ⊙ P
- Grad propagation: dV += P^T · dO, dP = dO · V^T, dQ += dS · K, dK += dS^T · Q
- Skip predicate: any_active = OR(mask_block); skip the tile when !any_active
17. Alias & Barrier Snippet
load mask -> sMask
any_active = or_reduce(sMask)
if any_active:
# reuse sMask region as sP after consumption
compute S
softmax -> write P into aliased region (sP)
...
__syncthreads() # ensure dS consumed
# reuse sBias region as sdS in next iteration
18. Glossary
- Block / Tile: matrix sub-region processed per step.
- Skip: branch eliminating compute for fully masked tile.
- LSE: log-sum-exp cache for stability.
- Aliasing: reusing shared memory region across disjoint lifetimes.
- Fragment-level: granularity of Tensor Core MMA fragments.
19. Integration
- HuggingFace-style example: modeling_doge.py
- Drop-in custom attention module inside transformer blocks.
- Planned: wrapper matching scaled_dot_product_attention signature for rapid...
v0.3.0
What's Changed
- Improve masking and memory management in backward kernel by @LoserCheems in #128
- Consolidate mask and bias memory operations by @LoserCheems in #129
- Enable sparse GEMM with mask tensor support by @LoserCheems in #130
Full Changelog: v0.2.0...v0.3.0
v0.2.0
What's Changed
- Remove unused CUDA generator includes for improved build performance by @LoserCheems in #105
- [WIP] Support Backward for Dynamic Mask Attention by @LoserCheems in #106
- Fix CUDA forward crash when seqlen_q == 1 by @LoserCheems in #108
- Add backward pass support for FlashDynamicMaskAttention by @LoserCheems in #109
- Fix varlen mask and bias tensor shapes for all varlen attention functions by @Copilot in #114
- Refactor backward pass and optimize kernel configurations by @LoserCheems in #116
- Integrate Flash Dynamic Mask Attention (FDMA) Into Transformers-Style Attention Flow by @LoserCheems in #118
- Fixes attention mask/bias shape documentation by @LoserCheems in #123
- Improve CUDA build configuration and fix gradient computation in attention by @LoserCheems in #124
- Enhance backward pass support and optimization for CUDA architectures by @LoserCheems in #125
- Bumps version to 0.2.0 by @LoserCheems in #126
Full Changelog: v0.1.0...v0.2.0
🎉 Flash-DMA v0.1.0
We're excited to announce the first official release of Flash-DMA (Flash Dynamic Mask Attention)!
🚀 What is Flash-DMA?
Flash-DMA is a high-performance attention implementation that combines:
- Flash Attention's memory efficiency
- Dynamic Mask Attention's sparse computation
- Support for extremely long sequences (128K+ tokens)
✨ Key Features
🔥 Performance
- Sparse Attention: Reduces computation from O(N²) to O(N·w) where w ≪ N
- Memory Efficient: Maintains O(N) memory complexity
- CUDA Accelerated: Custom sparse GEMM operations at kernel level
🛠️ Multiple Backends
- CUDA Backend: Maximum performance with custom kernels
- Triton Backend: Flexibility for research and development
- Flex Backend: Integration with Transformers library
📏 Long Sequence Support
- Efficiently handles sequences of 128K+ tokens
- Dynamic masking when sequence length exceeds keep_window_size (see the sketch below)
- Optimized memory layouts for large-scale processing
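One way such a dynamic mask can be constructed is a per-query top-k selection over importance scores. This is an illustrative sketch only; the library's actual mask construction lives in its integration code and may differ.

```python
import torch

def topk_window_mask(importance, keep_window_size):
    # importance: [B, H, Lq, Lk] scores used to rank key positions per query (illustrative).
    Lk = importance.shape[-1]
    if Lk <= keep_window_size:                 # short sequences stay fully dense
        return torch.ones_like(importance, dtype=torch.bool)
    idx = importance.topk(keep_window_size, dim=-1).indices
    mask = torch.zeros_like(importance, dtype=torch.bool)
    return mask.scatter(-1, idx, True)         # keep only the top-k keys per query
```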
📦 Installation
Prerequisites
- Python 3.9+
- PyTorch 2.0+
- CUDA 11.8+
- NVIDIA GPU with Compute Capability 8.0+
Install from Source
git clone https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn
git submodule update --init --recursive
pip install .
What's Changed
- Workspace by @LoserCheems in #1
- Add namespace_config to csrc by @LoserCheems in #2
- Add hardware_info to csrc by @LoserCheems in #3
- Add block_info to csrc by @LoserCheems in #4
- Add flash params to csrc by @LoserCheems in #5
- Workspace by @LoserCheems in #6
- Update golobal to Shared Memory operation by @LoserCheems in #7
- Update PREDICATES by @LoserCheems in #8
- Fix some nits for layout by @LoserCheems in #9
- Workspace by @LoserCheems in #10
- Fix Dynamic Mask Attention Integration in FlashAttention CUDA Kernel by @Copilot in #12
- Fix dynamic mask attention equivalence issue between Python and CUDA by @Copilot in #14
- Fix CUDA dynamic mask attention scaling to match Python implementation by @Copilot in #16
- Update mask.h by @Evanwu1125 in #17
- Comprehensive README improvement with installation, usage examples, and documentation by @Copilot in #19
- Adds no-topk variant for kernel performance analysis by @LoserCheems in #20
- Optimize sparse GEMM and enable in attention computation by @LoserCheems in #21
- Adds row stride support to offset calculation methods by @LoserCheems in #22
- Corrects ZOH tensor dimension comment by @LoserCheems in #23
- Adds rotary positional encoding operations by @LoserCheems in #24
- Adds conditional softcap switch macro by @LoserCheems in #25
- Removes unused template parameter from DynamicMask by @LoserCheems in #26
- Updates tensor offset calculations and formatting by @LoserCheems in #27
- Adds split-K attention kernel with sparse computation by @LoserCheems in #28
- Adds dropout support and softcap feature to flash attention by @LoserCheems in #29
- Add specialized CUDA kernels for multi-head attention with various head dimensions by @LoserCheems in #30
- Remove cub submodule and add cutlass; implement FlashDynamicMaskAttention by @LoserCheems in #31
- Refactors setup.py for production-ready package build by @LoserCheems in #32
- Update integration by @LoserCheems in #33
- Adds comprehensive API reference documentation by @LoserCheems in #34
- Updates README with improved technical accuracy and examples by @LoserCheems in #35
- Fix bug by @LoserCheems in #36
- Adds column stride parameters to ZOH_params struct by @LoserCheems in #37
- Adds column stride support to offset calculations by @LoserCheems in #38
- Fixes attention benchmarking and expands test coverage by @LoserCheems in #39
- Reorders stride parameter assignments by @LoserCheems in #40
- Adds column stride support to tensor memory layouts by @LoserCheems in #41
- Improves code clarity and test coverage by @LoserCheems in #42
- Updates copy function defaults and clarifies comments by @LoserCheems in #43
- Updates copy operations to use improved vectorization by @LoserCheems in #44
- Updates benchmark test configurations for better coverage by @LoserCheems in #45
- Temporarily disables Split-KV feature by @LoserCheems in #46
- Optimizes CUDA kernel block sizes for better occupancy by @LoserCheems in #49
- Enables test case for 512x512 input dimensions by @LoserCheems in #50
- Renames dzero_hold to dzoh and adds column stride by @LoserCheems in #51
- Improves code formatting consistency in comments by @LoserCheems in #52
- Fixes tensor addressing for ZOH and active mask in splitkv by @LoserCheems in #53
- Refactor attention mask and bias structures for clarity by @LoserCheems in #54
- Refactor backward kernel for attention mask and bias support by @LoserCheems in #55
- Adds Flash Attention implementation with dynamic masking by @LoserCheems in #56
- Fixes mask validation in forward kernel by @LoserCheems in #57
- Fixes mask comparison and scaling logic in attention kernel by @LoserCheems in #58
- Enhance Flash Attention with required parameters and improved backward pass by @LoserCheems in #59
- Reorganizes flash attention files into instantiations directory by @LoserCheems in #60
- Rename flash_dma to flash_dmattn and improve usability by @LoserCheems in #61
- Adds bias gradient computation to backward kernel by @LoserCheems in #62
- Add backend selection and dynamic mask attention support by @LoserCheems in #63
- Update by @LoserCheems in #64
- Removes no-topk CUDA implementation from benchmarks by @LoserCheems in #65
- Enables comprehensive benchmark configurations by @LoserCheems in #66
- Renames Flash Attention to SDPA in benchmark suite by @LoserCheems in #67
- Refactors variable declarations for better readability by @LoserCheems in #68
- Add bias gradient computation support in backward kernel by @LoserCheems in #69
- Fix function naming and standardize memory copy alignment in attention kernel by @LoserCheems in #70
- Adds unified mask application function with causal support by @LoserCheems in #71
- Enables Split-KV avoidance and updates error messages by @LoserCheems in #74
- Adds variable length forward pass support by @LoserCheems in #75
- Simplify attention mask and bias parameter naming by @LoserCheems in #76
- Remove unused parameters and simplify mask logic by @LoserCheems in #77
- Add CUDA-integrated flash attention interface by @LoserCheems in #78
- Improves version comparison using packaging library by @LoserCheems in #79
- Refactor CUDA interface for improved usability by @LoserCheems in https://github.co...