Commit 8b7eab5

add super kernel for decode moe
Signed-off-by: NNUCJ <616151263@qq.com>
1 parent 8181790 commit 8b7eab5

File tree: 6 files changed (+152, -89 lines)

tests/ut/ops/test_fused_ops.py

Lines changed: 27 additions & 7 deletions
@@ -238,11 +238,23 @@ def test_init_with_quant(self, mock_dist_env, default_moe_config):
 
     @pytest.mark.parametrize(
         "others_param",
-        [[None,
-          MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
-         [2, None, False, 5, None], [None, None, True, 5, None],
-         [None, None, False, 1, None], [None, None, True, 5, 1],
-         [None, None, False, 5, 1]])
+        [[
+            None,
+            MagicMock(return_value=torch.randn(5, 32)), False, 5, None, False,
+            None, False
+        ], [2, None, False, 5, None, False, None, False],
+         [None, None, True, 5, None, False, None, False],
+         [None, None, False, 1, None, False, None, False],
+         [None, None, True, 5, 1, False, None, False],
+         [None, None, False, 5, 1, False, None, False],
+         [
+             None, None, True, 5, 1, True,
+             MagicMock(return_value=(torch.randn(5, 8), None)), False
+         ],
+         [
+             None, None, False, 5, 1, True,
+             MagicMock(return_value=(torch.randn(5, 8), None)), True
+         ]])
     def test_forward(self, mock_dist_env, default_moe_config, others_param):
         """
         1 test has shared_experts
@@ -251,15 +263,22 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
         4 test single num_tokens(decode)
         5 test ep_size is 1 and is_prefill is true
         6 test ep_size is 1 and is_prefill is False
+        7 test enable_multistream_moe and is_prefill is true
         """
-        top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
+        top_k, shared_experts, is_prefill, num_tokens, ep_size, enable_multistream_moe, gate, enable_super_kernel = others_param
         inputs = torch.randn(num_tokens, 32)
         router_logits = torch.randn(num_tokens, 8)
         moe = AscendFusedMoE(**default_moe_config)
 
         if ep_size == 1:
             moe.moe_parallel_config.ep_size = 1
 
+        if enable_multistream_moe:
+            moe.enable_multistream_moe = True
+
+        if enable_super_kernel:
+            moe.enable_super_kernel = True
+
         moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
         forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
                                                          dtype=torch.bool),
@@ -270,7 +289,8 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
                              router_logits,
                              is_prefill=is_prefill,
                              top_k=top_k,
-                             shared_experts=shared_experts)
+                             shared_experts=shared_experts,
+                             gate=gate)
 
         moe.quant_method.apply.assert_called_once()
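
For reference, the new eight-element others_param tuple unpacks in the order shown below. This is a standalone illustration inferred from the test_forward signature above, not part of the commit:

from unittest.mock import MagicMock

import torch

# One of the new parametrize cases, in the order test_forward expects:
# (top_k, shared_experts, is_prefill, num_tokens, ep_size,
#  enable_multistream_moe, gate, enable_super_kernel)
case = [
    None, None, False, 5, 1, True,
    MagicMock(return_value=(torch.randn(5, 8), None)), True
]
(top_k, shared_experts, is_prefill, num_tokens, ep_size,
 enable_multistream_moe, gate, enable_super_kernel) = case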

tests/ut/torchair/test_utils.py

Lines changed: 11 additions & 0 deletions
@@ -1,4 +1,7 @@
 import os
+from contextlib import nullcontext
+
+import torchair
 
 from tests.ut.base import TestBase
 from vllm_ascend.torchair import utils
@@ -26,3 +29,11 @@ def test_torchair_cache_dir(self):
                          "Delete torchair cache dir failed")
         self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                          "Delete kv cache bytes cache dir failed")
+
+    def test_super_kernel(self):
+        super_kernel_unenable = utils.super_kernel("prefix", "stream-fusion=1",
+                                                   False)
+        self.assertTrue(super_kernel_unenable, nullcontext())
+        super_kernel_enable = utils.super_kernel("prefix", "stream-fusion=1",
+                                                 True)
+        self.assertIsInstance(super_kernel_enable, torchair.scope._Scope)

vllm_ascend/models/deepseek_v2.py

Lines changed: 16 additions & 8 deletions
@@ -315,15 +315,24 @@ def __init__(
         self.enable_multistream_moe = \
             ascend_config.torchair_graph_config.enable_multistream_moe and \
             self.torchair_graph_enabled
-
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
+        self.enable_super_kernel = self.enable_multistream_moe and self.tp_size == 1
+        self.params_dtype = torch.float32 if self.enable_super_kernel else torch.get_default_dtype()
+
+        # Converting gate weight to fp32 is to adapt to the super kernel feature.
+        # Super kernel feature currently cannot fuse operators such as cast, stridedslice, and add.
+        # In the moe stage, Cast will interrupt the fusion of the super kernel. To avoid this problem,
+        # modifications will be made in the initialization stage.
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.n_routed_experts,
+            bias=False,
+            quant_config=None,
+            params_dtype=self.params_dtype,
+            prefix=f"{prefix}.gate")
         if config.topk_method == "noaux_tc":
             self.gate.e_score_correction_bias = nn.Parameter(
-                torch.empty(config.n_routed_experts))
+                torch.empty(config.n_routed_experts,
+                            dtype=self.params_dtype))
         else:
             self.gate.e_score_correction_bias = None
 
@@ -370,7 +379,6 @@ def __init__(
         if transfer_config is not None:
             self.kv_consumer = transfer_config.kv_role == "kv_consumer"
 
-        self.params_dtype = torch.get_default_dtype()
         self.rm_router_logits = self.experts.rm_router_logits
 
     def forward(self,
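
To make the dtype decision above concrete, here is a minimal, self-contained sketch. It assumes only what the hunk states: the gate runs under the super kernel when multistream MoE is enabled and tensor parallelism is 1; pick_gate_dtype is a simplified stand-in, not the model's code:

import torch


def pick_gate_dtype(enable_multistream_moe: bool, tp_size: int) -> torch.dtype:
    # The super kernel cannot fuse Cast/StridedSlice/Add, so the gate weight
    # is created in fp32 up front instead of being cast inside the MoE stage.
    enable_super_kernel = enable_multistream_moe and tp_size == 1
    return torch.float32 if enable_super_kernel else torch.get_default_dtype()


torch.set_default_dtype(torch.bfloat16)
print(pick_gate_dtype(True, 1))  # torch.float32 (super kernel path)
print(pick_gate_dtype(True, 8))  # torch.bfloat16 (default dtype kept)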

vllm_ascend/ops/fused_moe.py

Lines changed: 22 additions & 10 deletions
@@ -48,7 +48,8 @@
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
                                get_ascend_soc_version,
@@ -1204,6 +1205,7 @@ def __init__(
         )
         AscendFusedMoE.moe_counter += 1
         self.moe_instance_id = AscendFusedMoE.moe_counter
+        self.prefix = prefix
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1268,6 +1270,7 @@ def __init__(
         self.enable_multistream_moe = \
             ascend_config.torchair_graph_config.enable_multistream_moe and \
             self.torchair_graph_enabled
+        self.enable_super_kernel = self.enable_multistream_moe and tp_size == 1
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -1372,16 +1375,23 @@ def forward(self,
         quantized_x_for_share, dynamic_scale_for_share = None, None
         from vllm_ascend.quantization.w8a8_dynamic import \
             AscendW8A8DynamicFusedMoEMethod
+        running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2
         if self.enable_multistream_moe:
-            if not self.rm_router_logits:
-                router_logits, _ = gate(hidden_states)
-            if hasattr(self.quant_method, "quant_method") and \
-                isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod
-                           ) and fused_moe_state == FusedMoEState.MC2:
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
-                        hidden_states)
+            with super_kernel(self.prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                if not self.rm_router_logits:
+                    if self.enable_super_kernel:
+                        router_logits, _ = gate(hidden_states.float())
+                    else:
+                        router_logits, _ = gate(hidden_states)
+                if hasattr(self.quant_method, "quant_method") and \
+                    isinstance(self.quant_method.quant_method,
+                               AscendW8A8DynamicFusedMoEMethod
+                               ) and fused_moe_state == FusedMoEState.MC2:
+                    with npu_stream_switch("moe_secondary", 0):
+                        quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                            hidden_states)
 
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
@@ -1481,6 +1491,8 @@ def forward(self,
             token_dispatcher=self.token_dispatcher,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
+            prefix=self.prefix,
+            running_in_super_kernel=running_in_super_kernel,
         )
 
         if shared_experts:
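
A hedged, standalone sketch of the fp32 gating path added above. Here gate is a plain torch.nn.Linear standing in for the model's ReplicatedLinear gate, and the bfloat16 inputs are illustrative:

import torch
import torch.nn as nn


def compute_router_logits(gate: nn.Module, hidden_states: torch.Tensor,
                          enable_super_kernel: bool) -> torch.Tensor:
    # Under the super kernel the gate weight is created in fp32 (see
    # deepseek_v2.py above), so the activations are promoted before the matmul
    # rather than inserting an in-graph Cast that would break kernel fusion.
    if enable_super_kernel:
        return gate(hidden_states.float())
    return gate(hidden_states)


gate = nn.Linear(32, 8, bias=False).float()
x = torch.randn(4, 32, dtype=torch.bfloat16)
print(compute_router_logits(gate, x, enable_super_kernel=True).dtype)  # float32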

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 71 additions & 64 deletions
@@ -28,7 +28,8 @@
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe import select_experts
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version)
 
@@ -898,59 +899,62 @@ def apply(
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        prefix: str = "",
+        running_in_super_kernel: bool = False,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
         is_deepseek_v3_r1 = global_num_experts == 256
-
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk is currently fixed to 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
-                renorm=0,  # 0: softmax->topk (fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid (fix)
-                # out_flag=False, # todo new api; whether to output the third output
-                # y2_flag=False, # old api; whether to output the third output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-
-        # this is a naive implementation for experts load balance so as
-        # to avoid accumulating too much tokens on a single rank.
-        # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
-
-        topk_weights = topk_weights.to(x.dtype)
+        with super_kernel(prefix,
+                          "stream-fusion=1",
+                          enabled=running_in_super_kernel):
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+            if is_deepseek_v3_r1:
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # topk is currently fixed to 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fix: 4
+                    group_count=num_expert_group,  # fix 8
+                    group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
+                    renorm=0,  # 0: softmax->topk (fix); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid (fix)
+                    # out_flag=False, # todo new api; whether to output the third output
+                    # y2_flag=False, # old api; whether to output the third output
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    top_k=top_k,
+                    use_grouped_topk=use_grouped_topk,
+                    renormalize=renormalize,
+                    topk_group=topk_group,
+                    num_expert_group=num_expert_group,
+                    custom_routing_function=custom_routing_function,
+                    scoring_func=scoring_func,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+            if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    npu_wait_tensor(quantized_x_for_share, router_logits)
+                    share_up_out, _ = shared_experts.gate_up_proj(
+                        (quantized_x_for_share, dynamic_scale_for_share))
+                    shared_gate_up, shared_dequant_scale = share_up_out[
+                        0], share_up_out[1]
+
+            # this is a naive implementation for experts load balance so as
+            # to avoid accumulating too much tokens on a single rank.
+            # currently it is only activated when doing profile runs.
+            if enable_force_load_balance:
+                topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+            topk_weights = topk_weights.to(x.dtype)
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -963,24 +967,27 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map)
         elif fused_moe_state == FusedMoEState.MC2:
-            return fused_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None),
-                shared_gate_up=shared_gate_up,
-                shared_dequant_scale=shared_dequant_scale)
+            with super_kernel(prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                return fused_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    mc2_mask=kwargs.get("mc2_mask", None),
+                    shared_gate_up=shared_gate_up,
+                    shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
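
The two super_kernel scopes above are only active on the MC2 decode path. Below is a small sketch of that gating condition, using a stand-in enum for vllm_ascend's FusedMoEState:

from enum import Enum, auto


class FusedMoEStateLike(Enum):
    MC2 = auto()
    AllGather = auto()


def super_kernel_active(enable_super_kernel: bool,
                        state: FusedMoEStateLike) -> bool:
    # Mirrors running_in_super_kernel from fused_moe.py: the routing block and
    # fused_experts_with_mc2 are wrapped only for MC2; otherwise super_kernel
    # degrades to a nullcontext and the code runs unchanged.
    return enable_super_kernel and state is FusedMoEStateLike.MC2


print(super_kernel_active(True, FusedMoEStateLike.MC2))        # True
print(super_kernel_active(True, FusedMoEStateLike.AllGather))  # False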

vllm_ascend/torchair/utils.py

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,7 @@
 from contextlib import contextmanager, nullcontext
 
 import torch
+from torchair.scope import super_kernel as _super_kernel
 
 try:
     # Recent release of torchair has moved these ops to `.scope`.
@@ -96,3 +97,7 @@ def npu_wait_tensor(self: torch.Tensor,
                     *,
                     enabled: bool = True):
     return _npu_wait_tensor(self, dependency) if enabled else self
+
+
+def super_kernel(prefix: str, option: str, enabled: bool = True):
+    return _super_kernel(prefix, option) if enabled else nullcontext()
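
A hedged usage sketch of the new helper. It assumes a torchair build that provides torchair.scope.super_kernel, and the prefix string below is illustrative only:

from contextlib import nullcontext

from vllm_ascend.torchair.utils import super_kernel

# Disabled: the helper returns contextlib.nullcontext(), so wrapped ops run
# exactly as before (useful outside torchair graph mode).
assert isinstance(super_kernel("any.prefix", "stream-fusion=1", enabled=False),
                  nullcontext)

# Enabled: returns the torchair scope object that groups ops launched inside
# it into a single super kernel; entering it is only meaningful inside a
# torchair-compiled graph on an NPU, e.g.:
#
#     with super_kernel("model.layers.0.mlp.experts", "stream-fusion=1"):
#         ...  # MoE routing / dispatch ops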
