[Kernel] Flash Attention 3 Support (vllm-project#12093)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson authored and rasmith committed Jan 30, 2025
1 parent d0a9088 commit 4cafa1c
Showing 8 changed files with 150 additions and 82 deletions.
45 changes: 20 additions & 25 deletions CMakeLists.txt
@@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")

# Prevent installation of dependencies (cutlass) by default.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
@@ -535,7 +532,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
endif()

# vllm-flash-attn currently only supported on CUDA
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
return()
endif ()

@@ -558,7 +555,7 @@ endif()
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
# If no component is specified, vllm-flash-attn is still installed.

# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
@@ -570,43 +567,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif()

if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
FetchContent_Declare(
vllm-flash-attn SOURCE_DIR
${VLLM_FLASH_ATTN_SRC_DIR}
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
endif()

# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON)

# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)

# Make sure vllm-flash-attn install rules are nested under vllm/
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)

# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")

# Restore the install prefix
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
# Copy over the vllm-flash-attn python files (duplicated for FA2 and FA3 in
# case only one is built; if both are built, redundant work is done)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py"
)

# Copy over the vllm-flash-attn python files
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT vllm_flash_attn_c
FILES_MATCHING PATTERN "*.py"
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN "*.py"
)

# Nothing after vllm-flash-attn, see comment about macros above
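
The CMake changes above replace the single vllm_flash_attn_c install component with _vllm_fa2_C and _vllm_fa3_C, so the FA2 and FA3 extensions can be built and installed independently. A minimal sketch, assuming vLLM is already installed with its usual layout, of probing which of the two extension modules actually ended up on disk:

# Sketch only: probe for the FA2/FA3 extension modules that the new install
# components (_vllm_fa2_C / _vllm_fa3_C) place under vllm/vllm_flash_attn.
import importlib.util

for ext in ("_vllm_fa2_C", "_vllm_fa3_C"):
    spec = importlib.util.find_spec(f"vllm.vllm_flash_attn.{ext}")
    print(ext, "available" if spec is not None else "missing")
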
12 changes: 8 additions & 4 deletions setup.py
@@ -228,8 +228,11 @@ def target_name(s: str) -> str:

# CMake appends the extension prefix to the install path,
# and outdir already contains that prefix, so we need to remove it.
# We assume CMake appends only the final component of the extension
# prefix; this holds for the current extensions but may not always be
# the case.
prefix = outdir
for i in range(ext.name.count('.')):
if '.' in ext.name:
prefix = prefix.parent

# prefix here should actually be the same for all components
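
To illustrate the trimming logic above — with a hypothetical extension name and output directory, not values taken from a real build — CMake is assumed to append only the final package component, so a single .parent hop removes it:

# Illustration only; the real values come from CMake at build time.
from pathlib import Path

ext_name = "vllm.vllm_flash_attn._vllm_fa3_C"      # hypothetical extension name
outdir = Path("build/lib/vllm/vllm_flash_attn")    # hypothetical CMake output dir

prefix = outdir
if '.' in ext_name:    # mirrors the new logic: strip one trailing component
    prefix = prefix.parent
print(prefix)          # build/lib/vllm
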
@@ -298,7 +301,8 @@ def run(self) -> None:
files_to_copy = [
"vllm/_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py",
"vllm/vllm_flash_attn/__init__.py",
"vllm/cumem_allocator.abi3.so",
@@ -593,8 +597,8 @@ def _read_requirements(filename: str) -> List[str]:
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))

if _is_cuda():
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

if _build_custom_ops():
24 changes: 14 additions & 10 deletions tests/kernels/test_cascade_flash_attn.py
@@ -78,6 +78,7 @@ def test_merge_kernel(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("soft_cap", [None, 50])
@pytest.mark.parametrize("num_blocks", [2048])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
@@ -87,8 +88,14 @@ def test_cascade(
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")

current_platform.seed_everything(0)

window_size = (-1, -1)
@@ -118,9 +125,7 @@ def test_cascade(
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
@@ -140,7 +145,7 @@
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
seqused_k=kv_lens_tensor,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
@@ -154,10 +159,8 @@
assert all(common_prefix_len < kv_len for kv_len in kv_lens)
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
dtype=torch.int32)
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32)
cu_suffix_kv_lens = (
cu_kv_lens -
torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len)
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
suffix_kv_lens = kv_lens_tensor - common_prefix_len
output = torch.empty_like(query)
cascade_attention(
output=output,
@@ -167,15 +170,16 @@
cu_query_lens=cu_query_lens,
max_query_len=max_query_len,
cu_prefix_query_lens=cu_prefix_query_lens,
cu_prefix_kv_lens=cu_prefix_kv_lens,
cu_suffix_kv_lens=cu_suffix_kv_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
max_kv_len=max_kv_len,
softmax_scale=scale,
alibi_slopes=None,
sliding_window=window_size,
logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables,
common_prefix_len=common_prefix_len,
fa_version=fa_version,
)

# Compare the results.
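
A recurring change in these tests is that the updated kernel interface takes per-sequence KV lengths (seqused_k, and prefix_kv_lens/suffix_kv_lens on the cascade path) instead of cumulative offsets (cu_seqlens_k). A small sketch with made-up lengths shows how the two encodings relate:

# Sketch with hypothetical lengths: seqused_k holds per-sequence KV lengths,
# while cu_seqlens_k is their exclusive prefix sum.
import torch

kv_lens = [17, 5, 32]

seqused_k = torch.tensor(kv_lens, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0] + kv_lens,
                            dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)

# Adjacent differences of the cumulative form recover the per-sequence lengths.
assert torch.equal(cu_seqlens_k[1:] - cu_seqlens_k[:-1], seqused_k)
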
22 changes: 18 additions & 4 deletions tests/kernels/test_flash_attn.py
@@ -80,6 +80,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
use_out: bool,
@@ -91,8 +92,14 @@ def test_flash_attn_with_paged_kv(
soft_cap: Optional[float],
num_blocks: int,
sliding_window: Optional[int],
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")

current_platform.seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
@@ -131,6 +138,7 @@ def test_flash_attn_with_paged_kv(
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
fa_version=fa_version,
)
output = output if not use_out else out
output = output.squeeze(1)
@@ -159,6 +167,7 @@
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_varlen_with_paged_kv(
use_out: bool,
@@ -170,8 +179,14 @@ def test_varlen_with_paged_kv(
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")

current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
@@ -198,9 +213,7 @@
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
@@ -215,14 +228,15 @@
v=value_cache,
out=out,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
seqused_k=kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
)
output = output if not use_out else out

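
Both test files repeat the same guard before running FA3 on SM 8.6 / 8.9 parts. One possible consolidation, sketched here as a hypothetical helper rather than taken from the diff:

# Hypothetical helper; the tests above inline this check instead.
import pytest
import torch

def maybe_skip_fa3(fa_version: int) -> None:
    if fa_version == 3 and torch.cuda.get_device_capability() in ((8, 6), (8, 9)):
        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
                    "insufficient shared memory for some shapes")
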
27 changes: 24 additions & 3 deletions vllm/attention/backends/flash_attn.py
@@ -17,15 +17,18 @@
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
flash_attn_with_kvcache,
is_fa_version_supported)


class FlashAttentionBackend(AttentionBackend):
@@ -634,6 +637,20 @@ def __init__(
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type

# If Hopper, default to FA3; otherwise stick to FA2 for now.
# TODO(lucas): profile FA3 on Ampere to see if it makes sense to
# use FA3 as the default for both.
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2

if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION

assert is_fa_version_supported(self.fa_version)

def forward(
self,
layer: AttentionLayer,
@@ -752,6 +769,7 @@ def forward(
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
)
else:
# prefix-enabled attention
@@ -765,7 +783,7 @@ def forward(
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
@@ -774,6 +792,7 @@ def forward(
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.fa_version,
)

if decode_meta := attn_metadata.decode_metadata:
@@ -793,7 +812,7 @@ def forward(
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
seqused_k=decode_meta.seq_lens_tensor,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
@@ -802,6 +821,7 @@ def forward(
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
fa_version=self.fa_version,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
@@ -822,6 +842,7 @@ def forward(
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
fa_version=self.fa_version,
)
return output

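
Taken together, the backend picks a flash-attention version once in __init__ and threads the chosen fa_version through every varlen and KV-cache call. A condensed sketch of that selection, reusing the imports this file already has:

# Condensed sketch of the version pick in the __init__ change above.
import torch
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.vllm_flash_attn import is_fa_version_supported

def pick_fa_version() -> int:
    # Default to FA3 on Hopper (SM 9.x) when the FA3 build is available;
    # otherwise stay on FA2.
    if torch.cuda.get_device_capability()[0] >= 9 and is_fa_version_supported(3):
        fa_version = 3
    else:
        fa_version = 2
    # Explicit override via the VLLM_FLASH_ATTN_VERSION environment variable.
    if VLLM_FLASH_ATTN_VERSION is not None:
        assert VLLM_FLASH_ATTN_VERSION in [2, 3]
        fa_version = VLLM_FLASH_ATTN_VERSION
    assert is_fa_version_supported(fa_version)
    return fa_version
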
(Diffs for the remaining 3 changed files were not loaded.)
