@@ -24,12 +24,24 @@ limitations under the License.
 
 #define WARP_SIZE 32
 
+template <typename scalar_t>
+__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
+                                      int32_t* cumsum_buffer, size_t numel) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.x;
+
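+  // Grid-stride loop: each thread visits tokens tid, tid + stride, ... and atomically
+  // claims the next free slot in its expert's segment of sorted_token_ids, using the
+  // per-expert offsets left in cumsum_buffer by the align kernel.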
+  for (size_t i = tid; i < numel; i += stride) {
+    int32_t expert_id = topk_ids[i];
+    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
+    sorted_token_ids[rank_post_pad] = i;
+  }
+}
+
 template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
                                             int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
                                             int32_t block_size, size_t numel, int32_t* cumsum) {
   __shared__ int32_t shared_counts[WARP_SIZE][8];
-  __shared__ int32_t local_offsets[256];
 
   const int warp_id = threadIdx.x / WARP_SIZE;
   const int experts_per_warp = 8;
@@ -72,20 +84,6 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
       expert_ids[i / block_size] = threadIdx.x;
     }
-    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
-  }
-
-  __syncthreads();
-
-  // Note: For the moe_align_kernel, the primary bottleneck lies in the atomic add and non-coalesced memory writes here.
-  // If these operations can be performed using multiple blocks, similar to the Triton version, the performance of this
-  // kernel can achieve state-of-the-art performance across all token cases. However, once multiple blocks are used,
-  // illegal memory access occurs. Even replacing these lines of code with the stage 4 kernel from the Triton version
-  // results in the same issue, and a correct solution has not yet been found.
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
-    int32_t expert_id = topk_ids[i];
-    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
-    sorted_token_ids[rank_post_pad] = i;
-  }
   }
 }
 
@@ -100,5 +98,15 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
     align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                                          experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
                                          num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+
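+    // Launch the token sort over multiple blocks; the grid size is capped, and the
+    // grid-stride loop inside moe_token_sort_kernel covers any remaining tokens.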
+    const int block_threads = 256;
+    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
+    const int max_blocks = 65535;
+    const int actual_blocks = std::min(num_blocks, max_blocks);
+
+    auto sort_kernel = moe_token_sort_kernel<scalar_t>;
+    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
+                                                             sorted_token_ids.data_ptr<int32_t>(),
+                                                             cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
   });
 }