
Commit b4e3fb1

add super kernel for decode moe
Signed-off-by: NNUCJ <616151263@qq.com>
1 parent 875a86c commit b4e3fb1

6 files changed: +149 −104 lines changed

tests/ut/ops/test_fused_ops.py

Lines changed: 27 additions & 23 deletions
@@ -224,11 +224,23 @@ def test_init_with_quant(self, mock_dist_env, default_moe_config):
 
     @pytest.mark.parametrize(
         "others_param",
-        [[None,
-          MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
-         [2, None, False, 5, None], [None, None, True, 5, None],
-         [None, None, False, 1, None], [None, None, True, 5, 1],
-         [None, None, False, 5, 1]])
+        [[
+            None,
+            MagicMock(return_value=torch.randn(5, 32)), False, 5, None, False,
+            None, False
+        ], [2, None, False, 5, None, False, None, False],
+         [None, None, True, 5, None, False, None, False],
+         [None, None, False, 1, None, False, None, False],
+         [None, None, True, 5, 1,
+          False, None, False], [None, None, False, 5, 1, False, None, False],
+         [
+             None, None, True, 5, 1, True,
+             MagicMock(return_value=(torch.randn(5, 8), None)), False
+         ],
+         [
+             None, None, False, 5, 1, True,
+             MagicMock(return_value=(torch.randn(5, 8), None)), True
+         ]])
     def test_forward(self, mock_dist_env, default_moe_config, others_param):
         """
         1 test has shared_experts
@@ -237,15 +249,22 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
         4 test single num_tokens(decode)
         5 test ep_size is 1 and is_prefill is true
         6 test ep_size is 1 and is_prefill is False
+        7 test enable_multistream_moe and is_prefill is true
         """
-        top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
+        top_k, shared_experts, is_prefill, num_tokens, ep_size, enable_multistream_moe, gate, enable_super_kernel = others_param
         inputs = torch.randn(num_tokens, 32)
         router_logits = torch.randn(num_tokens, 8)
         moe = AscendFusedMoE(**default_moe_config)
 
         if ep_size == 1:
             moe.moe_parallel_config.ep_size = 1
 
+        if enable_multistream_moe:
+            moe.enable_multistream_moe = True
+
+        if enable_super_kernel:
+            moe.enable_super_kernel = True
+
         moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
         forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
                                                          dtype=torch.bool),
@@ -256,7 +275,8 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
                              router_logits,
                              is_prefill=is_prefill,
                              top_k=top_k,
-                             shared_experts=shared_experts)
+                             shared_experts=shared_experts,
+                             gate=gate)
 
         moe.quant_method.apply.assert_called_once()
 
@@ -266,22 +286,6 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
         else:
             assert output.shape == (num_tokens, 32)
 
-    def test_forward_ms_fused_moe_comp(self, mock_dist_env,
-                                       default_moe_config):
-        inputs = torch.randn(5, 32)
-        router_logits = torch.randn(5, 8)
-        moe = AscendFusedMoE(**default_moe_config)
-
-        moe.quant_method = MockQuantMethod(None, 5)
-        output = moe._forward_ms_fused_moe_comp(inputs,
-                                                router_logits,
-                                                is_prefill=False,
-                                                real_top_k=1)
-
-        moe.quant_method.apply.assert_called_once()
-
-        assert output.shape == (5, 32)
-
 
 class TestAscendUnquantizedFusedMoEMethod:
 
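Read positionally, the widened `others_param` tuple is easiest to follow with named fields. A small illustrative sketch of how one entry unpacks (the standalone variable names are ours, taken from the unpack statement in the committed test):

import torch
from unittest.mock import MagicMock

# Illustrative: positional meaning of one parametrized case from the test above.
case = [
    None, None, True, 5, 1, True,
    MagicMock(return_value=(torch.randn(5, 8), None)), False
]
(top_k, shared_experts, is_prefill, num_tokens, ep_size,
 enable_multistream_moe, gate, enable_super_kernel) = case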

tests/ut/torchair/test_utils.py

Lines changed: 11 additions & 0 deletions
@@ -1,4 +1,7 @@
 import os
+from contextlib import nullcontext
+
+import torchair
 
 from tests.ut.base import TestBase
 from vllm_ascend.torchair import utils
@@ -26,3 +29,11 @@ def test_torchair_cache_dir(self):
                          "Delete torchair cache dir failed")
         self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                          "Delete kv cache bytes cache dir failed")
+
+    def test_super_kernel(self):
+        super_kernel_unenable = utils.super_kernel("prefix", "stream-fusion=1",
+                                                   False)
+        self.assertTrue(super_kernel_unenable, nullcontext())
+        super_kernel_enable = utils.super_kernel("prefix", "stream-fusion=1",
+                                                 True)
+        self.assertIsInstance(super_kernel_enable, torchair.scope._Scope)
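Note that `assertTrue(super_kernel_unenable, nullcontext())` only checks that the returned object is truthy; the second argument is treated as the failure message. A stricter check of the disabled path, sketched here as an illustration rather than part of the commit, would assert the type directly:

from contextlib import nullcontext

from vllm_ascend.torchair import utils

# Sketch: with enabled=False the helper is expected to fall back to nullcontext().
disabled = utils.super_kernel("prefix", "stream-fusion=1", False)
assert isinstance(disabled, nullcontext)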

vllm_ascend/models/deepseek_v2.py

Lines changed: 13 additions & 7 deletions
@@ -315,15 +315,22 @@ def __init__(
         self.enable_multistream_moe = \
             ascend_config.torchair_graph_config.enable_multistream_moe and \
             self.torchair_graph_enabled
+        self.enable_super_kernel = self.enable_multistream_moe and self.tp_size == 1
+        self.params_dtype = torch.get_default_dtype()
 
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.n_routed_experts,
+            bias=False,
+            quant_config=None,
+            params_dtype=torch.float32
+            if self.enable_super_kernel else self.params_dtype,
+            prefix=f"{prefix}.gate")
         if config.topk_method == "noaux_tc":
             self.gate.e_score_correction_bias = nn.Parameter(
-                torch.empty(config.n_routed_experts))
+                torch.empty(config.n_routed_experts,
+                            dtype=torch.float if self.enable_super_kernel else
+                            self.params_dtype))
         else:
             self.gate.e_score_correction_bias = None
 
@@ -370,7 +377,6 @@ def __init__(
         if transfer_config is not None:
             self.kv_consumer = transfer_config.kv_role == "kv_consumer"
 
-        self.params_dtype = torch.get_default_dtype()
         self.rm_router_logits = self.experts.rm_router_logits
 
     def forward(self,
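The effect of the change above is that router math stays in float32 whenever the super kernel is active. A minimal standalone sketch of the same dtype selection, with the surrounding module elided and the flag values assumed for illustration:

import torch
import torch.nn as nn

# Assumed flags for illustration: multistream MoE on, no tensor parallelism.
enable_multistream_moe = True
tp_size = 1
enable_super_kernel = enable_multistream_moe and tp_size == 1

params_dtype = torch.get_default_dtype()
gate_dtype = torch.float32 if enable_super_kernel else params_dtype
# The gate weight and the e_score_correction_bias follow the same rule.
e_score_correction_bias = nn.Parameter(torch.empty(256, dtype=gate_dtype))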

vllm_ascend/ops/fused_moe.py

Lines changed: 22 additions & 10 deletions
@@ -47,7 +47,8 @@
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
                                get_ascend_soc_version,
@@ -1198,6 +1199,7 @@ def __init__(
 
         AscendFusedMoE.moe_counter += 1
         self.moe_instance_id = AscendFusedMoE.moe_counter
+        self.prefix = prefix
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1262,6 +1264,7 @@ def __init__(
         self.enable_multistream_moe = \
             ascend_config.torchair_graph_config.enable_multistream_moe and \
             self.torchair_graph_enabled
+        self.enable_super_kernel = self.enable_multistream_moe and tp_size == 1
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -1365,16 +1368,23 @@ def forward(self,
         quantized_x_for_share, dynamic_scale_for_share = None, None
         from vllm_ascend.quantization.w8a8_dynamic import \
             AscendW8A8DynamicFusedMoEMethod
+        running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2
         if self.enable_multistream_moe:
-            if not self.rm_router_logits:
-                router_logits, _ = gate(hidden_states)
-            if hasattr(self.quant_method, "quant_method") and \
-                isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod
-                           ) and fused_moe_state == FusedMoEState.MC2:
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
-                        hidden_states)
+            with super_kernel(self.prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                if not self.rm_router_logits:
+                    if self.enable_super_kernel:
+                        router_logits, _ = gate(hidden_states.float())
+                    else:
+                        router_logits, _ = gate(hidden_states)
+                if hasattr(self.quant_method, "quant_method") and \
+                    isinstance(self.quant_method.quant_method,
+                               AscendW8A8DynamicFusedMoEMethod
+                               ) and fused_moe_state == FusedMoEState.MC2:
+                    with npu_stream_switch("moe_secondary", 0):
+                        quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                            hidden_states)
 
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
@@ -1464,6 +1474,8 @@ def forward(self,
                 token_dispatcher=self.token_dispatcher,
                 quantized_x_for_share=quantized_x_for_share,
                 dynamic_scale_for_share=dynamic_scale_for_share,
+                prefix=self.prefix,
+                running_in_super_kernel=running_in_super_kernel,
             )
 
         if shared_experts:
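Condensed, the decode-path control flow added to `AscendFusedMoE.forward` looks roughly like the sketch below; the helper function is ours and only illustrates how the scope is entered, it is not the actual method body:

from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.torchair.utils import super_kernel

def route_under_super_kernel(moe, gate, hidden_states, fused_moe_state):
    """Illustrative sketch of the routing portion of the new forward path."""
    # The super kernel scope is only armed on the MC2 (decode) path.
    running_in_super_kernel = moe.enable_super_kernel and \
        fused_moe_state == FusedMoEState.MC2
    router_logits = None
    if moe.enable_multistream_moe:
        with super_kernel(moe.prefix, "stream-fusion=1",
                          enabled=running_in_super_kernel):
            if not moe.rm_router_logits:
                # Under the super kernel the gate runs on a float32 copy of the input.
                x = hidden_states.float() if moe.enable_super_kernel else hidden_states
                router_logits, _ = gate(x)
    return router_logits, running_in_super_kernel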

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 71 additions & 64 deletions
@@ -28,7 +28,8 @@
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe import select_experts
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version)
 
@@ -898,59 +899,62 @@ def apply(
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        prefix: str = "",
+        running_in_super_kernel: bool = False,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
         is_deepseek_v3_r1 = global_num_experts == 256
-
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # top_k is currently fixed to 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=1,  # 0: max within group; 1: topk2.sum (fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; whether to emit the third output
-                # y2_flag=False, # old api; whether to emit the third output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-
-        # this is a naive implementation for experts load balance so as
-        # to avoid accumulating too much tokens on a single rank.
-        # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
-
-        topk_weights = topk_weights.to(x.dtype)
+        with super_kernel(prefix,
+                          "stream-fusion=1",
+                          enabled=running_in_super_kernel):
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+            if is_deepseek_v3_r1:
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # top_k is currently fixed to 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fix: 4
+                    group_count=num_expert_group,  # fix 8
+                    group_select_mode=1,  # 0: max within group; 1: topk2.sum (fix)
+                    renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                    # out_flag=False, # todo new api; whether to emit the third output
+                    # y2_flag=False, # old api; whether to emit the third output
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    top_k=top_k,
+                    use_grouped_topk=use_grouped_topk,
+                    renormalize=renormalize,
+                    topk_group=topk_group,
+                    num_expert_group=num_expert_group,
+                    custom_routing_function=custom_routing_function,
+                    scoring_func=scoring_func,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+            if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    npu_wait_tensor(quantized_x_for_share, router_logits)
+                    share_up_out, _ = shared_experts.gate_up_proj(
+                        (quantized_x_for_share, dynamic_scale_for_share))
+                    shared_gate_up, shared_dequant_scale = share_up_out[
+                        0], share_up_out[1]
+
+            # this is a naive implementation for experts load balance so as
+            # to avoid accumulating too much tokens on a single rank.
+            # currently it is only activated when doing profile runs.
+            if enable_force_load_balance:
+                topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+            topk_weights = topk_weights.to(x.dtype)
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -963,24 +967,27 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map)
         elif fused_moe_state == FusedMoEState.MC2:
-            return fused_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None),
-                shared_gate_up=shared_gate_up,
-                shared_dequant_scale=shared_dequant_scale)
+            with super_kernel(prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                return fused_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    mc2_mask=kwargs.get("mc2_mask", None),
+                    shared_gate_up=shared_gate_up,
+                    shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
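With `prefix` and `running_in_super_kernel` now threaded into `apply`, top-k routing, the shared-expert gate_up projection and the MC2 expert dispatch all execute under the same named scope on the decode path, while every other `fused_moe_state` keeps its original code path. The two new keyword arguments come from `AscendFusedMoE.forward`; a small sketch with illustrative values only:

# Illustrative values; in the real call they come from the MoE layer instance.
prefix = "model.layers.3.mlp.experts"  # hypothetical scope name from the layer prefix
running_in_super_kernel = True         # MC2 decode path, multistream MoE, tp_size == 1

extra_kwargs = dict(
    prefix=prefix,
    running_in_super_kernel=running_in_super_kernel,
)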

vllm_ascend/torchair/utils.py

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,7 @@
 from contextlib import contextmanager, nullcontext
 
 import torch
+from torchair.scope import super_kernel as _super_kernel
 
 try:
     # Recent release of torchair has moved these ops to `.scope`.
@@ -96,3 +97,7 @@ def npu_wait_tensor(self: torch.Tensor,
                     *,
                     enabled: bool = True):
     return _npu_wait_tensor(self, dependency) if enabled else self
+
+
+def super_kernel(prefix: str, stream: str, enabled: bool = True):
+    return _super_kernel(prefix, stream) if enabled else nullcontext()
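A minimal usage sketch of the new helper, assuming torchair graph mode; the scope name and the "stream-fusion=1" option string follow the call sites above:

from vllm_ascend.torchair.utils import super_kernel

prefix = "model.layers.0.mlp.experts"  # hypothetical layer prefix used as the scope name

# With enabled=True, ops issued inside the block are grouped into one torchair
# super kernel scope; with enabled=False the helper degrades to a no-op context.
with super_kernel(prefix, "stream-fusion=1", enabled=True):
    ...  # routing, dynamic quantization and MC2 dispatch for the decode MoE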
