
Commit cdae77b

optimize moe_align_kernel cuda (#3347)
1 parent adeee15 commit cdae77b

3 files changed: +29 -21 lines

benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py

Lines changed: 4 additions & 4 deletions
@@ -163,10 +163,10 @@ def calculate_diff(batch_size, seq_len):
     num_tokens_post_pad_cuda = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.empty(
+    token_cnts_buffer = torch.zeros(
         (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
     )
-    cumsum_buffer = torch.empty(
+    cumsum_buffer = torch.zeros(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )

@@ -260,10 +260,10 @@ def benchmark(batch_size, seq_len, provider):
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    token_cnts_buffer = torch.empty(
+    token_cnts_buffer = torch.zeros(
         (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
     )
-    cumsum_buffer = torch.empty(
+    cumsum_buffer = torch.zeros(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )

python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -417,12 +417,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.empty(
+        token_cnts_buffer = torch.zeros(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.empty(
+        cumsum_buffer = torch.zeros(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )
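The Python change at all three call sites is the same: the token_cnts_buffer and cumsum_buffer scratch tensors are now zero-filled instead of merely allocated, and the benchmark allocations above mirror it, presumably so both code paths start from identical buffers. A likely reason (the commit does not state one) is that kernels which build counts or offsets by atomically incrementing a buffer add to whatever value the memory already holds, while torch.empty leaves it uninitialized. A minimal CUDA sketch of that pattern; count_tokens_per_expert is hypothetical and only illustrates the zero-initialization requirement:

// Hypothetical per-expert counter, not taken from this repo. It shows why a buffer
// that is accumulated into with atomicAdd must start at zero: the atomic adds to
// whatever value the memory already holds.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void count_tokens_per_expert(const int32_t* topk_ids, int32_t* expert_counts, size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t i = tid; i < numel; i += stride) {
    atomicAdd(&expert_counts[topk_ids[i]], 1);  // adds on top of the existing value
  }
}

int main() {
  const int num_experts = 8;
  const size_t numel = 1 << 16;
  int32_t *topk_ids, *expert_counts;
  cudaMallocManaged(&topk_ids, numel * sizeof(int32_t));
  cudaMallocManaged(&expert_counts, num_experts * sizeof(int32_t));
  for (size_t i = 0; i < numel; ++i) topk_ids[i] = i % num_experts;

  // The torch.zeros equivalent: without it the counts start from whatever the
  // allocation happens to contain, which is the torch.empty situation.
  cudaMemset(expert_counts, 0, num_experts * sizeof(int32_t));

  count_tokens_per_expert<<<64, 256>>>(topk_ids, expert_counts, numel);
  cudaDeviceSynchronize();
  for (int e = 0; e < num_experts; ++e) printf("expert %d: %d tokens\n", e, expert_counts[e]);
  return 0;
}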

sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu

Lines changed: 23 additions & 15 deletions
@@ -24,12 +24,24 @@ limitations under the License.
 
 #define WARP_SIZE 32
 
+template <typename scalar_t>
+__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
+                                      int32_t* cumsum_buffer, size_t numel) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.x;
+
+  for (size_t i = tid; i < numel; i += stride) {
+    int32_t expert_id = topk_ids[i];
+    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
+    sorted_token_ids[rank_post_pad] = i;
+  }
+}
+
 template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
                                             int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
                                             int32_t block_size, size_t numel, int32_t* cumsum) {
   __shared__ int32_t shared_counts[WARP_SIZE][8];
-  __shared__ int32_t local_offsets[256];
 
   const int warp_id = threadIdx.x / WARP_SIZE;
   const int experts_per_warp = 8;
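The added moe_token_sort_kernel is what makes a multi-block launch possible: each thread walks the flattened topk_ids with a grid-stride loop and claims an output slot with atomicAdd on cumsum_buffer, which lives in global memory and is therefore visible to every block, unlike the block-local __shared__ local_offsets array that the next hunk removes. The toy driver below is not part of the commit (sort_by_expert and the sample numbers are made up); it shows the scatter on a hand-built exclusive prefix sum, ignoring the block_size padding the real cumsum carries:

// Toy driver, not part of the commit; sort_by_expert mirrors the kernel above.
// cumsum_buffer[e] starts at the exclusive prefix sum of per-expert token counts,
// so each atomicAdd hands out the next free slot for that expert and
// sorted_token_ids ends up grouping token indices by expert.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void sort_by_expert(const int32_t* topk_ids, int32_t* sorted_token_ids,
                               int32_t* cumsum_buffer, size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t i = tid; i < numel; i += stride) {
    const int32_t expert_id = topk_ids[i];
    // Global-memory atomic: device-wide, so any number of blocks can cooperate.
    const int32_t slot = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[slot] = i;
  }
}

int main() {
  const int num_experts = 4;
  const size_t numel = 8;
  const int32_t routing[8] = {2, 0, 1, 2, 0, 3, 1, 2};  // token i -> expert routing[i]

  int32_t *topk_ids, *sorted_token_ids, *cumsum_buffer;
  cudaMallocManaged(&topk_ids, numel * sizeof(int32_t));
  cudaMallocManaged(&sorted_token_ids, numel * sizeof(int32_t));
  cudaMallocManaged(&cumsum_buffer, (num_experts + 1) * sizeof(int32_t));
  for (size_t i = 0; i < numel; ++i) topk_ids[i] = routing[i];

  // Exclusive prefix sum of the per-expert counts {2, 2, 3, 1}. In the real path
  // this comes from moe_align_block_size_kernel and includes block_size padding.
  const int32_t prefix[5] = {0, 2, 4, 7, 8};
  for (int e = 0; e <= num_experts; ++e) cumsum_buffer[e] = prefix[e];

  sort_by_expert<<<2, 4>>>(topk_ids, sorted_token_ids, cumsum_buffer, numel);
  cudaDeviceSynchronize();

  // Slots 0-1 hold tokens routed to expert 0, 2-3 expert 1, 4-6 expert 2, 7 expert 3.
  for (size_t i = 0; i < numel; ++i) printf("%d ", sorted_token_ids[i]);
  printf("\n");
  return 0;
}

Grouping by expert is deterministic, but the order of tokens inside one expert's range depends on which thread wins each atomicAdd.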
@@ -72,20 +84,6 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
       expert_ids[i / block_size] = threadIdx.x;
     }
-    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
-  }
-
-  __syncthreads();
-
-  // Note: For the moe_align_kernel, the primary bottleneck lies in the atomic add and non-coalesced memory writes here.
-  // If these operations can be performed using multiple blocks, similar to the Triton version, the performance of this
-  // kernel can achieve state-of-the-art performance across all token cases. However, once multiple blocks are used,
-  // illegal memory access occurs. Even replacing these lines of code with the stage 4 kernel from the Triton version
-  // results in the same issue, and a correct solution has not yet been found.
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
-    int32_t expert_id = topk_ids[i];
-    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
-    sorted_token_ids[rank_post_pad] = i;
   }
 }

@@ -100,5 +98,15 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
     align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                                          experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
                                          num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+
+    const int block_threads = 256;
+    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
+    const int max_blocks = 65535;
+    const int actual_blocks = std::min(num_blocks, max_blocks);
+
+    auto sort_kernel = moe_token_sort_kernel<scalar_t>;
+    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
+                                                             sorted_token_ids.data_ptr<int32_t>(),
+                                                             cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
   });
 }
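The new launch shape is ceiling division of topk_ids.numel() by 256 threads, capped at 65,535 blocks; when the cap applies, the grid-stride loop inside moe_token_sort_kernel lets each thread cover more than one element, so correctness does not depend on the grid being large enough. A standalone sketch of that arithmetic (sort_grid_size is a hypothetical helper and the numbers in the comments are worked examples, not measurements):

// sort_grid_size is a hypothetical helper that mirrors the launch math above.
#include <algorithm>
#include <cstdio>

int sort_grid_size(long long numel, int block_threads = 256, int max_blocks = 65535) {
  const long long num_blocks = (numel + block_threads - 1) / block_threads;  // ceil(numel / block_threads)
  return static_cast<int>(std::min<long long>(num_blocks, max_blocks));
}

int main() {
  // 16,384 tokens routed with top-k 8 give 131,072 topk_ids entries -> 512 blocks of 256 threads.
  printf("%d\n", sort_grid_size(131072));
  // 64M entries would want 262,144 blocks; the cap clamps the grid to 65,535 and the
  // grid-stride loop in moe_token_sort_kernel covers the remaining elements.
  printf("%d\n", sort_grid_size(64LL << 20));
  return 0;
}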
