37 changes: 23 additions & 14 deletions tests/kernels/moe/modular_kernel_tools/common.py
@@ -209,18 +209,18 @@ def all2all_backend(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend

def is_valid(self):
def is_valid(self) -> tuple[bool, Optional[str]]:
# Check prepare-finalize and fused-experts compatibility
if self.is_batched_prepare_finalize():
if not self.is_batched_fused_experts():
return False
return False, "Mismatched format."
else:
if not self.is_standard_fused_experts():
return False
return False, "Mismatched format."

use_chunking = self.fused_moe_chunk_size is not None
if use_chunking and not self.is_fe_supports_chunking():
return False
return False, "Chunking not supported."

# Check quantization sanity
if (
@@ -229,42 +229,51 @@ def is_valid(self):
+ int(self.quant_block_shape is not None)
) > 1:
# invalid quant config
return False
return False, f"Bad quant_config {self.quant_config}."

# check type support
if self.quant_dtype is None:
if (
self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types()
):
return False
return False, (
f"Unsupported type {self.dtype} not in "
f"{self.pf_supported_types()} and "
f"{self.fe_supported_types()}."
)
else:
if (
self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types()
):
return False
return False, (
f"Unsupported quant type {self.quant_dtype} "
f"not in {self.pf_supported_types()} and "
f"{self.fe_supported_types()}."
)

# Check block quantization support
is_block_quatized = self.quant_block_shape is not None
if is_block_quatized and self.quant_dtype is None:
return False
return False, "No block quantization support."

if is_block_quatized and not self.is_block_quant_supported():
return False
return False, "Mismatched block quantization support."

# deep_gemm only works with block-quantized
if self.needs_deep_gemm() and not is_block_quatized:
return False
return False, "Needs DeepGEMM but not block quantized."

# Check dependencies (turn into asserts?)
if self.needs_deep_ep() and not has_deep_ep():
return False
return False, "Needs DeepEP, but DeepEP not available."
if self.needs_deep_gemm() and not has_deep_gemm():
return False
return False, "Needs DeepGEMM, but DeepGEMM not available."
if self.needs_pplx() and not has_pplx(): # noqa: SIM103
return False
return False, "Needs PPLX, but PPLX not available."

return True
return True, None


@dataclass
@@ -140,7 +140,7 @@ def add_to_results(
)

success = None
if config.is_valid():
if config.is_valid()[0]:
print(f"Running config : {config.describe()} ...")
try:
weights: WeightTensors = WeightTensors.make(config)
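With `is_valid()` now returning `tuple[bool, Optional[str]]`, call sites either index the boolean (as `add_to_results` does above) or unpack both values to surface the rejection reason. A minimal, self-contained sketch of the new contract — `DummyConfig` is a stand-in for illustration, not the real `Config` dataclass:

```python
from typing import Optional


class DummyConfig:
    """Stand-in for the test Config above; only models the new is_valid() contract."""

    def __init__(self, batched_prepare_finalize: bool, batched_fused_experts: bool):
        self.batched_prepare_finalize = batched_prepare_finalize
        self.batched_fused_experts = batched_fused_experts

    def is_valid(self) -> tuple[bool, Optional[str]]:
        # Mirrors the first check in the real method: prepare-finalize and
        # fused-experts formats must match.
        if self.batched_prepare_finalize != self.batched_fused_experts:
            return False, "Mismatched format."
        return True, None


config = DummyConfig(batched_prepare_finalize=True, batched_fused_experts=False)
valid, reason = config.is_valid()
if not valid:
    print(f"Config rejected: {reason}")  # -> Config rejected: Mismatched format.
```

Returning the reason instead of a bare bool lets the test-case generator further down report why a combination was filtered out.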
26 changes: 12 additions & 14 deletions tests/kernels/moe/modular_kernel_tools/mk_objects.py
@@ -244,7 +244,7 @@ def expert_info(kind) -> ExpertInfo:
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
standard_format,
nvfp4_types,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
backend=None,
force_multigpu=True,
@@ -254,7 +254,7 @@ def expert_info(kind) -> ExpertInfo:
register_experts(
FlashInferExperts,
standard_format,
nvfp4_types,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
@@ -274,17 +274,15 @@ def expert_info(kind) -> ExpertInfo:
needs_matching_quant=False,
needs_deep_gemm=True,
)
(
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
)
register_experts(
BatchedTritonOrDeepGemmExperts,
@@ -464,7 +462,7 @@ def make_fused_experts(
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts {quant_config} ...")
print(f"Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
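Besides widening the FlashInfer registrations to `nvfp4_types + fp8_types` and dropping the stray tuple wrapper around `register_experts(DeepGemmExperts, ...)`, the hunk above fixes a missing `f` prefix on the `DeepGemmExperts` print. A quick illustration (the dict value is made up):

```python
quant_config = {"quant_dtype": "fp8"}  # illustrative value only

# Missing f prefix: the braces are printed literally.
print("Making DeepGemmExperts {quant_config} ...")
# -> Making DeepGemmExperts {quant_config} ...

# With the f prefix, as in the fixed line, the value is interpolated.
print(f"Making DeepGemmExperts {quant_config} ...")
# -> Making DeepGemmExperts {'quant_dtype': 'fp8'} ...
```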
144 changes: 96 additions & 48 deletions tests/kernels/moe/test_modular_kernel_combinations.py
@@ -5,18 +5,17 @@
import textwrap
import traceback
from itertools import product
from typing import Optional
from typing import Any, Optional

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

from ...utils import multi_gpu_test
from .modular_kernel_tools.common import (
Config,
RankTensors,
@@ -132,7 +131,8 @@ def rank_worker(


def run(config: Config, verbose: bool):
assert config.is_valid()
assert config.is_valid()[0]
assert not is_nyi_config(config)

weights: WeightTensors = WeightTensors.make(config)

@@ -168,31 +168,97 @@ def is_nyi_config(config: Config) -> bool:
return not info.supports_expert_map


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
def generate_valid_test_cases(
world_size: int, prepare_finalize_types
) -> list[tuple[Any, ...]]:
cases = []
total = 0

for k, n, e, dtype, quant_config, combination, chunk_size in product(
Ks,
Ns,
Es,
DTYPEs,
MK_QUANT_CONFIGS,
product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
FUSED_MOE_CHUNK_SIZEs,
):
total += 1

config = Config(
Ms=Ms,
K=k,
N=n,
E=e,
topks=TOPKs,
dtype=dtype,
quant_config=quant_config,
prepare_finalize_type=combination[0],
fused_experts_type=combination[1],
fused_moe_chunk_size=chunk_size,
world_size=world_size,
)

# TODO(bnell): figure out how to get verbose flag here.
verbose = False # pytestconfig.getoption('verbose') > 0

valid, reason = config.is_valid()

if not valid:
if verbose:
print(f"Test config {config} is not valid: {reason}")
continue

if is_nyi_config(config):
if verbose:
print(f"Test config {config} is nyi.")
continue

cases.append(
(
k,
n,
e,
dtype,
quant_config,
combination[0],
combination[1],
chunk_size,
world_size,
)
)

print(f"{len(cases)} of {total} valid configs generated.")

return cases


@pytest.mark.parametrize(
"combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
generate_valid_test_cases(
world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
),
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
k: int,
n: int,
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size: Optional[int],
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
chunk_size: Optional[int],
world_size: int,
pytestconfig,
):
if cuda_device_count_stateless() < world_size:
pytest.skip(
f"Not enough GPUs available to run, got "
f"{cuda_device_count_stateless()} exepected "
f"{world_size}."
)

config = Config(
Ms=Ms,
K=k,
@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
topks=TOPKs,
dtype=dtype,
quant_config=quant_config,
prepare_finalize_type=combination[0],
fused_experts_type=combination[1],
fused_moe_chunk_size=fused_moe_chunk_size,
prepare_finalize_type=prepare_finalize_type,
fused_experts_type=fused_experts_type,
fused_moe_chunk_size=chunk_size,
world_size=world_size,
)

if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...")

if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")

verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0)


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
"combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
generate_valid_test_cases(
world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
),
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu(
k: int,
n: int,
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size: Optional[int],
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
chunk_size: Optional[int],
world_size: int,
pytestconfig,
):
@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
topks=TOPKs,
dtype=dtype,
quant_config=quant_config,
prepare_finalize_type=combination[0],
fused_experts_type=combination[1],
fused_moe_chunk_size=fused_moe_chunk_size,
prepare_finalize_type=prepare_finalize_type,
fused_experts_type=fused_experts_type,
fused_moe_chunk_size=chunk_size,
world_size=world_size,
)

if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...")

if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")

verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0)

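The net effect in the test file is that invalid and not-yet-implemented combinations are filtered out at collection time by `generate_valid_test_cases()` instead of being skipped inside each test. A trimmed-down, self-contained sketch of the same pattern — `is_valid_case` and the parameter lists are illustrative, not the real test matrix:

```python
from itertools import product
from typing import Any, Optional

import pytest

# Illustrative parameter lists, not the real Ks/Ns/chunk sizes used above.
Ks = [128, 256]
CHUNK_SIZES: list[Optional[int]] = [None, 64, 512]


def is_valid_case(k: int, chunk: Optional[int]) -> tuple[bool, Optional[str]]:
    # Hypothetical validity rule standing in for Config.is_valid().
    if chunk is not None and chunk > k:
        return False, f"chunk size {chunk} larger than K={k}"
    return True, None


def generate_valid_cases() -> list[tuple[Any, ...]]:
    cases = []
    total = 0
    for k, chunk in product(Ks, CHUNK_SIZES):
        total += 1
        valid, _ = is_valid_case(k, chunk)
        if not valid:
            continue
        cases.append((k, chunk))
    print(f"{len(cases)} of {total} generated configs are valid.")
    return cases


@pytest.mark.parametrize("k,chunk", generate_valid_cases())
def test_chunking(k: int, chunk: Optional[int]):
    # Only valid combinations reach this point; no runtime pytest.skip needed.
    assert is_valid_case(k, chunk)[0]
```

Parametrizing over pre-filtered cases keeps the reported test count meaningful: combinations that would previously show up as runtime skips never get collected at all.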