From 29c5a5cf78ba6cb458e465435fed6123cb8c153f Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Sun, 9 Jun 2024 16:23:30 -0400 Subject: [PATCH] [Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047) --- CMakeLists.txt | 22 +- Dockerfile.rocm | 6 +- cmake/cpu_extension.cmake | 12 +- cmake/utils.cmake | 11 +- csrc/activation_kernels.cu | 2 +- csrc/attention/attention_kernels.cu | 34 ++- csrc/cache.h | 14 +- csrc/cache_kernels.cu | 13 +- csrc/cpu/attention.cpp | 26 +- csrc/cpu/cache.cpp | 13 +- csrc/cpu/cpu_types.hpp | 2 +- csrc/cpu/layernorm.cpp | 4 +- csrc/cpu/pos_encoding.cpp | 2 +- csrc/cpu/pybind.cpp | 43 --- csrc/cpu/torch_bindings.cpp | 106 +++++++ csrc/cuda_utils.h | 6 +- csrc/cuda_utils_kernels.cu | 6 +- csrc/custom_all_reduce.cu | 22 +- csrc/dispatch_utils.h | 2 +- csrc/layernorm_kernels.cu | 6 +- csrc/moe/moe_ops.cpp | 8 - csrc/moe/moe_ops.h | 2 +- csrc/moe/topk_softmax_kernels.cu | 2 +- csrc/moe/torch_bindings.cpp | 12 + csrc/moe_align_block_size_kernels.cu | 6 +- csrc/ops.h | 68 +++-- csrc/pos_encoding_kernels.cu | 12 +- csrc/punica/punica_ops.cu | 6 +- csrc/punica/punica_ops.h | 6 +- csrc/punica/punica_pybind.cpp | 13 - csrc/punica/torch_bindings.cpp | 18 ++ csrc/pybind.cpp | 114 ------- csrc/quantization/aqlm/gemm_kernels.cu | 2 +- csrc/quantization/awq/gemm_kernels.cu | 8 +- .../compressed_tensors/int8_quant_kernels.cu | 2 +- .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 2 +- .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 2 +- .../cutlass_w8a8/scaled_mm_dq_entry.cu | 2 +- csrc/quantization/fp8/common.cu | 2 +- csrc/quantization/gptq/q_gemm.cu | 6 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 2 +- csrc/quantization/gptq_marlin/gptq_marlin.cuh | 2 +- .../marlin/dense/marlin_cuda_kernel.cu | 2 +- .../marlin/sparse/marlin_24_cuda_kernel.cu | 2 +- .../squeezellm/quant_cuda_kernel.cu | 1 - csrc/registration.h | 22 ++ csrc/torch_bindings.cpp | 283 ++++++++++++++++++ setup.py | 2 +- tests/kernels/test_int8_quant.py | 7 +- vllm/_custom_ops.py | 217 ++++++++++---- vllm/attention/backends/flash_attn.py | 10 +- .../device_communicators/custom_all_reduce.py | 34 ++- vllm/lora/punica.py | 45 ++- .../layers/fused_moe/fused_moe.py | 3 +- vllm/utils.py | 7 +- 55 files changed, 833 insertions(+), 451 deletions(-) delete mode 100644 csrc/cpu/pybind.cpp create mode 100644 csrc/cpu/torch_bindings.cpp delete mode 100644 csrc/moe/moe_ops.cpp create mode 100644 csrc/moe/torch_bindings.cpp delete mode 100644 csrc/punica/punica_pybind.cpp create mode 100644 csrc/punica/torch_bindings.cpp delete mode 100644 csrc/pybind.cpp create mode 100644 csrc/registration.h create mode 100644 csrc/torch_bindings.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index a197063f33601..ad6736c47f459 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,19 +66,6 @@ endif() # find_package(Torch REQUIRED) -# -# Normally `torch.utils.cpp_extension.CUDAExtension` would add -# `libtorch_python.so` for linking against an extension. Torch's cmake -# configuration does not include this library (presumably since the cmake -# config is used for standalone C++ binaries that link against torch). -# The `libtorch_python.so` library defines some of the glue code between -# torch/python via pybind and is required by VLLM extensions for this -# reason. So, add it by manually with `find_library` using torch's -# installed library path. -# -find_library(torch_python_LIBRARY torch_python PATHS - "${TORCH_INSTALL_PREFIX}/lib") - # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -171,7 +158,7 @@ set(VLLM_EXT_SRC "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" - "csrc/pybind.cpp") + "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") include(FetchContent) @@ -218,6 +205,7 @@ define_gpu_extension_target( COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + USE_SABI 3 WITH_SOABI) # @@ -225,7 +213,7 @@ define_gpu_extension_target( # set(VLLM_MOE_EXT_SRC - "csrc/moe/moe_ops.cpp" + "csrc/moe/torch_bindings.cpp" "csrc/moe/topk_softmax_kernels.cu") define_gpu_extension_target( @@ -235,6 +223,7 @@ define_gpu_extension_target( SOURCES ${VLLM_MOE_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 WITH_SOABI) # @@ -249,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" "csrc/punica/punica_ops.cu" - "csrc/punica/punica_pybind.cpp") + "csrc/punica/torch_bindings.cpp") # # Copy GPU compilation flags+update for punica @@ -286,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES) SOURCES ${VLLM_PUNICA_EXT_SRC} COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} + USE_SABI 3 WITH_SOABI) else() message(WARNING "Unable to create _punica_C target because none of the " diff --git a/Dockerfile.rocm b/Dockerfile.rocm index e30a2aaf30209..954958df88fc0 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -106,9 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \ && cd .. diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 0cf37769a6960..61d4843838ba0 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc") # # Check the compile flags # -list(APPEND CXX_COMPILE_FLAGS +list(APPEND CXX_COMPILE_FLAGS "-fopenmp" "-DVLLM_CPU_EXTENSION") @@ -44,8 +44,8 @@ if (AVX512_FOUND) find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) - if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND - CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") else() message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") @@ -73,7 +73,7 @@ set(VLLM_EXT_SRC "csrc/cpu/cache.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/pybind.cpp") + "csrc/cpu/torch_bindings.cpp") define_gpu_extension_target( _C @@ -81,10 +81,10 @@ define_gpu_extension_target( LANGUAGE CXX SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${CXX_COMPILE_FLAGS} - WITH_SOABI + USE_SABI 3 + WITH_SOABI ) add_custom_target(default) message(STATUS "Enabling C extension.") add_dependencies(default _C) - diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 00c81e4d00ad8..f3c1286dd8498 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -5,7 +5,7 @@ macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) file(REAL_PATH ${EXECUTABLE} EXECUTABLE) set(Python_EXECUTABLE ${EXECUTABLE}) - find_package(Python COMPONENTS Interpreter Development.Module) + find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) if (NOT Python_FOUND) message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") endif() @@ -294,6 +294,7 @@ endmacro() # INCLUDE_DIRECTORIES - Extra include directories. # LIBRARIES - Extra link libraries. # WITH_SOABI - Generate library with python SOABI suffix name. +# USE_SABI - Use python stable api # # Note: optimization level/debug info is set via cmake build type. # @@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) cmake_parse_arguments(PARSE_ARGV 1 GPU "WITH_SOABI" - "DESTINATION;LANGUAGE" + "DESTINATION;LANGUAGE;USE_SABI" "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") # Add hipify preprocessing step when building with HIP/ROCm. @@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME) set(GPU_WITH_SOABI) endif() - Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI}) + if (GPU_USE_SABI) + Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") + else() + Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") + endif() if (GPU_LANGUAGE STREQUAL "HIP") # Make this target dependent on the hipify preprocessor step. diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 867f63f12de4b..86ac2e75e78ee 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8f89f89786c3b..91083481705cb 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -17,7 +17,7 @@ * limitations under the License. */ -#include +#include #include #include #include @@ -808,16 +808,17 @@ void paged_attention_v1( torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] - int block_size, int max_seq_len, + int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, @@ -972,16 +973,17 @@ void paged_attention_v2( torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] - int block_size, int max_seq_len, + int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) @@ -990,4 +992,4 @@ void paged_attention_v2( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/csrc/cache.h b/csrc/cache.h index 435ae3e57f555..86caa9345361d 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -8,14 +8,18 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping); -void copy_blocks(std::vector& key_caches, - std::vector& value_caches, +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& block_mapping); void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, const float kv_scale); + const std::string& kv_cache_dtype, + const double kv_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, @@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const float scale, const std::string& kv_cache_dtype); + const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index d924ac39b89ca..72041076ae009 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, } // namespace vllm -void copy_blocks(std::vector& key_caches, - std::vector& value_caches, +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); @@ -255,7 +258,7 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const float kv_scale) { + const std::string& kv_cache_dtype, const double kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -334,7 +337,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, // Only for testing. void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const float kv_scale, const std::string& kv_cache_dtype) { + const double kv_scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index ed8cfbd421f0f..8367093325314 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -420,12 +420,13 @@ void paged_attention_v1_impl_launcher( void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); @@ -738,12 +739,13 @@ void paged_attention_v2_impl_launcher( void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 2890ba6e2bb32..2b5c3bd6ee70b 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -5,8 +5,8 @@ namespace { template -void copy_blocks_cpu_impl(std::vector& key_caches, - std::vector& value_caches, +void copy_blocks_cpu_impl(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& mapping_pairs, const int element_num_per_block, const int layer_num) { @@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl( } }; // namespace -void copy_blocks(std::vector& key_caches, - std::vector& value_caches, +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& block_mapping) { unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); @@ -104,7 +107,7 @@ void copy_blocks(std::vector& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, float kv_scale) { + const std::string& kv_cache_dtype, double kv_scale) { TORCH_CHECK(kv_scale == 1.0f); int num_tokens = key.size(0); diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index c1d3ec058b991..034c406a532d5 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -3,7 +3,7 @@ #define CPU_TYPES_HPP #include -#include +#include namespace vec_op { diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp index 65d3ddcec5709..a76ad08928a2c 100644 --- a/csrc/cpu/layernorm.cpp +++ b/csrc/cpu/layernorm.cpp @@ -88,7 +88,7 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input, } // namespace void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, - float epsilon) { + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -102,7 +102,7 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, } void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, float epsilon) { + torch::Tensor& weight, double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index e8aead17ae5a7..96bce7dda0132 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl( }; // namespace void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int head_size, + torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp deleted file mode 100644 index e5b2ce4f30113..0000000000000 --- a/csrc/cpu/pybind.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "cache.h" -#include "ops.h" -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // vLLM custom ops - pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); - - // Attention ops - ops.def("paged_attention_v1", &paged_attention_v1, - "Compute the attention between an input query and the cached " - "keys/values using PagedAttention."); - ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); - - // Activation ops - ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); - ops.def("gelu_and_mul", &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); - ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); - - // Layernorm - ops.def("rms_norm", &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); - - ops.def("fused_add_rms_norm", &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); - - // Rotary embedding - ops.def("rotary_embedding", &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - - // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def("swap_blocks", &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def("copy_blocks", ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def("reshape_and_cache", &reshape_and_cache, - "Reshape the key and value tensors and cache them"); -} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp new file mode 100644 index 0000000000000..a2bf0d49adba5 --- /dev/null +++ b/csrc/cpu/torch_bindings.cpp @@ -0,0 +1,106 @@ +#include "cache.h" +#include "ops.h" +#include "registration.h" + +#include + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + // vLLM custom ops + + // Attention ops + // Compute the attention between an input query and the cached keys/values + // using PagedAttention. + ops.def( + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); + + // PagedAttention V2. + ops.def( + "paged_attention_v2(" + " Tensor! out, Tensor exp_sums, Tensor max_logits," + " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); + + // Activation ops + + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); + + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); + + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); + + // GELU implementation used in GPT-2. + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_new", torch::kCPU, &gelu_new); + + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_fast", torch::kCPU, &gelu_fast); + + // Layernorm + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "()"); + ops.impl("rms_norm", torch::kCPU, &rms_norm); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " + "float epsilon) -> ()"); + ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); + + // Rotary embedding + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { + // Cache ops + // Swap in (out) the cache blocks from src to dst. + cache_ops.def( + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); + cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); + + // Copy the cache blocks from src to dst. + cache_ops.def( + "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " + "block_mapping) -> ()"); + cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); + + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float kv_scale) -> ()"); + cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 2ba49b339e148..73944f4c14890 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,7 +1,5 @@ #pragma once -#include +int64_t get_device_attribute(int64_t attribute, int64_t device_id); -int get_device_attribute(int attribute, int device_id); - -int get_max_shared_memory_per_block_device_attribute(int device_id); +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 7d8e2e19720fa..d6f9eb646fad5 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -2,7 +2,7 @@ #include #include #endif -int get_device_attribute(int attribute, int device_id) { +int64_t get_device_attribute(int64_t attribute, int64_t device_id) { int device, value; if (device_id < 0) { cudaGetDevice(&device); @@ -14,8 +14,8 @@ int get_device_attribute(int attribute, int device_id) { return value; } -int get_max_shared_memory_per_block_device_attribute(int device_id) { - int attribute; +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) { + int64_t attribute; // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 0b1d95848525a..82a3563979f16 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -1,17 +1,17 @@ #include #include #include -#include +#include #include "custom_all_reduce.cuh" -// fake pointer type -using fptr_t = uint64_t; +// fake pointer type, must match fptr_t type in ops.h +using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, - const std::vector& offsets, int rank, + const std::vector& offsets, int64_t rank, bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) @@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, bool full_nvlink) { auto inp_size = inp.numel() * inp.element_size(); // custom allreduce requires input byte size to be multiples of 16 @@ -125,7 +125,7 @@ void dispose(fptr_t _fa) { delete fa; } -int meta_size() { return sizeof(vllm::Signal); } +int64_t meta_size() { return sizeof(vllm::Signal); } void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, @@ -134,10 +134,16 @@ void register_buffer(fptr_t _fa, torch::Tensor& t, fa->register_buffer(handles, offsets, t.data_ptr()); } -std::pair, std::vector> get_graph_buffer_ipc_meta( +std::tuple> get_graph_buffer_ipc_meta( fptr_t _fa) { auto fa = reinterpret_cast(_fa); - return fa->get_graph_buffer_ipc_meta(); + auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = + torch::empty({static_cast(handle_bytes.size())}, options); + std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); + return {handles, std::move(offsets)}; } void register_graph_buffers(fptr_t _fa, const std::vector& handles, diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 3ecea03242f06..a634e1c3d4886 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -4,7 +4,7 @@ */ #pragma once -#include +#include #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 70a2b3b0a07b1..ca1c04bd880d9 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -291,7 +291,7 @@ fused_add_rms_norm_kernel( void rms_norm(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] - float epsilon) { + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -319,7 +319,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] - float epsilon) { + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp deleted file mode 100644 index 4122f7630d7c7..0000000000000 --- a/csrc/moe/moe_ops.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include "moe_ops.h" - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("topk_softmax", &topk_softmax, - "Apply topk softmax to the gating outputs."); -} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 93e7844ac1993..a251730aa765a 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -1,6 +1,6 @@ #pragma once -#include +#include void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 6ba4fcdb3a3f2..de9747b602524 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -16,7 +16,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include #include #include #include "../cuda_compat.h" diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp new file mode 100644 index 0000000000000..243752b9a9e8c --- /dev/null +++ b/csrc/moe/torch_bindings.cpp @@ -0,0 +1,12 @@ +#include "registration.h" +#include "moe_ops.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + // Apply topk softmax to the gating outputs. + m.def( + "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " + "token_expert_indices, Tensor gating_output) -> ()"); + m.impl("topk_softmax", torch::kCUDA, &topk_softmax); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index edc441d121029..1f8d75da83bb8 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -108,8 +108,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, } } // namespace vllm -void moe_align_block_size(torch::Tensor topk_ids, int num_experts, - int block_size, torch::Tensor sorted_token_ids, +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); diff --git a/csrc/ops.h b/csrc/ops.h index 06b60e748886f..0c270a78c331f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,40 +1,42 @@ #pragma once -#include +#include void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step); + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step); + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, - float epsilon); + double epsilon); void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, float epsilon); + torch::Tensor& weight, double epsilon); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int head_size, + torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int head_size, + torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox, - int rot_dim, + int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets); void silu_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -60,12 +62,12 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, - int split_k_iters); + int64_t split_k_iters); torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int split_k_iters, int thx, - int thy); + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy); torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, @@ -88,9 +90,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); -int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales); +void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales); #endif @@ -106,9 +108,9 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int bit); + bool use_exllama, int64_t bit); -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit); +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); @@ -116,28 +118,28 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); -void moe_align_block_size(torch::Tensor topk_ids, int num_experts, - int block_size, torch::Tensor sorted_token_ids, +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM -using fptr_t = uint64_t; +using fptr_t = int64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, - const std::vector& offsets, int rank, + const std::vector& offsets, int64_t rank, bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out); void dispose(fptr_t _fa); -int meta_size(); +int64_t meta_size(); void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); -std::pair, std::vector> get_graph_buffer_ipc_meta( +std::tuple> get_graph_buffer_ipc_meta( fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector>& offsets); diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 69d6dae1c26bc..97184a8735593 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -127,7 +127,7 @@ void rotary_embedding( // [num_tokens, num_heads * head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or // [num_tokens, num_kv_heads * head_size] - int head_size, + int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { int64_t num_tokens = query.numel() / query.size(-1); @@ -138,7 +138,7 @@ void rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { @@ -168,9 +168,9 @@ void batched_rotary_embedding( // [num_tokens, num_heads * head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or // [num_tokens, num_kv_heads * head_size] - int head_size, + int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, int rot_dim, + bool is_neox, int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets // [num_tokens] ) { int64_t num_tokens = cos_sin_cache_offsets.size(0); @@ -180,7 +180,7 @@ void batched_rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu index 61de3b37937cc..dd29820144b34 100644 --- a/csrc/punica/punica_ops.cu +++ b/csrc/punica/punica_ops.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -88,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, } void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, float scale) { + torch::Tensor indicies, int64_t layer_idx, double scale) { CHECK_INPUT(y); CHECK_INPUT(x); CHECK_INPUT(w); @@ -320,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, - float scale, int64_t h_in, int64_t h_out, + double scale, int64_t h_in, int64_t h_out, int64_t y_offset) { CHECK_INPUT(y); CHECK_INPUT(x); diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h index 937e2d1d25d4a..5d625d0564f75 100644 --- a/csrc/punica/punica_ops.h +++ b/csrc/punica/punica_ops.h @@ -1,11 +1,11 @@ #pragma once -#include +#include void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, float scale); + torch::Tensor indicies, int64_t layer_idx, double scale); void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, - float scale, int64_t h_in, int64_t h_out, + double scale, int64_t h_in, int64_t h_out, int64_t y_offset); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp deleted file mode 100644 index 9490ad59cdd5f..0000000000000 --- a/csrc/punica/punica_pybind.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include - -#include "punica_ops.h" - -//====== pybind ====== - -#define DEFINE_pybind(name) m.def(#name, &name, #name); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); - m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, - "dispatch_bgmv_low_level"); -} diff --git a/csrc/punica/torch_bindings.cpp b/csrc/punica/torch_bindings.cpp new file mode 100644 index 0000000000000..894e229b6d9db --- /dev/null +++ b/csrc/punica/torch_bindings.cpp @@ -0,0 +1,18 @@ +#include "registration.h" +#include "punica_ops.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + m.def( + "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int " + "layer_idx, float scale) -> ()"); + m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); + + m.def( + "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w," + "Tensor indicies, int layer_idx," + "float scale, int h_in, int h_out," + "int y_offset) -> ()"); + m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp deleted file mode 100644 index 547823aa1b04e..0000000000000 --- a/csrc/pybind.cpp +++ /dev/null @@ -1,114 +0,0 @@ -#include "cache.h" -#include "cuda_utils.h" -#include "ops.h" -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // vLLM custom ops - pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); - - // Attention ops - ops.def("paged_attention_v1", &paged_attention_v1, - "Compute the attention between an input query and the cached " - "keys/values using PagedAttention."); - ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); - - // Activation ops - ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); - ops.def("gelu_and_mul", &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); - ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); - - // Layernorm - ops.def("rms_norm", &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); - - ops.def("fused_add_rms_norm", &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); - - // Rotary embedding - ops.def("rotary_embedding", &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - - ops.def("batched_rotary_embedding", &batched_rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key " - "(supports multiple loras)"); - -// Quantization ops -#ifndef USE_ROCM - ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); - ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); - ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, - "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, - "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, - "gptq_marlin Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_repack", &gptq_marlin_repack, - "gptq_marlin repack from GPTQ"); - ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); - ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, - "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or " - "per-row/column quantization."); -#endif - - ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); - ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); - ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, - "Compute FP8 quantized tensor for given scaling factor"); - ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, - "Compute FP8 quantized tensor and scaling factor"); - ops.def("moe_align_block_size", &moe_align_block_size, - "Aligning the number of tokens to be processed by each expert such " - "that it is divisible by the block size."); - - ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, - "Compute int8 quantized tensor for given scaling factor"); - - ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant, - "Compute int8 quantized tensor and scaling factor"); - - // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def("swap_blocks", &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def("copy_blocks", ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def("reshape_and_cache", &reshape_and_cache, - "Reshape the key and value tensors and cache them"); - cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash, - "Reshape the key and value tensors and cache them"); - cache_ops.def("convert_fp8", &convert_fp8, - "Convert the key and value cache to fp8 data type"); - - // Cuda utils - pybind11::module cuda_utils = - m.def_submodule("cuda_utils", "vLLM cuda utils"); - cuda_utils.def("get_device_attribute", &get_device_attribute, - "Gets the specified device attribute."); - - cuda_utils.def("get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute, - "Gets the maximum shared memory per block device attribute."); - -#ifndef USE_ROCM - // Custom all-reduce kernels - pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); - custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); - custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar"); - custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); - custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg"); - custom_ar.def("dispose", &dispose, "dispose"); - custom_ar.def("meta_size", &meta_size, "meta_size"); - custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); - custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, - "get_graph_buffer_ipc_meta"); - custom_ar.def("register_graph_buffers", ®ister_graph_buffers, - "register_graph_buffers"); -#endif -} diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 255844eec56d4..8fb9856800867 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index bb8e5bbb23d7f..6d6da5f3d8746 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ -#include +#include #include #include "dequantize.cuh" @@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64) torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int split_k_iters, int thx, - int thy) { + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy) { int in_c = _kernel.size(0); int qout_c = _kernel.size(1); int out_c = qout_c * 8; @@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, - int split_k_iters) { + int64_t split_k_iters) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 280b0327111da..aa9511daa2772 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include "../../dispatch_utils.h" diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index 088fee4783faa..23a8b4070b70e 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -1,5 +1,5 @@ #include -#include +#include #include diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 8fc4ba662ecdd..a99802153643a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -4,7 +4,7 @@ #if defined CUDA_VERSION && CUDA_VERSION >= 12000 -#include +#include #include diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index eb532f2ac7a9b..423e64a4932e2 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -1,7 +1,7 @@ #include #include -#include +#include void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 55be3305a9b8c..8c5b693bf6ed7 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 480c4986c3821..785f1a09c1900 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -6,7 +6,7 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa #include #include -#include +#include #include #include #include @@ -1823,7 +1823,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int bit) { + bool use_exllama, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); @@ -1845,7 +1845,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, return c; } -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) { +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); vllm::gptq::shuffle_exllama_weight( (uint32_t*)q_weight.data_ptr(), diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index c573b9041065b..0beb9de14c687 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1867,4 +1867,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } -#endif \ No newline at end of file +#endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh index ba5368ea8835f..42af44951efda 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index 03d66cecedf1f..d124c0149912d 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -15,7 +15,7 @@ * limitations under the License. */ -#include +#include #include #include diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 686dd7851e6af..b5effc3055441 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -16,7 +16,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include #include #include diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 1b339fa4b392b..40baac6108695 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -1,5 +1,4 @@ #include -#include #include #include #include diff --git a/csrc/registration.h b/csrc/registration.h new file mode 100644 index 0000000000000..e5396e9a8b137 --- /dev/null +++ b/csrc/registration.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp new file mode 100644 index 0000000000000..df2603544c85a --- /dev/null +++ b/csrc/torch_bindings.cpp @@ -0,0 +1,283 @@ +#include "cache.h" +#include "cuda_utils.h" +#include "ops.h" +#include "registration.h" + +#include + +// Note on op signatures: +// The X_meta signatures are for the meta functions corresponding to op X. +// They must be kept in sync with the signature for X. Generally, only +// functions that return Tensors require a meta function. +// +// See the following links for detailed docs on op registration and function +// schemas. +// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + // vLLM custom ops + + // Attention ops + // Compute the attention between an input query and the cached + // keys/values using PagedAttention. + ops.def( + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); + + // PagedAttention V2. + ops.def( + "paged_attention_v2(" + " Tensor! out, Tensor exp_sums, Tensor max_logits," + " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); + + // Activation ops + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + // GELU implementation used in GPT-2. + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_new", torch::kCUDA, &gelu_new); + + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); + + // Layernorm + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "()"); + ops.impl("rms_norm", torch::kCUDA, &rms_norm); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " + "float epsilon) -> ()"); + ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); + + // Rotary embedding + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); + + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key + // (supports multiple loras). + ops.def( + "batched_rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox," + " int rot_dim," + " Tensor cos_sin_cache_offsets) -> ()"); + ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); + + // Quantization ops +#ifndef USE_ROCM + // Quantized GEMM for AQLM. + ops.def("aqlm_gemm", &aqlm_gemm); + ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); + + // Decompression method for AQLM. + ops.def("aqlm_dequant", &aqlm_dequant); + ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); + + // Quantized GEMM for AWQ. + ops.def("awq_gemm", &awq_gemm); + ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); + + // Dequantization for AWQ. + ops.def("awq_dequantize", &awq_dequantize); + ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); + + // Marlin (Dense) Optimized Quantized GEMM for GPTQ. + ops.def("marlin_gemm", &marlin_gemm); + ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); + + // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. + ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); + ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); + + // gptq_marlin Optimized Quantized GEMM for GPTQ. + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); + ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); + + // gptq_marlin repack from GPTQ. + ops.def("gptq_marlin_repack", &gptq_marlin_repack); + ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); + + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_dq(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales) -> ()"); + ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq); +#endif + + // Quantized GEMM for GPTQ. + ops.def("gptq_gemm", &gptq_gemm); + ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); + + // Post processing for GPTQ. + ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); + ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); + + // Quantized GEMM for SqueezeLLM. + ops.def( + "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor " + "lookup_table) -> ()"); + ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm); + + // Compute FP8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); + ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); + + // Compute FP8 quantized tensor and scaling factor. + ops.def( + "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "()"); + ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); + + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size. + ops.def( + "moe_align_block_size(Tensor topk_ids, int num_experts," + " int block_size, Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " + "()"); + ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); + + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, + &dynamic_scaled_int8_quant); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { + // Cache ops + // Swap in (out) the cache blocks from src to dst. + cache_ops.def( + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); + cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); + + // Copy the cache blocks from src to dst. + cache_ops.def( + "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " + "block_mapping) -> ()"); + cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks); + + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float kv_scale) -> ()"); + cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); + + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache_flash(Tensor key, Tensor value," + " Tensor! key_cache," + " Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype) -> ()"); + cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, + &reshape_and_cache_flash); + + // Convert the key and value cache to fp8 data type. + cache_ops.def( + "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str " + "kv_cache_dtype) -> ()"); + cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { + // Cuda utils + + // Gets the specified device attribute. + cuda_utils.def("get_device_attribute", &get_device_attribute); + cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute); + + // Gets the maximum shared memory per block device attribute. + cuda_utils.def("get_max_shared_memory_per_block_device_attribute", + &get_max_shared_memory_per_block_device_attribute); + cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", + torch::kCUDA, + &get_max_shared_memory_per_block_device_attribute); +} + +#ifndef USE_ROCM +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { + // Custom all-reduce kernels + custom_ar.def("init_custom_ar", &init_custom_ar); + custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + custom_ar.def("should_custom_ar", &should_custom_ar); + custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); + + custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); + custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); + + custom_ar.def( + "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> " + "()"); + custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); + + custom_ar.def("dispose", &dispose); + custom_ar.impl("dispose", torch::kCPU, &dispose); + + custom_ar.def("meta_size", &meta_size); + custom_ar.impl("meta_size", torch::kCPU, &meta_size); + + custom_ar.def("register_buffer", ®ister_buffer); + custom_ar.impl("register_buffer", torch::kCUDA, ®ister_buffer); + + custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); + custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU, + &get_graph_buffer_ipc_meta); + + custom_ar.def("register_graph_buffers", ®ister_graph_buffers); + custom_ar.impl("register_graph_buffers", torch::kCPU, + ®ister_graph_buffers); +} +#endif + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/setup.py b/setup.py index f7d465b60c153..339b0ad6de2d1 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ def remove_prefix(text, prefix): class CMakeExtension(Extension): def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: - super().__init__(name, sources=[], **kwa) + super().__init__(name, sources=[], py_limited_api=True, **kwa) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index aab7af9d2cbf6..0daf7439468aa 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -1,7 +1,8 @@ import pytest import torch -from vllm._C import ops +# ruff: noqa: F401 +import vllm._C DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -33,7 +34,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") - ops.dynamic_scaled_int8_quant(ops_out, x, scales_out) + torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out) assert torch.allclose(scales_out, scales) assert torch.allclose(torch_out, ops_out, @@ -60,6 +61,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, out2 = torch.empty_like(x, dtype=torch.int8) scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") - ops.static_scaled_int8_quant(out2, x, scale_argument) + torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument) assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7e12f1ba14cde..440b0e8afa99a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,35 +1,47 @@ -from typing import Optional, Tuple, Type +import contextlib +from typing import List, Optional, Tuple, Type import torch try: - from vllm._C import cache_ops as vllm_cache_ops - from vllm._C import ops as vllm_ops + import vllm._C except ImportError as e: from vllm.logger import init_logger logger = init_logger(__name__) logger.warning("Failed to import from vllm._C with %r", e) +with contextlib.suppress(ImportError): + import vllm._moe_C + +with contextlib.suppress(ImportError): + # ruff: noqa: F401 + import vllm._punica_C + + +def is_custom_op_supported(op_name: str) -> bool: + op, overloads = torch._C._jit_get_operation(op_name) + return op is not None + # activation ops def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.silu_and_mul(out, x) + torch.ops._C.silu_and_mul(out, x) def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_and_mul(out, x) + torch.ops._C.gelu_and_mul(out, x) def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_tanh_and_mul(out, x) + torch.ops._C.gelu_tanh_and_mul(out, x) def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_fast(out, x) + torch.ops._C.gelu_fast(out, x) def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_new(out, x) + torch.ops._C.gelu_new(out, x) # page attention ops @@ -53,7 +65,7 @@ def paged_attention_v1( blocksparse_block_size: int = 64, blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v1( + torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, @@ -83,7 +95,7 @@ def paged_attention_v2( blocksparse_block_size: int = 64, blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v2( + torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, @@ -100,8 +112,8 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, - is_neox) + torch.ops._C.rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox) def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, @@ -109,20 +121,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, cos_sin_cache: torch.Tensor, is_neox: bool, rot_dim: int, cos_sin_cache_offsets: torch.Tensor) -> None: - vllm_ops.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) + torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) # layer norm ops def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: - vllm_ops.rms_norm(out, input, weight, epsilon) + torch.ops._C.rms_norm(out, input, weight, epsilon) def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: - vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) + torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) # quantization ops @@ -130,13 +142,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: - return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, - thy) + return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, + thx, thy) def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: - return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # gptq @@ -144,27 +156,27 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, bit: int) -> torch.Tensor: - return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit) + return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit) def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: - vllm_ops.gptq_shuffle(q_weight, q_perm, bit) + torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) # squeezellm def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, lookup_table: torch.Tensor) -> None: - vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) + torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table) # marlin def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, - size_n, size_k) + return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) # marlin_24 @@ -172,9 +184,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, num_bits, size_m, size_n, - size_k) + return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, + workspace, num_bits, size_m, + size_n, size_k) # cutlass @@ -188,7 +200,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, n = b.shape[1] out = torch.empty((m, n), dtype=out_dtype, device=a.device) - vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b) + torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b) return out @@ -198,21 +210,22 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, codebook_partition_sizes: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, - codebook_partition_sizes, bias) + return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, + codebook_partition_sizes, bias) def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: torch.Tensor) -> torch.Tensor: - return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) + return torch.ops._C.aqlm_dequant(codes, codebooks, + codebook_partition_sizes) # gptq_marlin def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, - num_bits) + return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, @@ -220,9 +233,9 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int, is_k_full: bool) -> torch.Tensor: - return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, - workspace, num_bits, size_m, size_n, - size_k, is_k_full) + return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, + workspace, num_bits, size_m, size_n, + size_k, is_k_full) # fp8 @@ -259,9 +272,9 @@ def scaled_fp8_quant( output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - vllm_ops.static_scaled_fp8_quant(output, input, scale) + torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -284,14 +297,14 @@ def scaled_int8_quant( output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - vllm_ops.static_scaled_int8_quant(output, input, scale) + torch.ops._C.static_scaled_int8_quant(output, input, scale) return output, scale # dynamic-per-token quantization. input_scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) - vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales) return output, input_scales @@ -300,9 +313,16 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor) -> None: - vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) + torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indicies: torch.Tensor, + gating_output: float) -> None: + torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, + token_expert_indicies, gating_output) def reshape_and_cache( @@ -314,8 +334,9 @@ def reshape_and_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, kv_scale) def reshape_and_cache_flash( @@ -326,25 +347,115 @@ def reshape_and_cache_flash( slot_mapping: torch.Tensor, kv_cache_dtype: str, ) -> None: - vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype) def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, block_mapping: torch.Tensor) -> None: - vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) def swap_blocks(src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor) -> None: - vllm_cache_ops.swap_blocks(src, dst, block_mapping) + torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) def convert_fp8(output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8") -> None: - vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) + torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) + + +def get_device_attribute(attribute: int, device: int) -> int: + return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) + + +def get_max_shared_memory_per_block_device_attribute(device: int) -> int: + # ruff: noqa: E501 + return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( + device) + + +# custom ar +def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, + handles: List[str], offsets: List[int], rank: int, + full_nvlink: bool) -> int: + return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles, + offsets, rank, full_nvlink) + + +def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int, + full_nvlink: bool) -> bool: + return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size, + full_nvlink) + + +def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) + +def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, + out: torch.Tensor) -> None: + torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out) -#TODO: cuda_utils, custom_ar + +def dispose(fa: int) -> None: + torch.ops._C_custom_ar.dispose(fa) + + +def meta_size() -> int: + return torch.ops._C_custom_ar.meta_size() + + +def register_buffer(fa: int, t: torch.Tensor, handles: List[str], + offsets: List[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets) + + +def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]: + return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) + + +def register_graph_buffers(fa: int, handles: List[str], + offsets: List[List[int]]) -> None: + torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + + +# punica +def dispatch_bgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.Tensor, + layer_idx: int, + scale: float, +) -> None: + torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, + scale) + + +def dispatch_bgmv_low_level( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.Tensor, + layer_idx: int, + scale: float, + h_in: int, + h_out: int, + y_offset: int, +) -> None: + torch.ops._punica_C.dispatch_bgmv_low_level( + y, + x, + w_t_all, + indicies, + layer_idx, + scale, + h_in, + h_out, + y_offset, + ) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 070c074e511bc..8c64c2bfdeb8f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -5,7 +5,7 @@ import torch from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache -from vllm._C import cache_ops +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) @@ -47,11 +47,11 @@ def swap_blocks( ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( @@ -60,7 +60,7 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass @@ -285,7 +285,7 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - cache_ops.reshape_and_cache_flash( + ops.reshape_and_cache_flash( key, value, key_cache, diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index a3902aecb3793..4a0e19bc0c159 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -6,6 +6,7 @@ from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import ( @@ -15,7 +16,11 @@ try: import pynvml - from vllm._C import custom_ar + # Simulate ImportError if custom_ar ops are not supported. + if not ops.is_custom_op_supported("_C_custom_ar::meta_size"): + raise ImportError("custom_ar", __file__) + + custom_ar = True @contextmanager def _nvml(): @@ -27,7 +32,7 @@ def _nvml(): except ImportError: # For AMD GPUs - custom_ar = None + custom_ar = False pynvml = None @contextmanager @@ -97,7 +102,7 @@ def __init__(self, self._IS_CAPTURING = False self.disabled = True - if custom_ar is None: + if not custom_ar: # disable because of missing custom allreduce library # e.g. in a non-cuda environment return @@ -175,7 +180,7 @@ def __init__(self, # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate # allreduce results. - self.meta = torch.zeros(custom_ar.meta_size() + max_size, + self.meta = torch.zeros(ops.meta_size() + max_size, dtype=torch.uint8, device=self.device) # This is a pre-registered IPC buffer. In eager mode, input tensors @@ -196,9 +201,8 @@ def __init__(self, self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink - self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, - handles, offsets, rank, - self.full_nvlink) + self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles, + offsets, rank, self.full_nvlink) self.register_buffer(self.buffer) @contextmanager @@ -252,31 +256,31 @@ def _gather_ipc_meta(self, shard_data): def register_buffer(self, inp: torch.Tensor): handles, offsets = self._get_ipc_meta(inp) - custom_ar.register_buffer(self._ptr, inp, handles, offsets) + ops.register_buffer(self._ptr, inp, handles, offsets) def register_graph_buffers(self): - handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr) + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) logger.info("Registering %d cuda graph addresses", len(offset)) - custom_ar.register_graph_buffers(self._ptr, handles, offsets) + ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): - return custom_ar.should_custom_ar(inp, self.max_size, self.world_size, - self.full_nvlink) + return ops.should_custom_ar(inp, self.max_size, self.world_size, + self.full_nvlink) # all reduce, assuming inp tensor is IPC registered with register_buffer, # or, in the context of cuda graphs, register_graph_buffers def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - custom_ar.all_reduce_reg(self._ptr, inp, out) + ops.all_reduce_reg(self._ptr, inp, out) return out # all reduce, assuming inp tensor is NOT IPC registered def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) + ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -304,7 +308,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: def close(self): if not self.disabled and self._ptr: - custom_ar.dispose(self._ptr) + ops.dispose(self._ptr) self._ptr = 0 def __del__(self): diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index c87bed54726fc..7ecaa450f1758 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -4,16 +4,21 @@ import torch +from vllm import _custom_ops as ops + + +def _check_punica_support(): + if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): + return -def _raise_import_error(e): if torch.cuda.get_device_capability() < (8, 0): raise ImportError( - "punica LoRA kernels require compute capability >= 8.0") from e + "punica LoRA kernels require compute capability >= 8.0") else: raise ImportError( "punica LoRA kernels could not be imported. If you built vLLM " "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var " - "was set.") from e + "was set.") def bgmv( @@ -41,12 +46,9 @@ def bgmv( layer_idx: Layer index of the weight matrices. scale: Scaling factor. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() - punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, @@ -75,11 +77,9 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, y_offset: Offset to apply to the starting column of y. y_slice_size: Size of the y column slice. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) - punica_kernels.dispatch_bgmv_low_level( + _check_punica_support() + + ops.dispatch_bgmv_low_level( y, x, w_t_all, @@ -122,10 +122,7 @@ def add_lora(y: torch.Tensor, scale: Scaling factor. buffer: Optional. Shape: `[B, R]`. Temporary buffer. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() r = wb_t_all.size(-1) if buffer is None: @@ -135,9 +132,8 @@ def add_lora(y: torch.Tensor, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) - punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, - scale) + ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) + ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale) def add_lora_slice(y: torch.Tensor, @@ -176,10 +172,7 @@ def add_lora_slice(y: torch.Tensor, y_offset: Offset to apply to the starting column of y. y_slice_size: Size of the y column slice. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() r = wb_t_all.size(-1) if buffer is None: @@ -189,7 +182,7 @@ def add_lora_slice(y: torch.Tensor, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - punica_kernels.dispatch_bgmv_low_level( + ops.dispatch_bgmv_low_level( buffer, x, wa_t_all, @@ -200,7 +193,7 @@ def add_lora_slice(y: torch.Tensor, buffer.size(1), 0, ) - punica_kernels.dispatch_bgmv_low_level( + ops.dispatch_bgmv_low_level( y, buffer, wb_t_all, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1c6947137a1c9..4d0160ff296a0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,7 +8,6 @@ import triton import triton.language as tl -import vllm._moe_C as moe_kernels from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -355,7 +354,7 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) - moe_kernels.topk_softmax( + ops.topk_softmax( topk_weights, topk_ids, token_expert_indicies, diff --git a/vllm/utils.py b/vllm/utils.py index 2bd24d086f690..54d446b23350a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -22,6 +22,7 @@ import torch import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.logger import enable_trace_function_call, init_logger T = TypeVar("T") @@ -148,12 +149,8 @@ def is_neuron() -> bool: @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" - # NOTE: This import statement should be executed lazily since - # the Neuron-X backend does not have the `cuda_utils` module. - from vllm._C import cuda_utils - max_shared_mem = ( - cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu)) + ops.get_max_shared_memory_per_block_device_attribute(gpu)) # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # will fail assert max_shared_mem > 0, "max_shared_mem can not be zero"