10 changes: 10 additions & 0 deletions CMakeLists.txt
@@ -175,6 +175,16 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()

+#
+# Set CUDA include flags for CXX compiler.
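+# (CUDA 13.0 relocates the bundled CCCL headers, e.g. CUB and Thrust, to
+# <toolkit>/include/cccl, hence the extra include directory below.)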
+#
+if(VLLM_GPU_LANG STREQUAL "CUDA")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include")
+  if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl")
+  endif()
+endif()
+
#
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
17 changes: 17 additions & 0 deletions csrc/cub_helpers.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#ifndef USE_ROCM
+  #include <cub/cub.cuh>
+  #if CUB_VERSION >= 200800
+    #include <cuda/std/functional>
+using CubAddOp = cuda::std::plus<>;
+using CubMaxOp = cuda::maximum<>;
+  #else  // if CUB_VERSION < 200800
+using CubAddOp = cub::Sum;
+using CubMaxOp = cub::Max;
+  #endif  // CUB_VERSION
+#else
+  #include <hipcub/hipcub.hpp>
+using CubAddOp = cub::Sum;
+using CubMaxOp = cub::Max;
+#endif  // USE_ROCM
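
For reference, a minimal sketch of how these aliases drop into CUB's block reductions. It mirrors the call sites updated in the files below; the kernel name and the fixed block size are illustrative, not part of the change:

#include "cub_helpers.h"  // pulls in <cub/cub.cuh> on CUDA builds

// One block per row: each thread accumulates a partial sum of squares,
// then the block combines the partials with CubAddOp.
__global__ void sum_of_squares_sketch(const float* __restrict__ input,
                                      float* __restrict__ out,
                                      int hidden_size) {
  float partial = 0.0f;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    const float x = input[blockIdx.x * hidden_size + i];
    partial += x * x;
  }
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  partial = BlockReduce(reduceStore).Reduce(partial, CubAddOp{}, blockDim.x);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = partial;  // block-wide sum of squares for this row
  }
}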
13 changes: 4 additions & 9 deletions csrc/layernorm_kernels.cu
@@ -1,15 +1,10 @@
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#endif

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
@@ -30,7 +25,7 @@ __global__ void rms_norm_kernel(

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
-variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -85,7 +80,7 @@ fused_add_rms_norm_kernel(

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
-variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -126,7 +121,7 @@ fused_add_rms_norm_kernel(

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
-variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
13 changes: 4 additions & 9 deletions csrc/layernorm_quant_kernels.cu
@@ -8,16 +8,11 @@
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#endif

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
@@ -39,7 +34,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
-variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -100,7 +95,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
-variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -149,7 +144,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
-variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
16 changes: 3 additions & 13 deletions csrc/moe/topk_softmax_kernels.cu
@@ -20,17 +20,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"

-#ifndef USE_ROCM
-#include <cub/util_type.cuh>
-#include <cub/cub.cuh>
-#include <cuda/std/functional>
-using AddOp = cuda::std::plus<float>;
-#else
-#include <hipcub/util_type.hpp>
-#include <hipcub/hipcub.hpp>
-using AddOp = cub::Sum;
-#endif
#include "../cub_helpers.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -79,7 +69,7 @@ __launch_bounds__(TPB) __global__
threadData = max(static_cast<float>(input[idx]), threadData);
}

-const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
+const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
if (threadIdx.x == 0)
{
float_max = maxElem;
@@ -94,7 +84,7 @@
threadData += exp((static_cast<float>(input[idx]) - float_max));
}

-const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp());
+const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());

if (threadIdx.x == 0)
{
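The topk_softmax change above exercises both aliases in the usual two-pass pattern: a block-wide max for numerical stability, then a block-wide sum of shifted exponentials. A self-contained sketch of that pattern, assuming one row per block and a row length n no larger than TPB (all names here are illustrative, not vLLM's):

#include <cfloat>
#include "../cub_helpers.h"  // pulls in <cub/cub.cuh> on CUDA builds

template <int TPB>
__global__ void block_softmax_sketch(const float* in, float* out, int n) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;
  __shared__ float s_max, s_sum;

  const int idx = blockIdx.x * n + threadIdx.x;
  const float x = (threadIdx.x < n) ? in[idx] : -FLT_MAX;

  // Pass 1: block-wide max keeps the exponentials in range.
  const float maxElem = BlockReduce(tmpStorage).Reduce(x, CubMaxOp{}, n);
  if (threadIdx.x == 0) s_max = maxElem;
  __syncthreads();

  // Pass 2: block-wide sum of the shifted exponentials.
  // Reusing tmpStorage is safe because of the barrier above.
  const float e = (threadIdx.x < n) ? expf(x - s_max) : 0.0f;
  const float Z = BlockReduce(tmpStorage).Reduce(e, CubAddOp{}, n);
  if (threadIdx.x == 0) s_sum = Z;
  __syncthreads();

  if (threadIdx.x < n) out[idx] = e / s_sum;
}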
11 changes: 2 additions & 9 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -7,17 +7,10 @@

#include <cmath>

#include "../../cub_helpers.h"
#include "../../dispatch_utils.h"
#include "../vectorization_utils.cuh"

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#include <cub/util_type.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#include <hipcub/util_type.hpp>
-#endif

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
static constexpr auto i8_min =
@@ -173,7 +166,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
});
using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
-float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
+float block_max = BlockReduce(tmp).Reduce(thread_max, CubMaxOp{}, blockDim.x);
__shared__ float absmax;
if (tid == 0) {
absmax = block_max;
9 changes: 2 additions & 7 deletions csrc/quantization/fp8/common.cu
@@ -1,15 +1,10 @@
#include "common.cuh"
#include "dispatch_utils.h"
#include "../../cub_helpers.h"
#include "../vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#endif

namespace vllm {

template <typename scalar_t, typename fp8_type>
@@ -116,7 +111,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
const float block_max =
-BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
+BlockReduce(tmp).Reduce(absmax_val, CubMaxOp{}, blockDim.x);

__shared__ float token_scale;
if (tid == 0) {
14 changes: 5 additions & 9 deletions csrc/quantization/fused_kernels/layernorm_utils.cuh
@@ -8,11 +8,7 @@
#include "quantization/utils.cuh"
#include "quant_conversions.cuh"

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#endif
#include "../../cub_helpers.h"

namespace vllm {

@@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
-ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
+ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);

__shared__ float s_rms;
if (threadIdx.x == 0) {
@@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales(
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
-.Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
+.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

__shared__ float s_token_scale;
if (threadIdx.x == 0) {
@@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
-ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
+ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);

__shared__ float s_rms;
if (threadIdx.x == 0) {
@@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales(
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
-.Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
+.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

__shared__ float s_token_scale;
if (threadIdx.x == 0) {