 except ImportError:
     has_pplx = False
 
+from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
 from tests.kernels.utils import torch_experts
-from tests.kernels.moe.utils import (make_test_weights, naive_batched_moe)
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe import (
-    override_config,
-    fused_topk)
-from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
+from vllm.model_executor.layers.fused_moe import fused_topk, override_config
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEModularKernel)
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
+from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
 from vllm.platforms import current_platform
 from vllm.utils import round_up
 
@@ -573,11 +571,14 @@ def _pplx_moe(
 
     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
-                                     qtype, per_act_token_quant, block_shape)
-        pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a,
-                               w1, w2, topk_weight, topk_ids, w1_s, w2_s, qtype,
-                               per_act_token_quant, block_shape)
+        torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids,
+                                     w1_scale=w1_s, w2_scale=w2_s,
+                                     quant_dtype=qtype,
+                                     per_act_token_quant=per_act_token_quant,
+                                     block_shape=block_shape)
+        pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
+                               a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
+                               qtype, per_act_token_quant, block_shape)
 
     # TODO (bnell): fix + re-enable
     #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
     #                              topk_ids)
@@ -595,7 +596,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.bfloat16]) # torch.float8_e4m3fn,
+@pytest.mark.parametrize("dtype", [torch.bfloat16])  # torch.float8_e4m3fn,
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
@@ -628,8 +629,11 @@ def test_pplx_moe(
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(
-        e, n, k, quant_dtype=quant_dtype, block_shape=block_shape)
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
+                                                 n,
+                                                 k,
+                                                 quant_dtype=quant_dtype,
+                                                 block_shape=block_shape)
 
     parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
                     w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,