
Commit ccc0b01

Merge pull request #134 from SmallDoges/optimize-sparse-logic
Fix block size condition and enhance documentation
2 parents 064a533 + fd1dea2 commit ccc0b01

File tree: 4 files changed (+41, -15 lines)

README.md

Lines changed: 20 additions & 7 deletions
@@ -17,11 +17,12 @@ Flash-DMA is a high-performance attention implementation that integrates Flash A
 
 ## Key Features
 
-- **Sparse Attention Computation**: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$.
-- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without materializing the full attention matrix.
-- **CUDA-Accelerated**: Deep integration at the CUDA kernel level with custom sparse GEMM operations for maximum performance.
-- **Long Sequence Support**: Efficiently handles sequences of 128K+ tokens through dynamic masking when sequence length exceeds `keep_window_size`.
-- **Advanced Integration**: Complete integration from Python frontend to CUDA backend with optimized memory layouts and sparse computation strategies.
+- **Dynamic Sparse Attention**: Dynamically selects the most relevant keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$, supporting trainable sparse patterns.
+- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without instantiating the full attention matrix.
+- **CUDA Deep Optimization**: Utilizes custom CUDA kernels with shared memory aliasing, pipelined prefetching, and block skipping for high throughput and low memory access overhead.
+- **Extremely Long Context Support**: Handles 128K+ token sequences efficiently through dynamic mask windowing while preserving accuracy.
+- **Learnable Bias**: Built-in learnable attention bias and its gradient path dbias, eliminating the need for additional external operators.
+- **Fusion-Friendly Training**: Both forward and backward passes support block-level zero-mask skipping, further reducing computation in sparse scenarios.
 
 
 ## Performance
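
The feature list in this hunk rests on per-query top-`w` key selection and the resulting $O(N \cdot w)$ cost. Below is a minimal CPU-side sketch of that selection step, assuming a hypothetical `select_top_w_keys` helper and per-key importance scores; the repository performs this selection inside its CUDA kernels, not in host code.

```cpp
// Illustrative sketch only: pick the keep_window_size most important keys for
// one query from precomputed importance scores (cf. the ZOH states below).
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// scores: one importance value per key (N entries) for a single query.
// Returns indices of the w highest-scoring keys in O(N) average time, so
// attention over the selected keys costs O(N * w) overall instead of O(N^2).
std::vector<int> select_top_w_keys(const std::vector<float>& scores, std::size_t w) {
    std::vector<int> idx(scores.size());
    std::iota(idx.begin(), idx.end(), 0);
    if (w >= idx.size()) return idx;  // short sequence: every key stays active
    std::nth_element(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(w), idx.end(),
                     [&scores](int a, int b) { return scores[a] > scores[b]; });
    idx.resize(w);  // only these w keys participate in attention for this query
    return idx;
}
```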
@@ -129,7 +130,7 @@ The integration happens at the CUDA kernel level with several key components:
 
 - **ZOH States**: Pre-computed importance scores for key selection
 - **Active Masks**: Binary masks indicating which keys should be considered for each query
-- **Sparse Matrix Multiplication**: Custom CUDA kernels for efficient sparse attention computation
+- **Sparse Skipping**: Custom CUDA kernels for efficient sparse attention computation
 - **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency
 
 This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences.
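
The renamed "Sparse Skipping" bullet and the "block-level zero-mask skipping" feature above amount to: if a key block's active mask is entirely zero for a query block, that tile is never computed. A hedged host-side sketch with hypothetical `KeyBlock` and `active_mask` names follows; the actual skip happens inside the CUDA kernels.

```cpp
// Sketch of block-level zero-mask skipping: a fully masked key block
// contributes nothing, so its QK^T / softmax / PV tile work is skipped.
#include <cstdint>
#include <vector>

struct KeyBlock {
    int start;  // index of the first key in this block
    int size;   // number of keys in the block (e.g. kBlockN)
};

// active_mask[k] != 0 means key k was selected for the current query block.
bool block_fully_masked(const std::vector<std::uint8_t>& active_mask, const KeyBlock& blk) {
    for (int k = blk.start; k < blk.start + blk.size; ++k) {
        if (active_mask[static_cast<std::size_t>(k)] != 0) return false;  // must compute
    }
    return true;  // nothing active in this block
}

void process_query_block(const std::vector<std::uint8_t>& active_mask,
                         const std::vector<KeyBlock>& key_blocks) {
    for (const KeyBlock& blk : key_blocks) {
        if (block_fully_masked(active_mask, blk)) continue;  // zero-mask skip
        // ... compute the (query block, key block) attention tile here ...
    }
}
```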
@@ -185,12 +186,24 @@ python benchmarks/forward_equivalence.py
 ```
 Validates numerical consistency between Python reference and CUDA implementation.
 
-### Performance Benchmarking
+### Forward Pass Performance Benchmarking
 ```bash
 python benchmarks/forward_performance.py
 ```
 Compares Flash-DMA against standard SDPA across various sequence lengths and batch sizes.
 
+### Backward Pass Equivalence
+```bash
+python benchmarks/backward_equivalence.py
+```
+Validates numerical consistency between Python reference and CUDA implementation.
+
+### Backward Pass Performance Benchmarking
+```bash
+python benchmarks/backward_performance.py
+```
+Compares Flash-DMA against standard SDPA across various sequence lengths and batch sizes.
+
 ### Gradient Computation
 ```bash
 python benchmarks/grad_equivalence.py

README_zh.md

Lines changed: 19 additions & 6 deletions
@@ -17,11 +17,12 @@ Flash-DMA is a high-performance attention implementation that combines Flash Attention's memory
 
 ## Key Features
 
-- **Sparse Attention Computation**: Dynamically selects the most important keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$, where $w \ll N$.
+- **Dynamic Sparse Attention**: Dynamically selects the most important keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$, where $w \ll N$, with support for trainable sparse structures.
 - **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without instantiating the full attention matrix.
-- **CUDA Accelerated**: Deep integration at the CUDA kernel level with custom sparse GEMM operations for optimal performance.
-- **Long Sequence Support**: Efficiently handles sequences of 128K+ tokens through dynamic masking when the sequence length exceeds `keep_window_size`.
-- **Advanced Integration**: Complete integration from the Python frontend to the CUDA backend, with optimized memory layouts and sparse computation strategies.
+- **Deep CUDA Optimization**: Uses custom CUDA kernels with shared memory aliasing, pipelined prefetching, and per-block skipping for high throughput and low memory access overhead.
+- **Ultra-Long Context Support**: Handles 128K+ token contexts through dynamic mask windowing while preserving accuracy.
+- **Learnable Bias**: Built-in learnable attention bias and its backward gradient path dbias, with no extra external operators required.
+- **Fusion-Friendly Training**: Both the forward and backward passes support block-level all-zero mask skipping, further reducing computation in sparse scenarios.
 
 
 ## Performance
@@ -129,7 +130,7 @@ Flash-DMA combines two complementary techniques:
 
 - **ZOH States**: Pre-computed importance scores for key selection
 - **Active Masks**: Binary masks indicating which keys each query should consider
-- **Sparse Matrix Multiplication**: Custom CUDA kernels for efficient sparse attention computation
+- **Sparse Skipping**: Custom CUDA kernels for efficient sparse attention computation
 - **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency
 
 This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences.
@@ -184,12 +185,24 @@ python benchmarks/forward_equivalence.py
 ```
 Validates numerical consistency between the Python reference implementation and the CUDA implementation.
 
-### Performance Benchmarking
+### Forward Pass Performance Benchmarking
 ```bash
 python benchmarks/forward_performance.py
 ```
 Compares Flash-DMA with standard SDPA across various sequence lengths and batch sizes.
 
+### Backward Pass Equivalence
+```bash
+python benchmarks/backward_equivalence.py
+```
+Validates numerical consistency between the Python reference implementation and the CUDA implementation.
+
+### Backward Pass Performance Benchmarking
+```bash
+python benchmarks/backward_performance.py
+```
+Compares the performance of Flash-DMA and standard SDPA across various sequence lengths and batch sizes.
+
 ### Gradient Computation
 ```bash
 python benchmarks/grad_equivalence.py

csrc/flash_api.cpp

Lines changed: 1 addition & 1 deletion
@@ -298,7 +298,7 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
 ) {
 
     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
+    const int block_n = head_size <= 64 ? 64 : (head_size <= 128 ? 64 : 32);
     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.

csrc/src/flash_fwd_launch_template.h

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64; // Fixed for all head dimensions
-    constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32);
+    constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim <= 128 ? 64 : 32);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
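
Both one-line changes above fix the same boundary case: with `<`, a head size (or `Headdim`) of exactly 128 fell into the 32-wide branch, while the split-KV dispatch path expects 64. A small standalone check, illustrative only, with the two expressions copied from the diff:

```cpp
// Compare the block-size expression before and after this commit.
// Only head_size == 128 changes: the old form picks 32, the new form picks 64,
// keeping set_params_splitkv consistent with run_mha_fwd_splitkv_dispatch.
#include <cstdio>

int block_n_old(int head_size) {
    return head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);   // before the fix
}

int block_n_new(int head_size) {
    return head_size <= 64 ? 64 : (head_size <= 128 ? 64 : 32);  // after the fix
}

int main() {
    for (int head_size : {64, 96, 128, 256}) {
        std::printf("head_size=%3d  old block_n=%d  new block_n=%d\n",
                    head_size, block_n_old(head_size), block_n_new(head_size));
    }
    return 0;
}
```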

0 commit comments
