[Continuation] Merge EmbeddedLLM/vllm-rocm into vLLM main #1836

Merged
63 commits merged on Dec 8, 2023

Commits (63 total; changes shown from 1 commit)
43af310  port dtype_float16.cuh and cache_kernels.cu  (pcmoritz, Oct 10, 2023)
cc81866  port dtype_bfloat16.cuh  (pcmoritz, Oct 10, 2023)
475b5e2  port attention_utils.cuh  (pcmoritz, Oct 10, 2023)
ddc496c  port more kernels  (pcmoritz, Oct 10, 2023)
5eaa7a1  fix typo  (pcmoritz, Oct 10, 2023)
f7273c6  add cuda_compat.h  (pcmoritz, Oct 10, 2023)
99c3be7  Merge branch 'main' into port-to-rocm  (pcmoritz, Oct 16, 2023)
f8093dc  sync branches  (pcmoritz, Oct 16, 2023)
41df689  update  (pcmoritz, Oct 16, 2023)
93be9c5  update  (pcmoritz, Oct 16, 2023)
d96fa3c  fixes  (pcmoritz, Oct 16, 2023)
421365b  cleanup  (pcmoritz, Oct 16, 2023)
06b800e  update  (pcmoritz, Oct 16, 2023)
2312beb  update  (pcmoritz, Oct 16, 2023)
2958b39  update  (pcmoritz, Oct 16, 2023)
3f89734  fmt  (pcmoritz, Oct 16, 2023)
5397a57  cleanup  (pcmoritz, Oct 16, 2023)
90e02d2  refactor  (pcmoritz, Oct 16, 2023)
a420202  update  (pcmoritz, Oct 16, 2023)
b072182  Merge branch 'main' into port-to-rocm  (pcmoritz, Oct 17, 2023)
2d1e435  detecting rocm and adding flag for compiling  (iAmir97, Oct 17, 2023)
e231b79  using asm volatile instead of hip api  (iAmir97, Oct 17, 2023)
31bb335  using asm volatile for type casting of f16  (iAmir97, Oct 17, 2023)
b027d06  Hipifying csrc file to accomodate rocm builds  (kliuae, Nov 27, 2023)
9a1781c  Checked CUDA ROCm Compatibility (#15)  (tjtanaa, Nov 29, 2023)
0f67117  merged with latest upstream  (kliuae, Nov 29, 2023)
7dbf2d4  format code  (kliuae, Nov 29, 2023)
52ffcf0  downgrade torch requirement in toml to torch 2.0.1 to accommodate ROC…  (kliuae, Nov 29, 2023)
27f0513  Merged changes from vllm main  (kliuae, Dec 1, 2023)
5cce649  Merged with changes in vllm main  (kliuae, Dec 1, 2023)
16d3ccc  Updated Dockerfile, rocm installation guide and setuppy  (kliuae, Dec 1, 2023)
d764f9d  Updated amd installation guide and dockerfile  (kliuae, Dec 2, 2023)
e798632  Added num_gpus for ray init in ROCm  (kliuae, Dec 2, 2023)
0e8129f  Synced torch version with vllm main in pyproject.toml  (kliuae, Dec 2, 2023)
2b3821b  Format code  (kliuae, Dec 2, 2023)
0c8795a  Merge branch 'main' into vllm-cuda-rocm-dev  (kliuae, Dec 4, 2023)
5793f30  Updated dockerfile.rocm and requirements-rocm.txt  (kliuae, Dec 4, 2023)
b172cdd  Disable mistral for ROCm  (kliuae, Dec 4, 2023)
9cd5b18  Format code  (kliuae, Dec 4, 2023)
b86f88a  Revert to cuda kernels  (kliuae, Dec 5, 2023)
9727ab4  Merge remote-tracking branch 'pcmoritz/port-to-rocm'  (kliuae, Dec 5, 2023)
c4aa2af  Port latest kernels to ROCm  (kliuae, Dec 5, 2023)
f8c304e  Update readme  (kliuae, Dec 5, 2023)
e608c30  Cleaned up kernel code  (kliuae, Dec 5, 2023)
951e225  Added wrapper for setting devFuncAttributeMaxDynamicSharedMemorySize  (kliuae, Dec 6, 2023)
25f9a97  Added wrapper for setting devFuncAttributeMaxDynamicSharedMemorySize  (kliuae, Dec 6, 2023)
e984ada  Updated ROCm warp size  (kliuae, Dec 6, 2023)
cc1195f  Format code  (kliuae, Dec 6, 2023)
f92980e  Check hip from wrapper  (kliuae, Dec 6, 2023)
66b4aa1  Format code  (kliuae, Dec 6, 2023)
4a0ecb8  Enable support for mistral models  (kliuae, Dec 6, 2023)
acf51a8  Fixed hip device attribute  (kliuae, Dec 6, 2023)
4a52977  Format code  (kliuae, Dec 6, 2023)
23a987a  Restored awq file  (kliuae, Dec 7, 2023)
8787a4e  Format code  (kliuae, Dec 7, 2023)
5911131  Merge latest vllm main  (kliuae, Dec 7, 2023)
9fa8075  Updated rocm dockerfile  (kliuae, Dec 7, 2023)
81e052d  Update amd installation guide  (kliuae, Dec 7, 2023)
fb8ac26  Update vLLM Documentations (#18)  (tjtanaa, Dec 7, 2023)
98f5487  Updated setup.py, vllm/utils.py and amd-installation doc  (kliuae, Dec 8, 2023)
d90187a  Updated setup.py  (kliuae, Dec 8, 2023)
c840531  Format code  (kliuae, Dec 8, 2023)
9dba1d8  Merge branch 'main' into vllm-cuda-rocm-mod  (kliuae, Dec 8, 2023)
port attention_utils.cuh
pcmoritz committed Oct 10, 2023
commit 475b5e2875f9f870b88206bf087ff6adc99517a9
30 changes: 29 additions & 1 deletion csrc/attention/attention_kernels.cu
@@ -39,7 +39,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
#else
sum += __shfl_xor(uint32_t(-1), sum, mask);
#endif
}

// Warp leaders store the data to shared memory.
@@ -58,11 +62,19 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
#else
sum += __shfl_xor(uint32_t(-1), sum, mask);
#endif
}

// Broadcast to other threads.
#ifndef USE_ROCM
return __shfl_sync(uint32_t(-1), sum, 0);
#else
return __shfl(uint32_t(-1), sum, 0);
#endif
}

// Grid: (num_heads, num_seqs).
@@ -196,7 +208,11 @@ __global__ void single_query_cached_kv_attention_kernel(
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
#ifndef USE_ROCM
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
#else
qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask));
#endif
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
@@ -208,10 +224,18 @@ __global__ void single_query_cached_kv_attention_kernel(
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
#else
qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask));
#endif
}
// Broadcast the max qk value to all threads.
#ifndef USE_ROCM
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
#else
qk_max = __shfl(uint32_t(-1), qk_max, 0);
#endif

// Get the sum of the exp values.
float exp_sum = 0.f;
@@ -284,7 +308,11 @@ __global__ void single_query_cached_kv_attention_kernel(
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
#else
acc += __shfl_xor(uint32_t(-1), acc, mask);
#endif
}
accs[i] = acc;
}
@@ -342,7 +370,7 @@ __global__ void single_query_cached_kv_attention_kernel(

#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
cudaFuncSetAttribute( \
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
(void*)vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
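Later commits in this PR (951e225 and 25f9a97, "Added wrapper for setting devFuncAttributeMaxDynamicSharedMemorySize") suggest the cudaFuncSetAttribute call above eventually moves behind a small portability helper. Below is only a sketch of what such a shim could look like under the USE_ROCM flag this port introduces; the macro name and exact shape are assumptions, not the PR's actual code:

// Sketch only (assumed name SET_MAX_DYNAMIC_SHARED_MEM, not the PR's API):
// raise the dynamic shared memory limit of a kernel on either platform.
#ifndef USE_ROCM
  #include <cuda_runtime.h>
  #define SET_MAX_DYNAMIC_SHARED_MEM(kernel, bytes)                    \
    cudaFuncSetAttribute((void*)(kernel),                              \
                         cudaFuncAttributeMaxDynamicSharedMemorySize,  \
                         (bytes))
#else
  #include <hip/hip_runtime.h>
  #define SET_MAX_DYNAMIC_SHARED_MEM(kernel, bytes)                    \
    hipFuncSetAttribute((void*)(kernel),                               \
                        hipFuncAttributeMaxDynamicSharedMemorySize,    \
                        (bytes))
#endif

A helper like this would also keep the (void*) cast added to LAUNCH_ATTENTION_KERNEL in one place instead of repeating it at every launch site.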
4 changes: 4 additions & 0 deletions csrc/attention/attention_utils.cuh
@@ -39,7 +39,11 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
#else
qk += __shfl_xor(uint32_t(-1), qk, mask);
#endif
}
return qk;
}
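Every shuffle in this commit is wrapped in the same #ifndef USE_ROCM / #else pair: recent CUDA exposes the *_sync shuffle intrinsics that take an explicit lane mask, while HIP provides the mask-free __shfl / __shfl_xor forms. The later cuda_compat.h commit (f7273c6) and the "Check hip from wrapper" commit (f92980e) point toward consolidating this into a single header. A minimal sketch, with illustrative macro names that are not necessarily the PR's own:

#include <cstdint>

// Illustrative compatibility shim (assumed names, not the PR's actual header):
// route warp shuffles through one macro so kernels stay free of #ifdef blocks.
#ifndef USE_ROCM
  // CUDA: the *_sync variants need an explicit participation mask (all lanes here).
  #define WARP_SHFL_XOR(var, lane_mask) __shfl_xor_sync(uint32_t(-1), (var), (lane_mask))
  #define WARP_SHFL(var, src_lane)      __shfl_sync(uint32_t(-1), (var), (src_lane))
#else
  // HIP: __shfl_xor / __shfl take no mask; the whole wavefront participates
  // (64 lanes on CDNA/GCN GPUs, which is why a later commit updates WARP_SIZE).
  #define WARP_SHFL_XOR(var, lane_mask) __shfl_xor((var), (lane_mask))
  #define WARP_SHFL(var, src_lane)      __shfl((var), (src_lane))
#endif

With such a header, each reduction above collapses to a single line per shuffle, e.g. sum += WARP_SHFL_XOR(sum, mask); and qk_max = WARP_SHFL(qk_max, 0);, with no per-call #ifdef.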