A modular Triton implementation of FlashAttention with PyTorch autograd support, correctness tests, and GPU benchmarks.
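FlashAttention is an exact attention algorithm (it produces the same result as standard attention, not an approximation) designed to make Transformer attention faster and more memory-efficient on GPUs.
Standard attention materializes the full N × N matrix of attention scores in GPU high-bandwidth memory (HBM), so memory use and memory traffic grow quadratically with the sequence length N. FlashAttention avoids this by processing queries, keys and values in tiles that fit in on-chip SRAM and combining the tiles with an online softmax. This reduces reads and writes between HBM and SRAM, and the full attention matrix is never written to HBM.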
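To make the tiling idea concrete, here is a minimal pure-Python sketch of the online-softmax recurrence for a single query row. It is an illustration, not the Triton kernel in this repository, and the function names and the block_size parameter are assumptions. The tiled version keeps a running maximum, a running normalizer and an unnormalized output accumulator, and rescales them whenever a new key tile raises the maximum, so the full row of scores is never stored.

```python
import math


def naive_attention_row(q, keys, values, scale):
    """Reference: score every key at once, then apply a normal softmax."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, values)) / total
            for c in range(len(values[0]))]


def tiled_attention_row(q, keys, values, scale, block_size=2):
    """Online softmax: visit keys/values one tile at a time."""
    m = -math.inf                   # running max of the scores seen so far
    denom = 0.0                     # running sum of exp(score - m)
    acc = [0.0] * len(values[0])    # running unnormalized output
    for start in range(0, len(keys), block_size):
        k_tile = keys[start:start + block_size]
        v_tile = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        m_new = max(m, max(scores))
        # Rescale the old statistics so they are relative to the new max.
        correction = math.exp(m - m_new)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_tile):
            p = math.exp(s - m_new)
            denom += p
            acc = [a + p * vc for a, vc in zip(acc, v)]
        m = m_new
    return [a / denom for a in acc]  # normalize once at the end


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
scale = 1.0 / math.sqrt(len(q))
print(naive_attention_row(q, keys, values, scale))
print(tiled_attention_row(q, keys, values, scale, block_size=2))
```

The two printed lists agree up to floating-point rounding for any block size, which is why the tiled result is exact rather than approximate.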
Current implementation:
- Triton FlashAttention forward kernel
- Causal and non-causal attention support
- Custom PyTorch autograd integration
- Backward kernels for gradients
- Naive PyTorch attention comparison
- CPU-safe package interface
- Pytest correctness tests
- GPU-marked Triton tests that skip automatically on non-CUDA machines
- Runtime benchmark script
- CUDA memory benchmark script
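The difference between causal and non-causal attention can be seen in a small pure-Python reference. This is a sketch, not the repository's naive_attention.py; the function name and the list-of-rows layout are assumptions, and it assumes queries and keys have the same length. With causal=True, query position i only attends to key positions 0 through i. Tiled kernels such as the Triton fused-attention tutorial typically exploit this mask by skipping key blocks that lie entirely above the diagonal.

```python
import math


def naive_attention(Q, K, V, causal=False):
    """Naive attention over lists of rows; Q, K: n x d, V: n x dv."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        # With causal=True, query i may only attend to keys 0..i
        # (assumes len(Q) == len(K), as in self-attention).
        visible = range(i + 1) if causal else range(len(K))
        scores = [scale * sum(a * b for a, b in zip(q, K[j])) for j in visible]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * V[j][c] for w, j in zip(weights, visible)) / total
            for c in range(len(V[0]))
        ])
    return out


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
causal_out = naive_attention(Q, K, V, causal=True)
full_out = naive_attention(Q, K, V, causal=False)
print(causal_out[0])  # first query sees only the first key: [1.0, 0.0]
```

The last query sees every key in both modes, so the last rows of causal_out and full_out match, while earlier rows differ.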
Planned improvements:
- Run full GPU correctness tests on a CUDA-capable machine
- Add verified GPU benchmark results
- Add benchmark plots
- Add memory reduction tables
- Benchmark across multiple sequence lengths and GPU architectures
- Improve modular split of Triton forward and backward kernels
Features:
- Triton implementation of FlashAttention
- PyTorch custom autograd integration
- Causal and non-causal attention support
- Naive PyTorch reference implementation
- CPU-safe fallback for non-CUDA machines
- GPU tests marked with pytest.mark.gpu
- Benchmarking against:
  - Naive PyTorch attention
  - PyTorch scaled dot product attention
  - Triton FlashAttention, when CUDA is available
- CUDA warm-up kernels for learning GPU programming fundamentals
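The backward kernels have to reproduce the gradients of softmax attention. As a hedged, pure-Python reference for a single query row (function names are illustrative and not taken from this repository), the sketch below uses the identity from the FlashAttention papers, D = rowsum(dO ∘ O). It computes the per-row correction term of the softmax backward, dS = P ∘ (dP − D), from O and dO alone, so a tiled backward pass can recompute P block by block instead of storing it. The query gradient is checked against central finite differences.

```python
import math


def forward_row(q, K, V, scale):
    """Output o and probabilities p of attention for a single query row."""
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    e = [math.exp(x - m) for x in s]
    z = sum(e)
    p = [x / z for x in e]
    o = [sum(pj * v[c] for pj, v in zip(p, V)) for c in range(len(V[0]))]
    return o, p


def backward_row(q, K, V, dO, scale):
    """Analytic gradients of loss = dot(dO, o) with respect to q, K and V."""
    o, p = forward_row(q, K, V, scale)
    dP = [sum(g * vc for g, vc in zip(dO, v)) for v in V]  # dP_j = dO . v_j
    D = sum(g * oc for g, oc in zip(dO, o))                 # D = rowsum(dO * O)
    dS = [pj * (dpj - D) for pj, dpj in zip(p, dP)]         # softmax backward
    dq = [scale * sum(ds * k[c] for ds, k in zip(dS, K)) for c in range(len(q))]
    dK = [[scale * ds * qc for qc in q] for ds in dS]
    dV = [[pj * g for g in dO] for pj in p]
    return dq, dK, dV


def numeric_dq(q, K, V, dO, scale, eps=1e-6):
    """Central finite differences of the same loss, for checking dq."""
    def loss(qq):
        o, _ = forward_row(qq, K, V, scale)
        return sum(g * oc for g, oc in zip(dO, o))

    grads = []
    for c in range(len(q)):
        plus, minus = list(q), list(q)
        plus[c] += eps
        minus[c] -= eps
        grads.append((loss(plus) - loss(minus)) / (2 * eps))
    return grads


q = [0.2, -0.4, 0.1]
K = [[0.5, 0.1, -0.3], [0.0, 0.7, 0.2], [-0.6, 0.3, 0.9]]
V = [[1.0, -1.0], [0.5, 2.0], [-0.3, 0.4]]
dO = [0.7, -0.2]
scale = 1.0 / math.sqrt(len(q))
dq, dK, dV = backward_row(q, K, V, dO, scale)
print(max(abs(a - b) for a, b in zip(dq, numeric_dq(q, K, V, dO, scale))))
```

The printed maximum difference should be tiny, which is the same kind of check a gradient test against the naive PyTorch reference performs.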
Flash-Attention-GPU-Kernel/
├── src/
│   └── flash_attention_kernel/
│       ├── __init__.py
│       ├── interface.py
│       ├── naive_attention.py
│       └── triton_flash_attention.py
├── tests/
│   ├── test_imports.py
│   ├── test_naive_attention.py
│   └── test_triton_attention_gpu.py
├── examples/
│   └── run_triton_flash_attention.py
├── benchmarks/
│   ├── README.md
│   ├── benchmark_attention.py
│   ├── benchmark_memory.py
│   └── results/
├── docs/
│   ├── algorithm.md
│   ├── limitations.md
│   └── running.md
├── triton/
│   ├── flash_attention.py
│   └── requirements.txt
├── cuda/
│   ├── Makefile
│   ├── cuda_common.cuh
│   ├── matrix_add.cu
│   ├── vector_add.cu
│   └── vector_add_simple.cu
├── .github/
│   └── workflows/
│       └── tests.yml
├── LICENSE
├── pyproject.toml
├── .gitignore
└── README.md
CPU-only mode: use this on machines without CUDA, including most laptops:
pip install -e ".[dev]"
Run tests:
pytest
Expected result on a non-CUDA machine:
7 passed, 2 skipped
The skipped tests are GPU-specific Triton tests.
GPU mode: use this on a CUDA-capable machine:
pip install -e ".[dev,gpu]"
Run all tests:
pytest
Run only GPU tests:
pytest -m gpu
Basic usage:
import torch
from flash_attention_kernel import flash_attention
q = torch.randn(1, 2, 128, 64)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attention(q, k, v, causal=True)
print(out.shape)
On CPU, the public API falls back to the naive PyTorch reference implementation.
On CUDA tensors (for example, q, k and v created with device="cuda"), the API can use the Triton backend:
out = flash_attention(q, k, v, causal=True, use_triton=True)
Run the Triton example script:
python examples/run_triton_flash_attention.py
This requires a CUDA-capable GPU with Triton installed.
Run the full test suite:
pytest
Run only GPU tests:
pytest -m gpu
Expected result on a non-CUDA machine:
7 passed, 2 skipped
The GPU tests are skipped automatically when CUDA or Triton is unavailable.
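One way to implement this kind of automatic skipping is a small availability check combined with a module-level skip marker. The sketch below is an assumption about how such a check could look, not the code in tests/test_triton_attention_gpu.py, and the helper name gpu_available is hypothetical. It uses importlib.util.find_spec so it never fails on machines where torch or triton are not installed.

```python
import importlib.util


def gpu_available() -> bool:
    """True only if torch and triton are installed and a CUDA device is visible."""
    # find_spec checks for a package without importing it, so this is safe on
    # CPU-only installs where triton (or even torch) is missing.
    for package in ("torch", "triton"):
        if importlib.util.find_spec(package) is None:
            return False
    import torch  # imported lazily, only once we know it is installed

    return torch.cuda.is_available()


print(gpu_available())

# Hypothetical use at the top of a GPU test module (requires pytest):
#
#   import pytest
#
#   pytestmark = [
#       pytest.mark.gpu,
#       pytest.mark.skipif(not gpu_available(), reason="requires CUDA and Triton"),
#   ]
```

With the marker applied, pytest -m gpu selects these tests; registering gpu under markers in the pytest configuration avoids unknown-marker warnings.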
Runtime benchmark:
python benchmarks/benchmark_attention.py --seq-lens 128 256 512 1024 2048 --causal
The runtime benchmark compares:
- Naive PyTorch attention
- PyTorch scaled dot product attention
- Triton FlashAttention, when CUDA is available
On non-CUDA machines, the benchmark still runs for the PyTorch baselines, but Triton results are left empty.
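Timing GPU code needs care because CUDA kernels launch asynchronously and the first Triton call includes JIT compilation. A common pattern, sketched here in plain Python, is to run a few warm-up iterations, synchronize before reading the clock, and report the median. The benchmark helper and its parameters are hypothetical and not necessarily what benchmark_attention.py does.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], warmup: int = 3, repeats: int = 10,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall-clock time of fn() in milliseconds.

    GPU kernels launch asynchronously, so pass a sync callback such as
    torch.cuda.synchronize to wait for queued work before reading the clock.
    """
    for _ in range(warmup):          # warm-up runs absorb compilation/caching costs
        fn()
    if sync is not None:
        sync()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()                   # make sure the work has actually finished
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times)


print(f"{benchmark(lambda: sum(range(100_000))):.3f} ms")
```

On a GPU, the timed callable could be lambda: flash_attention(q, k, v, causal=True, use_triton=True) with sync=torch.cuda.synchronize. Triton also ships triton.testing.do_bench for this kind of measurement.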
Memory benchmark:
python benchmarks/benchmark_memory.py --seq-lens 128 256 512 1024 2048 --causal
The memory benchmark requires CUDA.
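A useful yardstick for the memory results is the size of the score matrix that naive attention materializes: batch × heads × N × N elements, which grows quadratically in N. The sketch below computes that size for the sequence lengths used above, assuming a hypothetical batch size of 1, 2 heads and fp32 scores. The real benchmark's shapes and dtype may differ, and naive attention usually also stores the softmax output, so measured peaks can be higher.

```python
def score_matrix_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int) -> int:
    """Size of one materialized (batch, heads, seq_len, seq_len) score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


for n in (128, 256, 512, 1024, 2048):
    # Hypothetical shape: batch=1, heads=2, fp32 (4 bytes per element).
    mib = score_matrix_bytes(1, 2, n, 4) / 2**20
    print(f"seq_len={n:5d}: {mib:8.3f} MiB")
```

Each doubling of N quadruples the size, from 0.125 MiB at N = 128 to 32 MiB at N = 2048, whereas FlashAttention's extra storage (the output plus per-row softmax statistics) grows only linearly in N.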
Generated benchmark outputs are saved under:
benchmarks/results/
Locally generated CSV and PNG files are ignored by Git by default.
The cuda/ directory contains small CUDA programs used to practice GPU programming fundamentals, including vector addition and matrix addition.
These files are separate from the Triton FlashAttention implementation.
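The central idea these warm-up kernels practice is mapping each thread to one array element. The plain-Python emulation below is not the code in vector_add.cu, and its names are illustrative. It mirrors the usual CUDA pattern: compute a global index as blockIdx.x * blockDim.x + threadIdx.x, launch ceil(n / blockDim.x) blocks, and guard with i < n so threads past the end of the array do nothing.

```python
def launch_config(n: int, threads_per_block: int = 256) -> int:
    """Number of blocks needed to cover n elements (ceiling division)."""
    return (n + threads_per_block - 1) // threads_per_block


def simulate_vector_add(a, b, threads_per_block=256):
    """Emulate a 1D CUDA grid: each (block, thread) pair handles one index."""
    n = len(a)
    c = [0.0] * n
    for block_idx in range(launch_config(n, threads_per_block)):  # blockIdx.x
        for thread_idx in range(threads_per_block):               # threadIdx.x
            i = block_idx * threads_per_block + thread_idx         # global index
            if i < n:  # guard: the last block may be only partly used
                c[i] = a[i] + b[i]
    return c


print(launch_config(1000))  # 4 blocks of 256 threads; 24 threads stay idle
print(simulate_vector_add([1, 2, 3], [4, 5, 6], threads_per_block=2))  # [5, 7, 9]
```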
Build CUDA warm-up examples:
cd cuda
make
Current limitations:
- The main FlashAttention implementation is written in Triton, not CUDA.
- The CUDA files are warm-up kernels for learning GPU programming basics.
- GPU correctness tests require a CUDA-capable machine with Triton installed.
- Benchmark and memory results require CUDA.
- CPU execution falls back to the naive PyTorch reference implementation.
- Dropout is not implemented.
- Variable-length packed sequences are not implemented.
- The implementation has not yet been benchmarked across multiple GPU architectures.
See docs/limitations.md for more details.
References:
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Triton Fused Attention Tutorial
The goal of this project is to demonstrate practical understanding of:
- GPU kernel programming
- Triton
- PyTorch custom autograd
- FlashAttention-style memory-efficient attention
- Correctness testing
- Runtime benchmarking
- CUDA memory benchmarking
- Clean Python project structure