Skip to content

[Hardware][NVIDIA] FP4 MoE kernel optimization #19110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_cutlass_fp4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def bench_run(

score = torch.randn((m, num_experts), device=device, dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

quant_blocksize = 16
w1_blockscale = torch.empty(
Expand Down
6 changes: 5 additions & 1 deletion csrc/moe/moe_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,8 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
int64_t BLOCK_SIZE_K, int64_t bit);
#endif

bool moe_permute_unpermute_supported();
bool moe_permute_unpermute_supported();

// Row gather used by the MoE path: for every row r of output_tensor, copies
// row dst2src_map[r] of input_tensor into it. The same source row may be
// referenced by several destination rows. input_tensor and output_tensor must
// have the same scalar type; dst2src_map is read as int32 indices.
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor);
56 changes: 56 additions & 0 deletions csrc/moe/moe_permute_unpermute_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,62 @@ void moe_unpermute(
});
}

// Gathers rows of `input` into `output`: destination row r receives a copy of
// source row dst2src_map[r]. A source row may appear several times in the map
// (rows are duplicated as well as permuted).
//
// Launch layout: one thread block per destination row (blockIdx.x is the
// destination row index); the threads of the block walk the row in steps of
// blockDim.x, each step copying one 128-bit chunk (ELEM_PER_THREAD elements).
//
// Preconditions (not verified inside the kernel):
//  - both buffers are dense row-major with `num_cols` elements per row;
//  - num_cols is a multiple of ELEM_PER_THREAD, otherwise the trailing
//    columns of each row are not copied;
//  - every map entry that is read is a valid row index of `input`.
//    `num_src_rows` is accepted for interface compatibility but is not used
//    to validate the map.
template <typename T>
__global__ void shuffleInputRowsKernel(const T* input,
                                       const int32_t* dst2src_map, T* output,
                                       int64_t num_src_rows,
                                       int64_t num_dst_rows, int64_t num_cols) {
  int64_t const dest_row_idx = blockIdx.x;

  // Bounds-check BEFORE touching the map: a block beyond the last destination
  // row must not dereference dst2src_map at all.
  if (dest_row_idx >= num_dst_rows) {
    return;
  }
  int64_t const source_row_idx = dst2src_map[dest_row_idx];

  // Load 128-bits per thread
  constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
  using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;

  // Duplicate and permute rows
  auto const* source_row_ptr =
      reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
  auto* dest_row_ptr =
      reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);

  int64_t const start_offset = threadIdx.x;
  int64_t const stride = blockDim.x;
  int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;

  // 64-bit counter so the comparison against the 64-bit chunk count cannot
  // wrap for very wide rows.
  for (int64_t elem_index = start_offset; elem_index < num_elems_in_col;
       elem_index += stride) {
    dest_row_ptr[elem_index] = source_row_ptr[elem_index];
  }
}

// Host launcher for shuffleInputRowsKernel.
//
// For every row r of `output_tensor`, copies row dst2src_map[r] of
// `input_tensor` into it, on the current CUDA stream. One thread block of 256
// threads is launched per destination row.
//
// Requirements enforced here:
//  - input and output share the same scalar type and row width;
//  - both are contiguous (the kernel addresses rows as base + row * num_cols);
//  - dst2src_map is a contiguous int32 tensor with at least one entry per
//    destination row;
//  - the row width is a multiple of the number of elements that fit in the
//    128-bit chunk the kernel copies per thread.
// Values stored in dst2src_map are NOT range-checked against the number of
// source rows.
void shuffle_rows(const torch::Tensor& input_tensor,
                  const torch::Tensor& dst2src_map,
                  torch::Tensor& output_tensor) {
  TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
              "Input and output tensors must have the same data type");
  TORCH_CHECK(input_tensor.dim() == 2 && output_tensor.dim() == 2,
              "Input and output tensors must be 2-dimensional");
  TORCH_CHECK(input_tensor.is_contiguous() && output_tensor.is_contiguous(),
              "Input and output tensors must be contiguous");
  TORCH_CHECK(dst2src_map.is_contiguous(), "dst2src_map must be contiguous");

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  int64_t const threads = 256;
  int64_t const num_dest_rows = output_tensor.size(0);
  int64_t const num_src_rows = input_tensor.size(0);
  int64_t const num_cols = input_tensor.size(1);
  int64_t const blocks = num_dest_rows;

  TORCH_CHECK(output_tensor.size(1) == num_cols,
              "Input and output tensors must have the same number of columns");
  TORCH_CHECK(dst2src_map.numel() >= num_dest_rows,
              "dst2src_map must have one entry per output row");

  // Elements per 128-bit access, derived from the element size in bytes.
  // (sizeof(scalar_type()) would measure the ScalarType enum itself, not the
  // tensor element, and therefore must not be used here.)
  int64_t const elems_per_access =
      128 / static_cast<int64_t>(input_tensor.element_size()) / 8;
  TORCH_CHECK(elems_per_access > 0 && num_cols % elems_per_access == 0,
              "num_cols (", num_cols, ") must be divisible by ",
              elems_per_access, " (128 bits / element size in bits)");

  // A zero-sized grid is an invalid launch configuration; nothing to copy.
  if (num_dest_rows == 0) {
    return;
  }

  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  });
}

