from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import select_experts
- from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+ from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                          super_kernel)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                               dispose_tensor, get_ascend_soc_version)
@@ -898,59 +899,62 @@ def apply(
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
+         prefix: str = "",
+         running_in_super_kernel: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
            1] == global_num_experts, "Number of global experts mismatch"

        is_deepseek_v3_r1 = global_num_experts == 256
-
-         # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
-         if is_deepseek_v3_r1:
-             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                 router_logits,
-                 k=top_k,  # top_k is currently set to 8
-                 bias=e_score_correction_bias,
-                 k_group=topk_group,  # fix: 4
-                 group_count=num_expert_group,  # fix: 8
-                 group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
-                 renorm=0,  # 0: softmax->topk (fix); 1: topk->softmax
-                 norm_type=1,  # 0: softmax; 1: sigmoid (fix)
-                 # out_flag=False,  # todo new api; whether the third output is produced
-                 # y2_flag=False,  # old api; whether the third output is produced
-                 routed_scaling_factor=1,
-                 eps=float(1e-20))
-         else:
-             topk_weights, topk_ids = select_experts(
-                 hidden_states=x,
-                 router_logits=router_logits,
-                 top_k=top_k,
-                 use_grouped_topk=use_grouped_topk,
-                 renormalize=renormalize,
-                 topk_group=topk_group,
-                 num_expert_group=num_expert_group,
-                 custom_routing_function=custom_routing_function,
-                 scoring_func=scoring_func,
-                 e_score_correction_bias=e_score_correction_bias,
-             )
-
        fused_moe_state = get_forward_context().fused_moe_state
        shared_gate_up, shared_dequant_scale = None, None
-         if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-             with npu_stream_switch("moe_secondary", 0):
-                 npu_wait_tensor(quantized_x_for_share, router_logits)
-                 share_up_out, _ = shared_experts.gate_up_proj(
-                     (quantized_x_for_share, dynamic_scale_for_share))
-                 shared_gate_up, shared_dequant_scale = share_up_out[
-                     0], share_up_out[1]
-
-         # This is a naive implementation of expert load balancing, intended
-         # to avoid accumulating too many tokens on a single rank.
-         # Currently it is only activated during profile runs.
-         if enable_force_load_balance:
-             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
-
-         topk_weights = topk_weights.to(x.dtype)
+         with super_kernel(prefix,
+                           "stream-fusion=1",
+                           enabled=running_in_super_kernel):
+             # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
+             if is_deepseek_v3_r1:
+                 topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                     router_logits,
+                     k=top_k,  # top_k is currently set to 8
+                     bias=e_score_correction_bias,
+                     k_group=topk_group,  # fix: 4
+                     group_count=num_expert_group,  # fix: 8
+                     group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fix)
+                     renorm=0,  # 0: softmax->topk (fix); 1: topk->softmax
+                     norm_type=1,  # 0: softmax; 1: sigmoid (fix)
+                     # out_flag=False,  # todo new api; whether the third output is produced
+                     # y2_flag=False,  # old api; whether the third output is produced
+                     routed_scaling_factor=1,
+                     eps=float(1e-20))
+             else:
+                 topk_weights, topk_ids = select_experts(
+                     hidden_states=x,
+                     router_logits=router_logits,
+                     top_k=top_k,
+                     use_grouped_topk=use_grouped_topk,
+                     renormalize=renormalize,
+                     topk_group=topk_group,
+                     num_expert_group=num_expert_group,
+                     custom_routing_function=custom_routing_function,
+                     scoring_func=scoring_func,
+                     e_score_correction_bias=e_score_correction_bias,
+                 )
+             if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                 with npu_stream_switch("moe_secondary", 0):
+                     npu_wait_tensor(quantized_x_for_share, router_logits)
+                     share_up_out, _ = shared_experts.gate_up_proj(
+                         (quantized_x_for_share, dynamic_scale_for_share))
+                     shared_gate_up, shared_dequant_scale = share_up_out[
+                         0], share_up_out[1]
+
+             # This is a naive implementation of expert load balancing, intended
+             # to avoid accumulating too many tokens on a single rank.
+             # Currently it is only activated during profile runs.
+             if enable_force_load_balance:
+                 topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+             topk_weights = topk_weights.to(x.dtype)
        if fused_moe_state == FusedMoEState.AllGatherEP:
            return fused_experts_with_allgather(
                hidden_states=x,
@@ -963,24 +967,27 @@ def apply(
                top_k=top_k,
                expert_map=expert_map)
        elif fused_moe_state == FusedMoEState.MC2:
-             return fused_experts_with_mc2(
-                 hidden_states=x,
-                 w1=layer.w13_weight,
-                 w2=layer.w2_weight,
-                 w1_scale=layer.w13_weight_scale_fp32,
-                 w2_scale=layer.w2_weight_scale,
-                 topk_weights=topk_weights,
-                 topk_ids=topk_ids,
-                 top_k=top_k,
-                 expert_map=expert_map,
-                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                 log2phy=log2phy,
-                 global_redundant_expert_num=global_redundant_expert_num,
-                 shared_experts=shared_experts,
-                 is_torchair=self.torchair_graph_enabled,
-                 mc2_mask=kwargs.get("mc2_mask", None),
-                 shared_gate_up=shared_gate_up,
-                 shared_dequant_scale=shared_dequant_scale)
+             with super_kernel(prefix,
+                               "stream-fusion=1",
+                               enabled=running_in_super_kernel):
+                 return fused_experts_with_mc2(
+                     hidden_states=x,
+                     w1=layer.w13_weight,
+                     w2=layer.w2_weight,
+                     w1_scale=layer.w13_weight_scale_fp32,
+                     w2_scale=layer.w2_weight_scale,
+                     topk_weights=topk_weights,
+                     topk_ids=topk_ids,
+                     top_k=top_k,
+                     expert_map=expert_map,
+                     moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                     log2phy=log2phy,
+                     global_redundant_expert_num=global_redundant_expert_num,
+                     shared_experts=shared_experts,
+                     is_torchair=self.torchair_graph_enabled,
+                     mc2_mask=kwargs.get("mc2_mask", None),
+                     shared_gate_up=shared_gate_up,
+                     shared_dequant_scale=shared_dequant_scale)
        elif fused_moe_state in [
                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
        ]:
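
Note (not part of the diff): the change only takes effect when running_in_super_kernel is True; otherwise super_kernel must behave as a transparent no-op so the routing and MC2 dispatch paths run exactly as before. A minimal sketch of that guard pattern, assuming a context-manager shape matching the call sites above; the body is illustrative only, not the actual vllm_ascend.torchair.utils implementation, and the prefix string in the usage line is hypothetical:

# Illustrative sketch only -- not the real super_kernel from vllm_ascend.torchair.utils.
from contextlib import contextmanager

@contextmanager
def super_kernel(prefix: str, options: str, enabled: bool = True):
    # Assumption: when disabled, yield without side effects so the wrapped
    # MoE gating / shared-expert / dispatch code runs unchanged.
    if not enabled:
        yield
        return
    # When enabled, the real utility would open a super-kernel fusion scope
    # identified by `prefix` with the given options (e.g. "stream-fusion=1");
    # that backend call is intentionally omitted from this sketch.
    yield

# Hypothetical usage mirroring the diff: `prefix` names the MoE layer scope.
with super_kernel("model.layers.3.mlp.experts", "stream-fusion=1", enabled=True):
    pass  # gating, shared-expert gate_up_proj, and fused_experts_with_mc2 would run here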