55import textwrap
66import traceback
77from itertools import product
8- from typing import Optional
8+ from typing import Any , Optional
99
1010import pytest
1111import torch
1212
1313import vllm .model_executor .layers .fused_moe .modular_kernel as mk
1414from vllm .config import VllmConfig , set_current_vllm_config
1515from vllm .platforms import current_platform
16- from vllm .utils import has_deep_ep , has_deep_gemm , has_pplx
16+ from vllm .utils import (cuda_device_count_stateless , has_deep_ep ,
17+ has_deep_gemm , has_pplx )
1718from vllm .utils .flashinfer import has_flashinfer_cutlass_fused_moe
1819
19- from ...utils import multi_gpu_test
2020from .modular_kernel_tools .common import (Config , RankTensors , WeightTensors ,
2121 reference_moe_impl ,
2222 run_modular_kernel )
@@ -122,7 +122,8 @@ def rank_worker(
122122
123123
124124def run (config : Config , verbose : bool ):
125- assert config .is_valid ()
125+ assert config .is_valid ()[0 ]
126+ assert not is_nyi_config (config )
126127
127128 weights : WeightTensors = WeightTensors .make (config )
128129
@@ -156,24 +157,63 @@ def is_nyi_config(config: Config) -> bool:
156157 return not info .supports_expert_map
157158
158159
159- @pytest .mark .parametrize ("k" , Ks )
160- @pytest .mark .parametrize ("n" , Ns )
161- @pytest .mark .parametrize ("e" , Es )
162- @pytest .mark .parametrize ("dtype" , DTYPEs )
163- @pytest .mark .parametrize ("quant_config" , MK_QUANT_CONFIGS )
def generate_valid_test_cases(
        world_size: int,
        prepare_finalize_types,
        verbose: bool = False) -> list[tuple[Any, ...]]:
    """Enumerate runnable parametrizations for the modular-kernel tests.

    Builds every (k, n, e, dtype, quant_config, prepare_finalize_type,
    fused_experts_type, chunk_size, world_size) combination, constructs a
    ``Config`` for it, and keeps only those that are valid and implemented,
    so pytest never collects cases that would immediately skip.

    Args:
        world_size: number of ranks the generated configs target.
        prepare_finalize_types: candidate ``FusedMoEPrepareAndFinalize``
            implementations to cross with ``MK_FUSED_EXPERT_TYPES``.
        verbose: print a line for each rejected combination. This runs at
            pytest collection time, before fixtures such as ``pytestconfig``
            exist, so verbosity is a plain keyword instead of a pytest option.

    Returns:
        A list of argument tuples matching the ``parametrize`` argnames of
        the test functions below.
    """
    cases: list[tuple[Any, ...]] = []

    for (k, n, e, dtype, quant_config,
         (prepare_finalize_type, fused_experts_type),
         chunk_size) in product(
             Ks, Ns, Es, DTYPEs, MK_QUANT_CONFIGS,
             product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
             FUSED_MOE_CHUNK_SIZEs):

        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=prepare_finalize_type,
            fused_experts_type=fused_experts_type,
            fused_moe_chunk_size=chunk_size,
            world_size=world_size,
        )

        # Config.is_valid() returns (ok, reason); filter invalid combinations
        # out at collection time rather than skipping them at run time.
        valid, reason = config.is_valid()
        if not valid:
            if verbose:
                print(f"Tests config {config} is not valid: {reason}")
            continue

        # Likewise drop combinations that are not yet implemented.
        if is_nyi_config(config):
            if verbose:
                print(f"Tests config {config} is nyi.")
            continue

        cases.append((k, n, e, dtype, quant_config, prepare_finalize_type,
                      fused_experts_type, chunk_size, world_size))

    return cases
