Skip to content

Commit

Permalink
[Bugfix] Fix Marlin MoE act order when is_k_full == False (vllm-proje…
Browse files Browse the repository at this point in the history
…ct#8741)

Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
  • Loading branch information
2 people authored and garg-amit committed Oct 28, 2024
1 parent 433b0b3 commit 376df3e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 18 deletions.
3 changes: 3 additions & 0 deletions csrc/core/exception.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#pragma once

#define VLLM_IMPLIES(p, q) (!(p) || (q))
12 changes: 6 additions & 6 deletions csrc/moe/marlin_moe_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <iostream>

#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
Expand Down Expand Up @@ -189,7 +190,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int load_groups =
tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
return load_groups * tb_n * 4;

} else {
int tb_scales = tb_groups * tb_n * 2;
Expand Down Expand Up @@ -433,11 +434,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
int4* C_ptr = (int4*)C;
const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr =
(const int4*)s +
(((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
prob_n / 8) *
expert_idx;
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace;
Expand Down Expand Up @@ -521,6 +518,9 @@ torch::Tensor marlin_gemm_moe(
" is not size_n = ", size_n);
num_groups = b_scales.size(1);

TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
"if is_k_full is false, has_act_order must be true");

if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
Expand Down
32 changes: 23 additions & 9 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def compute_max_diff(output, output_ref):
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_fused_marlin_moe(
m: int,
n: int,
Expand All @@ -154,6 +155,7 @@ def test_fused_marlin_moe(
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
seed_everything(7)

Expand All @@ -166,6 +168,9 @@ def test_fused_marlin_moe(
return
if group_size in (k, n):
return
else:
if not is_k_full:
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
Expand Down Expand Up @@ -246,6 +251,7 @@ def test_fused_marlin_moe(
w1_scale=scales1,
w2_scale=scales2,
num_bits=num_bits,
is_k_full=is_k_full,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2
Expand Down Expand Up @@ -290,6 +296,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_single_marlin_moe_multiply(
m: int,
n: int,
Expand All @@ -299,6 +306,7 @@ def test_single_marlin_moe_multiply(
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
if topk > e:
return
Expand All @@ -309,6 +317,9 @@ def test_single_marlin_moe_multiply(
return
if group_size == k:
return
else:
if not is_k_full:
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
Expand Down Expand Up @@ -339,15 +350,18 @@ def test_single_marlin_moe_multiply(
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False,
num_bits=num_bits)
marlin_output = single_marlin_moe(
a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False,
num_bits=num_bits,
is_k_full=is_k_full,
)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
Expand Down
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def single_marlin_moe(
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
Expand Down Expand Up @@ -86,7 +87,7 @@ def single_marlin_moe(

intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk,
block_size_m, True, False)

return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
Expand All @@ -107,6 +108,7 @@ def fused_marlin_moe(
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
Expand Down Expand Up @@ -199,7 +201,7 @@ def fused_marlin_moe(
M,
2 * N,
K,
True,
is_k_full,
E,
topk,
block_size_m,
Expand All @@ -223,7 +225,7 @@ def fused_marlin_moe(
M,
K,
N,
True,
is_k_full,
E,
topk,
block_size_m,
Expand Down

0 comments on commit 376df3e

Please sign in to comment.