
[Bugfix] Fix Maverick correctness by filling zero to cache space in cutlass_moe #20167


Open

wants to merge 4 commits into main
136 changes: 117 additions & 19 deletions tests/kernels/moe/test_cutlass_moe.py
@@ -1,16 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from math import prod
from typing import Optional

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8, run_cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
fused_topk)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.platforms import current_platform

NUM_EXPERTS = [40, 64]
Collaborator:
does it help with the working sets at line 38-39?

Author:
unfortunately no. We can look into this more separately.

@@ -236,6 +240,7 @@ def test_cutlass_moe_8_bit_no_graph(
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
ep_size: Optional[int] = None,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
@@ -254,7 +259,13 @@ def test_cutlass_moe_8_bit_no_graph(
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)

cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
if ep_size is not None:
assert e % ep_size == 0, "Cannot distribute experts evenly"
number_local_experts = e // ep_size
else:
number_local_experts = None
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
number_local_experts)

torch.testing.assert_close(triton_output,
cutlass_output,
@@ -337,9 +348,62 @@ def test_cutlass_moe_8_bit_EP(
per_out_channel: bool,
ep_size: int,
monkeypatch,
):
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
per_out_channel, monkeypatch, ep_size)


LARGE_MNK_FACTORS = [
(1, 8192, 5120, 31),
(32768, 1024, 1024, 16),
(65536, 512, 1024, 16),
]


@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_EP_large(
m: int,
n: int,
k: int,
e: int,
topk: int,
per_act_token: bool,
per_out_channel: bool,
ep_size: int,
monkeypatch,
):
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
per_out_channel, monkeypatch, ep_size)


@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
def test_run_cutlass_moe_fp8(
m: int,
n: int,
k: int,
e: int,
topk: int,
per_act_token: bool,
per_out_channel: bool,
ep_size: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
Author:
this is not very git-friendly, but note this line is not removed during refactoring. see test_cutlass_moe_8_bit_no_graph

with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel)
@@ -349,19 +413,53 @@ def test_cutlass_moe_8_bit_EP(
score,
topk,
renormalize=False)

# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)

assert e % ep_size == 0, "Cannot distribute experts evenly"
cutlass_output = run_8_bit(mt,
topk_weights,
topk_ids,
num_local_experts=e // ep_size)

torch.testing.assert_close(triton_output,
cutlass_output,
atol=5e-2,
rtol=1e-2)
# we want to make sure there is at least one token that's generated in
# this expert shard and at least one token that's NOT generated in this
# expert shard
topk_ids[0][0] = -1
topk_ids[0][1] = 1

workspace13_shape = (m * topk, max(2 * n, k))
workspace2_shape = (m * topk, n)
output_shape = (m * topk, k)

workspace13 = torch.empty(prod(workspace13_shape),
device="cuda",
dtype=mt.a.dtype)
workspace2 = torch.empty(prod(workspace2_shape),
device="cuda",
dtype=mt.a.dtype)

num_local_experts = e // ep_size
start, end = 0, num_local_experts
expert_map = [-1] * e
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
per_act_token)
global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)

workspace13.random_()
output_random_workspace = torch.empty(output_shape,
device="cuda",
dtype=mt.a.dtype)
func(output_random_workspace)

workspace13.fill_(0)
output_zero_workspace = torch.zeros(output_shape,
device="cuda",
dtype=mt.a.dtype)
func(output_zero_workspace)

torch.testing.assert_close(output_random_workspace,
output_zero_workspace,
atol=5e-3,
rtol=1e-3)
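
For context, here is a minimal, self-contained sketch (not part of the PR; the rank, expert counts, and routing values are hypothetical) of how an expert-parallel rank's expert_map behaves: local experts get contiguous local ids and every other expert maps to -1, so token/expert pairs routed off-rank are skipped by the grouped GEMM and the corresponding rows of the intermediate cache are never written. That is why the test above runs run_cutlass_moe_fp8 once with workspace13 filled with random data and once with zeros, and requires the two outputs to match.

import torch

e, ep_size, rank = 128, 8, 0                       # hypothetical EP setup
num_local = e // ep_size
start, end = rank * num_local, (rank + 1) * num_local

# local experts -> contiguous local ids, everything else -> -1
expert_map = torch.full((e,), -1, dtype=torch.int32)
expert_map[start:end] = torch.arange(num_local, dtype=torch.int32)

topk_ids = torch.tensor([[0, 1, 64, 127]])         # one token routed to four experts
local_ids = expert_map[topk_ids]                   # tensor([[ 0,  1, -1, -1]])
rows_written = local_ids != -1                     # only these rows of the cache get fresh data
print(local_ids, rows_written)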
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -177,6 +177,9 @@ def run_cutlass_moe_fp8(
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))

if expert_map is not None:
c1.fill_(0)
Collaborator:
sounds like this should impact both chunking and non-chunking paths?

Author:
I think the non-chunking (batched) path is not impacted because c1 is fully overridden. I don't have solid proof though (I need to look into that code path more). @bnellnm / @ElizaWszola comments?

Contributor:
I tested this locally and found that the batched case needs to be cleared also. I think it's probably best to unconditionally zero out c1.

Author:
I see, updated. Could you share how to run the batched case tests? Thanks.

Contributor:
Is this needed when we don't use expert_map? If it's not, can you write a condition for this?

Author:
Thanks, updated!


ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
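
To make the thread above concrete, the following is a hedged, plain-PyTorch illustration (not the CUTLASS kernel; fake_grouped_gemm and all tensors are made-up stand-ins) of the stale-cache problem the fill addresses: workspace13 is reused across chunks, and with an expert_map the grouped GEMM writes only the rows that belong to local experts, so whatever an earlier chunk left in the untouched rows would otherwise leak into the downstream computation.

import torch

workspace = torch.full((4, 8), 7.0)             # leftovers from an earlier chunk
local_rows = torch.tensor([0, 2])               # rows the grouped GEMM actually writes

def fake_grouped_gemm(c1: torch.Tensor) -> torch.Tensor:
    # stand-in for ops.cutlass_moe_mm: only local-expert rows are produced
    c1[local_rows] = 1.0
    return c1.sum(dim=0)                        # downstream step that reads every row

stale = fake_grouped_gemm(workspace.clone())    # rows 1 and 3 still hold 7.0
cleared = workspace.clone()
cleared.fill_(0)                                # what the PR adds before the GEMM
clean = fake_grouped_gemm(cleared)
print(stale, clean)                             # the two results differ without the fill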
@@ -269,7 +272,7 @@ def apply(
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
activation_callable = lambda i, o: self.activation(activation, i, o)
activation_callable = lambda o, i: self.activation(activation, o, i)
run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale,
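
A short sketch of the activation_callable fix in apply: torch.ops._C.silu_and_mul takes the output tensor first and the input second (the new test's lambda uses the same ordering), so the callable must forward its arguments as (o, i) rather than (i, o). The reference function and shapes below are hypothetical stand-ins used only to illustrate the argument order.

import torch

def silu_and_mul_ref(out: torch.Tensor, x: torch.Tensor) -> None:
    # mirrors the (output, input) ordering of torch.ops._C.silu_and_mul
    d = x.shape[-1] // 2
    out.copy_(torch.nn.functional.silu(x[..., :d]) * x[..., d:])

x = torch.randn(4, 16)
out = torch.empty(4, 8)
activation_callable = lambda o, i: silu_and_mul_ref(o, i)   # (output, input), as in the fix
activation_callable(out, x)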