vllm/platforms/cuda.py: 13 changes (11 additions, 2 deletions)
@@ -146,6 +146,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
# required block_size.
use_flashmla = False
use_cutlass_mla = False
use_flashinfer_mla = False

if envs.VLLM_ATTENTION_BACKEND is None:
# Default case
@@ -164,6 +165,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
use_cutlass_mla = (
envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
use_flashinfer_mla = (
envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA")

from vllm.attention.ops.flashmla import is_flashmla_supported
if use_flashmla and is_flashmla_supported()[0] \
@@ -176,6 +179,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
cache_config.block_size = 128
logger.info("Forcing kv cache block size to 128 for "
"CUTLASS_MLA backend.")
if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA "
"backend.")

# lazy import to avoid circular import
from vllm.config import CUDAGraphMode
@@ -228,8 +236,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None and cls.is_device_capability(100)
and block_size == 128)
- use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA
-     and cls.has_device_capability(100))
+ use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
+     selected_backend is None and cls.is_device_capability(100)
+     and block_size in [32, 64])
Comment on lines +239 to +241
Contributor
high

The use of cls.is_device_capability(100) for auto-selecting the FlashInfer MLA backend is too restrictive. It will only match for devices with exactly compute capability 10.0 (Blackwell), and will not automatically select this backend for future architectures with higher compute capabilities (e.g., > 10.0).

The corresponding test for this kernel (tests/kernels/attention/test_flashinfer_mla_decode.py) uses current_platform.has_device_capability(100), which suggests the kernel is expected to work on compute capabilities 10.0 and above.

To ensure future compatibility and correct auto-selection on upcoming hardware, cls.has_device_capability(100) should be used instead. This will match devices with compute capability 10.0 or greater.

A similar issue exists for the cutlass_mla backend logic, which you may want to address in a separate change for consistency.

Suggested change
- use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
-     selected_backend is None and cls.is_device_capability(100)
-     and block_size in [32, 64])
+ use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
+     selected_backend is None and cls.has_device_capability(100)
+     and block_size in [32, 64])
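To make the reviewer's point concrete, here is a minimal standalone sketch of the two checks being contrasted. It is not vLLM's implementation: the helper names only mirror the platform methods, and encoding compute capability as major * 10 + minor (so 100 means SM 10.0) is an assumption about the convention. The suggested change simply swaps the exact-match check for the threshold check so that SM 10.x and later hardware still auto-selects the backend.

# Minimal sketch, assuming capability is encoded as major * 10 + minor.
def is_device_capability(current: int, target: int) -> bool:
    # Exact match: with target 100, only SM 10.0 qualifies.
    return current == target

def has_device_capability(current: int, target: int) -> bool:
    # Minimum threshold: SM 10.0 and anything newer qualifies.
    return current >= target

for cap in (90, 100, 103, 110):
    print(cap, is_device_capability(cap, 100), has_device_capability(cap, 100))
# 90  -> False False   (pre-Blackwell: neither check passes)
# 100 -> True  True    (SM 10.0: both checks pass)
# 103 -> False True    (only the >= check passes)
# 110 -> False True    (future architecture: only the >= check passes)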

use_flashmla = selected_backend in [
_Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
] or (selected_backend is None and is_flashmla_supported()[0])