
Commit 709d83e

add super kernel for decode moe
Signed-off-by: NNUCJ <616151263@qq.com>
1 parent 4b3a210 commit 709d83e

File tree: 5 files changed, +122 / -81 lines


tests/ut/torchair/test_utils.py

Lines changed: 11 additions & 0 deletions
@@ -1,4 +1,7 @@
 import os
+from contextlib import nullcontext
+
+import torchair
 
 from tests.ut.base import TestBase
 from vllm_ascend.torchair import utils
@@ -26,3 +29,11 @@ def test_torchair_cache_dir(self):
                          "Delete torchair cache dir failed")
         self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                          "Delete kv cache bytes cache dir failed")
+
+    def test_super_kernel(self):
+        super_kernel_unenable = utils.super_kernel("prefix", "stream-fusion=1",
+                                                   False)
+        self.assertTrue(super_kernel_unenable, nullcontext())
+        super_kernel_enable = utils.super_kernel("prefix", "stream-fusion=1",
+                                                 True)
+        self.assertIsInstance(super_kernel_enable, torchair.scope._Scope)

vllm_ascend/models/deepseek_v2.py

Lines changed: 13 additions & 7 deletions
@@ -315,15 +315,22 @@ def __init__(
         self.enable_multistream_moe = \
             ascend_config.torchair_graph_config.enable_multistream_moe and \
             self.torchair_graph_enabled
+        self.enable_super_kernel = self.enable_multistream_moe and self.tp_size == 1
+        self.params_dtype = torch.get_default_dtype()
 
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.n_routed_experts,
+            bias=False,
+            quant_config=None,
+            params_dtype=torch.float32
+            if self.enable_super_kernel else self.params_dtype,
+            prefix=f"{prefix}.gate")
         if config.topk_method == "noaux_tc":
             self.gate.e_score_correction_bias = nn.Parameter(
-                torch.empty(config.n_routed_experts))
+                torch.empty(config.n_routed_experts,
+                            dtype=torch.float if self.enable_super_kernel else
+                            self.params_dtype))
         else:
             self.gate.e_score_correction_bias = None
 
@@ -370,7 +377,6 @@ def __init__(
         if transfer_config is not None:
             self.kv_consumer = transfer_config.kv_role == "kv_consumer"
 
-        self.params_dtype = torch.get_default_dtype()
         self.rm_router_logits = self.experts.rm_router_logits
 
     def forward(self,
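The MoE routing gate (and the noaux_tc correction bias) is now built in float32 whenever the super kernel is enabled, rather than in the model's default dtype, presumably to keep the routing numerics in full precision inside the fused scope. A minimal sketch of that dtype selection, using a plain torch.nn.Linear as a stand-in for vLLM's ReplicatedLinear; the flag value and the DeepSeek-V3-style sizes (7168 hidden, 256 routed experts) are illustrative assumptions, not part of the patch:

import torch
import torch.nn as nn

# Assumed flag/sizes for illustration only.
enable_super_kernel = True                      # multistream MoE on and tp_size == 1
params_dtype = torch.get_default_dtype()        # the model dtype in a real vLLM run
gate_dtype = torch.float32 if enable_super_kernel else params_dtype

# Stand-in for ReplicatedLinear(config.hidden_size, config.n_routed_experts, ...).
gate = nn.Linear(7168, 256, bias=False, dtype=gate_dtype)
e_score_correction_bias = nn.Parameter(torch.empty(256, dtype=gate_dtype))

print(gate.weight.dtype, e_score_correction_bias.dtype)  # torch.float32 torch.float32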

vllm_ascend/ops/fused_moe.py

Lines changed: 22 additions & 10 deletions
@@ -50,7 +50,8 @@
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
                                get_ascend_soc_version,
@@ -1201,6 +1202,7 @@ def __init__(
 
         AscendFusedMoE.moe_counter += 1
         self.moe_instance_id = AscendFusedMoE.moe_counter
+        self.prefix = prefix
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1265,6 +1267,7 @@ def __init__(
         self.enable_multistream_moe = \
             ascend_config.torchair_graph_config.enable_multistream_moe and \
             self.torchair_graph_enabled
+        self.enable_super_kernel = self.enable_multistream_moe and tp_size == 1
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -1368,16 +1371,23 @@ def forward(self,
         quantized_x_for_share, dynamic_scale_for_share = None, None
         from vllm_ascend.quantization.w8a8_dynamic import \
             AscendW8A8DynamicFusedMoEMethod
+        running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2
         if self.enable_multistream_moe:
-            if not self.rm_router_logits:
-                router_logits, _ = gate(hidden_states)
-            if hasattr(self.quant_method, "quant_method") and \
-                isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod
-                           ) and fused_moe_state == FusedMoEState.MC2:
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
-                        hidden_states)
+            with super_kernel(self.prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                if not self.rm_router_logits:
+                    if self.enable_super_kernel:
+                        router_logits, _ = gate(hidden_states.float())
+                    else:
+                        router_logits, _ = gate(hidden_states)
+                if hasattr(self.quant_method, "quant_method") and \
+                    isinstance(self.quant_method.quant_method,
+                               AscendW8A8DynamicFusedMoEMethod
+                               ) and fused_moe_state == FusedMoEState.MC2:
+                    with npu_stream_switch("moe_secondary", 0):
+                        quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                            hidden_states)
 
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
@@ -1467,6 +1477,8 @@ def forward(self,
             token_dispatcher=self.token_dispatcher,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
+            prefix=self.prefix,
+            running_in_super_kernel=running_in_super_kernel,
         )
 
         if shared_experts:
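In AscendFusedMoE.forward, the multistream-MoE prologue (router-logits computation plus the optional W8A8 dynamic quantization of the activations) is now wrapped in the new super_kernel scope, and the scope only takes effect when the super kernel is enabled and the dispatcher is in the MC2 state. A rough, CPU-runnable sketch of that control flow; the FusedMoEState enum, the gate, the prefix string, and the moe_prologue helper below are stand-ins for the vLLM objects, not the actual implementation:

import torch
from contextlib import nullcontext
from enum import Enum

class FusedMoEState(Enum):          # stand-in for vllm_ascend's forward-context enum
    MC2 = 1
    AllGather = 2

def super_kernel(prefix, stream, enabled=True):
    # Stand-in with the same shape as vllm_ascend.torchair.utils.super_kernel;
    # the real helper returns a torchair scope when enabled.
    return nullcontext()

def gate(x):
    # Stand-in for the ReplicatedLinear routing gate: returns (router_logits, bias).
    return x @ torch.randn(x.shape[-1], 8, dtype=x.dtype), None

def moe_prologue(hidden_states, *, enable_super_kernel, fused_moe_state,
                 prefix="model.layers.0.mlp.experts"):
    running_in_super_kernel = (enable_super_kernel
                               and fused_moe_state == FusedMoEState.MC2)
    with super_kernel(prefix, "stream-fusion=1", enabled=running_in_super_kernel):
        # The gate weights are float32 when the super kernel is enabled,
        # so the activations are upcast to match before routing.
        x = hidden_states.float() if enable_super_kernel else hidden_states
        router_logits, _ = gate(x)
    return router_logits, running_in_super_kernel

logits, fused = moe_prologue(torch.randn(4, 16, dtype=torch.bfloat16),
                             enable_super_kernel=True,
                             fused_moe_state=FusedMoEState.MC2)
print(logits.dtype, fused)   # torch.float32 True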

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 71 additions & 64 deletions
@@ -28,7 +28,8 @@
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe import select_experts
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version)
 
@@ -830,59 +831,62 @@ def apply(
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        prefix: str = "",
+        running_in_super_kernel: bool = False,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
         is_deepseek_v3_r1 = global_num_experts == 256
-
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # top_k is currently hard-coded to 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=1,  # 0: max within each group; 1: sum of top-2 per group (fixed)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; whether to emit the third output
-                # y2_flag=False, # old api; whether to emit the third output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-
-        # this is a naive implementation for experts load balance so as
-        # to avoid accumulating too many tokens on a single rank.
-        # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
-
-        topk_weights = topk_weights.to(x.dtype)
+        with super_kernel(prefix,
+                          "stream-fusion=1",
+                          enabled=running_in_super_kernel):
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+            if is_deepseek_v3_r1:
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # top_k is currently hard-coded to 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fix: 4
+                    group_count=num_expert_group,  # fix 8
+                    group_select_mode=1,  # 0: max within each group; 1: sum of top-2 per group (fixed)
+                    renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                    # out_flag=False, # todo new api; whether to emit the third output
+                    # y2_flag=False, # old api; whether to emit the third output
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    top_k=top_k,
+                    use_grouped_topk=use_grouped_topk,
+                    renormalize=renormalize,
+                    topk_group=topk_group,
+                    num_expert_group=num_expert_group,
+                    custom_routing_function=custom_routing_function,
+                    scoring_func=scoring_func,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+            if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    npu_wait_tensor(quantized_x_for_share, router_logits)
+                    share_up_out, _ = shared_experts.gate_up_proj(
+                        (quantized_x_for_share, dynamic_scale_for_share))
+                    shared_gate_up, shared_dequant_scale = share_up_out[
+                        0], share_up_out[1]
+
+            # this is a naive implementation for experts load balance so as
+            # to avoid accumulating too many tokens on a single rank.
+            # currently it is only activated when doing profile runs.
+            if enable_force_load_balance:
+                topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+            topk_weights = topk_weights.to(x.dtype)
 
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
@@ -895,24 +899,27 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map)
         elif fused_moe_state == FusedMoEState.MC2:
-            return fused_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None),
-                shared_gate_up=shared_gate_up,
-                shared_dequant_scale=shared_dequant_scale)
+            with super_kernel(prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                return fused_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    mc2_mask=kwargs.get("mc2_mask", None),
+                    shared_gate_up=shared_gate_up,
+                    shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
             FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
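The W8A8 dynamic apply method now accepts prefix and running_in_super_kernel, and runs both the expert routing and the MC2 dispatch inside the same super_kernel scope. For reference, here is a rough CPU sketch of the grouped top-k routing that the npu_moe_gating_top_k call above is configured for (sigmoid scores plus a correction bias, group score = sum of the top-2 expert scores per group, then top-k experts inside the selected groups). It only illustrates the parameter comments; grouped_topk_sketch is a hypothetical helper, and the fused NPU op's exact numerics and tie-breaking may differ:

import torch

def grouped_topk_sketch(router_logits, bias, top_k=8, k_group=4, group_count=8,
                        routed_scaling_factor=1.0):
    scores = torch.sigmoid(router_logits.float())           # norm_type=1: sigmoid scoring
    scores_for_choice = scores + bias                        # e_score_correction_bias
    n_tokens, n_experts = scores.shape
    grouped = scores_for_choice.view(n_tokens, group_count, -1)
    group_scores = grouped.topk(2, dim=-1).values.sum(-1)    # group_select_mode=1: top-2 sum
    top_groups = group_scores.topk(k_group, dim=-1).indices  # keep k_group groups
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_groups, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(n_tokens, n_experts)
    masked = scores_for_choice.masked_fill(expert_mask == 0, float("-inf"))
    topk_ids = masked.topk(top_k, dim=-1).indices            # top_k experts in kept groups
    topk_weights = scores.gather(1, topk_ids) * routed_scaling_factor
    return topk_weights, topk_ids

logits = torch.randn(4, 256)          # e.g. 256 routed experts (DeepSeek-V3/R1 layout)
bias = torch.zeros(256)
w, ids = grouped_topk_sketch(logits, bias)
print(w.shape, ids.shape)             # torch.Size([4, 8]) torch.Size([4, 8])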

vllm_ascend/torchair/utils.py

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,7 @@
 from contextlib import contextmanager, nullcontext
 
 import torch
+from torchair.scope import super_kernel as _super_kernel
 
 try:
     # Recent release of torchair has moved these ops to `.scope`.
@@ -96,3 +97,7 @@ def npu_wait_tensor(self: torch.Tensor,
                     *,
                     enabled: bool = True):
     return _npu_wait_tensor(self, dependency) if enabled else self
+
+
+def super_kernel(prefix: str, stream: str, enabled: bool = True):
+    return _super_kernel(prefix, stream) if enabled else nullcontext()
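The new helper is a thin conditional wrapper: when enabled it defers to torchair.scope.super_kernel(prefix, stream), and when disabled it degenerates to contextlib.nullcontext(), which is what the new unit test checks. A minimal, NPU-free sketch of the same pattern; super_kernel_like is a hypothetical name, and treating the torchair import as optional is an assumption of this sketch, not of the patch:

from contextlib import nullcontext

def super_kernel_like(prefix: str, stream: str, enabled: bool = True):
    # Mirrors vllm_ascend.torchair.utils.super_kernel, but tolerates a missing
    # torchair install so the disabled path can be exercised anywhere.
    try:
        from torchair.scope import super_kernel as _super_kernel
    except ImportError:
        _super_kernel = None
    if enabled and _super_kernel is not None:
        return _super_kernel(prefix, stream)   # fused-scope context manager
    return nullcontext()                       # no-op: ops run unfused

# Disabled path behaves like a plain no-op context manager.
with super_kernel_like("model.layers.3.mlp.experts", "stream-fusion=1",
                       enabled=False):
    pass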
