 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe import select_experts
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version)
@@ -830,59 +831,62 @@ def apply(
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        prefix: str = "",
+        running_in_super_kernel: bool = False,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"

         is_deepseek_v3_r1 = global_num_experts == 256
-
-        # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # top_k is currently set to 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fixed: 4
-                group_count=num_expert_group,  # fixed: 8
-                group_select_mode=1,  # 0: max within group; 1: sum of top-2 (fixed)
-                renorm=0,  # 0: softmax->topk (fixed); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid (fixed)
-                # out_flag=False, # todo new api; whether to output the third result
-                # y2_flag=False, # old api; whether to output the third result
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
         fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
-            with npu_stream_switch("moe_secondary", 0):
-                npu_wait_tensor(quantized_x_for_share, router_logits)
-                share_up_out, _ = shared_experts.gate_up_proj(
-                    (quantized_x_for_share, dynamic_scale_for_share))
-                shared_gate_up, shared_dequant_scale = share_up_out[
-                    0], share_up_out[1]
-
-        # This is a naive implementation of expert load balancing, used to
-        # avoid accumulating too many tokens on a single rank.
-        # It is currently only activated during profile runs.
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
-
-        topk_weights = topk_weights.to(x.dtype)
+        with super_kernel(prefix,
+                          "stream-fusion=1",
+                          enabled=running_in_super_kernel):
+            # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
+            if is_deepseek_v3_r1:
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # top_k is currently set to 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fixed: 4
+                    group_count=num_expert_group,  # fixed: 8
+                    group_select_mode=1,  # 0: max within group; 1: sum of top-2 (fixed)
+                    renorm=0,  # 0: softmax->topk (fixed); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid (fixed)
+                    # out_flag=False, # todo new api; whether to output the third result
+                    # y2_flag=False, # old api; whether to output the third result
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    top_k=top_k,
+                    use_grouped_topk=use_grouped_topk,
+                    renormalize=renormalize,
+                    topk_group=topk_group,
+                    num_expert_group=num_expert_group,
+                    custom_routing_function=custom_routing_function,
+                    scoring_func=scoring_func,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+            if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    npu_wait_tensor(quantized_x_for_share, router_logits)
+                    share_up_out, _ = shared_experts.gate_up_proj(
+                        (quantized_x_for_share, dynamic_scale_for_share))
+                    shared_gate_up, shared_dequant_scale = share_up_out[
+                        0], share_up_out[1]
+
+            # This is a naive implementation of expert load balancing, used to
+            # avoid accumulating too many tokens on a single rank.
+            # It is currently only activated during profile runs.
+            if enable_force_load_balance:
+                topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+
+            topk_weights = topk_weights.to(x.dtype)

         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -895,24 +899,27 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map)
         elif fused_moe_state == FusedMoEState.MC2:
-            return fused_experts_with_mc2(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale_fp32,
-                w2_scale=layer.w2_weight_scale,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                top_k=top_k,
-                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                log2phy=log2phy,
-                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None),
-                shared_gate_up=shared_gate_up,
-                shared_dequant_scale=shared_dequant_scale)
+            with super_kernel(prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                return fused_experts_with_mc2(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    w1_scale=layer.w13_weight_scale_fp32,
+                    w2_scale=layer.w2_weight_scale,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    top_k=top_k,
+                    expert_map=expert_map,
+                    moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                    log2phy=log2phy,
+                    global_redundant_expert_num=global_redundant_expert_num,
+                    shared_experts=shared_experts,
+                    is_torchair=self.torchair_graph_enabled,
+                    mc2_mask=kwargs.get("mc2_mask", None),
+                    shared_gate_up=shared_gate_up,
+                    shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
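The sketch below illustrates the control flow this diff introduces: expert routing (and the MC2 dispatch in `apply`) is wrapped in a `super_kernel(prefix, "stream-fusion=1", enabled=...)` scope, and the default path is unchanged when `running_in_super_kernel` is False. The `super_kernel` body here is a hypothetical stand-in for `vllm_ascend.torchair.utils.super_kernel` (which is not part of this diff), and the layer prefix in the usage example is made up.

```python
from contextlib import contextmanager


@contextmanager
def super_kernel(prefix: str, options: str, enabled: bool = True):
    """Hypothetical stand-in for vllm_ascend.torchair.utils.super_kernel.

    The real helper is expected to open a torchair super-kernel scope named
    by `prefix` with the given fusion options; this stub only models the
    enabled/disabled control flow so the gating pattern can be exercised.
    """
    if not enabled:
        yield
        return
    print(f"enter super-kernel scope {prefix!r} ({options})")
    try:
        yield
    finally:
        print(f"exit super-kernel scope {prefix!r}")


def route_tokens(prefix: str = "", running_in_super_kernel: bool = False) -> None:
    # Routing and dispatch run inside one scope so the enclosed kernels can be
    # fused into a single launch when the flag is set; with the flag off, the
    # body executes exactly as it did before this change.
    with super_kernel(prefix,
                      "stream-fusion=1",
                      enabled=running_in_super_kernel):
        pass  # npu_moe_gating_top_k / select_experts would run here


# Example (hypothetical prefix): enabled only when the caller opts in.
route_tokens(prefix="model.layers.3.mlp.experts", running_in_super_kernel=True)
route_tokens()  # disabled path: behaves like the pre-change code
```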