Skip to content

Commit 56f7d45

Browse files
committed
fix up
1 parent 12358d3 commit 56f7d45

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

csrc/moe_align_block_size_kernels.cu

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "dispatch_utils.h"
99

1010
const static size_t NUM_MAX_EXPERTS = 64;
11+
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
1112

1213
namespace vllm {
1314
template <typename scalar_t>
@@ -22,39 +23,61 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
2223
const size_t start_idx = threadIdx.x * tokens_per_thread;
2324
__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
2425
__shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
25-
for(int i = 0; i < num_experts; ++i){
26+
for (int i = 0; i < num_experts; ++i) {
2627
tokens_cnts[threadIdx.x + 1][i] = 0;
2728
}
2829

30+
/**
31+
* In the first step we compute token_cnts[thread_index + 1][expert_index],
32+
* which counts how many tokens in the token shard of thread_index are assigned
33+
* to expert expert_index.
34+
*/
2935
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
3036
++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
3137
}
3238

3339
__syncthreads();
3440

41+
// For each expert we accumulate the token counts from the different threads.
3542
tokens_cnts[0][threadIdx.x] = 0;
36-
for(int i = 1; i <= blockDim.x; ++i){
43+
for (int i = 1; i <= blockDim.x; ++i) {
3744
tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
3845
}
3946

4047
__syncthreads();
4148

42-
if(threadIdx.x == 0){
49+
// We accumulate the token counts of all experts in thread 0.
50+
if (threadIdx.x == 0) {
4351
cumsum[0] = 0;
44-
for(int i = 1; i <= num_experts; ++i){
45-
cumsum[i] = cumsum[i-1] + (tokens_cnts[blockDim.x][i - 1] + block_size - 1) / block_size * block_size;
52+
for (int i = 1; i <= num_experts; ++i) {
53+
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
4654
}
4755
*total_tokens_post_pad = cumsum[num_experts];
4856
}
4957

5058
__syncthreads();
5159

52-
for(int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size){
60+
/**
61+
* For each expert, each thread processes the tokens of the corresponding blocks
62+
* and stores the corresponding expert_id for each block.
63+
*/
64+
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
5365
expert_ids[i / block_size] = threadIdx.x;
5466
}
5567

68+
/**
69+
* Each thread processes a token shard, calculating the index of each token after
70+
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
71+
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
72+
* where * represents a padding value (preset in Python).
73+
*/
5674
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
5775
int32_t expert_id = topk_ids[i];
76+
/** cumsum[expert_id] stores the starting index of the tokens that the
77+
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
78+
* stores the indices of the tokens processed by the expert with expert_id within
79+
* the current thread's token shard.
80+
*/
5881
int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
5982
sorted_token_ids[rank_post_pad] = i;
6083
++tokens_cnts[threadIdx.x][expert_id];
@@ -82,4 +105,4 @@ void moe_align_block_size(
82105
block_size,
83106
topk_ids.numel());
84107
});
85-
}
108+
}

0 commit comments

Comments
 (0)