3. **Multi-Query Attention** (`sparse_mla_fwd.py`, `sparse_mla_fwd_pipelined.py`, and `sparse_mla_bwd.py`) - Core attention mechanism implementation with sparse MLA (Multi-head Latent Attention) forward and backward passes
### Lightning Indexer
```python
for i_i in T.serial(T.ceildiv(NI, 2)):
    ...
```
Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
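As a rough mental model of this producer/consumer scheme, here is a plain-Python analogy of the double-buffered hand-off (CPU threads and queues stand in for warp groups and hardware barriers, and only one consumer is shown; none of this is TileLang code):

```python
import threading
import queue

# Ping-pong scheme: two buffers, a producer filling them, a consumer draining them.
NUM_TILES = 8
buffers = [None, None]
free_slots = [queue.Queue(maxsize=1) for _ in range(2)]  # "buffer i is free" signal
full_slots = [queue.Queue(maxsize=1) for _ in range(2)]  # "buffer i is ready" signal
for sig in free_slots:
    sig.put(True)  # both buffers start out free

def producer():
    for tile in range(NUM_TILES):
        slot = tile % 2
        free_slots[slot].get()          # wait until the consumer has released this buffer
        buffers[slot] = f"tile-{tile}"  # stands in for the async copy from global memory
        full_slots[slot].put(True)      # signal "data ready" to the consumer

def consumer():
    for tile in range(NUM_TILES):
        slot = tile % 2
        full_slots[slot].get()          # wait on the "ready" barrier
        _ = buffers[slot]               # stands in for the matmul/softmax work on this tile
        free_slots[slot].put(True)      # release the buffer so the next load can start

threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

While the consumer works on one buffer, the producer is already filling the other, which is the latency-hiding effect described above.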
### Sparse MLA Backward
The Sparse MLA backward kernel (`sparse_mla_bwd.py`) computes gradients with respect to queries (dQ) and key-values (dKV) for the sparse attention mechanism. Like the forward pass, it processes only the selected top-k indices, maintaining O(seq_len * topk) complexity.
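To put the O(seq_len * topk) claim in perspective, here is a quick back-of-the-envelope count of query-key pairs (the sequence length and top-k below are made-up illustration values, not the kernel's defaults):

```python
seq_len, topk = 65536, 2048          # illustrative values only

dense_pairs = seq_len * seq_len      # full attention: every query scores every key
sparse_pairs = seq_len * topk        # sparse MLA: every query scores only its top-k keys

print(dense_pairs // sparse_pairs)   # 32, i.e. 32x fewer pairs at this length
```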
The backward pass consists of three main stages:
**1. Preprocessing**: Computes delta values (row-wise dot products of output and output gradient):
```python
# Accumulate the elementwise product O * dO over the head dimension D,
# one block_ND-wide tile at a time, with software pipelining across stages.
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
    T.copy(O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o)
    T.copy(dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], do)
    for i, j in T.Parallel(block_ND, block_ND):
        acc[i, j] += o[i, j] * do[i, j]
# Row-wise reduction of the accumulated products yields delta for each query row.
T.reduce_sum(acc, delta, 1)
```
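This is the standard FlashAttention-style backward preprocessing step, and the whole block collapses to one line of plain PyTorch (the tensor names and the `[batch, seq, heads, dim]` layout below are assumed for illustration, not taken from the kernel):

```python
import torch

# Illustrative shapes: O and dO are [batch, seq_len, heads, dim].
O = torch.randn(2, 128, 8, 64)
dO = torch.randn_like(O)

# delta[b, s, h] = sum_d O[b, s, h, d] * dO[b, s, h, d]
delta = (O * dO).sum(dim=-1)
```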
**2. Main Backward Computation**: Computes gradients through sparse attention:
```python
# Sparse MLA backward: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
    ...
```
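To make the gradient math concrete, here is a minimal, unfused PyTorch reference for one head of sparse MLA-style attention backward; the function name, shapes, and the split of the latent `kv` into a score dimension `D` and a value slice of width `dv` are illustrative assumptions rather than the kernel's actual interface:

```python
import torch

def sparse_mla_bwd_reference(q, kv, indices, dO, dv, scale):
    """Unfused reference: q [seq_q, D], kv [seq_kv, D], indices [seq_q, topk],
    dO [seq_q, dv]. Scores use all D dims of kv; values use its first dv dims."""
    dQ = torch.zeros_like(q)
    dKV = torch.zeros_like(kv)

    for t in range(q.shape[0]):                 # the kernel parallelizes over queries
        idx = indices[t]                        # [topk] selected key positions
        k = kv[idx]                             # [topk, D]
        v = kv[idx, :dv]                        # [topk, dv]

        # Recompute the forward probabilities for this query.
        s = (q[t] @ k.T) * scale                # [topk]
        p = torch.softmax(s, dim=-1)            # [topk]
        o = p @ v                               # [dv]

        # Softmax backward, using delta = <o, dO[t]> from the preprocessing stage.
        delta = (o * dO[t]).sum()
        dp = dO[t] @ v.T                        # [topk]
        ds = p * (dp - delta)                   # [topk]

        # Gradients w.r.t. the query and the selected kv rows.
        dQ[t] = scale * (ds @ k)                         # [D]
        dkv_rows = scale * ds[:, None] * q[t][None, :]   # dK part, [topk, D]
        dkv_rows[:, :dv] += p[:, None] * dO[t][None, :]  # dV part lands in the first dv dims

        # Scatter-add into dKV; in the fused kernel these become atomic updates.
        dKV.index_add_(0, idx, dkv_rows)

    return dQ, dKV
```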
**Performance**: The sparse MLA backward pass achieves:

- **H800 SXM**: ~100 TFlops
- **H200 SXM**: ~115 TFlops

The implementation efficiently handles the irregular memory access patterns inherent in sparse attention while maintaining high compute utilization through careful memory management and atomic update strategies. Note that this is a relatively naive implementation that requires further optimization.