
Commit 35fb647

bnellnm authored and xuebwang-amd committed
[Kernels] Modular kernel refactor (vllm-project#24812)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 73279f0 commit 35fb647

22 files changed: +597 -505 lines changed

tests/kernels/moe/modular_kernel_tools/common.py

Lines changed: 23 additions & 14 deletions
@@ -209,18 +209,18 @@ def all2all_backend(self):
         info = prepare_finalize_info(self.prepare_finalize_type)
         return info.backend
 
-    def is_valid(self):
+    def is_valid(self) -> tuple[bool, Optional[str]]:
         # Check prepare-finalize and fused-experts compatibility
         if self.is_batched_prepare_finalize():
             if not self.is_batched_fused_experts():
-                return False
+                return False, "Mismatched format."
         else:
             if not self.is_standard_fused_experts():
-                return False
+                return False, "Mismatched format."
 
         use_chunking = self.fused_moe_chunk_size is not None
         if use_chunking and not self.is_fe_supports_chunking():
-            return False
+            return False, "Chunking not supported."
 
         # Check quantization sanity
         if (
@@ -229,42 +229,51 @@ def is_valid(self):
             + int(self.quant_block_shape is not None)
         ) > 1:
             # invalid quant config
-            return False
+            return False, f"Bad quant_config {self.quant_config}."
 
         # check type support
         if self.quant_dtype is None:
             if (
                 self.dtype not in self.pf_supported_types()
                 or self.dtype not in self.fe_supported_types()
             ):
-                return False
+                return False, (
+                    f"Unsupported type {self.dtype} not in "
+                    f"{self.pf_supported_types()} and "
+                    f"{self.fe_supported_types()}."
+                )
         else:
             if (
                 self.quant_dtype not in self.pf_supported_types()
                 or self.quant_dtype not in self.fe_supported_types()
             ):
-                return False
+                return False, (
+                    f"Unsupported quant type {self.quant_dtype} "
+                    f"not in {self.pf_supported_types()} and "
+                    f"{self.fe_supported_types()}."
+                )
 
         # Check block quanization support
         is_block_quatized = self.quant_block_shape is not None
         if is_block_quatized and self.quant_dtype is None:
-            return False
+            return False, "No block quantization support."
+
         if is_block_quatized and not self.is_block_quant_supported():
-            return False
+            return False, "Mismatched block quantization support."
 
         # deep_gemm only works with block-quantized
         if self.needs_deep_gemm() and not is_block_quatized:
-            return False
+            return False, "Needs DeepGEMM but not block quantized."
 
         # Check dependencies (turn into asserts?)
         if self.needs_deep_ep() and not has_deep_ep():
-            return False
+            return False, "Needs DeepEP, but DeepEP not available."
         if self.needs_deep_gemm() and not has_deep_gemm():
-            return False
+            return False, "Needs DeepGEMM, but DeepGEMM not available."
         if self.needs_pplx() and not has_pplx():  # noqa: SIM103
-            return False
+            return False, "Needs PPLX, but PPLX not available."
 
-        return True
+        return True, None
 
 
 @dataclass
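For orientation only (not part of this commit): a minimal sketch of how the new (bool, reason) contract is meant to be consumed by callers. DummyConfig below is a hypothetical stand-in for the real Config helper in common.py.

from typing import Optional

class DummyConfig:
    """Hypothetical stand-in for the Config class in common.py."""

    def __init__(self, uses_chunking: bool, supports_chunking: bool):
        self.uses_chunking = uses_chunking
        self.supports_chunking = supports_chunking

    def is_valid(self) -> tuple[bool, Optional[str]]:
        # Same pattern as the diff above: return a reason alongside the verdict.
        if self.uses_chunking and not self.supports_chunking:
            return False, "Chunking not supported."
        return True, None

valid, reason = DummyConfig(uses_chunking=True, supports_chunking=False).is_valid()
if not valid:
    print(f"Skipping config: {reason}")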

tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def add_to_results(
     )
 
     success = None
-    if config.is_valid():
+    if config.is_valid()[0]:
         print(f"Running config : {config.describe()} ...")
         try:
             weights: WeightTensors = WeightTensors.make(config)

tests/kernels/moe/modular_kernel_tools/mk_objects.py

Lines changed: 12 additions & 14 deletions
@@ -244,7 +244,7 @@ def expert_info(kind) -> ExpertInfo:
     register_prepare_and_finalize(
         FlashInferCutlassMoEPrepareAndFinalize,
         standard_format,
-        nvfp4_types,
+        nvfp4_types + fp8_types,
         blocked_quantization_support=True,
         backend=None,
         force_multigpu=True,
@@ -254,7 +254,7 @@ def expert_info(kind) -> ExpertInfo:
     register_experts(
         FlashInferExperts,
         standard_format,
-        nvfp4_types,
+        nvfp4_types + fp8_types,
         blocked_quantization_support=True,
         supports_chunking=True,
         # Note: this is a hack to get it to run for now
@@ -274,17 +274,15 @@ def expert_info(kind) -> ExpertInfo:
         needs_matching_quant=False,
         needs_deep_gemm=True,
     )
-    (
-        register_experts(
-            DeepGemmExperts,
-            standard_format,
-            fp8_types,
-            blocked_quantization_support=True,
-            supports_chunking=True,
-            supports_expert_map=True,
-            needs_matching_quant=False,
-            needs_deep_gemm=True,
-        ),
+    register_experts(
+        DeepGemmExperts,
+        standard_format,
+        fp8_types,
+        blocked_quantization_support=True,
+        supports_chunking=True,
+        supports_expert_map=True,
+        needs_matching_quant=False,
+        needs_deep_gemm=True,
     )
     register_experts(
         BatchedTritonOrDeepGemmExperts,
@@ -464,7 +462,7 @@ def make_fused_experts(
         print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
         experts = BatchedTritonOrDeepGemmExperts(**kwargs)
     elif fused_experts_type == DeepGemmExperts:
-        print("Making DeepGemmExperts {quant_config} ...")
+        print(f"Making DeepGemmExperts {quant_config} ...")
         experts = DeepGemmExperts(quant_config)
     elif fused_experts_type == TritonExperts:
         kwargs = quant_kwargs
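Aside (illustration, not from the repo): the parentheses and trailing comma removed around the DeepGemmExperts registration only wrapped the call's return value in a single-element tuple that was immediately discarded, so the cleanup changes formatting, not behavior. A tiny standalone sketch of that Python detail, using a hypothetical register() helper:

def register(name: str) -> str:
    # Hypothetical helper; stands in for register_experts().
    print(f"registered {name}")
    return name

(
    register("wrapped"),
)  # evaluates to ("wrapped",), which is thrown away; the call still runs
register("plain")  # same effect, without building the stray tuple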

tests/kernels/moe/test_modular_kernel_combinations.py

Lines changed: 96 additions & 48 deletions
@@ -5,18 +5,17 @@
 import textwrap
 import traceback
 from itertools import product
-from typing import Optional
+from typing import Any, Optional
 
 import pytest
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.platforms import current_platform
-from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
+from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
-from ...utils import multi_gpu_test
 from .modular_kernel_tools.common import (
     Config,
     RankTensors,
@@ -132,7 +131,8 @@ def rank_worker(
 
 
 def run(config: Config, verbose: bool):
-    assert config.is_valid()
+    assert config.is_valid()[0]
+    assert not is_nyi_config(config)
 
     weights: WeightTensors = WeightTensors.make(config)
 
@@ -168,31 +168,97 @@ def is_nyi_config(config: Config) -> bool:
     return not info.supports_expert_map
 
 
-@pytest.mark.parametrize("k", Ks)
-@pytest.mark.parametrize("n", Ns)
-@pytest.mark.parametrize("e", Es)
-@pytest.mark.parametrize("dtype", DTYPEs)
-@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
+def generate_valid_test_cases(
+    world_size: int, prepare_finalize_types
+) -> list[tuple[Any, ...]]:
+    cases = []
+    total = 0
+
+    for k, n, e, dtype, quant_config, combination, chunk_size in product(
+        Ks,
+        Ns,
+        Es,
+        DTYPEs,
+        MK_QUANT_CONFIGS,
+        product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
+        FUSED_MOE_CHUNK_SIZEs,
+    ):
+        total = total + 1
+
+        config = Config(
+            Ms=Ms,
+            K=k,
+            N=n,
+            E=e,
+            topks=TOPKs,
+            dtype=dtype,
+            quant_config=quant_config,
+            prepare_finalize_type=combination[0],
+            fused_experts_type=combination[1],
+            fused_moe_chunk_size=chunk_size,
+            world_size=world_size,
+        )
+
+        # TODO(bnell): figure out how to get verbose flag here.
+        verbose = False  # pytestconfig.getoption('verbose') > 0
+
+        valid, reason = config.is_valid()
+
+        if not valid:
+            if verbose:
+                print(f"Test config {config} is not valid: {reason}")
+            continue
+
+        if is_nyi_config(config):
+            if verbose:
+                print(f"Test config {config} is nyi.")
+            continue
+
+        cases.append(
+            (
+                k,
+                n,
+                e,
+                dtype,
+                quant_config,
+                combination[0],
+                combination[1],
+                chunk_size,
+                world_size,
+            )
+        )
+
+    print(f"{len(cases)} of {total} valid configs generated.")
+
+    return cases
+
+
 @pytest.mark.parametrize(
-    "combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
+    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
+    generate_valid_test_cases(
+        world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
+    ),
 )
-@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
-@pytest.mark.parametrize("world_size", [2])
-@multi_gpu_test(num_gpus=2)
 @meets_multi_gpu_requirements
 def test_modular_kernel_combinations_multigpu(
     k: int,
     n: int,
     e: int,
     dtype: torch.dtype,
     quant_config: Optional[TestMoEQuantConfig],
-    combination: tuple[
-        mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
-    ],
-    fused_moe_chunk_size: Optional[int],
+    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
+    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
+    chunk_size: Optional[int],
     world_size: int,
     pytestconfig,
 ):
+    if cuda_device_count_stateless() < world_size:
+        pytest.skip(
+            f"Not enough GPUs available to run, got "
+            f"{cuda_device_count_stateless()} exepected "
+            f"{world_size}."
+        )
+
     config = Config(
         Ms=Ms,
         K=k,
@@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu(
         topks=TOPKs,
         dtype=dtype,
         quant_config=quant_config,
-        prepare_finalize_type=combination[0],
-        fused_experts_type=combination[1],
-        fused_moe_chunk_size=fused_moe_chunk_size,
+        prepare_finalize_type=prepare_finalize_type,
+        fused_experts_type=fused_experts_type,
+        fused_moe_chunk_size=chunk_size,
         world_size=world_size,
     )
-
-    if not config.is_valid():
-        pytest.skip(f"Tests config {config} is not valid. Skipping ...")
-
-    if is_nyi_config(config):
-        pytest.skip(f"Tests config {config} is nyi. Skipping ...")
-
     verbosity = pytestconfig.getoption("verbose")
     run(config, verbosity > 0)
 
 
-@pytest.mark.parametrize("k", Ks)
-@pytest.mark.parametrize("n", Ns)
-@pytest.mark.parametrize("e", Es)
-@pytest.mark.parametrize("dtype", DTYPEs)
-@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
 @pytest.mark.parametrize(
-    "combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
+    "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
+    generate_valid_test_cases(
+        world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
+    ),
 )
-@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
-@pytest.mark.parametrize("world_size", [1])
 def test_modular_kernel_combinations_singlegpu(
     k: int,
     n: int,
     e: int,
     dtype: torch.dtype,
     quant_config: Optional[TestMoEQuantConfig],
-    combination: tuple[
-        mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
-    ],
-    fused_moe_chunk_size: Optional[int],
+    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
+    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
+    chunk_size: Optional[int],
     world_size: int,
     pytestconfig,
 ):
@@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu(
         topks=TOPKs,
         dtype=dtype,
         quant_config=quant_config,
-        prepare_finalize_type=combination[0],
-        fused_experts_type=combination[1],
-        fused_moe_chunk_size=fused_moe_chunk_size,
+        prepare_finalize_type=prepare_finalize_type,
+        fused_experts_type=fused_experts_type,
+        fused_moe_chunk_size=chunk_size,
         world_size=world_size,
     )
 
-    if not config.is_valid():
-        pytest.skip(f"Tests config {config} is not valid. Skipping ...")
-
-    if is_nyi_config(config):
-        pytest.skip(f"Tests config {config} is nyi. Skipping ...")
-
     verbosity = pytestconfig.getoption("verbose")
     run(config, verbosity > 0)
 
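The new generate_valid_test_cases helper moves filtering from run time (pytest.skip inside each test body) to collection time (the parametrize list itself), so only runnable configs ever become test cases. A self-contained sketch of that pattern, with invented shapes and a made-up validity rule standing in for the real MoE configs:

from itertools import product

import pytest

def generate_cases() -> list[tuple[int, int]]:
    # Enumerate every (m, k) combination and keep only the valid ones,
    # so invalid combinations never become collected-then-skipped tests.
    cases = []
    for m, k in product([16, 32], [64, 128]):
        if m * k < 2048:  # stand-in validity rule
            continue
        cases.append((m, k))
    return cases

@pytest.mark.parametrize("m,k", generate_cases())
def test_shapes(m: int, k: int):
    assert m * k >= 2048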
