[Bugfix] Fix Maverick correctness by filling zero to cache space in cutlass_moe #2

Open: wants to merge 4 commits into base: main
136 changes: 117 additions & 19 deletions tests/kernels/moe/test_cutlass_moe.py
@@ -1,16 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from math import prod
from typing import Optional

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8, run_cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
fused_topk)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.platforms import current_platform

NUM_EXPERTS = [40, 64]
@@ -236,6 +240,7 @@ def test_cutlass_moe_8_bit_no_graph(
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
ep_size: Optional[int] = None,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
@@ -254,7 +259,13 @@ def test_cutlass_moe_8_bit_no_graph(
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)

cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
if ep_size is not None:
assert e % ep_size == 0, "Cannot distribute experts evenly"
number_local_experts = e // ep_size
else:
number_local_experts = None
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
number_local_experts)

torch.testing.assert_close(triton_output,
cutlass_output,
@@ -337,9 +348,62 @@ def test_cutlass_moe_8_bit_EP(
per_out_channel: bool,
ep_size: int,
monkeypatch,
):
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
per_out_channel, monkeypatch, ep_size)


LARGE_MNK_FACTORS = [
(1, 8192, 5120, 31),
(32768, 1024, 1024, 16),
(65536, 512, 1024, 16),
]


@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_EP_large(
m: int,
n: int,
k: int,
e: int,
topk: int,
per_act_token: bool,
per_out_channel: bool,
ep_size: int,
monkeypatch,
):
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
per_out_channel, monkeypatch, ep_size)


@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
def test_run_cutlass_moe_fp8(
m: int,
n: int,
k: int,
e: int,
topk: int,
per_act_token: bool,
per_out_channel: bool,
ep_size: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel)
@@ -349,19 +413,53 @@ def test_cutlass_moe_8_bit_EP(
score,
topk,
renormalize=False)

# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)

assert e % ep_size == 0, "Cannot distribute experts evenly"
cutlass_output = run_8_bit(mt,
topk_weights,
topk_ids,
num_local_experts=e // ep_size)

torch.testing.assert_close(triton_output,
cutlass_output,
atol=5e-2,
rtol=1e-2)
# we want to make sure there is at least one token that's generated in
# this expert shard and at least one token that's NOT generated in this
# expert shard
topk_ids[0][0] = -1
topk_ids[0][1] = 1

workspace13_shape = (m * topk, max(2 * n, k))
workspace2_shape = (m * topk, n)
output_shape = (m * topk, k)

workspace13 = torch.empty(prod(workspace13_shape),
device="cuda",
dtype=mt.a.dtype)
workspace2 = torch.empty(prod(workspace2_shape),
device="cuda",
dtype=mt.a.dtype)

num_local_experts = e // ep_size
start, end = 0, num_local_experts
expert_map = [-1] * e
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
per_act_token)
global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)

workspace13.random_()
output_random_workspace = torch.empty(output_shape,
device="cuda",
dtype=mt.a.dtype)
func(output_random_workspace)

workspace13.fill_(0)
output_zero_workspace = torch.zeros(output_shape,
device="cuda",
dtype=mt.a.dtype)
func(output_zero_workspace)

torch.testing.assert_close(output_random_workspace,
output_zero_workspace,
atol=5e-3,
rtol=1e-3)
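
A note on the expert_map built in the new test above: under expert parallelism each rank owns a contiguous slice of the global experts, and every non-local expert maps to -1. A minimal, self-contained sketch of that mapping, generalized to an arbitrary rank (the make_expert_map helper and its arguments are illustrative, not a vLLM API):

import torch


def make_expert_map(global_num_experts: int, ep_rank: int,
                    ep_size: int) -> torch.Tensor:
    # Each EP rank owns a contiguous slice of experts; every other expert maps
    # to -1, meaning "not local to this rank".
    num_local = global_num_experts // ep_size
    start = ep_rank * num_local
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    expert_map[start:start + num_local] = torch.arange(num_local,
                                                       dtype=torch.int32)
    return expert_map


# Example: 128 experts sharded over 8 ranks, viewed from rank 0 as in the test;
# entries 0..15 are the local experts, the remaining 112 entries are all -1.
print(make_expert_map(128, ep_rank=0, ep_size=8))
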
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -177,6 +177,9 @@ def run_cutlass_moe_fp8(
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))

if expert_map is not None:
c1.fill_(0)

Reviewer comment on this hunk: great! any way to capture this in test_cutlass_moe?

Author reply: yep, added a couple unit tests

ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
@@ -269,7 +272,7 @@ def apply(
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
activation_callable = lambda i, o: self.activation(activation, i, o)
activation_callable = lambda o, i: self.activation(activation, o, i)
run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale,
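
Why the zero fill above matters (a plausible reading of the fix, which the new test_run_cutlass_moe_fp8 exercises): when expert_map is set, tokens routed to non-local experts contribute no rows to the grouped GEMM, so their slots in the reused c1 workspace are never written; without zeroing, whatever stale data the workspace held can flow through the activation and into the final output. A minimal, self-contained sketch of that failure mode under these assumptions (partial_gemm_then_reduce is a toy stand-in, not the vLLM kernel):

import torch


def partial_gemm_then_reduce(scratch: torch.Tensor, written_rows: torch.Tensor,
                             values: torch.Tensor) -> torch.Tensor:
    # Toy stand-in: only rows handled by local experts are written into the
    # scratch buffer, but the final output reduces over all rows.
    scratch[written_rows] = values
    return scratch.sum(dim=0)


m, k = 4, 8
written_rows = torch.tensor([0, 2])  # rows produced by local experts
values = torch.randn(2, k)

stale = torch.full((m, k), 123.0)  # reused workspace holding stale data
zeroed = torch.zeros(m, k)  # the fix: zero the cache space first

out_stale = partial_gemm_then_reduce(stale, written_rows, values)
out_clean = partial_gemm_then_reduce(zeroed, written_rows, values)

# The stale rows leak into the reduction, so the two results differ; with the
# zero fill, the output is independent of the workspace's prior contents,
# which is exactly the invariant the new unit test asserts.
print(torch.allclose(out_stale, out_clean))  # False
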