Algorithm Overview

Standard Attention

Transformer attention computes:

Attention(Q, K, V) = softmax(QKᵀ / √d) V

Here d is the per-head dimension. For a sequence of length N, standard attention creates an N x N score matrix for each head. This becomes expensive for long sequences because the memory needed for that matrix grows quadratically with N.
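As a point of comparison, the formula above can be written directly as a small reference routine. This is a minimal pure-Python sketch, not the project's reference implementation; the name `naive_attention` and the list-of-rows layout are assumptions chosen for illustration.

```python
import math


def naive_attention(Q, K, V, causal=False):
    """Reference attention over lists of row vectors (illustrative only)."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Full N x N score matrix: this is the quadratic-memory intermediate.
    scores = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    out = []
    for i, row in enumerate(scores):
        if causal:
            # Query i may only attend to keys 0..i.
            row = [s if j <= i else float("-inf") for j, s in enumerate(row)]
        # Subtract the row maximum before exponentiating for stability.
        row_max = max(row)
        weights = [math.exp(s - row_max) for s in row]
        total = sum(weights)
        out.append([
            sum(w * v[c] for w, v in zip(weights, V)) / total
            for c in range(len(V[0]))
        ])
    return out
```

For a sense of scale, at N = 8192 with 2-byte (fp16) elements the score matrix alone takes 8192 x 8192 x 2 bytes = 128 MiB per head.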

FlashAttention Idea

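FlashAttention avoids materializing the full attention matrix in GPU high-bandwidth memory (HBM). Instead, it processes blocks of queries, keys, and values and keeps intermediate values in faster on-chip memory when possible. An online softmax update rescales the running maximum and running sum as each block arrives, so the result matches the exact softmax and stays numerically stable without ever holding a full row of scores.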

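This reduces memory traffic and allows attention to scale better to longer sequences. The arithmetic cost is still quadratic in N; the savings come from not storing or re-reading the N x N intermediate.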
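The online softmax update can be illustrated on a single row of scores. The sketch below is an assumption-laden toy (the function name `online_softmax_stats` and the plain-list input are hypothetical, not taken from the project): it consumes the scores one block at a time and keeps only a running maximum and a running sum of exponentials.

```python
import math


def online_softmax_stats(scores, block_size):
    """Return (max, sum of exp(s - max)) computed one block at a time."""
    running_max = float("-inf")
    running_sum = 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        new_max = max(running_max, max(block))
        # Rescale the old sum to the new reference maximum, then add
        # the contributions of the current block.
        running_sum = (
            running_sum * math.exp(running_max - new_max)
            + sum(math.exp(s - new_max) for s in block)
        )
        running_max = new_max
    return running_max, running_sum


# Usage: softmax weights can be recovered afterwards from the two statistics.
row = [1.0, 3.0, -2.0, 1000.0, 0.5]
row_max, row_sum = online_softmax_stats(row, block_size=2)
weights = [math.exp(s - row_max) / row_sum for s in row]
```

Because every exponent is taken relative to the current maximum, a score as large as 1000.0 causes no overflow, whereas calling `math.exp(1000.0)` directly would raise `OverflowError`.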

What This Project Implements

This project implements a Triton-based FlashAttention kernel with:

  • Block-wise forward attention computation
  • Causal and non-causal masking
  • Numerically stable softmax updates
  • Custom PyTorch autograd integration
  • Backward kernels for gradients
  • CPU reference attention for correctness testing

Why Tiling Matters

Naive attention materializes the full score matrix before applying softmax and multiplying by V.

The Triton implementation processes attention in blocks:

  1. Load a block of Q.
  2. Iterate over blocks of K and V.
  3. Compute partial attention scores.
  4. Apply causal masking when needed.
  5. Update running softmax statistics.
  6. Accumulate the output block.

This approach reduces the amount of intermediate data written to and read from GPU global memory.
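The six steps above can be sketched in plain Python to make the data flow concrete. This is a hedged illustration, not the project's Triton kernel: `tiled_attention`, `block_q`, and `block_k` are hypothetical names, the inner loop handles one query row at a time where a real kernel works on whole tiles, and nothing here models on-chip memory.

```python
import math


def tiled_attention(Q, K, V, block_q=2, block_k=2, causal=False):
    """Block-wise attention that never stores the full N x N scores."""
    d, dv = len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    neg_inf = float("-inf")
    out = []
    for q_start in range(0, len(Q), block_q):
        q_block = Q[q_start:q_start + block_q]          # 1. load a Q block
        row_max = [neg_inf] * len(q_block)
        row_sum = [0.0] * len(q_block)
        acc = [[0.0] * dv for _ in q_block]
        for k_start in range(0, len(K), block_k):       # 2. loop over K/V
            k_block = K[k_start:k_start + block_k]
            v_block = V[k_start:k_start + block_k]
            for r, q in enumerate(q_block):
                i = q_start + r
                # 3. partial scores against the current key block
                s = [scale * sum(a * b for a, b in zip(q, k))
                     for k in k_block]
                if causal:                              # 4. causal mask
                    s = [x if k_start + j <= i else neg_inf
                         for j, x in enumerate(s)]
                new_max = max(row_max[r], max(s))
                if new_max == neg_inf:
                    continue  # nothing visible yet for this row
                # 5. update running softmax statistics
                alpha = math.exp(row_max[r] - new_max)
                p = [math.exp(x - new_max) for x in s]
                row_sum[r] = row_sum[r] * alpha + sum(p)
                # 6. rescale and accumulate the output row
                acc[r] = [
                    a * alpha + sum(w * v[c] for w, v in zip(p, v_block))
                    for c, a in enumerate(acc[r])
                ]
                row_max[r] = new_max
        out.extend([a / row_sum[r] for a in acc[r]]
                   for r in range(len(q_block)))
    return out
```

At any moment the sketch holds only one block of scores per query row plus the running statistics, which is the property that lets a GPU kernel keep the working set on chip. A production kernel would typically also skip key blocks that are entirely masked under causal attention; that optimization is left out here for clarity.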

References