[Kernel][ROCm][AMD] enable fused topk_softmax kernel for moe layer #4927

Merged · 9 commits · Jun 2, 2024 · Changes from all commits
8 changes: 3 additions & 5 deletions CMakeLists.txt
@@ -308,6 +308,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)

+  message(STATUS "Enabling moe extension.")
+  add_dependencies(default _moe_C)
+
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
# there are supported target arches.
@@ -317,8 +320,3 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
add_dependencies(default _punica_C)
endif()
endif()

-if(VLLM_GPU_LANG STREQUAL "CUDA")
-  message(STATUS "Enabling moe extension.")
-  add_dependencies(default _moe_C)
-endif()
1 change: 1 addition & 0 deletions Dockerfile.rocm
@@ -108,6 +108,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
&& python3 setup.py install \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \
+&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cd ..


4 changes: 4 additions & 0 deletions csrc/cuda_compat.h
@@ -19,8 +19,12 @@
#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
+__shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
+__shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
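The new VLLM_SHFL_XOR_SYNC_WIDTH macro papers over the API difference for width-limited shuffles: CUDA's __shfl_xor_sync takes a participation mask plus a width, while ROCm's __shfl_xor takes only the width. As a minimal sketch of the butterfly-reduction pattern the MoE kernel builds on it (the helper name warpReduceMax is illustrative, not from this PR; WIDTH is assumed to be a power of two no larger than the warp size):

#include "cuda_compat.h"  // assumed include path

// Butterfly max-reduction across a group of WIDTH consecutive lanes.
template <int WIDTH>
__device__ float warpReduceMax(float val)
{
#pragma unroll
    for (int mask = WIDTH / 2; mask > 0; mask /= 2)
    {
        // Expands to __shfl_xor_sync(uint32_t(-1), val, mask, WIDTH) on CUDA
        // and to __shfl_xor(val, mask, WIDTH) on ROCm.
        val = fmaxf(val, VLLM_SHFL_XOR_SYNC_WIDTH(val, mask, WIDTH));
    }
    return val;  // every lane in the group now holds the group maximum
}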
27 changes: 17 additions & 10 deletions csrc/moe/topk_softmax_kernels.cu
@@ -19,15 +19,22 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
+#include "../cuda_compat.h"

-#include <cub/cub.cuh>
-#include <cub/util_type.cuh>
+#ifndef USE_ROCM
+#include <cub/util_type.cuh>
+#include <cub/cub.cuh>
+#else
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#endif

+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))

namespace vllm {
namespace moe {

-static constexpr int WARP_SIZE = 32;

/// Aligned array type
template <
typename T,
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
-thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
+thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
}

// From this point, thread max in all the threads have the max within the row.
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
-row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
+row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
}

// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
Expand Down Expand Up @@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
-float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
-int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
+float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
+int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);

// We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert))
Expand Down Expand Up @@ -383,7 +390,7 @@ struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
-static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
+static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

-static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
+static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
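Beyond the shuffle portability, the kernel swaps std::min/std::max for the MIN/MAX macros in these constexpr initializers and drops its local WARP_SIZE = 32, presumably so the value can come from the shared compatibility header (AMD wavefronts are 64 lanes wide). A hedged worked example of the TopkConstants arithmetic, assuming EXPERTS = 8 and a 32-lane warp:

#include <cstddef>
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

constexpr int WARP_SIZE = 32;
constexpr int EXPERTS = 8;
constexpr std::size_t MAX_BYTES_PER_LDG = 16;
constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);  // 16
constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);                     // 4 floats per vector load
constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));   // 1
constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;                             // 4 gating logits per thread
constexpr int THREADS_PER_ROW = EXPERTS / VPT;                                  // 2 lanes cooperate on each row
constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;                      // 16 token rows per warp

static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
              EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");  // holds: 8 / 128 == 0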
2 changes: 1 addition & 1 deletion setup.py
@@ -382,7 +382,7 @@ def _read_requirements(filename: str) -> List[str]:

ext_modules = []

-if _is_cuda():
+if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))

if not _is_neuron():
46 changes: 19 additions & 27 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -8,9 +8,9 @@
import triton
import triton.language as tl

+import vllm._moe_C as moe_kernels
from vllm import _custom_ops as ops
from vllm.logger import init_logger
-from vllm.utils import is_hip

logger = init_logger(__name__)

@@ -319,34 +319,26 @@ def fused_topk(

M, _ = hidden_states.shape

-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output,
-                                        dim=-1,
-                                        dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import vllm._moe_C as moe_kernels
-
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=hidden_states.device)
-        topk_ids = torch.empty(M,
-                               topk,
-                               dtype=torch.int32,
-                               device=hidden_states.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=hidden_states.device)
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
+    topk_weights = torch.empty(M,
+                               topk,
+                               dtype=torch.float32,
+                               device=hidden_states.device)
+    topk_ids = torch.empty(M,
+                           topk,
+                           dtype=torch.int32,
+                           device=hidden_states.device)
+    token_expert_indicies = torch.empty(M,
+                                        topk,
+                                        dtype=torch.int32,
+                                        device=hidden_states.device)
+    moe_kernels.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
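With the is_hip() branch gone, fused_topk always preallocates the three outputs and lets the extension fill them in place on both CUDA and ROCm. A hedged sketch of that contract as a plain torch reference (topk_softmax_reference is a hypothetical name; it reproduces what the removed fallback computed, whereas the real vllm._moe_C.topk_softmax launches the fused kernel, and the spelling token_expert_indicies mirrors the caller):

#include <torch/extension.h>

void topk_softmax_reference(torch::Tensor& topk_weights,           // [M, topk] float32, filled in place
                            torch::Tensor& topk_ids,               // [M, topk] int32, filled in place
                            torch::Tensor& token_expert_indicies,  // [M, topk] int32; the fused kernel fills this
                            const torch::Tensor& gating_output)    // [M, num_experts] gating logits
{
    const int64_t topk = topk_weights.size(1);
    // Row-wise softmax over the experts, then top-k, matching the old fallback.
    auto probs = torch::softmax(gating_output.to(torch::kFloat32), /*dim=*/-1);
    auto [vals, ids] = probs.topk(topk, /*dim=*/-1);
    topk_weights.copy_(vals);
    topk_ids.copy_(ids.to(torch::kInt32));
    token_expert_indicies.zero_();  // placeholder; not consumed by fused_topk
}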