Transformer attention computes:
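Attention(Q, K, V) = softmax(QKᵀ / √d) V
where Q, K, and V are the query, key, and value matrices, d is the head dimension, and softmax is applied to each row.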
For a sequence of length N, standard attention materializes an N × N score matrix. This becomes expensive for long sequences because the memory needed for that matrix grows quadratically with sequence length.
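To put the quadratic growth in numbers, the sketch below computes the size of a single N × N score matrix. It assumes one batch element, one head, and 2-byte (fp16/bf16) elements; `score_matrix_bytes` is a hypothetical helper, not part of this project.

```python
def score_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to store one N x N attention score matrix.

    Assumes a single batch element and a single head; multiply by
    batch_size * num_heads for a whole layer. The default of 2 bytes
    per element models fp16/bf16 scores.
    """
    return seq_len * seq_len * bytes_per_element


# Print the per-head size for a few sequence lengths.
for n in (1024, 4096, 16384, 32768):
    mib = score_matrix_bytes(n) / 2**20
    print(f"N={n}: {mib:.1f} MiB per head")
```

Each doubling of N quadruples the size: 2 MiB at N = 1024 grows to 2048 MiB (2 GiB) at N = 32768, before multiplying by batch size and number of heads.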
FlashAttention avoids materializing the full attention matrix in GPU high-bandwidth memory (HBM). Instead, it processes queries, keys, and values in blocks, keeps intermediate values in faster on-chip memory when possible, and uses an online softmax update, which tracks a running row maximum and normalizer, so that softmax can be computed block by block while remaining numerically stable.
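The online softmax update can be illustrated on a single row of scores. This sketch, with hypothetical helper names, streams the row in blocks while keeping only a running maximum `m` and normalizer `l`, then compares the result with an ordinary softmax computed over the whole row.

```python
import math


def online_softmax_stats(scores, block_size):
    """Stream over `scores` in blocks, keeping a running max m and normalizer l.

    Returns (m, l) such that softmax(scores)[j] == exp(scores[j] - m) / l.
    Assumes all scores are finite.
    """
    m = -math.inf  # running maximum seen so far
    l = 0.0        # running sum of exp(score - m)
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old normalizer to the new maximum, then add this block.
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    return m, l


def softmax(scores):
    """Ordinary (non-streaming) softmax with max subtraction."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1000.0, 1001.0, 999.0, 2.0, 3.0]
m, l = online_softmax_stats(scores, block_size=2)
streamed = [math.exp(s - m) / l for s in scores]
# The difference should be at the level of floating-point rounding.
print(max(abs(a - b) for a, b in zip(streamed, softmax(scores))))
```

The scores near 1000 are deliberate: `math.exp(1000.0)` raises `OverflowError` in Python, so subtracting the running maximum is what keeps the streamed computation finite.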
This reduces memory traffic and allows attention to scale better to longer sequences.
This project implements a Triton-based FlashAttention kernel with:
- Block-wise forward attention computation
- Causal and non-causal masking
- Numerically stable softmax updates
- Custom PyTorch autograd integration
- Backward kernels for gradients
- CPU reference attention for correctness testing
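The backward kernels are listed without their math. As a hedged reference (the standard gradients of softmax attention, not taken from this project's code), using the same symbols as the attention formula and writing dO for the upstream gradient of the output, the backward pass has to produce:

```latex
\begin{aligned}
S &= QK^\top / \sqrt{d}, \qquad P = \operatorname{softmax}(S)\ \text{(row-wise)}, \qquad O = PV \\
dV &= P^\top dO \\
dP &= dO\, V^\top \\
D_i &= \textstyle\sum_c dO_{ic}\, O_{ic} \\
dS_{ij} &= P_{ij}\,\bigl(dP_{ij} - D_i\bigr) \\
dQ &= dS\, K / \sqrt{d} \\
dK &= dS^\top Q / \sqrt{d}
\end{aligned}
```

D_i equals Σ_j P_ij dP_ij, but it can be computed from O and dO alone, without P. Causally masked entries have P_ij = 0, so their dS_ij is zero as well. FlashAttention-style backward kernels typically recompute P block by block from Q, K, and saved per-row softmax statistics rather than storing it; exactly what this project saves for the backward pass is not stated here.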
Naive attention materializes the full score matrix before applying softmax and multiplying by V.
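A minimal pure-Python sketch of the naive approach, in the spirit of a CPU reference used for correctness testing. `reference_attention` is a hypothetical name, inputs are assumed to be plain lists of floats, and the causal option assumes as many keys as queries; the project's actual reference implementation may differ.

```python
import math


def reference_attention(q, k, v, causal=False):
    """Naive attention: build the full N x N score matrix, softmax each row, multiply by V.

    q, k: N x d lists of floats; v: N x d_v list of floats.
    With causal=True, query i may only attend to keys j <= i.
    """
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    # Full N x N score matrix: the quadratic-memory step FlashAttention avoids.
    scores = [[scale * sum(qi[c] * kj[c] for c in range(d)) for kj in k] for qi in q]
    if causal:
        for i in range(n):
            for j in range(i + 1, n):
                scores[i][j] = -math.inf
    out = []
    for row in scores:
        m = max(row)                        # subtract the row max for stability
        p = [math.exp(s - m) for s in row]  # exp(-inf) == 0.0 for masked entries
        total = sum(p)
        out.append([sum(p[j] * v[j][c] for j in range(len(v))) / total
                    for c in range(len(v[0]))])
    return out
```

With all-zero queries every score is equal, so each output row is simply the mean of the rows of V.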
The Triton implementation processes attention in blocks:
- Load a block of Q.
- Iterate over blocks of K and V.
- Compute partial attention scores.
- Apply causal masking when needed.
- Update running softmax statistics.
- Accumulate the output block.
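The steps above can be sketched in pure Python to show the bookkeeping a blocked kernel performs. This is not the project's Triton kernel: `blocked_attention`, `block_q`, and `block_k` are hypothetical names, plain loops stand in for tile-level operations, and everything runs on the CPU. The result should match naive attention up to floating-point rounding.

```python
import math


def blocked_attention(q, k, v, block_q=2, block_k=2, causal=False):
    """Block-wise attention forward pass with online softmax (pure-Python sketch).

    block_q / block_k stand in for the tile sizes a GPU kernel would use.
    The full N x N score matrix is never built; only one row-block of
    scores against one K block exists at a time.
    """
    n, d, dv = len(q), len(q[0]), len(v[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q_start in range(0, n, block_q):                 # load a block of Q
        q_blk = q[q_start:q_start + block_q]
        m = [-math.inf] * len(q_blk)                     # running row maximum
        l = [0.0] * len(q_blk)                           # running row normalizer
        acc = [[0.0] * dv for _ in q_blk]                # unnormalized output rows
        for k_start in range(0, len(k), block_k):        # iterate over K/V blocks
            k_end = min(k_start + block_k, len(k))
            for r, qi in enumerate(q_blk):
                i = q_start + r
                # Partial scores of query row i against the current K block.
                s = []
                for j in range(k_start, k_end):
                    if causal and j > i:
                        s.append(-math.inf)              # causal mask
                    else:
                        s.append(scale * sum(qi[c] * k[j][c] for c in range(d)))
                m_new = max(m[r], max(s))
                if m_new == -math.inf:                   # everything so far masked; avoid nan
                    continue
                alpha = math.exp(m[r] - m_new)           # rescales the old statistics
                p = [math.exp(x - m_new) for x in s]
                l[r] = l[r] * alpha + sum(p)             # update running normalizer
                for c in range(dv):                      # accumulate the output block
                    acc[r][c] = acc[r][c] * alpha + sum(
                        p[t] * v[k_start + t][c] for t in range(len(p)))
                m[r] = m_new
        # Normalize once every K/V block has been visited.
        for r in range(len(q_blk)):
            out.append([a / l[r] for a in acc[r]])
    return out
```

The per-row state is only `m`, `l`, and one accumulator row, which is what allows a GPU kernel to keep it in fast on-chip storage. A real causal kernel would typically also skip K/V blocks that lie entirely above the diagonal instead of masking them element by element.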
This approach reduces the amount of intermediate data written to and read from GPU global memory.