Skip to content

[Bugfix][Kernel] Fix CUDA 11.8 being broken by FA3 build #12375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
GIT_TAG 0aff05f577e8a10086066a00618609199b25231d
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
5 changes: 4 additions & 1 deletion setup.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,10 @@ def _read_requirements(filename: str) -> List[str]:

if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
# FA3 requires CUDA 12.0 or later
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

if _build_custom_ops():
Expand Down
11 changes: 6 additions & 5 deletions tests/kernels/test_cascade_flash_attn.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (cascade_attention,
merge_attn_states)
from vllm.vllm_flash_attn import flash_attn_varlen_func
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 192, 256]
Expand Down Expand Up @@ -91,10 +93,9 @@ def test_cascade(
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")

current_platform.seed_everything(0)

Expand Down
21 changes: 10 additions & 11 deletions tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import torch

from vllm.platforms import current_platform
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
Expand Down Expand Up @@ -95,10 +97,9 @@ def test_flash_attn_with_paged_kv(
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")

current_platform.seed_everything(0)
num_seqs = len(kv_lens)
Expand Down Expand Up @@ -182,11 +183,9 @@ def test_varlen_with_paged_kv(
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")

if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
Expand Down
14 changes: 11 additions & 3 deletions vllm/attention/backends/flash_attn.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
logger = init_logger(__name__)


class FlashAttentionBackend(AttentionBackend):
Expand Down Expand Up @@ -652,6 +655,11 @@ def __init__(
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION

if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))

assert is_fa_version_supported(self.fa_version)

def forward(
Expand Down
11 changes: 10 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported)

logger = init_logger(__name__)


class FlashAttentionBackend(AttentionBackend):

Expand Down Expand Up @@ -143,6 +147,11 @@ def __init__(
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION

if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))

assert is_fa_version_supported(self.fa_version)

def forward(
Expand Down
Loading