DeepseekMoE support with Fused MoE kernel #2453
Merged
Changes from all commits (20 commits):
023575a init support (esmeetu)
3f9a978 deepseekmoe support (zwd003)
8b64d5c add fused moe kernel (zwd003)
13f526d Add more code comments. (zwd003)
8e26af6 bug fix (zwd003)
3924f17 fp16 support (zwd003)
108a13d bug fix (zwd003)
49e2b10 remove global static cache (zwd003)
867eb13 add unit test (zwd003)
1b45bbc code format (zwd003)
1f43929 optimizing performance on small batch (zwd003)
9ae3d34 fix up (zwd003)
c017314 fix up (zwd003)
c3a30d1 rewrite alig_block_size in cpp (zwd003)
82ac2e4 fix up (zwd003)
12358d3 fix up (zwd003)
56f7d45 fix up (zwd003)
0bb745f fix up (esmeetu)
46603ca yapf (WoosukKwon)
b85c1eb Merge branch 'main' into deepseekmoe-dev (WoosukKwon)
New file (108 lines): CUDA kernel implementing moe_align_block_size.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "cuda_compat.h"
#include "dispatch_utils.h"

const static size_t NUM_MAX_EXPERTS = 64;
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))

namespace vllm {
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                            int32_t *sorted_token_ids,
                                            int32_t *expert_ids,
                                            int32_t *total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size,
                                            size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;
  __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
Review comment: nit: Actually, this part doesn't need to be static. Shared memory size can be configured dynamically at the kernel launch time. However, I think we can fix this in a later PR.
  __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[threadIdx.x + 1][i] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  tokens_cnts[0][threadIdx.x] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4]
   * and block_size = 4, then the output would be
   * [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
   * where * represents a padding value (preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
     * stores the indices of the tokens processed by the expert with expert_id within
     * the current thread's token shard.
     */
    int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[threadIdx.x][expert_id];
  }
}
} // namespace vllm

void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
  int block_size,
  torch::Tensor sorted_token_ids,
  torch::Tensor experts_ids,
  torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  assert(num_experts <= NUM_MAX_EXPERTS);
  VLLM_DISPATCH_INTEGRAL_TYPES(
    topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
      vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel());
    });
}
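To make the kernel's output layout concrete, here is a small pure-Python/NumPy sketch of the same block-aligned sort. It is not part of this PR: the helper name moe_align_block_size_ref, the pad value of -1, and num_experts = 5 in the example are illustrative assumptions. It reproduces the worked example from the kernel comment above.

# Hypothetical host-side reference for what the CUDA kernel above computes;
# only an illustration of the block-aligned sort, not the PR's implementation.
import numpy as np

def moe_align_block_size_ref(topk_ids, num_experts, block_size, pad=-1):
    # Count tokens per expert, then round each count up to a multiple of
    # block_size so that every block is owned by exactly one expert.
    counts = np.bincount(topk_ids, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    cumsum = np.concatenate(([0], np.cumsum(padded)))
    total_tokens_post_pad = int(cumsum[-1])

    # Token indices sorted by expert, with padding slots; one expert id per block.
    sorted_token_ids = np.full(total_tokens_post_pad, pad, dtype=np.int64)
    expert_ids = np.repeat(np.arange(num_experts), padded // block_size)

    # Place each token right after the tokens already assigned to its expert.
    offsets = cumsum[:-1].copy()
    for i, e in enumerate(topk_ids):
        sorted_token_ids[offsets[e]] = i
        offsets[e] += 1
    return sorted_token_ids, expert_ids, total_tokens_post_pad

# Reproduces the example from the kernel comment:
ids, experts, total = moe_align_block_size_ref(
    np.array([0, 1, 2, 1, 2, 3, 0, 3, 4]), num_experts=5, block_size=4)
print(ids)      # [0 6 -1 -1 1 3 -1 -1 2 4 -1 -1 5 7 -1 -1 8 -1 -1 -1]
print(experts)  # [0 1 2 3 4]
print(total)    # 20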
New file (50 lines): unit test comparing the fused MoE kernel against a PyTorch reference.
import pytest
import torch

from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.activation import SiluAndMul


def torch_moe(a, w1, w2, topk_weight, topk_ids):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
    out = torch.zeros(B * topk_ids.shape[1],
                      w2.shape[1],
                      dtype=a.dtype,
                      device=a.device)
    topk_ids = topk_ids.view(-1)
    topk_weight = topk_weight.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1)).sum(dim=1)


@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

    score = torch.randn((m, e), device='cuda', dtype=dtype)
    score = torch.softmax(score, dim=-1)
    topk_weight, topk_ids = torch.topk(score, topk)

    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
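As a usage illustration (not part of this PR), a MoE layer's forward pass could call fused_moe the same way the test does. The helper name moe_forward, the router projection gate_weight, and the meaning of the final False argument (assumed here to toggle in-place computation) are assumptions, not taken from the PR.

# Hedged sketch of calling fused_moe from a MoE layer, mirroring the
# call signature and weight shapes used in the test above.
import torch
from vllm.model_executor.layers.fused_moe import fused_moe

def moe_forward(hidden_states, gate_weight, w1, w2, topk):
    # hidden_states: [num_tokens, hidden_size]
    # gate_weight:   [num_experts, hidden_size] router projection (illustrative)
    # w1: [num_experts, 2 * intermediate_size, hidden_size] fused gate/up projection
    # w2: [num_experts, hidden_size, intermediate_size] down projection
    router_logits = hidden_states @ gate_weight.t()
    routing_weights = torch.softmax(router_logits, dim=-1)
    topk_weight, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    # The final positional argument mirrors the `False` passed in the test;
    # its exact semantics are an assumption here.
    return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, False)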