
[Perf] Improve MLA multistream performance #1353


Open · wants to merge 2 commits into base: main
68 changes: 40 additions & 28 deletions vllm_ascend/attention/mla_v1.py
@@ -20,7 +20,7 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -561,8 +559,6 @@
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -579,13 +577,14 @@
             " please make sure after the tensor parallel split, num_heads / num_kv_heads in "
             "{32, 64, 128}.")
 
-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        npu_prefetch(self.o_proj.weight, x, enabled=enable_multistream_mla)
         return self.o_proj(x)[0]
 
     # Return `ql_nope`, `q_pe`
@@ -870,21 +869,18 @@
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=self.enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
-        return k_pe, k_nope
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope, kv
 
     def exec_kv_prefill(
         self,
@@ -936,6 +932,7 @@
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: AscendMLAMetadata,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
@@ -1016,7 +1013,8 @@
                 out=attn_output)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output)
+            return self._v_up_proj_and_o_proj(attn_output,
+                                              enable_multistream_mla)
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1032,6 +1030,8 @@
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        enable_multistream_mla: bool = False,
+        ckq: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
@@ -1086,27 +1086,38 @@
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=enable_multistream_mla):
+                    npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                    ckq,
+                                    enabled=enable_multistream_mla)
+                    decode_k_pe, decode_k_nope, kv = self.exec_kv(
+                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                        attn_metadata.slot_mapping)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
                 npu_wait_tensor(decode_hs_or_q_c,
                                 cos,
-                                enabled=self.enable_multistream_mla)
+                                enabled=enable_multistream_mla)
                 npu_wait_tensor(decode_hs_or_q_c,
                                 sin,
-                                enabled=self.enable_multistream_mla)
+                                enabled=enable_multistream_mla)
+                npu_wait_tensor(decode_hs_or_q_c,
+                                kv,
+                                enabled=enable_multistream_mla)
+
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+
             if self.running_in_graph:
-                decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
                 with npu_stream_switch("mla_secondary",
                                        0,
-                                       enabled=self.enable_multistream_mla):
+                                       enabled=enable_multistream_mla):
                     npu_wait_tensor(decode_q_pe,
                                     decode_k_pe,
-                                    enabled=self.enable_multistream_mla)
+                                    enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
@@ -1189,7 +1200,8 @@
         if self.running_in_graph:
             return self._forward_decode(decode_ql_nope, decode_q_pe,
                                         decode_k_nope, decode_k_pe,
-                                        kv_cache, attn_metadata)
+                                        kv_cache, attn_metadata,
+                                        enable_multistream_mla)
         else:
             output_decode = self._forward_decode(decode_ql_nope,
                                                  decode_q_pe,
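To make the reorganized decode path easier to follow, here is a condensed sketch (not the PR's literal code) of how the pieces above fit together: the "mla_secondary" stream now covers the whole KV-cache update via exec_kv, the main stream keeps the q projection busy in parallel, and npu_prefetch starts loading o_proj.weight before the output projection. `impl` stands in for an AscendMLAImpl instance, `run_paged_attention` is a hypothetical placeholder for the paged-attention kernel call, and running this requires torch_npu on an Ascend device.

from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor


def decode_mla_sketch(impl, q_c, kv_c_normed, ckq, cos, sin, kv_cache, slots,
                      enable_multistream_mla=True):
    # Secondary stream: KvRmsNormRopeCache/SingleRope run while the main stream
    # continues; the wait on `ckq` keeps the KV path behind q_a_proj's output.
    with npu_stream_switch("mla_secondary", 0, enabled=enable_multistream_mla):
        npu_wait_tensor(kv_c_normed, ckq, enabled=enable_multistream_mla)
        k_pe, k_nope, kv = impl.exec_kv(kv_c_normed, cos, sin, kv_cache, slots)

    # Main stream: explicit waits keep the q projection ordered after cos/sin/kv
    # so the graph cannot hoist it ahead of the overlapped KV work.
    for dep in (cos, sin, kv):
        npu_wait_tensor(q_c, dep, enabled=enable_multistream_mla)
    ql_nope, q_pe = impl._q_proj_and_k_up_proj(q_c)

    # Output projection: start fetching o_proj.weight while attention still runs.
    attn_out = run_paged_attention(ql_nope, q_pe, k_nope, k_pe)  # placeholder
    npu_prefetch(impl.o_proj.weight, attn_out, enabled=enable_multistream_mla)
    return impl.o_proj(attn_out)[0]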
24 changes: 13 additions & 11 deletions vllm_ascend/models/deepseek_v2.py
@@ -69,8 +69,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import dispose_tensor, npu_prefetch
 
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -465,20 +464,23 @@
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        forward_kwargs = {}
         if self.q_lora_rank is not None:
+            enable_multistream_mla = (self.enable_multistream_mla
+                                      and self.torchair_graph_enabled
+                                      and attn_metadata is not None and
+                                      not attn_metadata.with_prefill_across_dp
+                                      and attn_metadata.num_decodes > 0)
+            npu_prefetch(self.q_a_proj.weight,
+                         hidden_states,
+                         enabled=enable_multistream_mla)
             ckq = self.q_a_proj(hidden_states)[0]
-            use_multistream_mla = (self.enable_multistream_mla
-                                   and attn_metadata is not None
-                                   and attn_metadata.num_decodes > 0)
-            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
-            with npu_stream_switch("mla_secondary",
-                                   0,
-                                   enabled=use_multistream_mla):
-                hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            forward_kwargs["enable_multistream_mla"] = enable_multistream_mla
+            forward_kwargs['ckq'] = ckq if enable_multistream_mla else None
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
-            forward_kwargs = {}
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape
                 output = torch.empty(output_shape,
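The gating condition above is the heart of the model-side change: prefetching q_a_proj.weight (and passing ckq through to the attention backend) only pays off for torchair-graph, decode-only steps. A CPU-runnable sketch of that decision logic, using a stand-in dataclass for the real attention metadata (the fields below mirror the diff; everything else is illustrative):

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeAttnMetadata:
    # Stand-in for the fields of the real Ascend MLA metadata used in the gate.
    with_prefill_across_dp: bool
    num_decodes: int


def should_enable_multistream_mla(
        enable_flag: bool, torchair_graph_enabled: bool,
        attn_metadata: Optional[FakeAttnMetadata]) -> bool:
    return (enable_flag
            and torchair_graph_enabled
            and attn_metadata is not None
            and not attn_metadata.with_prefill_across_dp
            and attn_metadata.num_decodes > 0)


# Decode-only torchair step: the prefetch/multistream path is taken.
assert should_enable_multistream_mla(True, True, FakeAttnMetadata(False, 8))
# A step that still has prefills across DP ranks falls back to the normal path.
assert not should_enable_multistream_mla(True, True, FakeAttnMetadata(True, 8))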
13 changes: 13 additions & 0 deletions vllm_ascend/utils.py
@@ -389,6 +389,19 @@
     return _npu_wait_tensor(self, dependency) if enabled else self
 
 
+def npu_prefetch(input: torch.Tensor,
+                 dependency: torch.Tensor,
+                 max_size: int = 0,
+                 *,
+                 enabled: bool = True):
+    if not enabled:
+        return
+    input_size = input.element_size() * input.numel()
+    if max_size <= 0 or max_size > input_size:
+        max_size = input_size
+    torch_npu.npu_prefetch(input, dependency, max_size)
+
+
 # TODO(zzzzwwjj): move this into forward_context
 class FusedMoEState(Enum):
     AllGather = 0
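A usage sketch for the new helper (requires torch_npu and an Ascend device; the tensors below are placeholders): max_size=0 means "prefetch the whole tensor", a positive cap is clamped to the tensor's byte size, and enabled=False turns the call into a no-op so callers can gate it on the multistream flag.

import torch
from vllm_ascend.utils import npu_prefetch

weight = torch.empty(4096, 4096, dtype=torch.bfloat16, device="npu")
hidden = torch.empty(32, 4096, dtype=torch.bfloat16, device="npu")

npu_prefetch(weight, hidden)                            # prefetch the full weight
npu_prefetch(weight, hidden, max_size=8 * 1024 * 1024)  # cap the prefetch at 8 MiB
npu_prefetch(weight, hidden, enabled=False)             # gated off: does nothing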