
[Bugfix] Fix Marlin MoE act order when is_k_full == False #8741

Merged: 13 commits, merged Sep 29, 2024. Changes shown below are from 5 of the commits.
7 changes: 5 additions & 2 deletions csrc/moe/marlin_moe_ops.cu
@@ -1749,6 +1749,9 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
has_act_order, is_k_full, max_shared_mem);
}

int group_tensor_size =
(!is_k_full && has_act_order) ? prob_k / num_groups : group_size;

TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
@@ -1826,8 +1829,8 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr =
(const int4*)s +
(((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
prob_n / 8) *
((group_tensor_size == -1 ? 1 : prob_k / group_tensor_size) * prob_n /
8) *
expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
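The core of the fix sits in the two hunks above: when act_order is enabled but this shard does not hold the full K dimension (is_k_full == false), the per-expert scale tensor contains num_groups rows rather than prob_k / group_size rows, so the offset to expert expert_idx's scales must be derived from num_groups. A minimal Python sketch of the updated pointer arithmetic (the helper name and standalone form are for illustration only; the expressions mirror the C++ above):

def expert_scale_offset(prob_k: int, prob_n: int, group_size: int,
                        num_groups: int, has_act_order: bool,
                        is_k_full: bool, expert_idx: int) -> int:
    # With act_order on a partial K, the scale tensor holds num_groups
    # rows, so derive an effective group size from num_groups instead of
    # trusting group_size.
    if has_act_order and not is_k_full:
        group_tensor_size = prob_k // num_groups
    else:
        group_tensor_size = group_size
    # group_tensor_size == -1 means channelwise scales: a single row.
    rows = 1 if group_tensor_size == -1 else prob_k // group_tensor_size
    # Scales are fp16 packed 8 per 128-bit int4, hence prob_n // 8 per row.
    return rows * (prob_n // 8) * expert_idx

Before this change the offset was computed from group_size alone, which mis-addresses the scales for every expert past index 0 whenever the scale tensor was built with act order on a partial K.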
32 changes: 23 additions & 9 deletions tests/kernels/test_moe.py
@@ -142,6 +142,7 @@ def compute_max_diff(output, output_ref):
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_fused_marlin_moe(
m: int,
n: int,
@@ -151,6 +152,7 @@ def test_fused_marlin_moe(
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
seed_everything(7)

@@ -163,6 +165,9 @@
return
if group_size in (k, n):
return
else:
if not is_k_full:
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
@@ -243,6 +248,7 @@ def test_fused_marlin_moe(
w1_scale=scales1,
w2_scale=scales2,
num_bits=num_bits,
is_k_full=is_k_full,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2
@@ -258,6 +264,7 @@
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_single_marlin_moe_multiply(
m: int,
n: int,
@@ -267,6 +274,7 @@ def test_single_marlin_moe_multiply(
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
if topk > e:
return
@@ -277,6 +285,9 @@
return
if group_size == k:
return
else:
if not is_k_full:
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
@@ -307,15 +318,18 @@ def test_single_marlin_moe_multiply(
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False,
num_bits=num_bits)
marlin_output = single_marlin_moe(
a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False,
num_bits=num_bits,
is_k_full=is_k_full,
)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
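Both tests gain an is_k_full parameterization. In the flattened hunks above, the added else branch presumably attaches to the existing act_order check (the condition line itself lies outside the diff context), so that is_k_full=False is only exercised on the act-order path, the only place a partial K arises. A sketch of the resulting gating in test_fused_marlin_moe under that assumption (test_single_marlin_moe_multiply applies the same pattern with its own group-size filter):

    if act_order:
        # act_order needs real scale groups: channelwise (-1) and group
        # sizes spanning a full K dimension are skipped.
        if group_size == -1:
            return
        if group_size in (k, n):
            return
    else:
        # Without act_order a partial K never occurs, so only
        # is_k_full=True is tested on that path.
        if not is_k_full:
            return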
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -21,6 +21,7 @@ def single_marlin_moe(
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
@@ -86,7 +87,7 @@

intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk,
block_size_m, True, False)

return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@@ -107,6 +108,7 @@ def fused_marlin_moe(
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -199,7 +201,7 @@
M,
2 * N,
K,
True,
is_k_full,
E,
topk,
block_size_m,
@@ -223,7 +225,7 @@
M,
K,
N,
True,
is_k_full,
E,
topk,
block_size_m,
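For callers, is_k_full describes whether this rank's weight shard spans the entire K dimension; it now defaults to True so existing call sites keep their old behavior, and both single_marlin_moe and fused_marlin_moe forward it to marlin_gemm_moe in place of the previously hard-coded True. As I understand the surrounding quantization code, the flag goes False when act order is combined with a tensor-parallel split of K. A hypothetical helper illustrating that derivation (not part of this PR; names are illustrative):

def compute_is_k_full(act_order: bool, input_size_per_partition: int,
                      input_size_full: int) -> bool:
    # Hypothetical (not from this diff): a shard is "k-full" unless act
    # order is combined with a tensor-parallel partition of K.
    return (not act_order) or (input_size_per_partition == input_size_full)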