Merged
11 changes: 6 additions & 5 deletions python/sglang/srt/layers/moe/cutlass_moe.py
@@ -15,6 +15,7 @@
if _is_cuda:
import sgl_kernel
from sgl_kernel import (
apply_shuffle_mul_sum,
cutlass_fp4_group_mm,
fp8_blockwise_scaled_grouped_mm,
prepare_moe_input,
@@ -151,8 +152,8 @@ def cutlass_fused_experts_fp8(
k,
)

rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map]
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))

c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -206,9 +207,9 @@ def cutlass_fused_experts_fp8(
expert_offsets[:-1],
workspace,
)
return (
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
).sum(dim=1)

result = torch.empty((m, k), device=device, dtype=out_dtype)
return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)


FLOAT4_E2M1_MAX = 6.0
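The two hunks above replace Python-side indexing with fused kernels: the row gathers a_q[a_map] / a1_scale[a_map] become shuffle_rows, and the final permute-multiply-reduce expression becomes apply_shuffle_mul_sum. As a reference point, the sketch below restates what the removed expressions compute in plain PyTorch; names and shapes are taken from the diff, and it is an illustrative reference, not the library code path.

# Hedged reference sketch of the pre-PR computation that the fused kernels replace.
# Shapes follow the diff: a_q is (m, k) and a1_scale is (m, k // 128); a_map has
# m * topk entries with values in [0, m), so the gather replicates each token's row
# into expert-grouped order. c2 is (m * topk, k) and c_map has m * topk entries.
import torch

def gather_rows_reference(src: torch.Tensor, row_map: torch.Tensor) -> torch.Tensor:
    # What shuffle_rows(src, row_map, out_shape) materializes: the row gather src[row_map].
    # (The original fp8 path viewed the data as uint8 first, since fp8 indexing is unsupported.)
    return src[row_map]

def shuffle_mul_sum_reference(
    c2: torch.Tensor,            # (m * topk, k) grouped GEMM outputs, expert-major order
    c_map: torch.Tensor,         # (m * topk,) maps token-major positions to rows of c2
    topk_weights: torch.Tensor,  # (m, topk) routing weights
    out_dtype: torch.dtype,
) -> torch.Tensor:
    m, topk = topk_weights.shape
    k = c2.size(1)
    # The exact expression removed from cutlass_fused_experts_fp8.
    return (
        c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
    ).sum(dim=1)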
3 changes: 2 additions & 1 deletion sgl-kernel/csrc/common_extension.cc
100644 → 100755
@@ -195,7 +195,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);

m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
/*
* From csrc/speculative
*/
12 changes: 6 additions & 6 deletions sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
100644 → 100755
@@ -174,10 +174,10 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
// bool use_small_config = a[0].size(0) <= 128;
struct MmaConfig1 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _32, _128>;
using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using MmaTileShape = Shape<_256, _32, _128>;
using ClusterShape = Shape<_2, _1, _1>; // Layout type for SFB matrix operand
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
using ScaleConfig =
cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
@@ -214,7 +214,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
torch::Tensor scales_a_t = scales_a.t();
torch::Tensor scales_b_t = scales_b.transpose(1, 2);

if (a.size(0) <= 512 && a.size(1) >= 2048) {
if (a.size(0) <= 2048 && a.size(1) >= 2048) {
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
expert_offsets,
a_ptrs,
@@ -247,7 +247,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
expert_offsets,
workspace);
output = output_t.t();
} else if (a.size(0) > 512 && a.size(1) >= 2048) {
} else if (a.size(0) > 2048 && a.size(1) >= 2048) {
run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
expert_offsets,
a_ptrs,
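Two things change in this file: MmaConfig1 moves from a 128x32x128 tile on a 1-SM schedule to a 256x32x128 tile on a 2-SM cluster (ClusterShape 2x1x1 with the 2Sm kernel and epilogue schedules), and the M boundary that separates MmaConfig1 from MmaConfig2 moves from 512 to 2048 rows. The sketch below mirrors the dispatch rule as it reads after this change; it is illustrative only (the real selection happens in C++ inside sm100_fp8_blockwise_group_mm_dispatch_shape), and the branch taken when K < 2048 is not shown in these hunks, so it appears here as a placeholder.

# Hedged Python mirror of the post-change dispatch rule (illustrative only).
def select_sm100_blockwise_config(m_total: int, k: int) -> str:
    """m_total = a.size(0), i.e. rows summed over all experts; k = a.size(1)."""
    if k < 2048:
        return "small_k_path"  # placeholder: this branch is outside the shown hunks
    if m_total <= 2048:
        return "MmaConfig1"    # now 256x32x128 on a 2-SM cluster; threshold was 512
    return "MmaConfig2"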
114 changes: 114 additions & 0 deletions sgl-kernel/csrc/moe/prepare_moe_input.cu
100644 → 100755
@@ -252,3 +252,117 @@ void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2sr
shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
return;
}

template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
const scalar_t* __restrict__ input_tensor, // [m * topk, row_stride]
scalar_t* __restrict__ output_tensor, // [m, row_stride]
const int32_t* __restrict__ permutation, // [m * topk]
int m,
int topk,
int row_stride,
const scalar_t* __restrict__ factors) // [m * topk] or nullptr
{
  int i = blockIdx.x;  // output row index, [0, m)

  if (i >= m) return;

  // Each thread strides across the row so row_stride may exceed blockDim.x.
  for (int d = threadIdx.x; d < row_stride; d += blockDim.x) {
    scalar_t sum_val = 0.0;

    for (int j = 0; j < topk; ++j) {
      int index_2d = i * topk + j;
      int src_row = permutation[index_2d];
      if (src_row >= m * topk) continue;  // bounds guard: input has m * topk rows

      scalar_t val = input_tensor[src_row * row_stride + d];

      scalar_t factor = 1.0;
      if (factors != nullptr) {
        factor = factors[index_2d];
      }

      sum_val += factor * val;
    }

    output_tensor[i * row_stride + d] = sum_val;
  }
}

void get_apply_shuffle_mul_sum_caller(
const torch::Tensor& input_tensor, // [m * topk, row_stride], bf16/f16
torch::Tensor& output_tensor, // [m, row_stride], bf16/f16
const torch::Tensor& permutation, // [m * topk], int32
const std::optional<torch::Tensor>& factors_opt) // optional [m * topk], bf16/f16
{
TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [m * topk, row_stride]");
TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [m, row_stride]");
TORCH_CHECK(permutation.dim() == 1, "permutation must be 1D [m * topk]");

int m = output_tensor.size(0);
int topk = int(permutation.size(0) / m);
int row_stride = output_tensor.size(1);

TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");

dim3 block(std::min(256, row_stride));
dim3 grid(m);  // one block per output row
auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());

const int32_t* perm_ptr = permutation.data_ptr<int32_t>();

void* factors_ptr = nullptr;
if (factors_opt.has_value()) {
TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
TORCH_CHECK(factors_opt->numel() == m * topk, "Factors must have shape [m * topk]");
factors_ptr = factors_opt->data_ptr();
}

if (output_tensor.scalar_type() == at::ScalarType::Half) {
apply_shuffle_mul_sum_kernel<at::Half><<<grid, block, 0, stream>>>(
input_tensor.data_ptr<at::Half>(),
output_tensor.data_ptr<at::Half>(),
perm_ptr,
m,
topk,
row_stride,
static_cast<const at::Half*>(factors_ptr));
} else if (output_tensor.scalar_type() == at::ScalarType::BFloat16) {
apply_shuffle_mul_sum_kernel<c10::BFloat16><<<grid, block, 0, stream>>>(
input_tensor.data_ptr<c10::BFloat16>(),
output_tensor.data_ptr<c10::BFloat16>(),
perm_ptr,
m,
topk,
row_stride,
static_cast<const c10::BFloat16*>(factors_ptr));
} else {
TORCH_CHECK(false, "Unsupported output dtype for apply_shuffle_mul_sum: ", output_tensor.scalar_type());
}
}

/**
* @brief Applies a permutation-based shuffle, element-wise multiplication, and reduction over the second dimension.
*
* This function performs the equivalent of the following PyTorch expression:
*
* (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
*
* Specifically:
* - `input` is shuffled using the `permutation` tensor.
* - The shuffled tensor is reshaped and multiplied element-wise with `factors` (e.g., top-k weights).
* - The result is summed along dimension 1 (the top-k dimension), and stored in `output`.
*
* @param input Input tensor of shape (m * topk, k), representing c2.
* @param output Output tensor of shape (m, k), where the final reduced results are stored.
* @param permutation Index tensor (e.g., c_map) that maps positions in `input` to shuffled layout.
* @param factors Optional scaling factors (e.g., top-k weights), shape (m * topk) or (m, topk).
*/
void apply_shuffle_mul_sum(
const torch::Tensor& input,
torch::Tensor& output,
const torch::Tensor& permutation,
const std::optional<torch::Tensor>& factors) {
get_apply_shuffle_mul_sum_caller(input, output, permutation, factors);
}
6 changes: 6 additions & 0 deletions sgl-kernel/include/sgl_kernel_ops.h
100644 → 100755
@@ -276,6 +276,12 @@ void ep_moe_post_reorder(

void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);

void apply_shuffle_mul_sum(
const torch::Tensor& input,
torch::Tensor& output,
const torch::Tensor& permutation,
const std::optional<torch::Tensor>& factors);

void cutlass_fp4_group_mm(
torch::Tensor& output,
const torch::Tensor& a,
1 change: 1 addition & 0 deletions sgl-kernel/python/sgl_kernel/__init__.py
@@ -48,6 +48,7 @@
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.moe import (
apply_shuffle_mul_sum,
cutlass_fp4_group_mm,
ep_moe_post_reorder,
ep_moe_pre_reorder,
11 changes: 11 additions & 0 deletions sgl-kernel/python/sgl_kernel/moe.py
@@ -178,6 +178,17 @@ def prepare_moe_input(
)


def apply_shuffle_mul_sum(
input,
output,
permutation,
factors,
):
torch.ops.sgl_kernel.apply_shuffle_mul_sum.default(
input, output, permutation, factors
)


def cutlass_fp4_group_mm(
a_fp4,
b_fp4,