8
8
#include " dispatch_utils.h"
9
9
10
10
const static size_t NUM_MAX_EXPERTS = 64 ;
11
+ #define CEILDIV (x,y ) (((x) + (y) - 1 ) / (y))
11
12
12
13
namespace vllm {
13
14
template <typename scalar_t >
@@ -22,39 +23,61 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
22
23
const size_t start_idx = threadIdx .x * tokens_per_thread;
23
24
__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1 ][NUM_MAX_EXPERTS];
24
25
__shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1 ];
25
- for (int i = 0 ; i < num_experts; ++i){
26
+ for (int i = 0 ; i < num_experts; ++i) {
26
27
tokens_cnts[threadIdx .x + 1 ][i] = 0 ;
27
28
}
28
29
30
+ /* *
31
+ * In the first step we compute token_cnts[thread_index + 1][expert_index],
32
+ * which counts how many tokens in the token shard of thread_index are assigned
33
+ * to expert expert_index.
34
+ */
29
35
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
30
36
++tokens_cnts[threadIdx .x + 1 ][topk_ids[i]];
31
37
}
32
38
33
39
__syncthreads ();
34
40
41
+ // For each expert we accumulate the token counts from the different threads.
35
42
tokens_cnts[0 ][threadIdx .x ] = 0 ;
36
- for (int i = 1 ; i <= blockDim .x ; ++i){
43
+ for (int i = 1 ; i <= blockDim .x ; ++i) {
37
44
tokens_cnts[i][threadIdx .x ] += tokens_cnts[i-1 ][threadIdx .x ];
38
45
}
39
46
40
47
__syncthreads ();
41
48
42
- if (threadIdx .x == 0 ){
49
+ // We accumulate the token counts of all experts in thread 0.
50
+ if (threadIdx .x == 0 ) {
43
51
cumsum[0 ] = 0 ;
44
- for (int i = 1 ; i <= num_experts; ++i){
45
- cumsum[i] = cumsum[i-1 ] + (tokens_cnts[blockDim .x ][i - 1 ] + block_size - 1 ) / block_size * block_size;
52
+ for (int i = 1 ; i <= num_experts; ++i) {
53
+ cumsum[i] = cumsum[i-1 ] + CEILDIV (tokens_cnts[blockDim .x ][i - 1 ], block_size) * block_size;
46
54
}
47
55
*total_tokens_post_pad = cumsum[num_experts];
48
56
}
49
57
50
58
__syncthreads ();
51
59
52
- for (int i = cumsum[threadIdx .x ];i < cumsum[threadIdx .x + 1 ];i += block_size){
60
+ /* *
61
+ * For each expert, each thread processes the tokens of the corresponding blocks
62
+ * and stores the corresponding expert_id for each block.
63
+ */
64
+ for (int i = cumsum[threadIdx .x ];i < cumsum[threadIdx .x + 1 ];i += block_size) {
53
65
expert_ids[i / block_size] = threadIdx .x ;
54
66
}
55
67
68
+ /* *
69
+ * Each thread processes a token shard, calculating the index of each token after
70
+ * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
71
+ * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
72
+ * where * represents a padding value(preset in python).
73
+ */
56
74
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
57
75
int32_t expert_id = topk_ids[i];
76
+ /* * The cumsum[expert_id] stores the starting index of the tokens that the
77
+ * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
78
+ * stores the indices of the tokens processed by the expert with expert_id within
79
+ * the current thread's token shard.
80
+ */
58
81
int32_t rank_post_pad = tokens_cnts[threadIdx .x ][expert_id] + cumsum[expert_id];
59
82
sorted_token_ids[rank_post_pad] = i;
60
83
++tokens_cnts[threadIdx .x ][expert_id];
@@ -82,4 +105,4 @@ void moe_align_block_size(
82
105
block_size,
83
106
topk_ids.numel ());
84
107
});
85
- }
108
+ }
0 commit comments