[Bugfix][V1] Allow manual FlashAttention for Blackwell #19492

Merged · 2 commits · Jun 12, 2025

17 changes: 13 additions & 4 deletions vllm/platforms/cuda.py
@@ -226,15 +226,21 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             if selected_backend == _Backend.FLASHINFER:
                 logger.info_once("Using FlashInfer backend on V1 engine.")
                 return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
-            if selected_backend == _Backend.FLEX_ATTENTION:
+            elif selected_backend == _Backend.FLEX_ATTENTION:
                 logger.info("Using FlexAttenion backend on V1 engine.")
                 return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
+            elif selected_backend == _Backend.FLASH_ATTN:
Collaborator

Suggested change:
-            elif selected_backend == _Backend.FLASH_ATTN:
+            elif selected_backend == _Backend.FLASH_ATTN_VLLM_V1:

It seems the V1 FA enum should be FLASH_ATTN_VLLM_V1, the same as FLASHINFER_VLLM_V1:

FLASH_ATTN_VLLM_V1 = enum.auto()
TRITON_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto() # Supported by V1
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_VLLM_V1 = enum.auto()

Member Author

We seem to be inconsistent here between using the V1 vs. "V0" attention backend names:

        if use_v1:
            if selected_backend == _Backend.FLASHINFER:
                logger.info_once("Using FlashInfer backend on V1 engine.")
                return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
            if selected_backend == _Backend.FLEX_ATTENTION:
                logger.info("Using FlexAttenion backend on V1 engine.")
                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                logger.info_once("Using Triton backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")

I'm also not sure it makes sense for a user to specify FLASH_ATTN and then have FlashInfer be used by default on V1.

Collaborator
@Isotr0py · Jun 12, 2025

> We seem to be inconsistent here between using the V1 vs. "V0" attention backend names

Yeah, and the _VLLM_V1 suffix is also sometimes annoying, because it's easy to type _V1_VLLM instead and only later find the engine initialized with an unexpected backend. 😅

Given we already have use_v1 to control V1 enablement, I think it should be OK to use the enum without the _VLLM_V1 suffix.

Member Author

I can do a follow-up to allow for both variants.
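
As an aside, here is a minimal sketch of what such a follow-up could look like (hypothetical code, not part of this PR; it assumes the _Backend members quoted above and that _Backend is importable from vllm.platforms.interface):

    from vllm.platforms.interface import _Backend

    # Hypothetical alias map: accept the V0-style names on V1 by translating
    # them to their _VLLM_V1 counterparts before backend selection.
    _V1_BACKEND_ALIASES = {
        _Backend.FLASH_ATTN: _Backend.FLASH_ATTN_VLLM_V1,
        _Backend.FLASHINFER: _Backend.FLASHINFER_VLLM_V1,
    }

    def normalize_v1_backend(selected_backend: _Backend) -> _Backend:
        # Return the V1 variant of a backend enum if one exists,
        # otherwise pass the selection through unchanged.
        return _V1_BACKEND_ALIASES.get(selected_backend, selected_backend)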

logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")

# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100):
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
try:
import flashinfer # noqa: F401
logger.info_once(
@@ -248,10 +254,13 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance.")
pass
if cls.has_device_capability(80):
# FlashAttention is the default for SM 8.0+ GPUs
elif cls.has_device_capability(80):
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")

# Backends for V0 engine
if selected_backend == _Backend.FLASHINFER:
logger.info("Using FlashInfer backend.")
return "vllm.attention.backends.flashinfer.FlashInferBackend"