
Commit 5d60def

zwd003 and esmeetu authored
DeepseekMoE support with Fused MoE kernel (#2453)
Co-authored-by: roy <jasonailu87@gmail.com>
1 parent ea8489f commit 5d60def

File tree

9 files changed: +924 -0 lines changed


csrc/dispatch_utils.h

Lines changed: 11 additions & 0 deletions
@@ -24,3 +24,14 @@
 #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(                                           \
     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+
+#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                 \
+    TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
csrc/moe_align_block_size_kernels.cu

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <ATen/ATen.h>
+#include <THC/THCAtomics.cuh>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+const static size_t NUM_MAX_EXPERTS = 64;
+#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
+
+namespace vllm {
+template <typename scalar_t>
+__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
+                                            int32_t *sorted_token_ids,
+                                            int32_t *expert_ids,
+                                            int32_t *total_tokens_post_pad,
+                                            int32_t num_experts,
+                                            int32_t block_size,
+                                            size_t numel) {
+  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
+  const size_t start_idx = threadIdx.x * tokens_per_thread;
+  __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
+  __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
+  for (int i = 0; i < num_experts; ++i) {
+    tokens_cnts[threadIdx.x + 1][i] = 0;
+  }
+
+  /**
+   * In the first step we compute token_cnts[thread_index + 1][expert_index],
+   * which counts how many tokens in the token shard of thread_index are assigned
+   * to expert expert_index.
+   */
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
+  }
+
+  __syncthreads();
+
+  // For each expert we accumulate the token counts from the different threads.
+  tokens_cnts[0][threadIdx.x] = 0;
+  for (int i = 1; i <= blockDim.x; ++i) {
+    tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
+  }
+
+  __syncthreads();
+
+  // We accumulate the token counts of all experts in thread 0.
+  if (threadIdx.x == 0) {
+    cumsum[0] = 0;
+    for (int i = 1; i <= num_experts; ++i) {
+      cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
+    }
+    *total_tokens_post_pad = cumsum[num_experts];
+  }
+
+  __syncthreads();
+
+  /**
+   * For each expert, each thread processes the tokens of the corresponding blocks
+   * and stores the corresponding expert_id for each block.
+   */
+  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
+    expert_ids[i / block_size] = threadIdx.x;
+  }
+
+  /**
+   * Each thread processes a token shard, calculating the index of each token after
+   * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
+   * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
+   * where * represents a padding value (preset in python).
+   */
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    int32_t expert_id = topk_ids[i];
+    /** The cumsum[expert_id] stores the starting index of the tokens that the
+     * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
+     * stores the indices of the tokens processed by the expert with expert_id within
+     * the current thread's token shard.
+     */
+    int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
+    sorted_token_ids[rank_post_pad] = i;
+    ++tokens_cnts[threadIdx.x][expert_id];
+  }
+}
+}
+
+void moe_align_block_size(
+  torch::Tensor topk_ids,
+  int num_experts,
+  int block_size,
+  torch::Tensor sorted_token_ids,
+  torch::Tensor experts_ids,
+  torch::Tensor num_tokens_post_pad) {
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  assert(num_experts <= NUM_MAX_EXPERTS);
+  VLLM_DISPATCH_INTEGRAL_TYPES(
+    topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+      vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
+        topk_ids.data_ptr<scalar_t>(),
+        sorted_token_ids.data_ptr<int32_t>(),
+        experts_ids.data_ptr<int32_t>(),
+        num_tokens_post_pad.data_ptr<int32_t>(),
+        num_experts,
+        block_size,
+        topk_ids.numel());
+    });
+}
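
To make the padded layout concrete, here is a minimal, single-threaded Python sketch of what the kernel computes; it is not part of the commit, and the function name moe_align_block_size_ref plus the choice of padding value (the token count numel, matching the "preset in python" note in the kernel comment) are illustrative assumptions. Padding each expert's segment to a multiple of block_size is what lets every block of sorted slots map to exactly one expert_ids entry.

    # Reference sketch (not part of the commit): reproduces the layout that
    # moe_align_block_size_kernel writes, using plain Python loops.
    import torch


    def moe_align_block_size_ref(topk_ids: torch.Tensor, num_experts: int, block_size: int):
        flat = topk_ids.flatten().tolist()
        numel = len(flat)

        # Per-expert token counts, each rounded up to a multiple of block_size.
        counts = [0] * num_experts
        for e in flat:
            counts[e] += 1
        padded = [-(-c // block_size) * block_size for c in counts]

        # Exclusive prefix sum: start offset of each expert's padded segment.
        cumsum = [0]
        for p in padded:
            cumsum.append(cumsum[-1] + p)
        total_post_pad = cumsum[-1]

        # One expert index per block of block_size consecutive slots.
        expert_ids = [0] * (total_post_pad // block_size)
        for e in range(num_experts):
            for i in range(cumsum[e], cumsum[e + 1], block_size):
                expert_ids[i // block_size] = e

        # Scatter token indices into their expert-sorted, padded positions.
        sorted_token_ids = [numel] * total_post_pad  # numel used as padding value here
        fill = list(cumsum[:-1])
        for i, e in enumerate(flat):
            sorted_token_ids[fill[e]] = i
            fill[e] += 1
        return sorted_token_ids, expert_ids, total_post_pad


    # Reproduces the example from the kernel comment: topk_ids = [0,1,2,1,2,3,0,3,4], block_size = 4.
    ids, experts, total = moe_align_block_size_ref(
        torch.tensor([0, 1, 2, 1, 2, 3, 0, 3, 4]), num_experts=5, block_size=4)
    print(ids)      # [0, 6, 9, 9, 1, 3, 9, 9, 2, 4, 9, 9, 5, 7, 9, 9, 8, 9, 9, 9]
    print(experts)  # [0, 1, 2, 3, 4]
    print(total)    # 20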

csrc/ops.h

Lines changed: 9 additions & 0 deletions
@@ -121,3 +121,12 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                             const std::vector<std::vector<int64_t>> &offsets);
 #endif
+
+void moe_align_block_size(
+  torch::Tensor topk_ids,
+  int num_experts,
+  int block_size,
+  torch::Tensor sorted_token_ids,
+  torch::Tensor experts_ids,
+  torch::Tensor num_tokens_post_pad
+);

csrc/pybind.cpp

Lines changed: 4 additions & 0 deletions
@@ -56,6 +56,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+  ops.def(
+    "moe_align_block_size",
+    &moe_align_block_size,
+    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
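
As context for the binding registered above, here is a hedged sketch of how a caller on the Python side might allocate the outputs and invoke it. The import path vllm._C, the worst-case buffer sizes, and the padding value are assumptions not shown in this diff; only the operator name and argument order come from the code above.

    # Hedged usage sketch (assumptions: module path and buffer sizing).
    import torch
    from vllm._C import ops  # assumed extension module path

    num_experts, block_size = 8, 16
    topk_ids = torch.randint(0, num_experts, (256, 2), dtype=torch.int32, device="cuda")

    # Worst case: each expert's segment grows by at most block_size - 1 padding slots.
    max_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_token_ids = torch.empty(max_padded, dtype=torch.int32, device="cuda")
    sorted_token_ids.fill_(topk_ids.numel())  # padding value, per the kernel comment
    expert_ids = torch.empty((max_padded + block_size - 1) // block_size,
                             dtype=torch.int32, device="cuda")
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

    ops.moe_align_block_size(topk_ids, num_experts, block_size,
                             sorted_token_ids, expert_ids, num_tokens_post_pad)
    # Each block of block_size consecutive entries in sorted_token_ids now
    # belongs to the single expert recorded in expert_ids for that block.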

setup.py

Lines changed: 1 addition & 0 deletions
@@ -309,6 +309,7 @@ def get_torch_arch_list() -> Set[str]:
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
     "csrc/quantization/gptq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
+    "csrc/moe_align_block_size_kernels.cu",
     "csrc/pybind.cpp",
 ]

tests/kernels/test_fused_moe.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import pytest
+import torch
+
+from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.activation import SiluAndMul
+
+
+def torch_moe(a, w1, w2, topk_weight, topk_ids):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
+    out = torch.zeros(B * topk_ids.shape[1],
+                      w2.shape[1],
+                      dtype=a.dtype,
+                      device=a.device)
+    topk_ids = topk_ids.view(-1)
+    topk_weight = topk_weight.view(-1)
+    for i in range(w1.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            out[mask] = SiluAndMul()(
+                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+    return (out.view(B, -1, w2.shape[1]) *
+            topk_weight.view(B, -1, 1)).sum(dim=1)
+
+
+@pytest.mark.parametrize("m", [512, 222, 33, 1])
+@pytest.mark.parametrize("n", [2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", [8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_fused_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+):
+    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
+
+    score = torch.randn((m, e), device='cuda', dtype=dtype)
+    score = torch.softmax(score, dim=-1)
+    topk_weight, topk_ids = torch.topk(score, topk)
+
+    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
+    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
+    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
