
[Kernel] vLLM Windows CUDA support #14891


Closed · wants to merge 5 commits
Changes from all commits
86 changes: 63 additions & 23 deletions CMakeLists.txt
@@ -249,6 +249,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
# Please keep this in sync with FetchContent_Declare line below.
set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_ENABLE_CUBLAS ON CACHE BOOL "cuBLAS enabled for Cutlass")
set(CUBLAS_ENABLED ON CACHE BOOL "cuBLAS enabled")
if (WIN32)
list(APPEND VLLM_GPU_FLAGS "-O2")
list(APPEND VLLM_GPU_FLAGS "-Xptxas=-O2")
list(APPEND VLLM_GPU_FLAGS "-Xcompiler=/O2")
set(CMAKE_CUDA_FLAGS_DEBUG "")
endif()

# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -262,18 +270,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation")
FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR})
else()
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# Please keep this in sync with CUTLASS_REVISION line above.
GIT_TAG v3.8.0
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
)
if (WIN32)
FetchContent_Declare(
cutlass
# For Windows, use the fixed v3.8.0 fix-sm100gemm custom branch until NVIDIA merges the fixes into the official repo. More info: https://github.com/NVIDIA/cutlass/pull/2167
GIT_REPOSITORY https://github.com/SystemPanic/cutlass.git
# Please keep this in sync with CUTLASS_REVISION line above.
GIT_TAG fix-sm100gemm
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
)
else()
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# Please keep this in sync with CUTLASS_REVISION line above.
GIT_TAG v3.8.0
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
)
endif()
endif()
FetchContent_MakeAvailable(cutlass)

@@ -534,17 +558,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()

message(STATUS "Enabling C extension.")
define_gpu_extension_target(
_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
if (DEFINED _CUBLAS_LIBRARY)
define_gpu_extension_target(
_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
LIBRARIES ${_CUBLAS_LIBRARY}
USE_SABI 3
WITH_SOABI)
else()
define_gpu_extension_target(
_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
endif()


# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
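Note: the `_C` extension target above links against cuBLAS only when `_CUBLAS_LIBRARY` is defined, and the hunk that populates that variable is not shown here. A minimal sketch of how it could be located on Windows, assuming `find_package(CUDAToolkit)` has already run (illustrative only, not taken from this PR):

```cmake
# Sketch: resolve the cuBLAS import library on Windows and expose it as
# _CUBLAS_LIBRARY so the conditional define_gpu_extension_target branch links it.
if (WIN32)
  find_library(_CUBLAS_LIBRARY
    NAMES cublas
    HINTS "${CUDAToolkit_LIBRARY_DIR}" "$ENV{CUDA_PATH}/lib/x64")
  if (_CUBLAS_LIBRARY)
    message(STATUS "Using cuBLAS at ${_CUBLAS_LIBRARY}")
  endif()
endif()
```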
9 changes: 9 additions & 0 deletions csrc/cache_kernels.cu
@@ -133,8 +133,13 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,

// Create data structures for the kernel.
// Create an array of pointers to the key and value caches.
#ifdef _WIN32
int64_t* key_cache_ptrs = new int64_t[num_layers];
int64_t* value_cache_ptrs = new int64_t[num_layers];
#else
int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers];
#endif
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
@@ -167,6 +172,10 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), numel_per_block);
}));
#ifdef _WIN32
delete[] key_cache_ptrs;
delete[] value_cache_ptrs;
#endif
}

// copy blocks kernel for MLA (assumes a joint KV-cache)
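The `#ifdef _WIN32` branch above exists because MSVC does not support C99-style variable-length arrays (`int64_t arr[num_layers]` with a runtime bound). A sketch of an equivalent approach that avoids the manual `delete[]` entirely, using a hypothetical helper name (illustrative only, not what the PR does):

```cpp
#include <cstdint>
#include <vector>

// Sketch: std::vector gives heap storage with automatic cleanup on all
// platforms, so no _WIN32 special case or trailing delete[] is needed.
void collect_cache_ptrs(std::size_t num_layers) {
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  // ... fill as in copy_blocks(), then pass key_cache_ptrs.data() to the
  // tensor-creation / kernel-launch code ...
}  // storage is released here automatically
```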
4 changes: 4 additions & 0 deletions csrc/core/math.hpp
@@ -5,5 +5,9 @@

inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
#ifdef _WIN32
return 1 << (CHAR_BIT * sizeof(num) - __lzcnt(num - 1));
#else
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
#endif
}
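For reference, `__builtin_clz(4)` is 29 for a 32-bit `uint32_t`, so `next_pow_2(5)` evaluates to `1 << 3 == 8`. One caveat with the Windows branch is that `__lzcnt` maps to the LZCNT instruction and silently returns wrong results on x86-64 CPUs without ABM support; a sketch of an alternative using `_BitScanReverse`, which is available on all x86-64 parts (illustrative only, not part of this PR):

```cpp
#include <intrin.h>
#include <climits>
#include <cstdint>

// Sketch: next power of two via _BitScanReverse, which works on all x86-64
// CPUs (unlike __lzcnt, which requires LZCNT/ABM).
inline uint32_t next_pow_2_bsr(uint32_t num) {
  if (num <= 1) return num;
  unsigned long idx;               // index of the highest set bit of num - 1
  _BitScanReverse(&idx, num - 1);  // e.g. num = 5 -> idx = 2
  return 1u << (idx + 1);          // -> 8
}
```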
4 changes: 4 additions & 0 deletions csrc/cpu/cpu_types_vsx.hpp
@@ -515,7 +515,11 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
}

inline void prefetch(const void* addr) {
#ifdef _WIN32
__asm("dcbt 0, %0" : : "r"(addr) : "memory");
#else
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
#endif
}

}; // namespace vec_op
4 changes: 4 additions & 0 deletions csrc/cumem_allocator.cpp
@@ -12,6 +12,10 @@ extern "C" {
#include <cuda_runtime_api.h>
#include <cuda.h>

#ifndef ssize_t
#define ssize_t ptrdiff_t
#endif

char error_msg[10240]; // 10KB buffer to store error messages
CUresult no_error = CUresult(0);
CUresult error_code = no_error; // store error code
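The `#ifndef ssize_t` guard works here because the Windows headers never declare `ssize_t` at all; note that it also fires on POSIX systems, where `ssize_t` is a typedef rather than a macro, but the resulting `ptrdiff_t` substitution is generally harmless there. A more conventional sketch keyed on the platform instead of the identifier (illustrative only):

```c
/* Sketch: define ssize_t only where it is actually missing (Windows), using
 * the SSIZE_T type that BaseTsd.h provides. Purely illustrative. */
#ifdef _WIN32
  #include <BaseTsd.h>
  typedef SSIZE_T ssize_t;
#endif
```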
29 changes: 16 additions & 13 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -15,7 +15,20 @@

#include "static_switch.h"


// Helper function to set the maximum dynamic shared memory attribute.
// This function is defined at file scope so that the preprocessor directives
// are not embedded inside a lambda.
template <typename KernelT>
void set_max_dynamic_shared_memory(KernelT kernel, int smem_size) {
if (smem_size >= 48 * 1024) {
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute((void*)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n" << std::endl;
#endif
}
}

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

@@ -499,18 +512,8 @@ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
dim3 grid(params.batch, params.dim);

auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

if (kSmemSize >= 48 * 1024) {
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
}

set_max_dynamic_shared_memory(kernel, kSmemSize);
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);

C10_CUDA_KERNEL_LAUNCH_CHECK();
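Hoisting `set_max_dynamic_shared_memory` to file scope matters because the usual reason for this kind of change on MSVC is that the original `#ifndef USE_ROCM` block sits inside a lambda expanded through a BOOL_SWITCH-style dispatch macro, and preprocessor directives inside macro arguments are undefined behavior (MSVC's preprocessor in particular handles them badly). A compressed sketch of the pattern being avoided, with a hypothetical `DISPATCH_BY_DTYPE` macro standing in for the real dispatchers (illustrative only):

```cpp
// Problematic shape (sketch): the lambda body becomes a macro argument, so the
// #ifndef inside it is not reliably seen by the preprocessor.
DISPATCH_BY_DTYPE(input_type, [&] {
#ifndef USE_ROCM  // <-- directive expanded inside a macro argument
  C10_CUDA_CHECK(cudaFuncSetAttribute(
      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#endif
  kernel<<<grid, kNThreads, kSmemSize, stream>>>(params);
});
```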
14 changes: 9 additions & 5 deletions csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -9,6 +9,10 @@
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#ifdef _WIN32
#include <math.h>
#endif

#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
@@ -308,15 +312,15 @@ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
// processing 1 row.
constexpr int kNRows = 1;
static constexpr int kNRows = 1;
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
static constexpr bool kIsVariableB = true;
static constexpr bool kIsVariableC = true;
static constexpr bool kHasZ = true;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
static constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
4 changes: 2 additions & 2 deletions csrc/mamba/mamba_ssm/static_switch.h
@@ -19,10 +19,10 @@
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
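The `constexpr` → `static constexpr` change (here and in `selective_scan_fwd.cu` above) is a common MSVC workaround: `CONST_NAME` is consumed as a compile-time template argument inside further nested lambdas, and MSVC has long been stricter than GCC/Clang about using non-static `constexpr` locals from an enclosing lambda without capturing them; giving them static storage duration removes any capture question. A self-contained sketch of the pattern (illustrative only):

```cpp
#include <cstdio>

template <bool kFlag>
int run() { return kFlag ? 1 : 0; }

#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      static constexpr bool CONST_NAME = true;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      static constexpr bool CONST_NAME = false; \
      return __VA_ARGS__();                     \
    }                                           \
  }()

int main(int argc, char**) {
  // kEven is used as a template argument inside the nested lambda; because it
  // has static storage duration, no capture is required on any compiler.
  int r = BOOL_SWITCH(argc % 2 == 0, kEven, [&] { return run<kEven>(); });
  std::printf("%d\n", r);
  return 0;
}
```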
20 changes: 10 additions & 10 deletions csrc/quantization/awq/gemm_kernels.cu
@@ -176,15 +176,15 @@ __global__ void __launch_bounds__(64)
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
{
unsigned int addr;
__asm__ __volatile__(
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }\n"
: "=r"(addr)
: "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
(((((int)threadIdx.x) & 15) * 40) +
((((int)threadIdx.x) >> 4) * 8)))));

__asm__ __volatile__(
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned*)(A_shared_warp + 0))[0]),
@@ -197,7 +197,7 @@ __global__ void __launch_bounds__(64)
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
{
unsigned int addr;
__asm__ __volatile__(
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }\n"
: "=r"(addr)
@@ -206,7 +206,7 @@ __global__ void __launch_bounds__(64)
(ax1_0 * 16))])) +
(((((int)threadIdx.x) & 15) * (N + 8)) +
((((int)threadIdx.x) >> 4) * 8)))));
__asm__ __volatile__(
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
@@ -219,7 +219,7 @@ __global__ void __launch_bounds__(64)
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
{
__asm__ __volatile__(
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
@@ -236,7 +236,7 @@ __global__ void __launch_bounds__(64)
}

{
__asm__ __volatile__(
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
@@ -253,7 +253,7 @@ __global__ void __launch_bounds__(64)
}

{
__asm__ __volatile__(
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
@@ -270,7 +270,7 @@ __global__ void __launch_bounds__(64)
}

{
__asm__ __volatile__(
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
@@ -287,7 +287,7 @@ __global__ void __launch_bounds__(64)
}
#else
{
__asm__ __volatile__(
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
"%13};\n"
@@ -308,7 +308,7 @@ __global__ void __launch_bounds__(64)
}

{
__asm__ __volatile__(
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
"%13};\n"
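`__asm__` and `__volatile__` are GNU-specific keyword spellings; when nvcc uses MSVC as the host compiler the translation unit must still parse for the host pass, so only the portable `asm volatile(...)` spelling documented for inline PTX is safe, which is presumably why the AWQ kernels are rewritten this way. A minimal self-contained inline-PTX example in that spelling (illustrative only):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Minimal inline-PTX kernel using the portable `asm volatile` spelling.
__global__ void add_one(int* out, int x) {
  int y;
  asm volatile("add.s32 %0, %1, 1;" : "=r"(y) : "r"(x));
  *out = y;
}

int main() {
  int* d_out = nullptr;
  int h_out = 0;
  cudaMalloc(&d_out, sizeof(int));
  add_one<<<1, 1>>>(d_out, 41);
  cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("%d\n", h_out);  // prints 42
  cudaFree(d_out);
  return 0;
}
```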
4 changes: 2 additions & 2 deletions csrc/quantization/fp8/common.cu
@@ -31,7 +31,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor =
1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
1.0f / (fp8_e4m3_adjusted_max<fp8_type>::val() * 512.f);

int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
@@ -67,7 +67,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
token_scale = max(token_scale / fp8_e4m3_adjusted_max<fp8_type>::val(),
min_scaling_factor);
scale[token_idx] = token_scale;
}
10 changes: 3 additions & 7 deletions csrc/quantization/fp8/common.cuh
@@ -50,10 +50,6 @@ struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
}
};

template <typename T>
MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
fp8_e4m3_adjusted_max<T>::val();

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -76,8 +72,8 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
x = val / scale;
}

float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
float r = fmax(-fp8_e4m3_adjusted_max<fp8_type>::val(),
fmin(x, fp8_e4m3_adjusted_max<fp8_type>::val()));
#ifndef USE_ROCM
return static_cast<fp8_type>(r);
#else
@@ -123,7 +119,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max<fp8_type>::val());
}
}

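Dropping the `fp8_e4m3_adjusted_max_v` variable template in favor of calling `fp8_e4m3_adjusted_max<T>::val()` directly avoids a `constexpr` variable template in host/device code, which nvcc with an MSVC host compiler presumably handles poorly here; the trait struct itself is unchanged. A simplified sketch of the surviving pattern, with illustrative types and values only:

```cpp
#include <algorithm>

// Sketch: per-type maximum exposed through a constexpr static member function
// on a trait struct, called directly at each use site instead of through a
// variable template.
struct float8_e4m3fn_tag {};

template <typename T>
struct fp8_adjusted_max;

template <>
struct fp8_adjusted_max<float8_e4m3fn_tag> {
  static constexpr float val() { return 448.0f; }  // e4m3fn finite max
};

template <typename T>
float clamp_for_fp8(float x) {
  return std::max(-fp8_adjusted_max<T>::val(),
                  std::min(x, fp8_adjusted_max<T>::val()));
}
```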