202+
203+
164204@pytest .mark .parametrize (
165- "combination" ,
166- product (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES , MK_FUSED_EXPERT_TYPES ))
167- @pytest .mark .parametrize ("fused_moe_chunk_size" , FUSED_MOE_CHUNK_SIZEs )
168- @pytest .mark .parametrize ("world_size" , [2 ])
169- @multi_gpu_test (num_gpus = 2 )
205+ "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size" ,
206+ generate_valid_test_cases (
207+ world_size = 2 ,
208+ prepare_finalize_types = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES ))
170209@meets_multi_gpu_requirements
171210def test_modular_kernel_combinations_multigpu (
172211 k : int , n : int , e : int , dtype : torch .dtype ,
173212 quant_config : Optional [TestMoEQuantConfig ],
174- combination : tuple [mk .FusedMoEPrepareAndFinalize ,
175- mk .FusedMoEPermuteExpertsUnpermute ],
176- fused_moe_chunk_size : Optional [int ], world_size : int , pytestconfig ):
213+ prepare_finalize_type : mk .FusedMoEPrepareAndFinalize ,
214+ fused_experts_type : mk .FusedMoEPermuteExpertsUnpermute ,
215+ chunk_size : Optional [int ], world_size : int , pytestconfig ):
216+ assert cuda_device_count_stateless () >= world_size
177217
178218 config = Config (
179219 Ms = Ms ,
@@ -183,38 +223,26 @@ def test_modular_kernel_combinations_multigpu(
183223 topks = TOPKs ,
184224 dtype = dtype ,
185225 quant_config = quant_config ,
186- prepare_finalize_type = combination [ 0 ] ,
187- fused_experts_type = combination [ 1 ] ,
188- fused_moe_chunk_size = fused_moe_chunk_size ,
226+ prepare_finalize_type = prepare_finalize_type ,
227+ fused_experts_type = fused_experts_type ,
228+ fused_moe_chunk_size = chunk_size ,
189229 world_size = world_size ,
190230 )
191-
192- if not config .is_valid ():
193- pytest .skip (f"Tests config { config } is not valid. Skipping ..." )
194-
195- if is_nyi_config (config ):
196- pytest .skip (f"Tests config { config } is nyi. Skipping ..." )
197-
198231 verbosity = pytestconfig .getoption ('verbose' )
199232 run (config , verbosity > 0 )
200233
201234
202- @pytest .mark .parametrize ("k" , Ks )
203- @pytest .mark .parametrize ("n" , Ns )
204- @pytest .mark .parametrize ("e" , Es )
205- @pytest .mark .parametrize ("dtype" , DTYPEs )
206- @pytest .mark .parametrize ("quant_config" , MK_QUANT_CONFIGS )
207235@pytest .mark .parametrize (
208- "combination " ,
209- product ( MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES , MK_FUSED_EXPERT_TYPES ))
210- @ pytest . mark . parametrize ( "fused_moe_chunk_size" , FUSED_MOE_CHUNK_SIZEs )
211- @ pytest . mark . parametrize ( "world_size" , [ 1 ] )
236+ "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size " ,
237+ generate_valid_test_cases (
238+ world_size = 1 ,
239+ prepare_finalize_types = MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES ) )
212240def test_modular_kernel_combinations_singlegpu (
213241 k : int , n : int , e : int , dtype : torch .dtype ,
214242 quant_config : Optional [TestMoEQuantConfig ],
215- combination : tuple [ mk .FusedMoEPrepareAndFinalize ,
216- mk .FusedMoEPermuteExpertsUnpermute ] ,
217- fused_moe_chunk_size : Optional [int ], world_size : int , pytestconfig ):
243+ prepare_finalize_type : mk .FusedMoEPrepareAndFinalize ,
244+ fused_experts_type : mk .FusedMoEPermuteExpertsUnpermute ,
245+ chunk_size : Optional [int ], world_size : int , pytestconfig ):
218246 config = Config (
219247 Ms = Ms ,
220248 K = k ,
@@ -223,18 +251,12 @@ def test_modular_kernel_combinations_singlegpu(
223251 topks = TOPKs ,
224252 dtype = dtype ,
225253 quant_config = quant_config ,
226- prepare_finalize_type = combination [ 0 ] ,
227- fused_experts_type = combination [ 1 ] ,
228- fused_moe_chunk_size = fused_moe_chunk_size ,
254+ prepare_finalize_type = prepare_finalize_type ,
255+ fused_experts_type = fused_experts_type ,
256+ fused_moe_chunk_size = chunk_size ,
229257 world_size = world_size ,
230258 )
231259
232- if not config .is_valid ():
233- pytest .skip (f"Tests config { config } is not valid. Skipping ..." )
234-
235- if is_nyi_config (config ):
236- pytest .skip (f"Tests config { config } is nyi. Skipping ..." )
237-
238260 verbosity = pytestconfig .getoption ('verbose' )
239261 run (config , verbosity > 0 )
240262