diff --git a/.gitignore b/.gitignore
index 20c9baee226..1f9ba162c13 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,13 @@
 target
 router/tokenizer.json
 *__pycache__*
+
+# ROCm auto-generated files
+*.hip
+server/exllamav2_kernels/exllamav2_kernels/hip/
+server/exllama_kernels/exllama_kernels/hip/
+server/exllama_kernels/exllama_kernels/hip_func/
+*_hip.cuh
+server/exllama_kernels/exllama_kernels/hip_buffers.cuh
+server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
+
diff --git a/Dockerfile_amd b/Dockerfile_amd
index dd331a5df66..d2b6f8979a0 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -75,8 +75,8 @@ RUN chmod +x ~/mambaforge.sh && \
     mamba init && \
     rm ~/mambaforge.sh
 
-# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
-RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
+# Install PyTorch 2.2 RC compiled against ROCm 5.7, as vLLM cannot be compiled with ROCm 5.6.
+RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
 
 FROM base AS kernel-builder
 
@@ -104,6 +104,20 @@ WORKDIR /usr/src
 COPY server/custom_kernels/ .
 RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
 
+# Build exllama kernels
+FROM kernel-builder as exllama-kernels-builder
+WORKDIR /usr/src
+COPY server/exllama_kernels/ .
+
+RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+
+# Build exllama v2 kernels
+FROM kernel-builder as exllamav2-kernels-builder
+WORKDIR /usr/src
+COPY server/exllamav2_kernels/ .
+
+RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+
 FROM base as base-copy
 
 # Text Generation Inference base env
@@ -120,6 +134,12 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86
 
 # Copy build artifacts from custom kernels builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
+# Copy build artifacts from exllama kernels builder
+COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
+# Copy build artifacts from exllamav2 kernels builder
+COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
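A quick sanity check for the new kernel-builder stages: the short Python sketch below, run inside the finished image, verifies that the ROCm build of PyTorch is present and that the copied extensions import. It is only a sketch; the package names `exllama_kernels` and `exllamav2_kernels` are assumed from their respective `setup.py` files, and a visible ROCm device is expected.

```python
# Hypothetical smoke test for the ROCm image built from Dockerfile_amd.
import torch

# torch.version.hip is set on ROCm builds of PyTorch and is None on CUDA builds.
assert torch.version.hip is not None, "expected a ROCm build of PyTorch"

# Extensions produced by the exllama / exllamav2 kernel-builder stages
# (package names assumed from their setup.py files).
import exllama_kernels  # noqa: F401
from exllamav2_kernels import gemm_half_q_half, make_q_matrix  # noqa: F401

print(f"ROCm {torch.version.hip}, device: {torch.cuda.get_device_name(0)}")
```

If either import fails, the corresponding `COPY --from=...-kernels-builder` line above is the first place to look.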
diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md
index 004790ab87a..df5102c2ceb 100644
--- a/docs/source/supported_models.md
+++ b/docs/source/supported_models.md
@@ -43,8 +43,8 @@ text-generation-launcher --model-id 
 
 TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
 
-TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
-* Quantization (GPTQ, AWQ, etc.)
+TGI also has support for ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention, GPTQ quantization, and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and support may be extended in the future:
+* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
 * Flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm)
 * Kernel for slinding window attention (Mistral)
diff --git a/server/exllama_kernels/exllama_kernels/cuda_compat.cuh b/server/exllama_kernels/exllama_kernels/cu_compat.cuh
similarity index 91%
rename from server/exllama_kernels/exllama_kernels/cuda_compat.cuh
rename to server/exllama_kernels/exllama_kernels/cu_compat.cuh
index 8dfa25de39c..c5258813e14 100644
--- a/server/exllama_kernels/exllama_kernels/cuda_compat.cuh
+++ b/server/exllama_kernels/exllama_kernels/cu_compat.cuh
@@ -43,12 +43,12 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 
 //
 
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ < 700
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
 
 __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
 
-#if __CUDA_ARCH__ < 600
+#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
 __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif
 
diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu
index 60dc4c9db4d..61380f4296f 100644
--- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu
+++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu
@@ -2,8 +2,11 @@
 #include "column_remap.cuh"
 #include "../util.cuh"
 #include "../matrix.cuh"
-#include "../cuda_compat.cuh"
+#include "../cu_compat.cuh"
 #include "../cuda_buffers.cuh"
+#if defined(USE_ROCM)
+#include "../hip_compat.cuh"
+#endif
 
 const int THREADS_X = 32;     // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;      // Block size and thread count along rows in x and out
@@ -128,7 +131,7 @@ __global__ void q4_matmul_kernel
 
     if constexpr (use_half2)
     {
-        half result = __hadd(acc.x, acc.y);
+        half result = __hadd(__low2half(acc), __high2half(acc));
         atomicAdd(out_.item_ptr(x_row, w_column), result);
     }
     else
diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh b/server/exllama_kernels/exllama_kernels/hip_compat.cuh
similarity index 68%
rename from server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh
rename to server/exllama_kernels/exllama_kernels/hip_compat.cuh
index 19b1e4a6041..4f2a7ae7df0 100644
--- a/server/exllamav2_kernels/exllamav2_kernels/cuda/compat_gemm.cuh
+++ b/server/exllama_kernels/exllama_kernels/hip_compat.cuh
@@ -1,12 +1,23 @@
-#ifndef _compat_gemm_cuh
-#define _compat_gemm_cuh
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
 
-#if defined(USE_ROCM)
+#ifndef _hip_compat_cuh
+#define _hip_compat_cuh
 
-// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
-// for symbols as hipblasHalf.
-#include <hipblas/hipblas.h>
+// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
+__device__ __forceinline__ __half __compat_hrcp(__half x) {
+    return __half_raw{
+        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
+}
+
+__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
+    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
+                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+}
+
+#define hrcp __compat_hrcp
+#define h2rcp __compat_h2rcp
+// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,
@@ -31,8 +42,10 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t
 
 #define hipblasHgemm __compat_hipblasHgemm
 
 // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
+#define rocblas_handle hipblasHandle_t
 #define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_get_stream hipblasGetStream
+#define rocblas_set_stream hipblasSetStream
 #define rocblas_hgemm __compat_hipblasHgemm
 
-#endif
-#endif
+#endif
\ No newline at end of file
diff --git a/server/exllama_kernels/exllama_kernels/util.cuh b/server/exllama_kernels/exllama_kernels/util.cuh
index 2839b10fafc..7b397573214 100644
--- a/server/exllama_kernels/exllama_kernels/util.cuh
+++ b/server/exllama_kernels/exllama_kernels/util.cuh
@@ -8,7 +8,11 @@
 #include
 #include
 
+#if defined(USE_ROCM)
+#define cudaUnspecified hipErrorUnknown
+#else
 #define cudaUnspecified cudaErrorApiFailureBase
+#endif
 
 // React to failure on return code != cudaSuccess
diff --git a/server/exllamav2_kernels/setup.py b/server/exllamav2_kernels/setup.py
index 518db1df9a2..4a16b546f70 100644
--- a/server/exllamav2_kernels/setup.py
+++ b/server/exllamav2_kernels/setup.py
@@ -1,5 +1,15 @@
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+
+extra_cuda_cflags = ["-lineinfo", "-O3"]
+
+if torch.version.hip:
+    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
+
+extra_compile_args = {
+    "nvcc": extra_cuda_cflags,
+}
 
 setup(
     name="exllamav2_kernels",
@@ -11,6 +21,7 @@
                 "exllamav2_kernels/cuda/q_matrix.cu",
                 "exllamav2_kernels/cuda/q_gemm.cu",
             ],
+            extra_compile_args=extra_compile_args,
         )
     ],
     cmdclass={"build_ext": BuildExtension},
diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py
index 2b897f252fc..80836a95974 100644
--- a/server/text_generation_server/utils/gptq/exllamav2.py
+++ b/server/text_generation_server/utils/gptq/exllamav2.py
@@ -1,12 +1,9 @@
 # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
 
-from logging import getLogger
-
 import torch
 import torch.nn as nn
-import math
 
-logger = getLogger(__name__)
+from loguru import logger
 
 try:
     from exllamav2_kernels import make_q_matrix, gemm_half_q_half
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index c9393d9971a..6ddfd6f4ef0 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -33,7 +33,7 @@
     major = 1
 
 HAS_EXLLAMA = False
-CAN_EXLLAMA = major >= 8
+CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 # if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
 #     V2 = False
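For context on the final `layers.py` hunk, the sketch below shows how the changed flag might sit among its neighbours. It is an illustration only; `IS_ROCM_SYSTEM` and the compute-capability probe are not part of this diff, so their definitions here are assumptions (mirroring the `torch.version.hip` check used in `exllamav2_kernels/setup.py` and the visible `major = 1` fallback).

```python
# Sketch only, not the actual layers.py: how CAN_EXLLAMA is derived after this change.
import os

import torch

# Assumption: IS_ROCM_SYSTEM is true exactly when PyTorch is a ROCm (HIP) build.
IS_ROCM_SYSTEM = torch.version.hip is not None

try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1  # matches the fallback visible in the hunk's context

HAS_EXLLAMA = False
# CUDA still requires compute capability >= 8.0 (Ampere) for the exllama GPTQ
# kernels; ROCm systems are now allowed through regardless of `major`.
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
# EXLLAMA_VERSION selects between the exllama and exllamav2 kernel implementations.
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
```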