[Kernel] Integrate batched/masked deepgemm kernel #19111

Merged: 6 commits, Jun 4, 2025
tests/kernels/moe/deepep_utils.py (5 changes: 4 additions & 1 deletion)
@@ -162,12 +162,14 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)

return DeepEPLLPrepareAndFinalize(
buffer=buffer,
world_size=pgi.world_size,
dp_size=dp_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
quant_dtype=q_dtype,
block_shape=block_shape,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
)

@@ -185,4 +187,5 @@ def make_deepep_a2a(pg: ProcessGroup,
block_shape)

assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype)
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
block_shape)
tests/kernels/moe/test_deepep_deepgemm_moe.py (190 changes: 166 additions & 24 deletions)
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP + DeepGEMM integration.
DeepGEMM provides GEMM kernels specialized for the
fp8 block-quantized case.
"""

import dataclasses
@@ -33,10 +35,14 @@
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)

from .deepep_utils import DeepEPHTArgs, make_deepep_a2a
from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

if has_deep_gemm:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)

@@ -53,6 +59,13 @@
P = ParamSpec("P")


def next_power_of_2(x):
import math
if x == 0:
return 1
return 2**math.ceil(math.log2(x))


def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
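
The body of per_block_cast_to_fp8 is collapsed above. For context, here is a minimal sketch of what per-block fp8 quantization of a weight matrix looks like; the helper name, the (128, block_size_n) tile shape, the zero padding, and the 1e-4 amax floor are assumptions for illustration, not necessarily what this test file implements:

```python
import torch


def per_block_cast_to_fp8_sketch(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative only: quantize x in (128, block_size_n) tiles,
    producing one float32 scale per tile."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    m, n = x.shape
    # Pad both dims up to the tile grid.
    pad_m = (128 - m % 128) % 128
    pad_n = (block_size_n - n % block_size_n) % block_size_n
    x_pad = torch.nn.functional.pad(x, (0, pad_n, 0, pad_m)).float()
    mb, nb = x_pad.shape[0] // 128, x_pad.shape[1] // block_size_n
    tiles = x_pad.view(mb, 128, nb, block_size_n)
    # One scale per (128 x block_size_n) tile, sized so the tile's
    # absolute max maps to the fp8 e4m3 max.
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
    scale = amax / fp8_max
    q = (tiles / scale).to(torch.float8_e4m3fn).view(x_pad.shape)[:m, :n]
    return q, scale.view(mb, nb)
```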
@@ -126,6 +139,9 @@ class TestConfig:
n: int
num_experts: int
block_size: list[int]
# configs for testing low-latency kernels
low_latency: bool
use_fp8_dispatch: Optional[bool] = False


@dataclasses.dataclass
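
The new low_latency and use_fp8_dispatch fields choose which kernels the test wires together: the high-throughput path (DeepEPHTPrepareAndFinalize + DeepGemmExperts) or the low-latency path (DeepEPLLPrepareAndFinalize + BatchedDeepGemmExperts). For illustration, two configs roughly matching the parametrizations further below; the particular m/n/k/topk values are simply picked from the test's parameter lists:

```python
# High-throughput path: the test derives block_size from
# deep_gemm.get_m_alignment_for_contiguous_layout(); [128, 128] is
# assumed here.
ht_config = TestConfig(topk=2, m=222, k=2048, n=1024, num_experts=32,
                       block_size=[128, 128],
                       low_latency=False, use_fp8_dispatch=None)

# Low-latency path: fp8 dispatch kept off, matching USE_FP8_DISPATCH below.
ll_config = TestConfig(topk=2, m=64, k=2560, n=1024, num_experts=32,
                       block_size=[128, 128],
                       low_latency=True, use_fp8_dispatch=False)
```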
@@ -170,9 +186,43 @@ def make(config: TestConfig, rank) -> "TestTensors":
config=config)


def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, q_dtype: Optional[torch.dtype],
block_shape: list[int]) -> FusedMoEModularKernel:
def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
max_tokens_per_rank: int, dp_size: int,
hidden_size: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel:

assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None

a2a: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
pg=pg,
pgi=pgi,
dp_size=dp_size,
deepep_ht_args=None,
deepep_ll_args=DeepEPLLArgs(
max_tokens_per_rank=max_tokens_per_rank,
hidden_size=hidden_size,
num_experts=test_config.num_experts,
use_fp8_dispatch=test_config.use_fp8_dispatch),
q_dtype=q_dtype,
block_shape=test_config.block_size)

fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
world_size=pgi.world_size,
dp_size=dp_size,
block_shape=test_config.block_size)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk


def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size: int, num_local_experts: int,
q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel:

assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None

a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
pg=pg,
@@ -181,20 +231,50 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
deepep_ll_args=None,
q_dtype=q_dtype,
block_shape=block_shape)
block_shape=test_config.block_size)

fused_experts = DeepGemmExperts()
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk


def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
test_tensors: TestTensors, w1: torch.Tensor,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
num_experts: int) -> torch.Tensor:
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int,
test_tensors: TestTensors) -> FusedMoEModularKernel:

q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config

mk: FusedMoEModularKernel
# Make modular kernel
if test_config.low_latency:
max_tokens_per_rank = max(
64, next_power_of_2(test_tensors.rank_tokens.size(0)))
hidden_size = test_tensors.rank_tokens.size(-1)

mk = make_ll_modular_kernel(pg=pg,
pgi=pgi,
max_tokens_per_rank=max_tokens_per_rank,
dp_size=dp_size,
hidden_size=hidden_size,
q_dtype=q_dtype,
test_config=test_config)
else:
mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
q_dtype, test_config)

return mk


def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size: int, test_tensors: TestTensors,
w1: torch.Tensor, w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor]) -> torch.Tensor:

test_config = test_tensors.config
num_experts = test_config.num_experts
num_local_experts = w1.size(0)

def build_expert_map():
@@ -208,14 +288,17 @@ def build_expert_map():
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)

q_dtype = torch.float8_e4m3fn

# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, dp_size, num_local_experts, q_dtype,
test_tensors.config.block_size)
pg=pg,
pgi=pgi,
dp_size=dp_size,
num_local_experts=num_local_experts,
test_tensors=test_tensors)

a1_scale = test_tensors.rank_token_scales
# Low-Latency kernels can't dispatch scales.
a1_scale = (None
if test_config.low_latency else test_tensors.rank_token_scales)

out = mk.forward(hidden_states=test_tensors.rank_tokens,
w1=w1,
@@ -258,7 +341,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
allow_deep_gemm=False)


def _deep_ep_moe(
def _test_deepep_deepgemm_moe(
pgi: ProcessGroupInfo,
dp_size: int,
config: TestConfig,
@@ -302,7 +385,7 @@ def _deep_ep_moe(
w1_scale_ep = w1_scale[e_start:e_end]
w2_scale_ep = w2_scale[e_start:e_end]

deepep_moe = deep_ep_moe_impl(
deepep_moe = deepep_deepgemm_moe_impl(
pg,
pgi,
dp_size,
@@ -311,7 +394,6 @@ def _deep_ep_moe(
w2_ep,
w1_scale_ep,
w2_scale_ep,
config.num_experts,
)

torch.testing.assert_close(
@@ -335,15 +417,21 @@ def _deep_ep_moe(
(222, 1024, 2048),
]

TOPKS = [2, 6]
NUM_EXPERTS = [32]


@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
world_dp_size: tuple[int, int]):
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]):
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
"""

m, n, k = mnk
current_platform.seed_everything(7)
@@ -354,6 +442,58 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]

world_size, dp_size = world_dp_size
config = TestConfig(topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts,
block_size=block_size,
low_latency=False,
use_fp8_dispatch=None)

w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size)

parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
w2, w1_scale, w2_scale)


MNKs = [
(1, 128, 2560),
(2, 128, 2560),
(3, 1024, 2560),
(32, 128, 2560),
(45, 512, 2560),
(64, 1024, 2560),
(222, 1024, 2560),
]
# Fix tests for USE_FP8_DISPATCH=True
USE_FP8_DISPATCH = [False]


@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
int], num_experts: int, topk: int,
use_fp8_dispatch: bool, block_size: list[int],
world_dp_size: tuple[int, int]):
"""
Tests for Low-Latency DeepEP + DeepGemm integration.
"""

m, n, k = mnk
current_platform.seed_everything(7)

if topk > num_experts:
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

world_size, dp_size = world_dp_size
config = TestConfig(
topk=topk,
@@ -362,10 +502,12 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int,
n=n,
num_experts=num_experts,
block_size=block_size,
low_latency=True,
use_fp8_dispatch=use_fp8_dispatch,
)

w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size)

parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2,
w1_scale, w2_scale)
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
w2, w1_scale, w2_scale)
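
One detail worth calling out from make_modular_kernel above: the low-latency buffer is sized per rank as max(64, next_power_of_2(m)). A quick sketch of that sizing for a few of the m values exercised in these tests, using the next_power_of_2 helper defined earlier in the file:

```python
for m in (1, 3, 45, 64, 222):
    print(m, "->", max(64, next_power_of_2(m)))
# 1 -> 64, 3 -> 64, 45 -> 64, 64 -> 64, 222 -> 256
```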