Skip to content

[NVIDIA] Add Cutlass MLA backend #17625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/attention/mla/cutlass_mla_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options(
{static_cast<ElementOut*>(out.data_ptr()), stride_O,
static_cast<ElementAcc*>(nullptr), stride_LSE},
hw_info,
-1, // split_kv
1, // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
Expand Down
4 changes: 3 additions & 1 deletion tests/kernels/test_cutlass_mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

q = torch.randn(bs, h_q, d)
# Amplify input values to ensure test coverage of edge cases where CUTLASS
# kernel errors occur with split_k settings.
q = torch.randn(bs, h_q, d) * 100
block_table = torch.randint(0,
bs * block_num, (bs, block_num),
dtype=torch.int32)
Expand Down
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,6 +1382,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA",
"CUTLASS_MLA_VLLM_V1",
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
Expand Down
8 changes: 8 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
if use_mla:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the perf looks really good! I think we should turn this on by default for blackwell

if use_v1:
logger.info_once("Using Cutlass MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"cutlass_mla.CutlassMLABackend")
else:
logger.warning(
"Cutlass MLA backend is only supported on V1 engine")
if selected_backend == _Backend.TRITON_MLA or block_size != 64:
if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.")
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class _Backend(enum.Enum):
TRITON_MLA_VLLM_V1 = enum.auto()
FLASHMLA_VLLM_V1 = enum.auto()
FLASHMLA = enum.auto() # Supported by V1
CUTLASS_MLA_VLLM_V1 = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def __init__(self,
self.num_heads = model_config.get_num_attention_heads(
runner.parallel_config)
self.mla_dims = get_mla_dims(model_config)
self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
self.aot_schedule = current_platform.is_cuda()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this right? I thought we still needed to guard against using aot_schedule if using FA2

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment here. My understanding is that this aot_schedule is irrelevant here. @LucasWilkinson can correct me.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya this is right, page-size is used by chunked-prefill in MLA, this naming was a bug

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need to make any further changes, or is this good to go now?

self.kv_cache_spec = kv_cache_spec

# Dont try to access the runner on AMD
Expand Down
96 changes: 96 additions & 0 deletions vllm/v1/attention/backends/mla/cutlass_mla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

import torch

import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)

logger = init_logger(__name__)


class CutlassMLABackend(MLACommonBackend):

@staticmethod
def get_name() -> str:
return "CUTLASS_MLA_VLLM_V1"

@staticmethod
def get_impl_cls() -> type["CutlassMLAImpl"]:
return CutlassMLAImpl


class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)

unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"CutlassMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CutlassMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"CutlassMLA V1 with FP8 KV cache not yet supported")

def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Cutlass MLA not yet supported")

B = q_nope.shape[0]

o = torch.empty((B, self.num_heads, self.kv_lora_rank),
dtype=q_nope.dtype,
device=q_nope.device)

# Run MLA
# Clone q_nope and q_pe to make sure strides computation is correct.
q_nope = q_nope.clone()
q_pe = q_pe.clone()
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale)

return self._v_up_proj(o)