|
| 1 | +cmake_minimum_required(VERSION 3.21) |
| 2 | + |
| 3 | +project(vllm_extensions LANGUAGES CXX) |
| 4 | + |
| 5 | +# |
| 6 | +# Find where user site-packages are installed and add it to cmake's search path. |
| 7 | +# |
| 8 | + |
| 9 | +if(NOT DEFINED PYTHON_EXECUTABLE) |
| 10 | + set(PYTHON_EXECUTABLE python3) |
| 11 | +endif() |
| 12 | + |
| 13 | +execute_process( |
| 14 | + COMMAND |
| 15 | + "${PYTHON_EXECUTABLE}" "-c" |
| 16 | + "import site; print(site.getusersitepackages())" |
| 17 | + OUTPUT_VARIABLE SITE_PATH |
| 18 | + ERROR_VARIABLE SITE_PATH_ERR |
| 19 | + OUTPUT_STRIP_TRAILING_WHITESPACE) |
| 20 | + |
| 21 | +if(SITE_PATH STREQUAL "") |
| 22 | + message(FATAL_ERROR "Failed to locate site-packages path," |
| 23 | + " full error message:\n${SITE_PATH_ERR}") |
| 24 | +endif() |
| 25 | + |
| 26 | +list(APPEND CMAKE_PREFIX_PATH ${SITE_PATH}) |
| 27 | + |
| 28 | +# |
| 29 | +# Find packages needed to compile |
| 30 | +# |
| 31 | +find_package(Python 3.8 REQUIRED COMPONENTS Interpreter Development.Module) |
| 32 | +find_package(Torch 2.1.2 EXACT REQUIRED) |
| 33 | +append_torchlib_if_found(torch_python) |
| 34 | +find_package(MPI REQUIRED) |
| 35 | + |
| 36 | +execute_process( |
| 37 | + COMMAND |
| 38 | + "${PYTHON_EXECUTABLE}" "-c" |
| 39 | + "import torch.utils.cpp_extension as torch_cpp_ext; print(' '.join(torch_cpp_ext.COMMON_NVCC_FLAGS))" |
| 40 | + OUTPUT_VARIABLE TORCH_NVCC_FLAGS |
| 41 | + ERROR_VARIABLE TORCH_NVCC_FLAGS_ERR |
| 42 | + OUTPUT_STRIP_TRAILING_WHITESPACE) |
| 43 | + |
| 44 | +if(TORCH_NVCC_FLAGS STREQUAL "") |
| 45 | + message(FATAL_ERROR "Unable to determine torch nvcc compiler flags," |
| 46 | + " full error message:\n${TORCH_NVCC_FLAGS_ERR}") |
| 47 | +endif() |
| 48 | + |
| 49 | +string(STRIP ${TORCH_NVCC_FLAGS} TORCH_NVCC_FLAGS) |
| 50 | +list(APPEND NVCC_FLAGS ${TORCH_NVCC_FLAGS}) |
| 51 | + |
| 52 | +set(PUNICA_NVCC_FLAGS "${NVCC_FLAGS}") |
| 53 | +foreach(OPT |
| 54 | + "-D__CUDA_NO_HALF_OPERATORS__" |
| 55 | + "-D__CUDA_NO_HALF_CONVERSIONS__" |
| 56 | + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" |
| 57 | + "-D__CUDA_NO_HALF2_OPERATORS__" |
| 58 | + ) |
| 59 | + string(REPLACE ${OPT} "" PUNICA_NVCC_FLAGS ${PUNICA_NVCC_FLAGS}) |
| 60 | +endforeach() |
| 61 | +string(STRIP ${PUNICA_NVCC_FLAGS} PUNICA_NVCC_FLAGS) |
| 62 | + |
| 63 | +if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) |
| 64 | + list(APPEND NVCC_FLAGS "-DENABLE_FP8_E5M2") |
| 65 | +endif() |
| 66 | + |
| 67 | +# |
| 68 | +# Check for existence of CUDA/HIP language support |
| 69 | +# |
| 70 | +# https://cliutils.gitlab.io/modern-cmake/chapters/packages/CUDA.html |
| 71 | +include(CheckLanguage) |
| 72 | +check_language(HIP) |
| 73 | +check_language(CUDA) |
| 74 | + |
| 75 | +if(NOT CMAKE_HIP_COMPILER STREQUAL "NOTFOUND") |
| 76 | + enable_language(HIP) |
| 77 | + list(APPEND NVCC_FLAGS "-DUSE_ROCM -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF_OPERATORS__") |
| 78 | + |
| 79 | + # TODO: intersect with this list? |
| 80 | + if(NOT DEFINED CMAKE_HIP_ARCHITECTURES) |
| 81 | + set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942") |
| 82 | + endif() |
| 83 | + |
| 84 | + foreach(HIP_ARCH ${CMAKE_HIP_ARCHITECTURES}) |
| 85 | + list(APPEND NVCC_FLAGS "--offload-arch=${HIP_ARCH}") |
| 86 | + endforeach() |
| 87 | +elseif(NOT CMAKE_CUDA_COMPILER STREQUAL "NOTFOUND") |
| 88 | + enable_language(CUDA) |
| 89 | + set(IS_CUDA true) |
| 90 | + |
| 91 | + # TODO: parse TORCH_CUDA_ARCH_LIST -> CMAKE_CUDA_ARCHITECTURES? |
| 92 | + |
| 93 | + # https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html#prop_tgt:CUDA_ARCHITECTURES |
| 94 | + # set_target_properties(tgt PROPERTIES CUDA_ARCHITECTURES "35;50;72") |
| 95 | + # TODO: PTX stuff |
| 96 | + if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) |
| 97 | + # This indicates support for both real architectures (i.e, no ptx). |
| 98 | + set(CMAKE_CUDA_ARCHITECTURES "70;75;80;86;89;90") |
| 99 | + endif() |
| 100 | +else() |
| 101 | + message(FATAL_ERROR "Can't find CUDA or HIP installation.") |
| 102 | +endif() |
| 103 | + |
| 104 | +if(NVCC_THREADS) |
| 105 | + list(APPEND NVCC_FLAGS "--threads=${NVCC_THREADS}") |
| 106 | +endif() |
| 107 | + |
| 108 | +# |
| 109 | +# Define target source files |
| 110 | +# |
| 111 | + |
| 112 | +set(VLLM_EXT_SRC |
| 113 | + "csrc/cache_kernels.cu" |
| 114 | + "csrc/attention/attention_kernels.cu" |
| 115 | + "csrc/pos_encoding_kernels.cu" |
| 116 | + "csrc/activation_kernels.cu" |
| 117 | + "csrc/layernorm_kernels.cu" |
| 118 | + "csrc/quantization/squeezellm/quant_cuda_kernel.cu" |
| 119 | + "csrc/quantization/gptq/q_gemm.cu" |
| 120 | + "csrc/cuda_utils_kernels.cu" |
| 121 | + "csrc/moe_align_block_size_kernels.cu" |
| 122 | + "csrc/pybind.cpp") |
| 123 | + |
| 124 | +if(IS_CUDA) |
| 125 | + list(APPEND VLLM_EXT_SRC |
| 126 | + "csrc/quantization/awq/gemm_kernels.cu" |
| 127 | + "csrc/custom_all_reduce.cu") |
| 128 | +endif() |
| 129 | + |
| 130 | +File(GLOB VLLM_MOE_EXT_SRC "csrc/moe/*.cu" "csrc/moe/*.cpp") |
| 131 | +File(GLOB VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/*.cu" "csrc/punica/*.cpp") |
| 132 | + |
| 133 | +# |
| 134 | +# Define targets |
| 135 | +# |
| 136 | +set(CMAKE_CXX_STANDARD 17) |
| 137 | + |
| 138 | +function(define_module_target MOD_NAME MOD_SRC MOD_NVCC_FLAGS) |
| 139 | + Python_add_library(${MOD_NAME} MODULE ${MOD_SRC} WITH_SOABI) |
| 140 | + # Note: optimization level/debug info is set by build type |
| 141 | + if (IS_CUDA) |
| 142 | + set(CUDA_LANG "CUDA") |
| 143 | + else() |
| 144 | + set(CUDA_LANG "HIP") |
| 145 | + endif() |
| 146 | + target_compile_options(${MOD_NAME} PRIVATE |
| 147 | + $<$<COMPILE_LANGUAGE:${CUDA_LANG}>:${MOD_NVCC_FLAGS}>) |
| 148 | + target_compile_definitions(${MOD_NAME} PRIVATE "-DTORCH_EXTENSION_NAME=${MOD_NAME}") |
| 149 | + target_include_directories(${MOD_NAME} PRIVATE csrc PRIVATE ${TORCH_INCLUDE_DIRS} ${MPI_CXX_INCLUDE_DIRS}) |
| 150 | + target_link_libraries(${MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) |
| 151 | + install(TARGETS ${MOD_NAME} LIBRARY DESTINATION vllm) |
| 152 | +endfunction() |
| 153 | + |
| 154 | +define_module_target(_C "${VLLM_EXT_SRC}" "${NVCC_FLAGS}") |
| 155 | +define_module_target(_moe_C "${VLLM_MOE_EXT_SRC}" "${NVCC_FLAGS}") |
| 156 | +define_module_target(_punica_C "${VLLM_PUNICA_EXT_SRC}" "${PUNICA_NVCC_FLAGS}") |
0 commit comments