#else

void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
Expand Down
18 changes: 12 additions & 6 deletions csrc/moe/permute_unpermute_kernels/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
__VA_ARGS__(); \
break; \
}
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
// Expands to one MOE_DISPATCH_CASE per element type the MoE permute/shuffle
// kernels are instantiated for. Byte is included so that uint8 tensors
// (used for packed fp4 data) can be dispatched as well.
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

// Entry point: switches on the runtime scalar type TYPE and invokes the
// callable passed in __VA_ARGS__ under the matching case listed above.
#define MOE_DISPATCH(TYPE, ...) \
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
Expand All @@ -39,6 +40,11 @@ template <>
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
using type = __nv_bfloat16;
};
// uint8 for packed fp4: tensors carrying packed fp4 data use
// at::ScalarType::Byte, so map that scalar type to the plain uint8_t
// element type for the kernels.
template <>
struct ScalarType2CudaType<at::ScalarType::Byte> {
using type = uint8_t;
};

// #if __CUDA_ARCH__ >= 890
// fp8
Expand Down
6 changes: 6 additions & 0 deletions csrc/moe/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def("moe_permute_unpermute_supported() -> bool");
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);

// Row shuffle for MoE
m.def(
"shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! "
"output_tensor) -> ()");
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);

#endif
}

Expand Down
3 changes: 2 additions & 1 deletion csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
Expand Down
36 changes: 31 additions & 5 deletions csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@ __global__ void compute_expert_offsets(
}
}

// Serial prefix sums over the per-expert row counts.
//
// problem_sizes1 holds three int32 values per expert; only the first one
// (index expert * 3, the expert's row count) is read. For each expert e the
// kernel writes:
//   atomic_buffer[e]          = number of rows that precede expert e
//   expert_offsets[e + 1]     = number of rows up to and including expert e
//   blockscale_offsets[e + 1] = the same running total, but with every
//                               expert's row count rounded up to a multiple
//                               of 128
// expert_offsets[0] and blockscale_offsets[0] are set to 0, so both output
// arrays need num_experts + 1 entries.
//
// The body is a plain sequential loop with no thread indexing; the visible
// call site launches it as <<<1, 1>>>.
__global__ void compute_expert_blockscale_offsets(
    const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
    int32_t* blockscale_offsets, int32_t* atomic_buffer,
    const int num_experts) {
  // Granularity to which each expert's block-scale rows are padded.
  constexpr int32_t kRowAlignment = 128;

  int32_t rows_so_far = 0;
  int32_t padded_rows_so_far = 0;

  expert_offsets[0] = rows_so_far;
  blockscale_offsets[0] = padded_rows_so_far;

  for (int expert = 0; expert < num_experts; ++expert) {
    // Start position of this expert, i.e. the exclusive prefix sum.
    atomic_buffer[expert] = rows_so_far;

    int32_t const expert_rows = problem_sizes1[expert * 3];
    int32_t const padded_expert_rows =
        (expert_rows + (kRowAlignment - 1)) / kRowAlignment * kRowAlignment;

    rows_so_far += expert_rows;
    expert_offsets[expert + 1] = rows_so_far;

    padded_rows_so_far += padded_expert_rows;
    blockscale_offsets[expert + 1] = padded_rows_so_far;
  }
}

__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
Expand Down Expand Up @@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k) {
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
Expand All @@ -89,10 +107,18 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
if (blockscale_offsets.has_value()) {
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
} else {
compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
Expand Down
9 changes: 6 additions & 3 deletions csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
#endif

void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
Expand Down Expand Up @@ -224,15 +225,17 @@ void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k) {
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k);
output_permutation, num_experts, n, k,
blockscale_offsets);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
Expand Down
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()",
" int n, int k, Tensor? blockscale_offsets) -> ()",
{stride_tag});
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);

