Skip to content

Revert "[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)" #19512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUTLASS MoE kernels

# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
# to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
Expand Down
63 changes: 51 additions & 12 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.fused_moe import (
cutlass_moe_fp8,
fused_experts,
fused_topk,
)
Expand Down Expand Up @@ -70,9 +70,18 @@ def bench_run(
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_q_notransp = w1_q.clone()
w2_q_notransp = w2_q.clone()
w1_q = w1_q.transpose(1, 2)
w2_q = w2_q.transpose(1, 2)

score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

Expand Down Expand Up @@ -113,17 +122,25 @@ def run_cutlass_moe(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats):
cutlass_moe_fp8(
a,
w1,
w2,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
a1_scale=a_scale,
)

Expand All @@ -136,6 +153,10 @@ def run_cutlass_from_graph(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
Expand All @@ -144,10 +165,14 @@ def run_cutlass_from_graph(
a,
w1_q,
w2_q,
topk_weights,
topk_ids,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
a1_scale=a_scale,
)

Expand Down Expand Up @@ -193,6 +218,10 @@ def replay_graph(graph, num_repeats):
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
)
torch.cuda.synchronize()

Expand All @@ -201,8 +230,8 @@ def replay_graph(graph, num_repeats):
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(
a,
w1_q,
w2_q,
w1_q_notransp,
w2_q_notransp,
topk_weights,
topk_ids,
w1_scale,
Expand All @@ -221,12 +250,18 @@ def replay_graph(graph, num_repeats):
"w2": w2,
"score": score,
"topk": topk,
"w1_q_notransp": w1_q_notransp,
"w2_q_notransp": w2_q_notransp,
# Cutlass params
"a_scale": a_scale,
"w1_q": w1_q,
"w2_q": w2_q,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"ab_strides1": ab_strides1,
"c_strides1": c_strides1,
"ab_strides2": ab_strides2,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
Expand All @@ -244,8 +279,8 @@ def replay_graph(graph, num_repeats):
# Warmup
run_triton_moe(
a,
w1_q,
w2_q,
w1_q_notransp,
w2_q_notransp,
topk_weights,
topk_ids,
w1_scale,
Expand All @@ -256,7 +291,7 @@ def replay_graph(graph, num_repeats):

results.append(
benchmark.Timer(
stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
Expand Down Expand Up @@ -287,12 +322,16 @@ def replay_graph(graph, num_repeats):
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
num_warmup,
)

results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
Expand Down
11 changes: 1 addition & 10 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,7 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
torch::Tensor const& b_strides, torch::Tensor const& c_strides);

void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
Expand All @@ -252,14 +251,6 @@ void get_cutlass_moe_mm_data(
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);

void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
Expand Down
29 changes: 10 additions & 19 deletions csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ void run_cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
Expand Down Expand Up @@ -114,23 +113,19 @@ void run_cutlass_moe_mm_sm90(
if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
problem_sizes, a_strides, b_strides, c_strides);
} else if (k >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
problem_sizes, a_strides, b_strides, c_strides);
} else if (m <= 16) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
problem_sizes, a_strides, b_strides, c_strides);
} else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
problem_sizes, a_strides, b_strides, c_strides);
}
}

Expand All @@ -139,18 +134,15 @@ void dispatch_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
problem_sizes, a_strides, b_strides, c_strides);
} else {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
problem_sizes, a_strides, b_strides, c_strides);
}
}

Expand All @@ -161,9 +153,8 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
c_strides);
}
6 changes: 4 additions & 2 deletions csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,17 @@ void cutlass_group_gemm_caller(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;

int num_experts = static_cast<int>(expert_offsets.size(0));
int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1);

bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;

auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

auto options_int =
Expand Down
42 changes: 4 additions & 38 deletions csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
Expand Down Expand Up @@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
}
}

__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
int32_t* output_permutation,
Expand Down Expand Up @@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(

int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
Expand All @@ -120,44 +120,10 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
topk_ids.size(1));
}

__global__ void compute_pplx_data(int32_t* expert_offsets,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens,
const int padded_m, const int n,
const int k) {
int expert_idx = threadIdx.x;

expert_offsets[expert_idx] = expert_idx * padded_m;
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
problem_sizes1[expert_idx * 3 + 2] = k;
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes2[expert_idx * 3 + 1] = k;
problem_sizes2[expert_idx * 3 + 2] = n;
}

void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());

compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
Loading