Skip to content

Commit c26bdf3

Browse files
CalebDu authored and minpeter committed
permute/unpermute kernel for moe optimization (vllm-project#14568)
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent 1f0a10c commit c26bdf3

19 files changed

+1474
-28
lines changed

CMakeLists.txt

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX)
1515

1616
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
1717
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
18-
1918
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
2019
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
2120

@@ -682,6 +681,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
682681
endif()
683682
endif()
684683

684+
if(VLLM_GPU_LANG STREQUAL "CUDA")
685+
set(MOE_PERMUTE_SRC
686+
"csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu"
687+
"csrc/moe/moe_permute_unpermute_op.cu")
688+
689+
set_gencode_flags_for_srcs(
690+
SRCS "${MARLIN_PERMUTE_SRC}"
691+
CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")
692+
693+
list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
694+
endif()
685695
message(STATUS "Enabling moe extension.")
686696
define_gpu_extension_target(
687697
_moe_C
@@ -690,6 +700,8 @@ define_gpu_extension_target(
690700
SOURCES ${VLLM_MOE_EXT_SRC}
691701
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
692702
ARCHITECTURES ${VLLM_GPU_ARCHES}
703+
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
704+
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
693705
USE_SABI 3
694706
WITH_SOABI)
695707

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
9090

9191
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
9292

93-
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
93+
topk_weights, topk_ids, token_expert_indices = fused_topk(
94+
a, score, topk, renormalize=False)
9495

9596
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
9697
topk_weights: torch.Tensor, topk_ids: torch.Tensor,

benchmarks/kernels/benchmark_moe.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -115,8 +115,8 @@ def run():
115115
from vllm.model_executor.layers.fused_moe import override_config
116116
with override_config(config):
117117
if use_deep_gemm:
118-
topk_weights, topk_ids = fused_topk(x, input_gating, topk,
119-
False)
118+
topk_weights, topk_ids, token_expert_indices = fused_topk(
119+
x, input_gating, topk, False)
120120
return fused_experts(
121121
x,
122122
w1,

0 commit comments

Comments (0)