Expand Down
5 changes: 4 additions & 1 deletion tests/kernels/moe/test_nvfp4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w2[expert], w2_gs[expert])

score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)

a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
Expand Down
45 changes: 33 additions & 12 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,11 +806,16 @@ def cutlass_scaled_sparse_mm(
return out


def get_cutlass_moe_mm_data(
topk_ids: torch.Tensor, expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor,
input_permutation: torch.Tensor, output_permutation: torch.Tensor,
num_experts: int, n: int, k: int):
def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
input_permutation: torch.Tensor,
output_permutation: torch.Tensor,
num_experts: int,
n: int,
k: int,
blockscale_offsets: Optional[torch.Tensor] = None):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
Expand All @@ -828,12 +833,31 @@ def get_cutlass_moe_mm_data(
before executing the MMs.
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
- blockscale_offsets: Optional argument passed for fp4 moe. Indices that
mark at which block scale index each expert begins
its computation. The number of block scale rows
computed with expert E is blockscale_offsets[E + 1] -
blockscale_offsets[E]
"""
return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
problem_sizes1, problem_sizes2,
input_permutation,
output_permutation,
num_experts, n, k)
num_experts, n, k,
blockscale_offsets)


def shuffle_rows(input_tensor: torch.Tensor,
                 dst2src_map: torch.Tensor) -> torch.Tensor:
    """
    Gather (shuffle and, where indices repeat, duplicate) rows of
    `input_tensor` according to `dst2src_map` and return the result as a
    newly allocated tensor.
    This is used in MoE to permute the input tensor before performing
    grouped matrix multiplications.

    Args:
        input_tensor: 2-D tensor whose rows are gathered.
        dst2src_map: 1-D index tensor; entry i is the row of `input_tensor`
            copied into row i of the result. The kernel reads it as int32.

    Returns:
        A tensor of shape (dst2src_map.shape[0], input_tensor.shape[1]) with
        the same dtype and device as `input_tensor`.
    """
    # The kernel addresses rows as base_ptr + row * num_cols, so it only
    # produces correct results for dense row-major data. contiguous() is a
    # no-op for tensors that already are contiguous.
    input_tensor = input_tensor.contiguous()
    dst2src_map = dst2src_map.contiguous()

    num_tokens_permuted = dst2src_map.shape[0]
    output_tensor = torch.empty((num_tokens_permuted, input_tensor.shape[1]),
                                device=input_tensor.device,
                                dtype=input_tensor.dtype)
    torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor)
    return output_tensor


def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
Expand Down Expand Up @@ -1085,14 +1109,12 @@ def scaled_fp4_experts_quant(
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
topk: int,
expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
packed MoE Inputs.
Args:
input: The input tensor to be quantized to FP4
expert_map: The expert map tensor
input_tensor: The input tensor to be quantized to FP4
input_global_scale: A scalar scaling factor for the entire tensor.
expert_offsets: The expert offsets tensor
blockscale_offsets: The blockscale offsets tensor
Expand All @@ -1104,14 +1126,13 @@ def scaled_fp4_experts_quant(
assert input_tensor.ndim == 2, (
f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')

input_tensor = input_tensor[
expert_map] if expert_map is not None else input_tensor
m_numtopk, k = input_tensor.shape
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
# from running out of memory. This value can also be increased to support
# larger models.
MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
m_numtopk, k = input_tensor.shape

assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
f"{MAX_TOKENS_PER_EXPERT})"
Expand Down
15 changes: 8 additions & 7 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
num_topk = topk_ids.shape[1]

expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,2n,k))
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,n,k))
Expand All @@ -344,20 +345,18 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, e, n, k)
problem_sizes2, a_map, c_map, e, n, k,
blockscale_offsets)

tokens_per_expert = problem_sizes1[:, 0]
rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
a = ops.shuffle_rows(a, a_map)

rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
a,
a1_gscale,
expert_offsets,
blockscale_offsets,
num_topk,
expert_map=a_map)
)

c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1,
Expand All @@ -378,6 +377,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
w2_alphas, problem_sizes2, expert_offsets[:-1],
blockscale_offsets[:-1], out_dtype, device)
del int_fp4, int_blockscale
out = (c2[c_map].view(m, num_topk, k) *

c2 = ops.shuffle_rows(c2, c_map)
out = (c2.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
return out.to(dtype=out_dtype)