 #include <map>
 #include <vector>

+#include "quant_utils.cuh"
+
 namespace vllm {

-template <typename scalar_t>
+template <typename scalar_t, typename cache_t = scalar_t,
+          KVCacheDType kv_cache_dtype = KVCacheDType::kFloat>
 __global__ void reshape_and_cache_kernel(
     const scalar_t* __restrict__ key,      // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,    // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ key_cache,      // [num_blocks, num_heads, head_size/x, block_size, x]
-    scalar_t* __restrict__ value_cache,    // [num_blocks, num_heads, head_size, block_size]
+    cache_t* __restrict__ key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+    cache_t* __restrict__ value_cache,     // [num_blocks, num_heads, head_size, block_size]
     const int* __restrict__ slot_mapping,  // [num_tokens]
     const int key_stride, const int value_stride, const int num_heads, const int head_size,
     const int block_size, const int x) {
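
The new template parameters reference a KVCacheDType enum that is expected to
come from quant_utils.cuh (included above) but is not shown in this diff. A
minimal sketch of its assumed shape, with names taken from the usages below:

    // Sketch only: assumed definition, per the usages in this file.
    enum class KVCacheDType {
      kFloat,      // cache kept in the model's scalar type (fp16 here)
      kE5M2Float,  // fp8 with a 5-bit exponent and 2-bit mantissa
      kE4M3Float,  // fp8 with a 4-bit exponent and 3-bit mantissa
    };
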
@@ -57,18 +60,35 @@ __global__ void reshape_and_cache_kernel(
     const int tgt_value_idx = block_idx * num_heads * head_size * block_size +
                               head_idx * head_size * block_size + head_offset * block_size +
                               block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+    if constexpr (kv_cache_dtype == KVCacheDType::kE5M2Float) {
+#if USE_CUDA_FP8
+      key_cache[tgt_key_idx] =
+          fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(__ldg(&key[src_key_idx]));
+      value_cache[tgt_value_idx] =
+          fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(__ldg(&value[src_value_idx]));
+#endif
+    } else if constexpr (kv_cache_dtype == KVCacheDType::kE4M3Float) {
+#if USE_CUDA_FP8
+      key_cache[tgt_key_idx] =
+          fp8_e4m3_unscaled::vec_conversion<uint8_t, scalar_t>(__ldg(&key[src_key_idx]));
+      value_cache[tgt_value_idx] =
+          fp8_e4m3_unscaled::vec_conversion<uint8_t, scalar_t>(__ldg(&value[src_value_idx]));
+#endif
+    } else {
+      key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
+      value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+    }
   }
 }

-template <typename scalar_t>
+template <typename scalar_t, typename cache_t, KVCacheDType kv_cache_dtype>
 __global__ void reconstruct_from_cache_kernel(
-    const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-    const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
-    const int* __restrict__ slot_mapping,      // [num_tokens]
-    scalar_t* __restrict__ key,                // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ value,              // [num_tokens, num_heads, head_size]
+    const cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+    const cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
+    const int* __restrict__ slot_mapping,      // [num_tokens]
+    scalar_t* __restrict__ key,                // [num_tokens, num_heads, head_size]
+    scalar_t* __restrict__ value,              // [num_tokens, num_heads, head_size]
     const int key_stride, const int value_stride, const int num_heads, const int head_size,
     const int block_size, const int x) {
   const int token_idx = blockIdx.x;
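
The fp8 branches above route each load through
fp8_e5m2_unscaled::vec_conversion / fp8_e4m3_unscaled::vec_conversion, whose
definitions are not part of this diff. Assuming they wrap the cuda_fp8.h
intrinsics (available since CUDA 11.8) and that "unscaled" means no per-tensor
scale factor is applied, the scalar e5m2 case would look roughly like:

    #include <cuda_fp8.h>
    #include <cstdint>

    // Sketch: pack one fp16 value (carried as uint16_t, as in the kernels
    // above) into an e5m2 byte, and unpack it again. Saturating conversion,
    // no scale factor.
    __device__ inline uint8_t half_to_fp8_e5m2(uint16_t bits) {
      __half_raw h;
      h.x = bits;
      return static_cast<uint8_t>(
          __nv_cvt_halfraw_to_fp8(h, __NV_SATFINITE, __NV_E5M2));
    }

    __device__ inline uint16_t fp8_e5m2_to_half(uint8_t fp8) {
      return __nv_cvt_fp8_to_halfraw(
          static_cast<__nv_fp8_storage_t>(fp8), __NV_E5M2).x;
    }
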
@@ -93,8 +113,24 @@ __global__ void reconstruct_from_cache_kernel(
                               head_idx * head_size * block_size + head_offset * block_size +
                               block_offset;

-    key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
-    value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
+    if constexpr (kv_cache_dtype == KVCacheDType::kE5M2Float) {
+#if USE_CUDA_FP8
+      key[tgt_key_idx] =
+          fp8_e5m2_unscaled::vec_conversion<scalar_t, uint8_t>(__ldg(&key_cache[src_key_idx]));
+      value[tgt_value_idx] =
+          fp8_e5m2_unscaled::vec_conversion<scalar_t, uint8_t>(__ldg(&value_cache[src_value_idx]));
+#endif
+    } else if constexpr (kv_cache_dtype == KVCacheDType::kE4M3Float) {
+#if USE_CUDA_FP8
+      key[tgt_key_idx] =
+          fp8_e4m3_unscaled::vec_conversion<scalar_t, uint8_t>(__ldg(&key_cache[src_key_idx]));
+      value[tgt_value_idx] =
+          fp8_e4m3_unscaled::vec_conversion<scalar_t, uint8_t>(__ldg(&value_cache[src_value_idx]));
+#endif
+    } else {
+      key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
+      value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
+    }
   }
 }

@@ -144,14 +180,16 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache")

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
-
   using scalar_t = uint16_t;
-  vllm::reshape_and_cache_kernel<scalar_t><<<grid, block>>>(
-      static_cast<const scalar_t*>(key->data), static_cast<const scalar_t*>(value->data),
-      static_cast<scalar_t*>(key_cache->data), static_cast<scalar_t*>(value_cache->data),
-      static_cast<const int*>(slot_mapping->data), key_stride, value_stride, num_heads,
-      head_size, block_size, vec_size);
-
+  using cache_t = uint16_t;
+  VLLM_DISPATCH_KV_CACHE_DTYPE(key_cache->dtype, {
+    vllm::reshape_and_cache_kernel<scalar_t, cache_t, kv_cache_dtype><<<grid, block>>>(
+        static_cast<const scalar_t*>(key->data), static_cast<const scalar_t*>(value->data),
+        static_cast<cache_t*>(key_cache->data), static_cast<cache_t*>(value_cache->data),
+        static_cast<const int*>(slot_mapping->data), key_stride, value_stride, num_heads,
+        head_size, block_size, vec_size);
+  });
   return Array{key_cache, value_cache};
 });

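
VLLM_DISPATCH_KV_CACHE_DTYPE is also not defined in this diff; the call sites
require it to bind cache_t and kv_cache_dtype and then execute the braced body
once. A sketch of the shape such a macro usually takes, where the dtype
predicates is_fp8_e5m2 / is_fp8_e4m3 are hypothetical helpers:

    // Sketch only: run the body with cache_t / kv_cache_dtype bound per dtype.
    #define VLLM_DISPATCH_KV_CACHE_DTYPE(DTYPE, ...)                       \
      if (is_fp8_e5m2(DTYPE)) {                                            \
        using cache_t = uint8_t;                                           \
        constexpr auto kv_cache_dtype = vllm::KVCacheDType::kE5M2Float;    \
        __VA_ARGS__                                                        \
      } else if (is_fp8_e4m3(DTYPE)) {                                     \
        using cache_t = uint8_t;                                           \
        constexpr auto kv_cache_dtype = vllm::KVCacheDType::kE4M3Float;    \
        __VA_ARGS__                                                        \
      } else { /* default: fp16 carried as uint16_t */                     \
        using cache_t = uint16_t;                                          \
        constexpr auto kv_cache_dtype = vllm::KVCacheDType::kFloat;        \
        __VA_ARGS__                                                        \
      }
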
@@ -174,13 +212,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache")
   dim3 block(std::min(num_heads * head_size, 512));

   using scalar_t = uint16_t;
-  vllm::reconstruct_from_cache_kernel<scalar_t>
-      <<<grid, block>>>(static_cast<const scalar_t*>(key_cache->data),
-                        static_cast<const scalar_t*>(value_cache->data),
-                        static_cast<const int*>(slot_mapping->data),
-                        static_cast<scalar_t*>(key->data), static_cast<scalar_t*>(value->data),
-                        key_stride, value_stride, num_heads, head_size, block_size, vec_size);
-
+  VLLM_DISPATCH_KV_CACHE_DTYPE(key_cache->dtype, {
+    vllm::reconstruct_from_cache_kernel<scalar_t, cache_t, kv_cache_dtype><<<grid, block>>>(
+        static_cast<const cache_t*>(key_cache->data),
+        static_cast<const cache_t*>(value_cache->data),
+        static_cast<const int*>(slot_mapping->data), static_cast<scalar_t*>(key->data),
+        static_cast<scalar_t*>(value->data), key_stride, value_stride, num_heads, head_size,
+        block_size, vec_size);
+  });
   return Array{key, value};
 });

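
Calling these registrations host-side goes through TVM's global function
registry. A sketch (the argument order is assumed from the parameter names
used above, and the returned Array is ignored):

    #include <tvm/runtime/registry.h>

    // Sketch: look up and invoke the packed function registered above.
    void call_reshape_and_cache(DLTensor* key, DLTensor* value,
                                DLTensor* key_cache, DLTensor* value_cache,
                                DLTensor* slot_mapping) {
      const tvm::runtime::PackedFunc* f =
          tvm::runtime::Registry::Get("tvm.contrib.vllm.reshape_and_cache");
      ICHECK(f != nullptr) << "packed function not registered";
      (*f)(key, value, key_cache, value_cache, slot_mapping);
    }
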
@@ -223,11 +262,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks")
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));

-  using scalar_t = uint16_t;
-  vllm::copy_blocks_kernel<scalar_t>
-      <<<grid, block>>>(static_cast<int64_t*>(key_cache_ptrs_gpu->data),
-                        static_cast<int64_t*>(value_cache_ptrs_gpu->data),
-                        static_cast<int64_t*>(block_mapping_gpu->data), numel_per_block);
+  VLLM_DISPATCH_KV_CACHE_DTYPE(key_cache->dtype, {
+    vllm::copy_blocks_kernel<cache_t>
+        <<<grid, block>>>(static_cast<int64_t*>(key_cache_ptrs_gpu->data),
+                          static_cast<int64_t*>(value_cache_ptrs_gpu->data),
+                          static_cast<int64_t*>(block_mapping_gpu->data), numel_per_block);
+  });
 });

 }  // namespace runtime
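
For reference, the flat offset into the [num_blocks, num_heads, head_size,
block_size] value cache that both kernels compute (tgt_value_idx above) can be
checked on the host; the values here are illustrative only:

    #include <cassert>

    // Mirrors the tgt_value_idx arithmetic from the kernels above.
    int value_cache_index(int block_idx, int head_idx, int head_offset,
                          int block_offset, int num_heads, int head_size,
                          int block_size) {
      return block_idx * num_heads * head_size * block_size +
             head_idx * head_size * block_size + head_offset * block_size +
             block_offset;
    }

    int main() {
      // With 2 heads, head_size 4, block_size 8, block 1 starts right after
      // block 0's 2 * 4 * 8 = 64 entries.
      assert(value_cache_index(1, 0, 0, 0, 2, 4, 8) == 64);
      return 0;
    }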