@@ -55,11 +55,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   __syncthreads();
 
   // For each expert we accumulate the token counts from the different threads.
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    tokens_cnts[index(num_experts, 0, eid)] = 0;
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, eid)] +=
-          tokens_cnts[index(num_experts, i - 1, eid)];
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
     }
   }
 
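Note on the hunk above: instead of striding over experts, each thread now owns exactly one expert column of `tokens_cnts` and turns the per-thread counts into a running sum. This relies on the host launch (changed later in this commit) raising the thread count to `max(num_experts, WARP_SIZE)`. A minimal standalone sketch of the pattern, not vLLM code: the `index()` helper is assumed to be the kernel's row-major `row * num_experts + col` mapping, and the counting phase is replaced by an input array.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

constexpr int NUM_EXPERTS = 4;
constexpr int NUM_THREADS = 32;  // >= NUM_EXPERTS, as the host launch guarantees

// Assumed row-major helper: one row of counters per thread, plus row 0.
__device__ __forceinline__ int index(int num_experts, int row, int col) {
  return row * num_experts + col;
}

__global__ void accumulate_counts(const int* per_thread_cnts, int* totals) {
  __shared__ int tokens_cnts[(NUM_THREADS + 1) * NUM_EXPERTS];

  // Stand-in for the counting phase: thread t fills row t + 1.
  for (int e = 0; e < NUM_EXPERTS; ++e)
    tokens_cnts[index(NUM_EXPERTS, threadIdx.x + 1, e)] =
        per_thread_cnts[threadIdx.x * NUM_EXPERTS + e];
  __syncthreads();

  // One thread per expert builds the running sum, as in the hunk above.
  if (threadIdx.x < NUM_EXPERTS) {
    tokens_cnts[index(NUM_EXPERTS, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i)
      tokens_cnts[index(NUM_EXPERTS, i, threadIdx.x)] +=
          tokens_cnts[index(NUM_EXPERTS, i - 1, threadIdx.x)];
    // The last row now holds the total token count for this expert.
    totals[threadIdx.x] =
        tokens_cnts[index(NUM_EXPERTS, blockDim.x, threadIdx.x)];
  }
}

int main() {
  int h_cnts[NUM_THREADS * NUM_EXPERTS];
  for (int& c : h_cnts) c = 1;  // every thread saw one token per expert
  int *d_cnts, *d_totals;
  cudaMalloc(&d_cnts, sizeof(h_cnts));
  cudaMalloc(&d_totals, NUM_EXPERTS * sizeof(int));
  cudaMemcpy(d_cnts, h_cnts, sizeof(h_cnts), cudaMemcpyHostToDevice);
  accumulate_counts<<<1, NUM_THREADS>>>(d_cnts, d_totals);
  int h_totals[NUM_EXPERTS];
  cudaMemcpy(h_totals, d_totals, sizeof(h_totals), cudaMemcpyDeviceToHost);
  for (int e = 0; e < NUM_EXPERTS; ++e)
    printf("expert %d: %d tokens\n", e, h_totals[e]);  // 32 each
  cudaFree(d_cnts);
  cudaFree(d_totals);
  return 0;
}
```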
@@ -83,9 +83,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
-      expert_ids[i / block_size] = eid;
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
     }
   }
 
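The same one-thread-per-expert shape is used for the `expert_ids` fill: after the cumulative sum over the padded per-expert counts (computed earlier in the kernel), `cumsum[e]`..`cumsum[e + 1]` is the block_size-aligned slot range owned by expert e, so stepping through it in `block_size` strides tags each block with its expert. A small CPU illustration, not vLLM code; the `cumsum` values are made up for the example.

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int block_size = 4;
  // Example: expert 0 owns 8 padded slots, expert 1 owns 4, expert 2 owns 8.
  std::vector<int> cumsum = {0, 8, 12, 20};
  std::vector<int> expert_ids(cumsum.back() / block_size);

  // Same loop shape as the kernel, with `eid` playing the role of threadIdx.x.
  for (int eid = 0; eid + 1 < (int)cumsum.size(); ++eid)
    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size)
      expert_ids[i / block_size] = eid;

  for (size_t b = 0; b < expert_ids.size(); ++b)
    printf("block %zu -> expert %d\n", b, expert_ids[b]);  // 0 0 1 2 2
  return 0;
}
```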
@@ -140,11 +141,11 @@ __global__ void moe_align_block_size_global_mem_kernel(
   __syncthreads();
 
   // For each expert we accumulate the token counts from the different threads.
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    tokens_cnts[index(num_experts, 0, eid)] = 0;
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
     for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[index(num_experts, i, eid)] +=
-          tokens_cnts[index(num_experts, i - 1, eid)];
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
     }
   }
 
@@ -168,9 +169,10 @@ __global__ void moe_align_block_size_global_mem_kernel(
    * For each expert, each thread processes the tokens of the corresponding
    * blocks and stores the corresponding expert_id for each block.
    */
-  for (int eid = threadIdx.x; eid < num_experts; eid += blockDim.x) {
-    for (int i = cumsum[eid]; i < cumsum[eid + 1]; i += block_size) {
-      expert_ids[i / block_size] = eid;
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
     }
   }
 
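The two hunks above apply the identical guard change to the global-memory variant of the kernel. Its definition is not part of this diff; judging only from the host-side launch in the next hunk, its parameter list presumably mirrors the shared-memory kernel plus the two device-global scratch buffers, roughly as sketched below (an assumption based on the call site, not the actual declaration).

```cuda
// Assumed shape only, inferred from the call site below; not shown in this diff.
template <typename scalar_t>
__global__ void moe_align_block_size_global_mem_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel,
    int32_t* tokens_cnts,  // [(num_threads + 1) * num_experts], device-global
    int32_t* cumsum);      // [num_experts + 1], device-global
```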
@@ -221,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       torch::Tensor experts_ids,
       torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_INTEGRAL_TYPES(
-      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-        // tensors
-        const int32_t num_thread = WARP_SIZE;
-        const int32_t shared_mem =
-            ((num_thread + 1) * num_experts + (num_experts + 1)) *
-            sizeof(int32_t);
-
-        // set dynamic shared mem
-        auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
-        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-            (void*)kernel, shared_mem));
-        kernel<<<1, num_thread, shared_mem, stream>>>(
-            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
-            experts_ids.data_ptr<int32_t>(),
-            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-            topk_ids.numel());
-      });
+
+  // If we have very large number of experts, we can no longer use shared
+  // memory.
+  // TODO(simon): the right solution should be calculating the exact right
+  // amount of shared memory and use that. The num_experts >= 256 is just a
+  // temporary solution to unblock Deepseek V3.
+  if (num_experts >= 96) {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
+          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+
+          const int32_t mem_tokens_cnts =
+              ((num_experts + 1) * num_experts) * sizeof(int32_t);
+          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
+          // allocate global memory
+          int32_t* tokens_cnts;
+          int32_t* cumsum;
+          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
+          cudaMalloc(&cumsum, mem_cumsum);
+
+          auto kernel =
+              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
+          kernel<<<1, num_thread, 0, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel(), tokens_cnts, cumsum);
+          cudaFree(tokens_cnts);
+          cudaFree(cumsum);
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+          const int32_t shared_mem =
+              ((num_thread + 1) * num_experts + (num_experts + 1)) *
+              sizeof(int32_t);
+
+          // set dynamic shared mem
+          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem));
+          kernel<<<1, num_thread, shared_mem, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  }
 }
 
 void moe_sum(torch::Tensor& input,  // [num_tokens, topk, hidden_size]
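On the host side, the shared-memory branch keeps the original sizing formula but now launches `max(num_experts, WARP_SIZE)` threads, while expert counts at or above 96 switch to the global-memory kernel with `cudaMalloc`/`cudaFree` scratch buffers (note the TODO comment still refers to the earlier 256 threshold). A quick back-of-the-envelope check, not vLLM code, of why the shared-memory requirement outgrows typical per-block limits:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int32_t WARP_SIZE = 32;
  for (int32_t num_experts : {8, 64, 96, 256}) {
    // Same formula as the shared-memory branch above.
    const int32_t num_thread = std::max(num_experts, WARP_SIZE);
    const size_t shared_mem =
        ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
    printf("num_experts=%3d -> %6zu bytes (%.1f KiB) of dynamic shared memory\n",
           (int)num_experts, shared_mem, shared_mem / 1024.0);
  }
  // Output: 8 -> ~1.1 KiB, 64 -> ~16.5 KiB, 96 -> ~36.8 KiB, 256 -> ~258 KiB;
  // the last figure is well beyond typical opt-in per-block limits, which is
  // why large expert counts fall back to the global-memory kernel.
  return 0;
}
```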