Skip to content

Commit

Permalink
Merge EmbeddedLLM/vllm-rocm into vLLM main (vllm-project#1836)
Browse files Browse the repository at this point in the history
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Amir Balwel <amoooori04@gmail.com>
Co-authored-by: root <kuanfu.liu@akirakan.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: kuanfu <kuanfu.liu@embeddedllm.com>
Co-authored-by: miloice <17350011+kliuae@users.noreply.github.com>
  • Loading branch information
7 people authored Dec 8, 2023
1 parent c8e7eb1 commit 6ccc0bf
Show file tree
Hide file tree
Showing 29 changed files with 873 additions and 118 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,7 @@ _build/
# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*
62 changes: 62 additions & 0 deletions Dockerfile.rocm
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

# Install some basic utilities
RUN apt-get update && apt-get install -y \
curl \
ca-certificates \
sudo \
git \
bzip2 \
libx11-6 \
build-essential \
wget \
unzip \
nvidia-cuda-toolkit \
tmux \
&& rm -rf /var/lib/apt/lists/*

### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

# Install ROCm flash-attention
RUN mkdir libs \
&& cd libs \
&& git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
&& cd flash-attention \
&& git checkout 3d2b6f5 \
&& git submodule update --init \
&& export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
&& patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
&& python3 setup.py install \
&& cd ..

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
RUN pip install xformers==0.0.22.post7 --no-deps

RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
&& bash patch_xformers-0.0.22.post7.rocm.sh \
&& python3 setup.py install \
&& cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]

CMD ["/bin/bash"]
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone
---

*Latest News* 🔥
- [2023/12] Added ROCm support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
Expand All @@ -43,6 +44,7 @@ vLLM is flexible and easy to use with:
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support NVIDIA CUDA and AMD ROCm.

vLLM seamlessly supports many Hugging Face models, including the following architectures:

Expand Down
7 changes: 4 additions & 3 deletions csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {
Expand All @@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
}
}
Expand Down Expand Up @@ -57,7 +58,7 @@ __global__ void activation_kernel(
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
Expand Down
34 changes: 21 additions & 13 deletions csrc/attention/attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

Expand All @@ -23,7 +27,11 @@

#include <algorithm>

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
Expand All @@ -40,7 +48,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}

// Warp leaders store the data to shared memory.
Expand All @@ -59,11 +67,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}

// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
return VLLM_SHFL_SYNC(sum, 0);
}

// TODO(woosuk): Merge the last two dimensions of the grid.
Expand Down Expand Up @@ -223,7 +231,7 @@ __device__ void paged_attention_kernel(
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
Expand All @@ -235,10 +243,10 @@ __device__ void paged_attention_kernel(
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
qk_max = VLLM_SHFL_SYNC(qk_max, 0);

// Get the sum of the exp values.
float exp_sum = 0.f;
Expand Down Expand Up @@ -326,7 +334,7 @@ __device__ void paged_attention_kernel(
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
Expand Down Expand Up @@ -492,7 +500,7 @@ __global__ void paged_attention_v2_reduce_kernel(
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
Expand All @@ -502,10 +510,10 @@ __global__ void paged_attention_v2_reduce_kernel(
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
max_logit = VLLM_SHFL_SYNC(max_logit, 0);

// Load rescaled exp sums to shared memory.
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
Expand Down Expand Up @@ -539,9 +547,9 @@ __global__ void paged_attention_v2_reduce_kernel(
} // namespace vllm

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
cudaFuncSetAttribute( \
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>), \
shared_mem_size); \
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
Expand Down
3 changes: 2 additions & 1 deletion csrc/attention/attention_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
#pragma once

#include "../cuda_compat.h"
#include "attention_dtypes.h"

#include <float.h>
Expand All @@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
Expand Down
19 changes: 16 additions & 3 deletions csrc/attention/dtype_bfloat16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {
Expand Down Expand Up @@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return a + b;
#ifndef USE_ROCM
return a + b;
#else
return __hadd(a, b);
#endif
#endif
}

Expand Down
Loading

0 comments on commit 6ccc0bf

Please sign in to comment.