
[Kernels] Add activation chunking logic to FusedMoEModularKernel #19168


Merged (17 commits, Jun 11, 2025)

Changes from all commits
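At a high level, this change teaches FusedMoEModularKernel to split large activation batches into fixed-size chunks and run the expert stage once per chunk, so intermediate workspace memory stays bounded as the token count grows. A minimal sketch of the idea, with hypothetical helper names (the real logic lives in vllm/model_executor/layers/fused_moe/modular_kernel.py and is driven by the VLLM_FUSED_MOE_CHUNK_SIZE environment variable):

```python
import torch

# Illustrative only: process at most chunk_size tokens per expert pass
# instead of allocating workspace for all num_tokens at once. The names
# chunked_experts/apply_experts are hypothetical, not vLLM's actual API.
def chunked_experts(apply_experts, hidden_states: torch.Tensor,
                    chunk_size: int) -> torch.Tensor:
    out = torch.empty_like(hidden_states)
    num_tokens = hidden_states.size(0)
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        # Each slice can reuse the same chunk-sized workspace buffers.
        out[start:end] = apply_experts(hidden_states[start:end])
    return out
```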
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/kernels/moe/test_cutlass_moe.py
@@ -29,6 +29,7 @@
     (224, 1024, 1536),
     (224, 3072, 1024),
     (224, 3072, 1536),
+    (1024 * 128, 1024, 1024),
 ]
 
 vllm_config = VllmConfig(parallel_config=ParallelConfig(
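The new (1024 * 128, 1024, 1024) case raises M to 131072 tokens, which should be large enough to force several chunks and thereby exercise the new chunked path in the cutlass tests. Back-of-the-envelope arithmetic, assuming a default VLLM_FUSED_MOE_CHUNK_SIZE of 32768 (check vllm/envs.py for the authoritative value):

```python
# Rough check that the new shape actually triggers chunking; 32768 is an
# assumed default for VLLM_FUSED_MOE_CHUNK_SIZE, not read from the code.
m = 1024 * 128                                 # 131072 tokens
chunk_size = 32768
num_chunks = (m + chunk_size - 1) // chunk_size
print(num_chunks)                              # 4 expert passes per forward
```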
23 changes: 22 additions & 1 deletion tests/kernels/moe/test_moe.py
@@ -15,7 +15,8 @@
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    fused_topk, modular_triton_fused_moe)
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
     fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@@ -76,6 +77,13 @@ def test_fused_moe(
     else:
         e_map = None
 
+    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
+                                           use_int8_w8a8=False,
+                                           use_int8_w8a16=False,
+                                           use_int4_w4a16=False,
+                                           per_channel_quant=False,
+                                           block_shape=None)
+
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe(a, w1, w2, score, topk, e_map)
         iterative_output = iterative_moe(a,
@@ -103,7 +111,20 @@ def test_fused_moe(
                               expert_map=e_map,
                               renormalize=False)
 
+    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
+    m_triton_output = m_fused_moe(a,
+                                  w1,
+                                  w2,
+                                  topk_weights,
+                                  topk_ids,
+                                  global_num_experts=e,
+                                  expert_map=e_map)
+
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(m_triton_output,
+                               torch_output,
+                               atol=2e-2,
+                               rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
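For context, modular_triton_fused_moe is understood to wrap a FusedMoEModularKernel that pairs a no-expert-parallel prepare/finalize stage with a Triton experts implementation. A sketch of that composition, mirroring the flags used in the test (class names follow vllm.model_executor.layers.fused_moe, but treat the exact signatures as approximate):

```python
# Approximate composition behind modular_triton_fused_moe; signatures are
# assumptions based on this PR's imports, not the authoritative API.
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)

# Prepare/finalize quantizes and routes the activations, the experts
# object runs the fused Triton GEMMs, and the modular kernel composes the
# two (now looping over activation chunks for large batches).
m_fused_moe = FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(),
    TritonExperts(use_fp8_w8a8=False,
                  use_int8_w8a8=False,
                  use_int8_w8a16=False,
                  use_int4_w4a16=False,
                  per_channel_quant=False,
                  block_shape=None),
)
```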
43 changes: 33 additions & 10 deletions tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional
+
 import pytest
 import torch
 
-from tests.pplx_utils import ProcessGroupInfo, parallel_launch
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -14,6 +15,8 @@
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
 
+from .deepep_utils import ProcessGroupInfo, parallel_launch
+
 try:
     from pplx_kernels import AllToAll
     from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
@@ -64,6 +67,7 @@ def pplx_cutlass_moe(
     out_dtype,
     per_act_token: bool,
     per_out_ch: bool,
+    group_name: Optional[str],
 ):
     from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
         PplxPrepareAndFinalize)
@@ -84,7 +88,7 @@
     else:
         scale_elems = (hidden_dim + block_size - 1) // block_size
 
-    ata = AllToAll.internode(
+    args = dict(
         max_num_tokens=max_num_tokens,
         num_experts=num_experts,
         experts_per_token=topk,
@@ -96,6 +100,12 @@
         hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
     )
 
+    if group_name is None:
+        ata = AllToAll.internode(**args)
+    else:
+        args["group_name"] = group_name
+        ata = AllToAll.intranode(**args)
+
     w1 = w1.to(device)
     w2 = w2.to(device)
     w1_scale = w1_scale.to(device)
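The dict-then-dispatch pattern above reflects that AllToAll.intranode takes the same arguments as AllToAll.internode plus a process-group name; condensed (pplx-kernels signatures inferred from this diff, not from its documentation):

```python
# Condensed from the diff: build the kwargs once, then pick the variant.
# internode is the NVSHMEM-backed multi-node path; intranode rides on an
# existing named torch.distributed group.
def make_all_to_all(AllToAll, args: dict, group_name):
    if group_name is None:
        return AllToAll.internode(**args)
    args["group_name"] = group_name
    return AllToAll.intranode(**args)
```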
@@ -113,7 +123,10 @@
     )
 
     experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
-                                out_dtype, per_act_token, per_out_ch)
+                                out_dtype,
+                                per_act_token,
+                                per_out_ch,
+                                use_batched_format=True)
 
     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -184,19 +197,25 @@ def _pplx_moe(
     w2_full: torch.Tensor,
     per_act_token: bool,
     per_out_ch: bool,
+    use_internode: bool,
 ):
-    uid = nvshmem_get_unique_id(
-    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-    torch.distributed.broadcast(uid, src=0)
-    nvshmem_init(uid, pgi.rank, pgi.world_size)
+    if use_internode:
+        uid = nvshmem_get_unique_id(
+        ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
+        torch.distributed.broadcast(uid, src=0)
+        nvshmem_init(uid, pgi.rank, pgi.world_size)
+    else:
+        group_ranks = list(range(pgi.world_size))
+        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
+        group_name = cpu_group.group_name
 
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
                                   topk_ids)
         pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
                                        w2_scale, topk_weights, topk_ids,
                                        a1_scale, out_dtype, per_act_token,
-                                       per_out_ch)
+                                       per_out_ch, group_name)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
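chunk_by_rank is not shown in this diff; it slices the global batch so each rank only validates the rows it owns. A plausible sketch of its behavior (a hypothetical reconstruction, not the test utility's actual source):

```python
import torch

# Hypothetical reconstruction of chunk_by_rank as used above: rank r keeps
# the r-th slice of the token dimension.
def chunk_by_rank(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    chunk = (t.size(0) + world_size - 1) // world_size
    return t[rank * chunk:(rank + 1) * chunk]
```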
@@ -207,7 +226,8 @@ def _pplx_moe(
 
     torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
 
-    nvshmem_finalize()
+    if use_internode:
+        nvshmem_finalize()
 
 
 @pytest.mark.parametrize("m", [2, 224])
@@ -218,6 +238,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])  #, [4, 2]])
+@pytest.mark.parametrize("use_internode", [False])
 @pytest.mark.skipif(
     (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
         current_platform.get_device_capability()),
@@ -232,6 +253,7 @@ def test_cutlass_moe_pplx(
     per_act_token: bool,
     per_out_ch: bool,
     world_dp_size: tuple[int, int],
+    use_internode: bool,
 ):
     current_platform.seed_everything(7)
 
@@ -284,4 +306,5 @@ def test_cutlass_moe_pplx(
 
     parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
                     w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
-                    dtype, a, w1_d, w2_d, per_act_token, per_out_ch)
+                    dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
+                    use_internode)