Skip to content

Commit 376df3e

Browse files
ElizaWszola and tlrmchlsmth
authored and committed
[Bugfix] Fix Marlin MoE act order when is_k_full == False (vllm-project#8741)
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent 433b0b3 commit 376df3e

File tree

4 files changed

+37
-18
lines changed

4 files changed

+37
-18
lines changed

csrc/core/exception.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#pragma once
2+
3+
#define VLLM_IMPLIES(p, q) (!(p) || (q))

csrc/moe/marlin_moe_ops.cu

+6-6
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include <iostream>
2727

28+
#include "core/exception.hpp"
2829
#include "core/scalar_type.hpp"
2930
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
3031
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
@@ -189,7 +190,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
189190
int load_groups =
190191
tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K
191192
load_groups = max(load_groups, 32); // We load at least 32 scale groups
192-
return load_groups * tb_n * 2;
193+
return load_groups * tb_n * 4;
193194

194195
} else {
195196
int tb_scales = tb_groups * tb_n * 2;
@@ -433,11 +434,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
433434
int4* C_ptr = (int4*)C;
434435
const float* topk_weights_ptr = (const float*)topk_weights;
435436
const int* sorted_ids_ptr = (const int*)sorted_ids;
436-
const int4* s_ptr =
437-
(const int4*)s +
438-
(((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
439-
prob_n / 8) *
440-
expert_idx;
437+
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
441438
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
442439
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
443440
int* locks = (int*)workspace;
@@ -521,6 +518,9 @@ torch::Tensor marlin_gemm_moe(
521518
" is not size_n = ", size_n);
522519
num_groups = b_scales.size(1);
523520

521+
TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
522+
"if is_k_full is false, has_act_order must be true");
523+
524524
if (has_act_order) {
525525
if (is_k_full) {
526526
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");

tests/kernels/test_moe.py

+23-9
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def compute_max_diff(output, output_ref):
145145
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
146146
@pytest.mark.parametrize("act_order", [True, False])
147147
@pytest.mark.parametrize("num_bits", [4, 8])
148+
@pytest.mark.parametrize("is_k_full", [True, False])
148149
def test_fused_marlin_moe(
149150
m: int,
150151
n: int,
@@ -154,6 +155,7 @@ def test_fused_marlin_moe(
154155
group_size: int,
155156
act_order: bool,
156157
num_bits: int,
158+
is_k_full: bool,
157159
):
158160
seed_everything(7)
159161

@@ -166,6 +168,9 @@ def test_fused_marlin_moe(
166168
return
167169
if group_size in (k, n):
168170
return
171+
else:
172+
if not is_k_full:
173+
return
169174

170175
quant_type = (scalar_types.uint4b8
171176
if num_bits == 4 else scalar_types.uint8b128)
@@ -246,6 +251,7 @@ def test_fused_marlin_moe(
246251
w1_scale=scales1,
247252
w2_scale=scales2,
248253
num_bits=num_bits,
254+
is_k_full=is_k_full,
249255
)
250256

251257
assert compute_max_diff(marlin_output, triton_output) < 4e-2
@@ -290,6 +296,7 @@ def test_fused_marlin_moe(
290296
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
291297
@pytest.mark.parametrize("act_order", [True, False])
292298
@pytest.mark.parametrize("num_bits", [4, 8])
299+
@pytest.mark.parametrize("is_k_full", [True, False])
293300
def test_single_marlin_moe_multiply(
294301
m: int,
295302
n: int,
@@ -299,6 +306,7 @@ def test_single_marlin_moe_multiply(
299306
group_size: int,
300307
act_order: bool,
301308
num_bits: int,
309+
is_k_full: bool,
302310
):
303311
if topk > e:
304312
return
@@ -309,6 +317,9 @@ def test_single_marlin_moe_multiply(
309317
return
310318
if group_size == k:
311319
return
320+
else:
321+
if not is_k_full:
322+
return
312323

313324
quant_type = (scalar_types.uint4b8
314325
if num_bits == 4 else scalar_types.uint8b128)
@@ -339,15 +350,18 @@ def test_single_marlin_moe_multiply(
339350
sort_indices = stack_and_dev(sort_indices_l)
340351

341352
score = torch.randn((m, e), device="cuda", dtype=dtype)
342-
marlin_output = single_marlin_moe(a,
343-
qweight,
344-
scales,
345-
score,
346-
g_idx,
347-
sort_indices,
348-
topk,
349-
renormalize=False,
350-
num_bits=num_bits)
353+
marlin_output = single_marlin_moe(
354+
a,
355+
qweight,
356+
scales,
357+
score,
358+
g_idx,
359+
sort_indices,
360+
topk,
361+
renormalize=False,
362+
num_bits=num_bits,
363+
is_k_full=is_k_full,
364+
)
351365
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
352366

353367
assert compute_max_diff(marlin_output, torch_output) < 1e-2

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def single_marlin_moe(
2121
renormalize: bool,
2222
override_config: Optional[Dict[str, Any]] = None,
2323
num_bits: int = 8,
24+
is_k_full: bool = True,
2425
) -> torch.Tensor:
2526
"""
2627
This function computes the multiplication of hidden_states with expert
@@ -86,7 +87,7 @@ def single_marlin_moe(
8687

8788
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
8889
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
89-
g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
90+
g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk,
9091
block_size_m, True, False)
9192

9293
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@@ -107,6 +108,7 @@ def fused_marlin_moe(
107108
w1_scale: Optional[torch.Tensor] = None,
108109
w2_scale: Optional[torch.Tensor] = None,
109110
num_bits: int = 8,
111+
is_k_full: bool = True,
110112
) -> torch.Tensor:
111113
"""
112114
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -199,7 +201,7 @@ def fused_marlin_moe(
199201
M,
200202
2 * N,
201203
K,
202-
True,
204+
is_k_full,
203205
E,
204206
topk,
205207
block_size_m,
@@ -223,7 +225,7 @@ def fused_marlin_moe(
223225
M,
224226
K,
225227
N,
226-
True,
228+
is_k_full,
227229
E,
228230
topk,
229231
block_size_m,

0 commit comments

Comments (0)