[NVIDIA] Add Cutlass MLA backend #17625
Changes from all commits
Changed file (metadata builder __init__):

@@ -349,7 +349,7 @@ def __init__(self,
         self.num_heads = model_config.get_num_attention_heads(
             runner.parallel_config)
         self.mla_dims = get_mla_dims(model_config)
-        self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
+        self.aot_schedule = current_platform.is_cuda()

         self.kv_cache_spec = kv_cache_spec

         # Dont try to access the runner on AMD

Review discussion on this change:
- "Is this right? I thought we still needed to guard against using …"
- "See comment here. My understanding is that this aot_schedule is irrelevant here. @LucasWilkinson can correct me."
- "Ya this is right, page-size is used by chunked-prefill in MLA, this naming was a bug."
- "Do I need to make any further changes, or is this good to go now?"
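For context, here is a minimal sketch of the condition change above (the helper name is hypothetical, not from the PR): the flag now depends only on whether vLLM is running on a CUDA platform, rather than on whether the vLLM FlashAttention 3 build is in use.

from vllm.platforms import current_platform  # vLLM's platform abstraction

def _use_aot_schedule() -> bool:
    """Hypothetical helper mirroring the diff: per the review discussion,
    this flag only feeds chunked-prefill MLA metadata, so knowing the
    platform is enough; the FlashAttention version check was unnecessary."""
    # Before: is_vllm_fa and (get_flash_attn_version() == 3)
    # After:  any CUDA platform qualifies; ROCm and others do not.
    return current_platform.is_cuda()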
New file:

@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

import torch

import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
                                                   MLACommonMetadata)

logger = init_logger(__name__)


class CutlassMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "CUTLASS_MLA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["CutlassMLAImpl"]:
        return CutlassMLAImpl


class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "CutlassMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CutlassMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CutlassMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Cutlass MLA not yet supported")

        B = q_nope.shape[0]

        o = torch.empty((B, self.num_heads, self.kv_lora_rank),
                        dtype=q_nope.dtype,
                        device=q_nope.device)

        # Run MLA
        # Clone q_nope and q_pe to make sure strides computation is correct.
        q_nope = q_nope.clone()
        q_pe = q_pe.clone()
        ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
                               attn_metadata.decode.seq_lens,
                               attn_metadata.decode.block_table, self.scale)

        return self._v_up_proj(o)
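As a usage sketch (not part of the diff), the backend should be selectable by the name returned from get_name() via vLLM's VLLM_ATTENTION_BACKEND environment variable. The model name below is an illustrative assumption; an MLA model (e.g. a DeepSeek variant) is needed for an MLA backend to apply.

import os

# Assumption: the backend is picked up by its registered name.
os.environ["VLLM_ATTENTION_BACKEND"] = "CUTLASS_MLA_VLLM_V1"

from vllm import LLM, SamplingParams

# DeepSeek-V2-Lite is one example of an MLA model; any MLA model would do.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)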
Review comment: "The perf looks really good! I think we should turn this on by default for Blackwell."
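A minimal sketch of what "on by default for Blackwell" could mean in practice (hypothetical helper, assuming Blackwell-class GPUs report CUDA compute capability 10.x):

import torch

def is_blackwell() -> bool:
    """Hypothetical check: Blackwell-class GPUs report compute capability 10.x."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10

# A default-selection policy might then prefer CUTLASS_MLA_VLLM_V1 on such GPUs.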