A flexible and efficient implementation of Flash Attention 2.0 for JAX, supporting multiple backends (GPU/TPU/CPU) and platforms (Triton/Pallas/JAX).
- 🚀 Multiple backend support: GPU, TPU, and CPU
- 🔧 Multiple platform implementations: Triton, Pallas, and JAX
- ⚡ Efficient caching of attention instances
- 🔄 Support for Grouped Query Attention (GQA) and head dimensions up to 256
- 📊 JAX sharding-friendly implementation
- 🎯 Automatic platform selection based on backend
- 🧩 Compatible with existing JAX mesh patterns
```bash
pip install jax-flash-attn2
```
```python
from jax_flash_attn2 import get_cached_flash_attention

# Get a cached attention instance
attention = get_cached_flash_attention(
    backend="gpu",       # 'gpu', 'tpu', or 'cpu'
    platform="triton",   # 'triton', 'pallas', or 'jax'
    blocksize_q=64,      # block size Q
    blocksize_k=128,     # block size K
    softmax_scale=headdim ** -0.5,  # optional scaling factor
)

# Use with your tensors
outputs = attention(
    query=query_states,
    key=key_states,
    value=value_states,
    bias=attention_bias,  # optional
)
```
```python
with mesh:
    attention_outputs = get_cached_flash_attention(
        backend="gpu",
        platform="triton",
        blocksize_q=128,
        blocksize_k=128,
        softmax_scale=None,
    )(
        query=with_sharding_constraint(query_states, qps).astype(dtype),
        key=with_sharding_constraint(key_states, kps).astype(dtype),
        value=with_sharding_constraint(value_states, vps).astype(dtype),
        bias=with_sharding_constraint(bias, bps).astype(dtype),
    )
```
- Triton GPU MHA
- Triton GPU MQA (coming soon)
- Pallas GPU MHA (coming soon)
- Pallas TPU MHA (coming soon)
- XLA CPU MHA (coming soon)
Backends:
- `gpu`: CUDA-capable GPUs
- `tpu`: Google Cloud TPUs
- `cpu`: CPU fallback

Platforms:
- `triton`: Optimized for NVIDIA GPUs
- `pallas`: Optimized for TPUs and supported on GPUs
- `jax`: Universal fallback, supports all backends
| Backend | Supported Platforms |
|---------|---------------------|
| GPU     | Triton, Pallas, JAX |
| TPU     | Pallas, JAX         |
| CPU     | JAX                 |
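The automatic platform selection mentioned in the features follows this compatibility table. The helper below is a hypothetical sketch of that mapping, written only to illustrate the table; it is not part of the jax-flash-attn2 API:

```python
# Hypothetical helper mirroring the compatibility table above; NOT part of the library.
PREFERRED_PLATFORMS = {
    "gpu": ("triton", "pallas", "jax"),
    "tpu": ("pallas", "jax"),
    "cpu": ("jax",),
}


def pick_platform(backend: str, requested: str | None = None) -> str:
    """Return a platform valid for `backend`, honoring an explicit request if given."""
    supported = PREFERRED_PLATFORMS[backend]
    if requested is None:
        return supported[0]  # first entry is the preferred default
    if requested not in supported:
        raise ValueError(f"platform {requested!r} is not supported on backend {backend!r}")
    return requested


assert pick_platform("tpu") == "pallas"
assert pick_platform("gpu", "jax") == "jax"
```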
```python
attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=128,    # customize query block size
    blocksize_k=128,    # customize key block size
    softmax_scale=1.0,  # custom softmax scaling
)
```
- `FORCE_MHA`: set to "true", "1", or "on" to force the MHA implementation even for GQA cases.
- `FLASH_ATTN_BLOCK_PTR`: set to "1" to use `tl.make_block_ptr` for pointer access in the forward pass (better for H100/H200 GPUs).
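Since these are process environment variables, one safe way to set them is before importing the library (whether they are read at import time or at call time is not specified here, so setting them early is the conservative assumption):

```python
import os

# Set the flags before importing so they are visible whenever the library reads them.
os.environ["FORCE_MHA"] = "1"
os.environ["FLASH_ATTN_BLOCK_PTR"] = "1"

from jax_flash_attn2 import get_cached_flash_attention  # noqa: E402
```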
- Block Sizes: Default block sizes (128) work well for most cases, but you might want to tune them for your specific hardware and model architecture.
- Platform Selection:
  - For NVIDIA GPUs: prefer `triton`
  - For TPUs: prefer `pallas`
  - For CPU or fallback: use `jax`
- Caching: The `get_cached_flash_attention` function automatically caches instances based on its parameters, so there is no need to manage caching manually.
- JAX
- einops
- chex
- jax.experimental.pallas (for TPU support)
- triton (for the optimized GPU implementation)
- Triton platform is only available on NVIDIA GPUs.
- Some platform-backend combinations are not supported (see table above).
- Custom attention masks are not yet supported (use bias instead).
Contributions are welcome! Please feel free to submit a Pull Request.
If you use this implementation in your research, please cite:
```bibtex
@software{jax_flash_attn2,
  title = {JAX Flash Attention 2.0},
  year  = {2024},
  url   = {https://github.com/erfanzar/jax-flash-attn2}
}
```
- This implementation (MHA) is based on:
  - The Flash Attention 2.0 paper
  - JAX ecosystem tools and libraries
  - Triton and Pallas optimization frameworks
- Custom Triton calls are made through JAX-Triton.
- All kernels are copied from EasyDeL.