[Attention] Default to FlashMLA backend for MLA #14451

Merged
40 changes: 24 additions & 16 deletions vllm/platforms/cuda.py
@@ -112,6 +112,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         compilation_config = vllm_config.compilation_config
+        model_config = vllm_config.model_config
 
         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
@@ -142,14 +143,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
-        # TODO(lucas): handle this more gracefully
-        if envs.VLLM_ATTENTION_BACKEND is not None \
-            and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \
-            and cache_config.block_size != 64:
-            cache_config.block_size = 64
-            logger.info(
-                "FlashMLA: Forcing kv cache block size to 64 since this"
-                " is currently the only block size supported by the kernel.")
+        # Note: model_config may be None during testing
+        if model_config is not None and model_config.use_mla:
+            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
+            # we default to FlashMLA backend, so we need to force the blocksize
+            # here
+            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
+                or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+            from vllm.attention.backends.flashmla import is_flashmla_supported
+            if use_flashmla and is_flashmla_supported()[0] \
+                and cache_config.block_size != 64:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashMLA backend.")
 
         if (parallel_config.data_parallel_size > 1
                 and compilation_config.use_cudagraph):
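
The hunk above changes when the KV-cache block size is forced to 64: previously only when `VLLM_ATTENTION_BACKEND=FLASHMLA` was set explicitly, now whenever an MLA model would default to FlashMLA. Below is a minimal standalone sketch of that decision with vLLM internals replaced by plain arguments; the helper name and its parameters are illustrative only and not part of the vLLM API.

```python
# Illustrative sketch only: mirrors the block-size decision made in
# check_and_update_config above. Not the actual vLLM implementation.
from typing import Optional


def force_block_size_for_flashmla(use_mla: bool,
                                  attention_backend_env: Optional[str],
                                  flashmla_supported: bool,
                                  block_size: int) -> int:
    """Return the KV-cache block size after applying the FlashMLA rule."""
    if not use_mla:
        # Non-MLA models are unaffected by this change.
        return block_size
    # FlashMLA is now the default for MLA models: it is chosen when
    # VLLM_ATTENTION_BACKEND is unset or explicitly set to "FLASHMLA".
    use_flashmla = attention_backend_env in (None, "FLASHMLA")
    if use_flashmla and flashmla_supported and block_size != 64:
        # The FlashMLA kernel currently only supports a block size of 64.
        return 64
    return block_size


# Env var unset + FlashMLA supported: block size is forced to 64.
assert force_block_size_for_flashmla(True, None, True, 16) == 64
# Explicitly requesting another backend leaves the block size alone.
assert force_block_size_for_flashmla(True, "TRITON_MLA", True, 16) == 16
# Non-MLA models keep their configured block size.
assert force_block_size_for_flashmla(False, None, True, 16) == 16
```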
@@ -173,7 +181,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
-            if selected_backend == _Backend.FLASHMLA:
+            if selected_backend == _Backend.TRITON_MLA or block_size != 64:
+                if use_v1:
+                    logger.info_once("Using Triton MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "triton_mla.TritonMLABackend")
+                else:
+                    logger.info("Using Triton MLA backend.")
+                    return "vllm.attention.backends.triton_mla.TritonMLABackend"
+            else:
                 from vllm.attention.backends.flashmla import (
                     is_flashmla_supported)
                 if not is_flashmla_supported()[0]:
@@ -195,14 +211,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                         logger.info("Using FlashMLA backend.")
                         return ("vllm.attention.backends."
                                 "flashmla.FlashMLABackend")
-
-            if use_v1:
-                logger.info_once("Using Triton MLA backend on V1 engine.")
-                return ("vllm.v1.attention.backends.mla."
-                        "triton_mla.TritonMLABackend")
-            else:
-                logger.info("Using Triton MLA backend.")
-                return "vllm.attention.backends.triton_mla.TritonMLABackend"
         if use_v1:
             logger.info_once("Using Flash Attention backend on V1 engine.")
             return ("vllm.v1.attention.backends.flash_attn."
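
For reference, the MLA branch of `get_attn_backend_cls` now resolves roughly as in the sketch below. The function and value names are simplified stand-ins, and the handling of an unsupported FlashMLA build (hidden in the collapsed part of the diff) is assumed here to fall back to Triton MLA.

```python
# Illustrative sketch only: a simplified restatement of the MLA backend
# choice after this change. Not the actual vLLM implementation.
from typing import Optional


def choose_mla_backend(selected_backend: Optional[str],
                       block_size: int,
                       flashmla_supported: bool) -> str:
    """Return the name of the MLA attention backend that would be picked."""
    # Triton MLA is used when it is explicitly requested, or when the
    # KV-cache block size is incompatible with FlashMLA (i.e. not 64).
    if selected_backend == "TRITON_MLA" or block_size != 64:
        return "TRITON_MLA"
    if not flashmla_supported:
        # Assumption: fall back to Triton MLA when the FlashMLA kernel is
        # unavailable; the exact handling lives in the collapsed diff lines.
        return "TRITON_MLA"
    # Otherwise FlashMLA is now the default, with no env var required.
    return "FLASHMLA"


# No backend requested, block size 64, kernel available: the new default.
assert choose_mla_backend(None, 64, True) == "FLASHMLA"
# An explicit request still wins.
assert choose_mla_backend("TRITON_MLA", 64, True) == "TRITON_MLA"
# An incompatible block size routes to Triton MLA.
assert choose_mla_backend(None, 16, True) == "TRITON_MLA"
```

Together with the block-size change in `check_and_update_config`, this means an MLA model served with default settings picks up FlashMLA automatically, while setting `VLLM_ATTENTION_BACKEND` to another backend or using a non-64 block size still routes to Triton MLA.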