FlashAttention GPU Kernel in Triton

A modular Triton implementation of FlashAttention with PyTorch autograd support, correctness tests, and GPU benchmarks.


Overview

FlashAttention is an optimized attention algorithm designed to make Transformer attention faster and more memory-efficient on GPUs.

Standard attention materializes the full seq_len × seq_len attention matrix in GPU memory, so memory use and memory traffic grow quadratically with sequence length. FlashAttention avoids this by computing attention in tiles that fit in on-chip SRAM, which reduces reads and writes to high-bandwidth memory (HBM).
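The tiling idea can be illustrated on the CPU with plain PyTorch. The sketch below is not the Triton kernel from this repository; `tiled_attention` and `block_size` are illustrative names. It visits K and V one block at a time and keeps a running row maximum and row sum (the online softmax trick), so the full score matrix never exists at once. A real kernel would also tile over queries and keep each block in SRAM.

```python
import math

import torch


def naive_attention(q, k, v):
    """Reference: builds the full (seq_len x seq_len) score matrix."""
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v


def tiled_attention(q, k, v, block_size=32):
    """Same result, computed over K/V blocks with a running softmax."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    out = torch.zeros_like(q)                                  # unnormalized output
    row_max = q.new_full((*q.shape[:-1], 1), float("-inf"))    # running max per query row
    row_sum = q.new_zeros((*q.shape[:-1], 1))                  # running softmax denominator

    for start in range(0, k.shape[-2], block_size):
        k_blk = k[..., start:start + block_size, :]
        v_blk = v[..., start:start + block_size, :]
        scores = (q @ k_blk.transpose(-2, -1)) * scale         # (..., seq_q, block)

        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)              # rescales earlier blocks
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max

    return out / row_sum


torch.manual_seed(0)
q = torch.randn(1, 2, 128, 64)
k = torch.randn_like(q)
v = torch.randn_like(q)
err = (tiled_attention(q, k, v) - naive_attention(q, k, v)).abs().max().item()
print(f"max abs difference: {err:.2e}")
```

The two functions should agree up to floating-point rounding, including when `block_size` does not divide the sequence length.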


Project Status

Current implementation:

  • Triton FlashAttention forward kernel
  • Causal and non-causal attention support
  • Custom PyTorch autograd integration
  • Backward kernels for gradients
  • Naive PyTorch attention comparison
  • CPU-safe package interface
  • Pytest correctness tests
  • GPU-marked Triton tests that skip automatically on non-CUDA machines
  • Runtime benchmark script
  • CUDA memory benchmark script

Planned improvements:

  • Run full GPU correctness tests on a CUDA-capable machine
  • Add verified GPU benchmark results
  • Add benchmark plots
  • Add memory reduction tables
  • Benchmark across multiple sequence lengths and GPU architectures
  • Improve modular split of Triton forward and backward kernels

Documentation

  • docs/algorithm.md
  • docs/running.md
  • docs/limitations.md


Features

  • Triton implementation of FlashAttention
  • PyTorch custom autograd integration
  • Causal and non-causal attention support
  • Naive PyTorch reference implementation
  • CPU-safe fallback for non-CUDA machines
  • GPU tests marked with pytest.mark.gpu
  • Benchmarking against:
    • Naive PyTorch attention
    • PyTorch scaled dot product attention
    • Triton FlashAttention, when CUDA is available
  • CUDA warm-up kernels for learning GPU programming fundamentals
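To make the "custom autograd integration" item concrete, here is a minimal sketch of how an attention forward and backward pass can be wrapped in `torch.autograd.Function`. It is not the repository's implementation: the class name `AttentionFunction` is hypothetical, the math runs in plain PyTorch rather than Triton, and it saves the full probability matrix for simplicity. A FlashAttention-style backward would instead save per-row softmax statistics and recompute the probabilities block by block.

```python
import math

import torch


class AttentionFunction(torch.autograd.Function):
    """Non-causal attention with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, q, k, v):
        scale = 1.0 / math.sqrt(q.shape[-1])
        p = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
        ctx.save_for_backward(q, k, v, p)
        ctx.scale = scale
        return p @ v

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, p = ctx.saved_tensors
        grad_v = p.transpose(-2, -1) @ grad_out
        grad_p = grad_out @ v.transpose(-2, -1)
        # Softmax backward, applied row by row.
        grad_s = p * (grad_p - (grad_p * p).sum(dim=-1, keepdim=True))
        grad_q = (grad_s @ k) * ctx.scale
        grad_k = (grad_s.transpose(-2, -1) @ q) * ctx.scale
        return grad_q, grad_k, grad_v


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 16, 8, requires_grad=True) for _ in range(3))
out = AttentionFunction.apply(q, k, v)
out.sum().backward()
print(out.shape, q.grad.shape)
```

A hand-written backward like this can be checked numerically with `torch.autograd.gradcheck` on small double-precision inputs.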

Repository Structure

Flash-Attention-GPU-Kernel/
├── src/
│   └── flash_attention_kernel/
│       ├── __init__.py
│       ├── interface.py
│       ├── naive_attention.py
│       └── triton_flash_attention.py
├── tests/
│   ├── test_imports.py
│   ├── test_naive_attention.py
│   └── test_triton_attention_gpu.py
├── examples/
│   └── run_triton_flash_attention.py
├── benchmarks/
│   ├── README.md
│   ├── benchmark_attention.py
│   ├── benchmark_memory.py
│   └── results/
├── docs/
│   ├── algorithm.md
│   ├── limitations.md
│   └── running.md
├── triton/
│   ├── flash_attention.py
│   └── requirements.txt
├── cuda/
│   ├── Makefile
│   ├── cuda_common.cuh
│   ├── matrix_add.cu
│   ├── vector_add.cu
│   └── vector_add_simple.cu
├── .github/
│   └── workflows/
│       └── tests.yml
├── pyproject.toml
├── .gitignore
└── README.md

Installation

CPU development mode

Use this mode on machines without CUDA, including most laptops:

pip install -e ".[dev]"

Run tests:

pytest

Expected result on a non-CUDA machine:

7 passed, 2 skipped

The skipped tests are GPU-specific Triton tests.

GPU development mode

Use this mode on a CUDA-capable machine:

pip install -e ".[dev,gpu]"

Run all tests:

pytest

Run only GPU tests:

pytest -m gpu

Usage

Python API

import torch

from flash_attention_kernel import flash_attention

q = torch.randn(1, 2, 128, 64)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attention(q, k, v, causal=True)
print(out.shape)

On CPU, the public API falls back to the naive PyTorch reference implementation.

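On CUDA tensors, the API can use the Triton backend when use_triton=True is passed:

q, k, v = (t.cuda() for t in (q, k, v))  # the Triton path expects CUDA tensors
out = flash_attention(q, k, v, causal=True, use_triton=True)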
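The fallback behaviour described above can be pictured with the following sketch. It is an assumption about how such a dispatch might look, not the code in `interface.py`: `reference_attention` and `pick_backend` are hypothetical names, and the actual package may handle `use_triton=True` on CPU tensors differently.

```python
import importlib.util
import math

import torch


def reference_attention(q, k, v, causal=False):
    """Naive attention that materializes the full score matrix."""
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    if causal:
        seq_q, seq_k = scores.shape[-2:]
        # True above the diagonal: positions a query must not attend to.
        future = torch.triu(torch.ones(seq_q, seq_k, device=q.device), diagonal=1).bool()
        scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


def pick_backend(q, use_triton=False):
    """Return "triton" only if it was requested and can actually run."""
    triton_installed = importlib.util.find_spec("triton") is not None
    if use_triton and q.is_cuda and triton_installed:
        return "triton"
    return "naive"


q = torch.randn(1, 2, 128, 64)
k = torch.randn_like(q)
v = torch.randn_like(q)
print(pick_backend(q, use_triton=True))
print(reference_attention(q, k, v, causal=True).shape)
```

With causal masking, the first query position attends only to the first key, so its output row equals the first row of `v`. This gives a quick sanity check for any backend.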

Running the Triton Example

python examples/run_triton_flash_attention.py

This requires a CUDA-capable GPU and Triton installed.


Testing

Run the full test suite:

pytest

Run only GPU tests:

pytest -m gpu

Expected result on a non-CUDA machine:

7 passed, 2 skipped

The GPU tests are skipped automatically when CUDA or Triton is unavailable.
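One common way to get this automatic skipping is a `skipif` marker combined with the `gpu` marker. The sketch below is an assumption about the mechanism rather than a copy of the repository's tests; `gpu_stack_available`, `requires_gpu`, and the test body are placeholders.

```python
import importlib.util

import pytest


def gpu_stack_available() -> bool:
    """True only if torch and triton are importable and CUDA is usable."""
    for name in ("torch", "triton"):
        if importlib.util.find_spec(name) is None:
            return False
    import torch

    return torch.cuda.is_available()


# Skips the decorated test when the GPU stack is missing.
requires_gpu = pytest.mark.skipif(
    not gpu_stack_available(),
    reason="CUDA or Triton is unavailable",
)


@pytest.mark.gpu      # lets `pytest -m gpu` select this test
@requires_gpu         # lets plain `pytest` skip it on CPU-only machines
def test_cuda_tensor_roundtrip():
    import torch

    x = torch.ones(4, device="cuda")
    assert x.sum().item() == 4.0
```

For `pytest -m gpu` to run without an unknown-marker warning, the `gpu` marker is normally registered under `markers` in the `[tool.pytest.ini_options]` table of `pyproject.toml`.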


Benchmarks

Runtime benchmark

python benchmarks/benchmark_attention.py --seq-lens 128 256 512 1024 2048 --causal

The runtime benchmark compares:

  • Naive PyTorch attention
  • PyTorch scaled dot product attention
  • Triton FlashAttention, when CUDA is available

On non-CUDA machines, the benchmark still runs for the PyTorch baselines, but Triton results are left empty.
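For readers who want to reproduce a single timing by hand, the sketch below shows one way to time an attention call fairly on both CPU and GPU. It is not taken from `benchmark_attention.py`; `time_fn`, `warmup`, and `iters` are illustrative names. The key point is that CUDA kernels launch asynchronously, so the clock is only read after `torch.cuda.synchronize()`.

```python
import time

import torch
import torch.nn.functional as F


def time_fn(fn, *args, warmup=3, iters=10):
    """Return the mean wall-clock time per call of fn(*args), in milliseconds."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):            # warm-up runs are not timed
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()       # wait for pending GPU work before starting
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()       # wait for the timed kernels to finish
    return (time.perf_counter() - start) * 1000.0 / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
for seq_len in (128, 256):
    q, k, v = (torch.randn(1, 2, seq_len, 64, device=device) for _ in range(3))
    ms = time_fn(F.scaled_dot_product_attention, q, k, v)
    print(f"seq_len={seq_len}: {ms:.3f} ms per call")
```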

Memory benchmark

python benchmarks/benchmark_memory.py --seq-lens 128 256 512 1024 2048 --causal

The memory benchmark requires CUDA.
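As a rough guide to what a memory benchmark measures, the sketch below estimates the size of the score matrix that naive attention materializes and shows one way to read peak CUDA memory around a call. The helper names are hypothetical and this is not the logic of `benchmark_memory.py`. On a machine without CUDA the measurement helper returns `None`.

```python
import torch


def score_matrix_mib(batch, heads, seq_len, bytes_per_elem=4):
    """Size in MiB of one (batch, heads, seq_len, seq_len) score matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_elem / 2**20


def peak_cuda_memory_mib(fn, *args):
    """Peak extra CUDA memory (MiB) allocated while running fn(*args)."""
    if not torch.cuda.is_available():
        return None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    baseline = torch.cuda.memory_allocated()   # memory already held, e.g. by inputs
    fn(*args)
    torch.cuda.synchronize()
    return (torch.cuda.max_memory_allocated() - baseline) / 2**20


for seq_len in (128, 256, 512, 1024, 2048):
    print(seq_len, score_matrix_mib(batch=1, heads=2, seq_len=seq_len))
```

The estimate grows by a factor of four each time the sequence length doubles, which is the quadratic cost that a tiled kernel is meant to avoid.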

Generated benchmark outputs are saved under:

benchmarks/results/

Locally generated CSV and PNG files are ignored by Git by default.


CUDA Warm-up Kernels

The cuda/ directory contains small CUDA programs used to practice GPU programming fundamentals, including vector addition and matrix addition.

These files are separate from the Triton FlashAttention implementation.

Build CUDA warm-up examples:

cd cuda
make

Current Limitations

  • The main FlashAttention implementation is written in Triton, not CUDA.
  • The CUDA files are warm-up kernels for learning GPU programming basics.
  • GPU correctness tests require a CUDA-capable machine with Triton installed.
  • Benchmark and memory results require CUDA.
  • CPU execution falls back to the naive PyTorch reference implementation.
  • Dropout is not implemented.
  • Variable-length packed sequences are not implemented.
  • The implementation has not yet been benchmarked across multiple GPU architectures.

See docs/limitations.md for more details.


Project Goal

The goal of this project is to demonstrate practical understanding of:

  • GPU kernel programming
  • Triton
  • PyTorch custom autograd
  • FlashAttention-style memory-efficient attention
  • Correctness testing
  • Runtime benchmarking
  • CUDA memory benchmarking
  • Clean Python project structure
