55import  textwrap 
66import  traceback 
77from  itertools  import  product 
8- from  typing  import  Optional 
8+ from  typing  import  Any ,  Optional 
99
1010import  pytest 
1111import  torch 
1212
1313import  vllm .model_executor .layers .fused_moe .modular_kernel  as  mk 
1414from  vllm .config  import  VllmConfig , set_current_vllm_config 
1515from  vllm .platforms  import  current_platform 
16- from  vllm .utils  import  has_deep_ep , has_deep_gemm , has_pplx 
16+ from  vllm .utils  import  cuda_device_count_stateless ,  has_deep_ep , has_deep_gemm , has_pplx 
1717from  vllm .utils .flashinfer  import  has_flashinfer_cutlass_fused_moe 
1818
19- from  ...utils  import  multi_gpu_test 
2019from  .modular_kernel_tools .common  import  (
2120    Config ,
2221    RankTensors ,
@@ -132,7 +131,8 @@ def rank_worker(
132131
133132
134133def  run (config : Config , verbose : bool ):
135-     assert  config .is_valid ()
134+     assert  config .is_valid ()[0 ]
135+     assert  not  is_nyi_config (config )
136136
137137    weights : WeightTensors  =  WeightTensors .make (config )
138138
@@ -168,31 +168,97 @@ def is_nyi_config(config: Config) -> bool:
168168    return  not  info .supports_expert_map 
169169
170170
171- @pytest .mark .parametrize ("k" , Ks ) 
172- @pytest .mark .parametrize ("n" , Ns ) 
173- @pytest .mark .parametrize ("e" , Es ) 
174- @pytest .mark .parametrize ("dtype" , DTYPEs ) 
175- @pytest .mark .parametrize ("quant_config" , MK_QUANT_CONFIGS ) 
def generate_valid_test_cases(
    world_size: int, prepare_finalize_types, *, verbose: bool = False
) -> list[tuple[Any, ...]]:
    """Enumerate the parametrize tuples for every runnable kernel combination.

    Walks the cartesian product of ``Ks``, ``Ns``, ``Es``, ``DTYPEs``,
    ``MK_QUANT_CONFIGS``, every (prepare/finalize, fused-experts) pairing and
    ``FUSED_MOE_CHUNK_SIZEs``. A ``Config`` is built for each point and the
    point is dropped when ``Config.is_valid()`` reports it invalid or when
    ``is_nyi_config`` flags it, so that only runnable cases are handed to
    ``pytest.mark.parametrize`` instead of being skipped inside the test body.

    Args:
        world_size: Value forwarded to ``Config(world_size=...)`` and appended
            as the last element of every returned tuple.
        prepare_finalize_types: Iterable of prepare/finalize types to pair
            with each entry of ``MK_FUSED_EXPERT_TYPES``.
        verbose: When true, print the reason each rejected config was
            filtered out. Defaults to ``False``.

    Returns:
        A list of tuples ordered as ``(k, n, e, dtype, quant_config,
        prepare_finalize_type, fused_experts_type, chunk_size, world_size)``.
    """
    cases: list[tuple[Any, ...]] = []
    total = 0

    # TODO(bnell): figure out how to get the pytest verbose flag here. This
    # function is evaluated inside the parametrize decorators, where the
    # `pytestconfig` fixture is not available, so `verbose` has to be passed
    # in explicitly for now.
    for (
        k,
        n,
        e,
        dtype,
        quant_config,
        (prepare_finalize_type, fused_experts_type),
        chunk_size,
    ) in product(
        Ks,
        Ns,
        Es,
        DTYPEs,
        MK_QUANT_CONFIGS,
        product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
        FUSED_MOE_CHUNK_SIZEs,
    ):
        total += 1

        config = Config(
            Ms=Ms,
            K=k,
            N=n,
            E=e,
            topks=TOPKs,
            dtype=dtype,
            quant_config=quant_config,
            prepare_finalize_type=prepare_finalize_type,
            fused_experts_type=fused_experts_type,
            fused_moe_chunk_size=chunk_size,
            world_size=world_size,
        )

        # is_valid() returns a (bool, reason) pair.
        valid, reason = config.is_valid()
        if not valid:
            if verbose:
                print(f"Test config {config} is not valid: {reason}")
            continue

        # Not-yet-implemented combinations are filtered out as well.
        if is_nyi_config(config):
            if verbose:
                print(f"Test config {config} is nyi.")
            continue

        # Element order must match the parametrize argument-name string used
        # by the tests that consume these cases.
        cases.append(
            (
                k,
                n,
                e,
                dtype,
                quant_config,
                prepare_finalize_type,
                fused_experts_type,
                chunk_size,
                world_size,
            )
        )

    print(f"{len(cases)} of {total} valid configs generated.")

    return cases
234+ 
235+ 
176236@pytest .mark .parametrize ( 
177-     "combination" , product (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES , MK_FUSED_EXPERT_TYPES ) 
237+     "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size" , 
238+     generate_valid_test_cases ( 
239+         world_size = 2 , prepare_finalize_types = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES  
240+     ), 
178241) 
179- @pytest .mark .parametrize ("fused_moe_chunk_size" , FUSED_MOE_CHUNK_SIZEs ) 
180- @pytest .mark .parametrize ("world_size" , [2 ]) 
181- @multi_gpu_test (num_gpus = 2 ) 
182242@meets_multi_gpu_requirements  
183243def  test_modular_kernel_combinations_multigpu (
184244    k : int ,
185245    n : int ,
186246    e : int ,
187247    dtype : torch .dtype ,
188248    quant_config : Optional [TestMoEQuantConfig ],
189-     combination : tuple [
190-         mk .FusedMoEPrepareAndFinalize , mk .FusedMoEPermuteExpertsUnpermute 
191-     ],
192-     fused_moe_chunk_size : Optional [int ],
249+     prepare_finalize_type : mk .FusedMoEPrepareAndFinalize ,
250+     fused_experts_type : mk .FusedMoEPermuteExpertsUnpermute ,
251+     chunk_size : Optional [int ],
193252    world_size : int ,
194253    pytestconfig ,
195254):
255+     if  cuda_device_count_stateless () <  world_size :
256+         pytest .skip (
257+             f"Not enough GPUs available to run, got " 
258+             f"{ cuda_device_count_stateless ()}  
259+             f"{ world_size }  
260+         )
261+ 
196262    config  =  Config (
197263        Ms = Ms ,
198264        K = k ,
@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
201267        topks = TOPKs ,
202268        dtype = dtype ,
203269        quant_config = quant_config ,
204-         prepare_finalize_type = combination [ 0 ] ,
205-         fused_experts_type = combination [ 1 ] ,
206-         fused_moe_chunk_size = fused_moe_chunk_size ,
270+         prepare_finalize_type = prepare_finalize_type ,
271+         fused_experts_type = fused_experts_type ,
272+         fused_moe_chunk_size = chunk_size ,
207273        world_size = world_size ,
208274    )
209- 
210-     if  not  config .is_valid ():
211-         pytest .skip (f"Tests config { config }  )
212- 
213-     if  is_nyi_config (config ):
214-         pytest .skip (f"Tests config { config }  )
215- 
216275    verbosity  =  pytestconfig .getoption ("verbose" )
217276    run (config , verbosity  >  0 )
218277
219278
220- @pytest .mark .parametrize ("k" , Ks ) 
221- @pytest .mark .parametrize ("n" , Ns ) 
222- @pytest .mark .parametrize ("e" , Es ) 
223- @pytest .mark .parametrize ("dtype" , DTYPEs ) 
224- @pytest .mark .parametrize ("quant_config" , MK_QUANT_CONFIGS ) 
225279@pytest .mark .parametrize ( 
226-     "combination" , product (MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES , MK_FUSED_EXPERT_TYPES ) 
280+     "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size" , 
281+     generate_valid_test_cases ( 
282+         world_size = 1 , prepare_finalize_types = MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES  
283+     ), 
227284) 
228- @pytest .mark .parametrize ("fused_moe_chunk_size" , FUSED_MOE_CHUNK_SIZEs ) 
229- @pytest .mark .parametrize ("world_size" , [1 ]) 
230285def  test_modular_kernel_combinations_singlegpu (
231286    k : int ,
232287    n : int ,
233288    e : int ,
234289    dtype : torch .dtype ,
235290    quant_config : Optional [TestMoEQuantConfig ],
236-     combination : tuple [
237-         mk .FusedMoEPrepareAndFinalize , mk .FusedMoEPermuteExpertsUnpermute 
238-     ],
239-     fused_moe_chunk_size : Optional [int ],
291+     prepare_finalize_type : mk .FusedMoEPrepareAndFinalize ,
292+     fused_experts_type : mk .FusedMoEPermuteExpertsUnpermute ,
293+     chunk_size : Optional [int ],
240294    world_size : int ,
241295    pytestconfig ,
242296):
@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
248302        topks = TOPKs ,
249303        dtype = dtype ,
250304        quant_config = quant_config ,
251-         prepare_finalize_type = combination [ 0 ] ,
252-         fused_experts_type = combination [ 1 ] ,
253-         fused_moe_chunk_size = fused_moe_chunk_size ,
305+         prepare_finalize_type = prepare_finalize_type ,
306+         fused_experts_type = fused_experts_type ,
307+         fused_moe_chunk_size = chunk_size ,
254308        world_size = world_size ,
255309    )
256310
257-     if  not  config .is_valid ():
258-         pytest .skip (f"Tests config { config }  )
259- 
260-     if  is_nyi_config (config ):
261-         pytest .skip (f"Tests config { config }  )
262- 
263311    verbosity  =  pytestconfig .getoption ("verbose" )
264312    run (config , verbosity  >  0 )
265313
0 commit comments