Skip to content

Commit

Permalink
[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custo…
Browse files Browse the repository at this point in the history
  • Loading branch information
bnellnm authored and joerunde committed Jun 13, 2024
1 parent d94203e commit 869bef9
Show file tree
Hide file tree
Showing 55 changed files with 833 additions and 451 deletions.
22 changes: 6 additions & 16 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,6 @@ endif()
#
find_package(Torch REQUIRED)

#
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
# `libtorch_python.so` for linking against an extension. Torch's cmake
# configuration does not include this library (presumably since the cmake
# config is used for standalone C++ binaries that link against torch).
# The `libtorch_python.so` library defines some of the glue code between
# torch/python via pybind and is required by VLLM extensions for this
# reason. So, add it by manually with `find_library` using torch's
# installed library path.
#
find_library(torch_python_LIBRARY torch_python PATHS
"${TORCH_INSTALL_PREFIX}/lib")

#
# Forward the non-CUDA device extensions to external CMake scripts.
#
Expand Down Expand Up @@ -171,7 +158,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
Expand Down Expand Up @@ -218,14 +205,15 @@ define_gpu_extension_target(
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

#
# _moe_C extension
#

set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_ops.cpp"
"csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu")

define_gpu_extension_target(
Expand All @@ -235,6 +223,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)

#
Expand All @@ -249,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cu"
"csrc/punica/punica_pybind.cpp")
"csrc/punica/torch_bindings.cpp")

#
# Copy GPU compilation flags+update for punica
Expand Down Expand Up @@ -286,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
SOURCES ${VLLM_PUNICA_EXT_SRC}
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)
else()
message(WARNING "Unable to create _punica_C target because none of the "
Expand Down
6 changes: 3 additions & 3 deletions Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
&& cd ..


Expand Down
12 changes: 6 additions & 6 deletions cmake/cpu_extension.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
#
# Check the compile flags
#
list(APPEND CXX_COMPILE_FLAGS
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")

Expand Down Expand Up @@ -44,8 +44,8 @@ if (AVX512_FOUND)

find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
else()
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
Expand Down Expand Up @@ -73,18 +73,18 @@ set(VLLM_EXT_SRC
"csrc/cpu/cache.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/pybind.cpp")
"csrc/cpu/torch_bindings.cpp")

define_gpu_extension_target(
_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
WITH_SOABI
USE_SABI 3
WITH_SOABI
)

add_custom_target(default)
message(STATUS "Enabling C extension.")
add_dependencies(default _C)

11 changes: 8 additions & 3 deletions cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
set(Python_EXECUTABLE ${EXECUTABLE})
find_package(Python COMPONENTS Interpreter Development.Module)
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
if (NOT Python_FOUND)
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
endif()
Expand Down Expand Up @@ -294,14 +294,15 @@ endmacro()
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
# LIBRARIES <libraries> - Extra link libraries.
# WITH_SOABI - Generate library with python SOABI suffix name.
# USE_SABI <version> - Use python stable api <version>
#
# Note: optimization level/debug info is set via cmake build type.
#
function (define_gpu_extension_target GPU_MOD_NAME)
cmake_parse_arguments(PARSE_ARGV 1
GPU
"WITH_SOABI"
"DESTINATION;LANGUAGE"
"DESTINATION;LANGUAGE;USE_SABI"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")

# Add hipify preprocessing step when building with HIP/ROCm.
Expand All @@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
set(GPU_WITH_SOABI)
endif()

Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
if (GPU_USE_SABI)
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
else()
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
endif()

if (GPU_LANGUAGE STREQUAL "HIP")
# Make this target dependent on the hipify preprocessor step.
Expand Down
2 changes: 1 addition & 1 deletion csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>
Expand Down
34 changes: 18 additions & 16 deletions csrc/attention/attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* limitations under the License.
*/

#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
Expand Down Expand Up @@ -808,16 +808,17 @@ void paged_attention_v1(
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads]
float scale,
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len,
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);

DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
Expand Down Expand Up @@ -972,16 +973,17 @@ void paged_attention_v2(
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads]
float scale,
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len,
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
Expand All @@ -990,4 +992,4 @@ void paged_attention_v2(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
14 changes: 9 additions & 5 deletions csrc/cache.h
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
#pragma once

#include <torch/extension.h>
#include <torch/all.h>

#include <map>
#include <vector>

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping);

void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping);

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const float kv_scale);
const std::string& kv_cache_dtype,
const double kv_scale);

void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
Expand All @@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float scale, const std::string& kv_cache_dtype);
const double scale, const std::string& kv_cache_dtype);
13 changes: 8 additions & 5 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

Expand Down Expand Up @@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,

} // namespace vllm

void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
Expand Down Expand Up @@ -255,7 +258,7 @@ void reshape_and_cache(
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const float kv_scale) {
const std::string& kv_cache_dtype, const double kv_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
Expand Down Expand Up @@ -334,7 +337,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,

// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float kv_scale, const std::string& kv_cache_dtype) {
const double kv_scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
Expand Down
26 changes: 14 additions & 12 deletions csrc/cpu/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,12 +420,13 @@ void paged_attention_v1_impl_launcher(

void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
Expand Down Expand Up @@ -738,12 +739,13 @@ void paged_attention_v2_impl_launcher(
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
Expand Down
13 changes: 8 additions & 5 deletions csrc/cpu/cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& mapping_pairs,
const int element_num_per_block,
const int layer_num) {
Expand Down Expand Up @@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl(
}
}; // namespace

void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
Expand All @@ -104,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, float kv_scale) {
const std::string& kv_cache_dtype, double kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);

int num_tokens = key.size(0);
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/cpu_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#define CPU_TYPES_HPP

#include <immintrin.h>
#include <torch/extension.h>
#include <torch/all.h>

namespace vec_op {

Expand Down
Loading

0 comments on commit 869bef9

Please sign in to comment.