
Commit 1fa1c24

Merge pull request apache#48 from vinx13/vllm-fp8
Add fp8 fused dequant-paged-attention in vllm
2 parents b1d7caa + 80651a2 commit 1fa1c24

File tree

9 files changed (+783 lines, -116 lines)


cmake/config.cmake

Lines changed: 2 additions & 0 deletions
@@ -442,3 +442,5 @@ set(USE_UMA OFF)
 
 # Set custom Alloc Alignment for device allocated memory ndarray points to
 set(USE_KALLOC_ALIGNMENT 64)
+
+set(USE_CUDA_FP8 OFF)

cmake/modules/CUDA.cmake

Lines changed: 5 additions & 0 deletions
@@ -109,6 +109,11 @@ if(USE_CUDA)
   # Add CUDA builtins to RelaxVM
   tvm_file_glob(GLOB RELAX_VM_CUDA_BUILTIN_SRC_CC src/runtime/relax_vm/cuda/*.cc)
   list(APPEND RUNTIME_SRCS ${RELAX_VM_CUDA_BUILTIN_SRC_CC})
+
+  if (USE_CUDA_FP8)
+    message(STATUS "Build with CUDA FP8 support")
+    add_definitions(-DUSE_CUDA_FP8=1)
+  endif()
 else(USE_CUDA)
   list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc)
 endif(USE_CUDA)
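
With USE_CUDA_FP8 turned on in config.cmake (set(USE_CUDA_FP8 ON)), this block passes -DUSE_CUDA_FP8=1 to every compilation unit, and the runtime sources in this commit test that macro around their fp8-only code. A minimal sketch of the consuming side (the cuda_fp8.h include is an illustrative assumption and is not part of this diff):

// Sketch only: how a CUDA source can consume the -DUSE_CUDA_FP8=1 definition
// added by cmake/modules/CUDA.cmake. The header choice is an assumption.
#if USE_CUDA_FP8
#include <cuda_fp8.h>  // fp8 storage type and conversion intrinsics (CUDA 11.8+)
#endif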

cmake/modules/LibInfo.cmake

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_MSC="${USE_MSC}"
     TVM_INFO_USE_CCACHE="${USE_CCACHE}"
     TVM_INFO_BACKTRACE_ON_SEGFAULT="${BACKTRACE_ON_SEGFAULT}"
+    TVM_INFO_USE_CUDA_FP8="${USE_CUDA_FP8}"
   )
 
 endfunction()

src/runtime/contrib/vllm/attention_kernels.cu

Lines changed: 123 additions & 76 deletions
Large diffs are not rendered by default.
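
The attention kernel diff is collapsed above, but together with the commit title and the helpers used below in cache_kernels.cu it implies the key idea: when the KV cache is stored as fp8, the paged-attention kernel dequantizes each cache value back to the 16-bit compute type at the point where it is loaded, so no separate fp16 copy of the cache has to be materialized. A rough scalar sketch of that fused load, reusing the commit's fp8_e5m2_unscaled::vec_conversion helper (the function name load_kv_dequant and the overall structure are illustrative, not the contents of the hidden diff, which is vectorized and paged):

// Illustrative sketch only; the rendered diff is not shown on this page.
#include "quant_utils.cuh"  // conversion helpers added elsewhere in this commit

template <typename scalar_t, typename cache_t, bool kIsFp8Cache>
__device__ __forceinline__ scalar_t load_kv_dequant(const cache_t* cache, int idx) {
  if constexpr (kIsFp8Cache) {
#if USE_CUDA_FP8
    // fp8 byte in the paged cache -> compute type, right at the load site.
    return fp8_e5m2_unscaled::vec_conversion<scalar_t, uint8_t>(cache[idx]);
#else
    return scalar_t(0);  // fp8 support compiled out
#endif
  } else {
    return cache[idx];  // fp16 cache path: plain load, no conversion
  }
}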

src/runtime/contrib/vllm/cache_alloc.cc

Lines changed: 8 additions & 8 deletions
@@ -25,22 +25,22 @@ namespace runtime {
 namespace vllm {
 
 Array<NDArray> AllocateKVCache(int head_size, int num_layers, int num_heads, int block_size,
-                               int num_blocks) {
+                               int num_blocks, tvm::runtime::DataType kv_cache_dtype) {
+  CHECK(kv_cache_dtype.is_float16() || kv_cache_dtype.is_float8())
+      << "Unsupported data type for kv_cache: " << kv_cache_dtype;
   Array<NDArray> cache;
-  int element_size = 2;
+  int element_size = kv_cache_dtype.bits() / 8;
   int vec_size = 16 / element_size;
-
   int device_id;
   cudaGetDevice(&device_id);
 
   DLDevice dev{DLDeviceType::kDLCUDA, device_id};
 
   for (int i = 0; i < num_layers; ++i) {
-    NDArray key_blocks =
-        NDArray::Empty({num_blocks, num_heads, head_size / vec_size, block_size, vec_size},
-                       runtime::DataType::Float(16), dev);
-    NDArray value_blocks = NDArray::Empty({num_blocks, num_heads, head_size, block_size},
-                                          runtime::DataType::Float(16), dev);
+    NDArray key_blocks = NDArray::Empty(
+        {num_blocks, num_heads, head_size / vec_size, block_size, vec_size}, kv_cache_dtype, dev);
+    NDArray value_blocks =
+        NDArray::Empty({num_blocks, num_heads, head_size, block_size}, kv_cache_dtype, dev);
     cache.push_back(key_blocks);
    cache.push_back(value_blocks);
   }
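
For the key cache, the two innermost axes pack vec_size elements so that one 16-byte access reads a whole vector; with the dtype now a parameter, element_size and vec_size are derived from it instead of being hard-coded for fp16. A small worked example of the resulting shapes (the head_size and block_size values are illustrative, not taken from this diff):

// Worked example of the shape arithmetic in AllocateKVCache (illustrative values).
#include <cstdio>

int main() {
  const int head_size = 128, block_size = 16;
  const int bits_options[] = {16, 8};  // fp16 cache vs. fp8 cache
  for (int bits : bits_options) {
    int element_size = bits / 8;       // bytes per cache element: 2 or 1
    int vec_size = 16 / element_size;  // elements per 16-byte vector: 8 or 16
    std::printf("bits=%2d  key block shape: [num_blocks, num_heads, %d, %d, %d]\n", bits,
                head_size / vec_size, block_size, vec_size);
  }
  return 0;
}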

src/runtime/contrib/vllm/cache_kernels.cu

Lines changed: 72 additions & 32 deletions
@@ -25,14 +25,17 @@
 #include <map>
 #include <vector>
 
+#include "quant_utils.cuh"
+
 namespace vllm {
 
-template <typename scalar_t>
+template <typename scalar_t, typename cache_t = scalar_t,
+          KVCacheDType kv_cache_dtype = KVCacheDType::kFloat>
 __global__ void reshape_and_cache_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-    scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
     const int* __restrict__ slot_mapping,  // [num_tokens]
     const int key_stride, const int value_stride, const int num_heads, const int head_size,
     const int block_size, const int x) {
@@ -57,18 +60,35 @@ __global__ void reshape_and_cache_kernel(
     const int tgt_value_idx = block_idx * num_heads * head_size * block_size +
                               head_idx * head_size * block_size + head_offset * block_size +
                               block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+    if constexpr (kv_cache_dtype == KVCacheDType::kE5M2Float) {
+#if USE_CUDA_FP8
+      key_cache[tgt_key_idx] =
+          fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(__ldg(&key[src_key_idx]));
+      value_cache[tgt_value_idx] =
+          fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(__ldg(&value[src_value_idx]));
+#endif
+    } else if constexpr (kv_cache_dtype == KVCacheDType::kE4M3Float) {
+#if USE_CUDA_FP8
+      key_cache[tgt_key_idx] =
+          fp8_e4m3_unscaled::vec_conversion<uint8_t, scalar_t>(__ldg(&key[src_key_idx]));
+      value_cache[tgt_value_idx] =
+          fp8_e4m3_unscaled::vec_conversion<uint8_t, scalar_t>(__ldg(&value[src_value_idx]));
+#endif
+    } else {
+      key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
+      value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+    }
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename cache_t, KVCacheDType kv_cache_dtype>
 __global__ void reconstruct_from_cache_kernel(
-    const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-    const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
-    const int* __restrict__ slot_mapping,      // [num_tokens]
-    scalar_t* __restrict__ key,                // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ value,              // [num_tokens, num_heads, head_size]
+    const cache_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size,
+                                              // x]
+    const cache_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    const int* __restrict__ slot_mapping,     // [num_tokens]
+    scalar_t* __restrict__ key,               // [num_tokens, num_heads, head_size]
+    scalar_t* __restrict__ value,             // [num_tokens, num_heads, head_size]
     const int key_stride, const int value_stride, const int num_heads, const int head_size,
     const int block_size, const int x) {
   const int token_idx = blockIdx.x;
@@ -93,8 +113,24 @@ __global__ void reconstruct_from_cache_kernel(
                               head_idx * head_size * block_size + head_offset * block_size +
                               block_offset;
 
-    key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
-    value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
+    if constexpr (kv_cache_dtype == KVCacheDType::kE5M2Float) {
+#if USE_CUDA_FP8
+      key[tgt_key_idx] =
+          fp8_e5m2_unscaled::vec_conversion<scalar_t, uint8_t>(__ldg(&key_cache[src_key_idx]));
+      value[tgt_value_idx] =
+          fp8_e5m2_unscaled::vec_conversion<scalar_t, uint8_t>(__ldg(&value_cache[src_value_idx]));
+#endif
+    } else if constexpr (kv_cache_dtype == KVCacheDType::kE4M3Float) {
+#if USE_CUDA_FP8
+      key[tgt_key_idx] =
+          fp8_e4m3_unscaled::vec_conversion<scalar_t, uint8_t>(__ldg(&key_cache[src_key_idx]));
+      value[tgt_value_idx] =
+          fp8_e4m3_unscaled::vec_conversion<scalar_t, uint8_t>(__ldg(&value_cache[src_value_idx]));
+#endif
+    } else {
+      key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
+      value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
+    }
   }
 }
 
@@ -144,14 +180,16 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache")
 
       dim3 grid(num_tokens);
       dim3 block(std::min(num_heads * head_size, 512));
-
       using scalar_t = uint16_t;
-      vllm::reshape_and_cache_kernel<scalar_t><<<grid, block>>>(
-          static_cast<const scalar_t*>(key->data), static_cast<const scalar_t*>(value->data),
-          static_cast<scalar_t*>(key_cache->data), static_cast<scalar_t*>(value_cache->data),
-          static_cast<const int*>(slot_mapping->data), key_stride, value_stride, num_heads,
-          head_size, block_size, vec_size);
-
+      using cache_t = uint16_t;
+      using scalar_t = uint16_t;
+      VLLM_DISPATCH_KV_CACHE_DTYPE(key_cache->dtype, {
+        vllm::reshape_and_cache_kernel<scalar_t, cache_t, kv_cache_dtype><<<grid, block>>>(
+            static_cast<const scalar_t*>(key->data), static_cast<const scalar_t*>(value->data),
+            static_cast<cache_t*>(key_cache->data), static_cast<cache_t*>(value_cache->data),
+            static_cast<const int*>(slot_mapping->data), key_stride, value_stride, num_heads,
+            head_size, block_size, vec_size);
+      });
       return Array{key_cache, value_cache};
     });
 
@@ -174,13 +212,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache")
       dim3 block(std::min(num_heads * head_size, 512));
 
       using scalar_t = uint16_t;
-      vllm::reconstruct_from_cache_kernel<scalar_t>
-          <<<grid, block>>>(static_cast<const scalar_t*>(key_cache->data),
-                            static_cast<const scalar_t*>(value_cache->data),
-                            static_cast<const int*>(slot_mapping->data),
-                            static_cast<scalar_t*>(key->data), static_cast<scalar_t*>(value->data),
-                            key_stride, value_stride, num_heads, head_size, block_size, vec_size);
-
+      VLLM_DISPATCH_KV_CACHE_DTYPE(key_cache->dtype, {
+        vllm::reconstruct_from_cache_kernel<scalar_t, cache_t, kv_cache_dtype><<<grid, block>>>(
+            static_cast<const cache_t*>(key_cache->data),
+            static_cast<const cache_t*>(value_cache->data),
+            static_cast<const int*>(slot_mapping->data), static_cast<scalar_t*>(key->data),
+            static_cast<scalar_t*>(value->data), key_stride, value_stride, num_heads, head_size,
+            block_size, vec_size);
+      });
       return Array{key, value};
     });
 
@@ -223,11 +262,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks")
       dim3 grid(num_layers, num_pairs);
       dim3 block(std::min(1024, numel_per_block));
 
-      using scalar_t = uint16_t;
-      vllm::copy_blocks_kernel<scalar_t>
-          <<<grid, block>>>(static_cast<int64_t*>(key_cache_ptrs_gpu->data),
-                            static_cast<int64_t*>(value_cache_ptrs_gpu->data),
-                            static_cast<int64_t*>(block_mapping_gpu->data), numel_per_block);
+      VLLM_DISPATCH_KV_CACHE_DTYPE(key_cache->dtype, {
+        vllm::copy_blocks_kernel<cache_t>
+            <<<grid, block>>>(static_cast<int64_t*>(key_cache_ptrs_gpu->data),
+                              static_cast<int64_t*>(value_cache_ptrs_gpu->data),
+                              static_cast<int64_t*>(block_mapping_gpu->data), numel_per_block);
+      });
     });
 
 } // namespace runtime
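
VLLM_DISPATCH_KV_CACHE_DTYPE is not shown in the rendered hunks; from its call sites it has to look at the cache tensor's DLDataType and bind cache_t and kv_cache_dtype before executing the body it is given. A guess at its shape follows (this is a hypothetical reconstruction, not the macro from this commit, and the real one presumably also distinguishes e4m3 from e5m2 via the dtype code rather than only the bit width):

// Hypothetical reconstruction of the dispatch macro used by the launchers above.
// It binds `cache_t` and `kv_cache_dtype` for the body passed in `...`.
#include <cstdint>
#include <dlpack/dlpack.h>  // DLDataType

#define VLLM_DISPATCH_KV_CACHE_DTYPE(DLTYPE, ...)                      \
  do {                                                                 \
    if ((DLTYPE).bits == 8) { /* fp8 cache, stored as raw bytes */     \
      using cache_t = uint8_t;                                         \
      constexpr auto kv_cache_dtype = vllm::KVCacheDType::kE5M2Float;  \
      __VA_ARGS__                                                      \
    } else { /* default: fp16 cache */                                 \
      using cache_t = uint16_t;                                        \
      constexpr auto kv_cache_dtype = vllm::KVCacheDType::kFloat;      \
      __VA_ARGS__                                                      \
    }                                                                  \
  } while (0)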
Lines changed: 28 additions & 0 deletions (new file)

@@ -0,0 +1,28 @@
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+// fp8 vector types for quantization of kv cache
+
+template <>
+struct Vec<uint8_t, 1> {
+  using Type = uint8_t;
+};
+
+template <>
+struct Vec<uint8_t, 2> {
+  using Type = uint16_t;
+};
+
+template <>
+struct Vec<uint8_t, 4> {
+  using Type = uint32_t;
+};
+
+template <>
+struct Vec<uint8_t, 8> {
+  using Type = uint2;
+};
+
+}  // namespace vllm
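
These Vec<uint8_t, N> specializations let N packed fp8 cache values travel through the same vectorized load/convert machinery the fp16 path already uses (1, 2, 4 and 8 bytes map to uint8_t, uint16_t, uint32_t and uint2). The fp8_e5m2_unscaled and fp8_e4m3_unscaled conversion helpers referenced by the cache kernels are defined elsewhere in this commit; for reference, a scalar half to e5m2 round trip can be written with the cuda_fp8.h intrinsics roughly as follows (a sketch of the underlying conversion, not the helpers' actual definitions):

// Sketch: scalar fp16 <-> fp8-e5m2 conversion with CUDA's cuda_fp8.h intrinsics
// (available since CUDA 11.8). "Unscaled" means no per-tensor scale is applied.
#include <cuda_fp8.h>
#include <stdint.h>

__device__ inline uint8_t half_bits_to_e5m2(uint16_t h_bits) {
  __half_raw h;
  h.x = h_bits;  // reinterpret the uint16_t payload as raw half bits
  return __nv_cvt_halfraw_to_fp8(h, __NV_SATFINITE, __NV_E5M2);
}

__device__ inline uint16_t e5m2_to_half_bits(uint8_t f8) {
  __half_raw h = __nv_cvt_fp8_to_halfraw(f8, __NV_E5M2);
  return h.x;
}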
