
Commit d4dacd7

more refactoring + filter out invalid tests ahead of time
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 2428aa1 commit d4dacd7

File tree

5 files changed (+181, -134 lines)


tests/kernels/moe/modular_kernel_tools/common.py

Lines changed: 15 additions & 14 deletions
@@ -197,56 +197,57 @@ def all2all_backend(self):
         info = prepare_finalize_info(self.prepare_finalize_type)
         return info.backend

-    def is_valid(self):
+    def is_valid(self) -> tuple[bool, Optional[str]]:
         # Check prepare-finalize and fused-experts compatibility
         if self.is_batched_prepare_finalize():
             if not self.is_batched_fused_experts():
-                return False
+                return False, "Mismatched format."
         else:
             if not self.is_standard_fused_experts():
-                return False
+                return False, "Mismatched format."

         use_chunking = self.fused_moe_chunk_size is not None
         if use_chunking and not self.is_fe_supports_chunking():
-            return False
+            return False, "Chunking not supported."

         # Check quantization sanity
         if (int(self.is_per_act_token_quant) +
                 int(self.is_per_tensor_act_quant) +
                 int(self.quant_block_shape is not None)) > 1:
             # invalid quant config
-            return False
+            return False, "Bad quant_config."

         # check type support
         if self.quant_dtype is None:
             if (self.dtype not in self.pf_supported_types()
                     or self.dtype not in self.fe_supported_types()):
-                return False
+                return False, "Unsupported type 1."
         else:
             if (self.quant_dtype not in self.pf_supported_types()
                     or self.quant_dtype not in self.fe_supported_types()):
-                return False
+                return False, "Unsupported type 2."

         # Check block quanization support
         is_block_quatized = self.quant_block_shape is not None
         if is_block_quatized and self.quant_dtype is None:
-            return False
+            return False, "No block quantization support."
+
         if is_block_quatized and not self.is_block_quant_supported():
-            return False
+            return False, "Mismatched block quantization support."

         # deep_gemm only works with block-quantized
         if self.needs_deep_gemm() and not is_block_quatized:
-            return False
+            return False, "Needs DeepGEMM but not block quantized."

         # Check dependencies (turn into asserts?)
         if self.needs_deep_ep() and not has_deep_ep():
-            return False
+            return False, "Needs DeepEP."
         if self.needs_deep_gemm() and not has_deep_gemm():
-            return False
+            return False, "Needs DeepGEMM."
         if self.needs_pplx() and not has_pplx():  # noqa: SIM103
-            return False
+            return False, "Needs PPLX."

-        return True
+        return True, None


 @dataclass
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ def add_to_results(config: Config,
                            fused_moe_chunk_size=None)

     success = None
-    if config.is_valid():
+    if config.is_valid()[0]:
         print(f"Running config : {config.describe()} ...")
         try:
             weights: WeightTensors = WeightTensors.make(config)
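Here only the boolean matters, so the call is indexed with [0]. A hedged alternative spelling (not in the commit) that unpacks the pair and keeps the rejection reason available in case the feature matrix later wants to record it:

valid, reason = config.is_valid()
success = None
if valid:
    print(f"Running config : {config.describe()} ...")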

tests/kernels/moe/modular_kernel_tools/mk_objects.py

Lines changed: 1 addition & 1 deletion
@@ -430,7 +430,7 @@ def make_fused_experts(
         print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
         experts = BatchedTritonOrDeepGemmExperts(**kwargs)
     elif fused_experts_type == DeepGemmExperts:
-        print("Making DeepGemmExperts {quant_config} ...")
+        print(f"Making DeepGemmExperts {quant_config} ...")
         experts = DeepGemmExperts(quant_config)
     elif fused_experts_type == TritonExperts:
         kwargs = quant_kwargs
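The one-character fix above matters because without the f prefix the braces are printed literally rather than interpolated. A standalone illustration (placeholder value, not from the repo):

quant_config = "fp8_w8a8"  # placeholder value
print("Making DeepGemmExperts {quant_config} ...")   # prints the literal {quant_config}
print(f"Making DeepGemmExperts {quant_config} ...")  # prints: Making DeepGemmExperts fp8_w8a8 ...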

tests/kernels/moe/test_modular_kernel_combinations.py

Lines changed: 70 additions & 48 deletions
@@ -5,18 +5,18 @@
 import textwrap
 import traceback
 from itertools import product
-from typing import Optional
+from typing import Any, Optional

 import pytest
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.platforms import current_platform
-from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
+from vllm.utils import (cuda_device_count_stateless, has_deep_ep,
+                        has_deep_gemm, has_pplx)
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

-from ...utils import multi_gpu_test
 from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
                                           reference_moe_impl,
                                           run_modular_kernel)
@@ -122,7 +122,8 @@ def rank_worker(


 def run(config: Config, verbose: bool):
-    assert config.is_valid()
+    assert config.is_valid()[0]
+    assert not is_nyi_config(config)

     weights: WeightTensors = WeightTensors.make(config)

@@ -156,24 +157,63 @@ def is_nyi_config(config: Config) -> bool:
     return not info.supports_expert_map


-@pytest.mark.parametrize("k", Ks)
-@pytest.mark.parametrize("n", Ns)
-@pytest.mark.parametrize("e", Es)
-@pytest.mark.parametrize("dtype", DTYPEs)
-@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
+def generate_valid_test_cases(world_size: int,
+                              prepare_finalize_types) -> list[tuple[Any, ...]]:
+    cases = []
+
+    for k, n, e, dtype, quant_config, combination, chunk_size in product(
+            Ks, Ns, Es, DTYPEs, MK_QUANT_CONFIGS,
+            product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
+            FUSED_MOE_CHUNK_SIZEs):
+
+        config = Config(
+            Ms=Ms,
+            K=k,
+            N=n,
+            E=e,
+            topks=TOPKs,
+            dtype=dtype,
+            quant_config=quant_config,
+            prepare_finalize_type=combination[0],
+            fused_experts_type=combination[1],
+            fused_moe_chunk_size=chunk_size,
+            world_size=world_size,
+        )
+
+        # TODO
+        verbose = False  #pytestconfig.getoption('verbose') > 0
+
+        valid, reason = config.is_valid()
+
+        if not valid:
+            if verbose:
+                print(f"Tests config {config} is not valid: {reason}")
+            continue
+
+        if is_nyi_config(config):
+            if verbose:
+                print(f"Tests config {config} is nyi.")
+            continue
+
+        cases.append((k, n, e, dtype, quant_config, combination[0],
+                      combination[1], chunk_size, world_size))
+
+    return cases
+
+
 @pytest.mark.parametrize(
-    "combination",
-    product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
-@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
-@pytest.mark.parametrize("world_size", [2])
-@multi_gpu_test(num_gpus=2)
+    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
+    generate_valid_test_cases(
+        world_size=2,
+        prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES))
 @meets_multi_gpu_requirements
 def test_modular_kernel_combinations_multigpu(
         k: int, n: int, e: int, dtype: torch.dtype,
         quant_config: Optional[TestMoEQuantConfig],
-        combination: tuple[mk.FusedMoEPrepareAndFinalize,
-                           mk.FusedMoEPermuteExpertsUnpermute],
-        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
+        prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
+        fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
+        chunk_size: Optional[int], world_size: int, pytestconfig):
+    assert cuda_device_count_stateless() >= world_size

     config = Config(
         Ms=Ms,
@@ -183,38 +223,26 @@ def test_modular_kernel_combinations_multigpu(
         topks=TOPKs,
         dtype=dtype,
         quant_config=quant_config,
-        prepare_finalize_type=combination[0],
-        fused_experts_type=combination[1],
-        fused_moe_chunk_size=fused_moe_chunk_size,
+        prepare_finalize_type=prepare_finalize_type,
+        fused_experts_type=fused_experts_type,
+        fused_moe_chunk_size=chunk_size,
         world_size=world_size,
     )
-
-    if not config.is_valid():
-        pytest.skip(f"Tests config {config} is not valid. Skipping ...")
-
-    if is_nyi_config(config):
-        pytest.skip(f"Tests config {config} is nyi. Skipping ...")
-
     verbosity = pytestconfig.getoption('verbose')
     run(config, verbosity > 0)


-@pytest.mark.parametrize("k", Ks)
-@pytest.mark.parametrize("n", Ns)
-@pytest.mark.parametrize("e", Es)
-@pytest.mark.parametrize("dtype", DTYPEs)
-@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
 @pytest.mark.parametrize(
-    "combination",
-    product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
-@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
-@pytest.mark.parametrize("world_size", [1])
+    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
+    generate_valid_test_cases(
+        world_size=1,
+        prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES))
 def test_modular_kernel_combinations_singlegpu(
         k: int, n: int, e: int, dtype: torch.dtype,
         quant_config: Optional[TestMoEQuantConfig],
-        combination: tuple[mk.FusedMoEPrepareAndFinalize,
-                           mk.FusedMoEPermuteExpertsUnpermute],
-        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
+        prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
+        fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
+        chunk_size: Optional[int], world_size: int, pytestconfig):
     config = Config(
         Ms=Ms,
         K=k,
@@ -223,18 +251,12 @@ def test_modular_kernel_combinations_singlegpu(
         topks=TOPKs,
         dtype=dtype,
         quant_config=quant_config,
-        prepare_finalize_type=combination[0],
-        fused_experts_type=combination[1],
-        fused_moe_chunk_size=fused_moe_chunk_size,
+        prepare_finalize_type=prepare_finalize_type,
+        fused_experts_type=fused_experts_type,
+        fused_moe_chunk_size=chunk_size,
         world_size=world_size,
     )

-    if not config.is_valid():
-        pytest.skip(f"Tests config {config} is not valid. Skipping ...")
-
-    if is_nyi_config(config):
-        pytest.skip(f"Tests config {config} is nyi. Skipping ...")
-
     verbosity = pytestconfig.getoption('verbose')
     run(config, verbosity > 0)

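Taken together, these test changes move the validity and NYI checks from the test bodies (pytest.skip at run time) into generate_valid_test_cases(), so invalid combinations are filtered out before pytest ever collects them. A self-contained sketch of the same pattern, using generic placeholder names rather than the vLLM ones:

from itertools import product
from typing import Optional

import pytest


def is_valid(a: int, b: int) -> tuple[bool, Optional[str]]:
    # Stand-in for Config.is_valid(): reject incompatible combinations and
    # report why, so the generator can optionally log the reason.
    if a > b:
        return False, "a must not exceed b"
    return True, None


def generate_valid_cases() -> list[tuple[int, int]]:
    # Enumerate the full cross-product, but keep only cases that will
    # actually run; pytest never collects (or skips) the rest.
    cases = []
    for a, b in product(range(3), range(3)):
        valid, _reason = is_valid(a, b)
        if valid:
            cases.append((a, b))
    return cases


@pytest.mark.parametrize("a,b", generate_valid_cases())
def test_pair(a: int, b: int):
    assert a <= b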