44#include " utils/vec_copy.h"
55#include " common/micros.h"
66
7- using colossalAI::cuda::utils::copy_vector;
87using colossalAI::cuda::utils::get_vec_size;
8+ using colossalAI::cuda::utils::copy;
9+ using colossalAI::funcs::CastFunctor;
910
1011
11- template <typename scalar_t , bool Aligned, int VecSize>
12+ template <typename T, typename CacheT , bool Aligned, int VecSize>
1213__global__ void context_kv_cache_memcpy_kernel (
13- const scalar_t * __restrict__ key,
14- const scalar_t * __restrict__ value,
15- scalar_t * __restrict__ key_cache,
16- scalar_t * __restrict__ value_cache,
14+ const T * __restrict__ key,
15+ const T * __restrict__ value,
16+ CacheT * __restrict__ key_cache,
17+ CacheT * __restrict__ value_cache,
1718 const int * __restrict__ sequence_lengths,
1819 const int * __restrict__ cu_seqlens,
1920 const int * __restrict__ block_tables,
@@ -54,8 +55,8 @@ __global__ void context_kv_cache_memcpy_kernel(
5455 + head_id * block_size * head_dim
5556 + block_offset * head_dim + head_offset;
5657
57- copy_vector< scalar_t , VecSize>(key_cache + target_id, key + key_src_id );
58- copy_vector< scalar_t , VecSize>(value_cache + target_id, value + value_src_id );
58+ copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id );
59+ copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id );
5960 }
6061
6162 // tail process
@@ -69,22 +70,22 @@ __global__ void context_kv_cache_memcpy_kernel(
6970 + head_id * block_size * head_dim
7071 + block_offset * head_dim + head_offset;
7172
72- key_cache[target_id] = key[key_src_id];
73- value_cache[target_id] = value[value_src_id];
73+ key_cache[target_id] = CastFunctor<T, CacheT>()( key[key_src_id]) ;
74+ value_cache[target_id] = CastFunctor<T, CacheT>()( value[value_src_id]) ;
7475 }
7576 }
7677
7778}
7879
79- template <typename scalar_t >
80+ template <typename T, typename CacheT >
8081void apply_context_kv_cache_memcpy (
81- at ::Tensor& key, // [num_tokens, head_num, head_dim]
82- at ::Tensor& value, // [num_tokens, head_num, head_dim]
83- at ::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
84- at ::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
85- at ::Tensor& sequence_lengths, // [batch_size]
86- at ::Tensor& cu_seqlens, // [batch_size + 1]
87- at ::Tensor& block_tables, // [batch_size, max_seq_len]
82+ torch ::Tensor& key, // [num_tokens, head_num, head_dim]
83+ torch ::Tensor& value, // [num_tokens, head_num, head_dim]
84+ torch ::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
85+ torch ::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
86+ torch ::Tensor& sequence_lengths, // [batch_size]
87+ torch ::Tensor& cu_seqlens, // [batch_size + 1]
88+ torch ::Tensor& block_tables, // [batch_size, max_seq_len]
8889 int max_seq_len_in_batch)
8990{
9091 int num_tokens = key.size (0 );
@@ -97,7 +98,7 @@ void apply_context_kv_cache_memcpy(
9798 int64_t value_stride = value.stride (0 );
9899 int block_table_stride = block_tables.stride (0 );
99100
100- int vec_size = get_vec_size<scalar_t >(key);
101+ int vec_size = get_vec_size<T >(key);
101102
102103 bool aligned = true ;
103104 if (head_dim % vec_size != 0 ) {
@@ -112,11 +113,11 @@ void apply_context_kv_cache_memcpy(
112113
113114#define CONTEXT_KV_CACHE_MEMCOPY_KERNEL_LAUNCH (__aligned, __vec_size ) \
114115 do { \
115- context_kv_cache_memcpy_kernel<scalar_t , __aligned, __vec_size><<<grid, block, 0 , stream>>> ( \
116- key.data_ptr < scalar_t >(), \
117- value.data_ptr < scalar_t >(), \
118- key_cache.data_ptr < scalar_t >(), \
119- value_cache.data_ptr < scalar_t >(), \
116+ context_kv_cache_memcpy_kernel<T, CacheT , __aligned, __vec_size><<<grid, block, 0 , stream>>> ( \
117+ reinterpret_cast <T*>( key.data_ptr ()), \
118+ reinterpret_cast <T*>( value.data_ptr ()), \
119+ reinterpret_cast <CacheT*>( key_cache.data_ptr ()), \
120+ reinterpret_cast <CacheT*>( value_cache.data_ptr ()), \
120121 sequence_lengths.data_ptr <int >(), \
121122 cu_seqlens.data_ptr <int >(), \
122123 block_tables.data_ptr <int >(), \
@@ -161,26 +162,63 @@ void apply_context_kv_cache_memcpy(
161162}
162163
// Entry point: copy the context-stage K/V tensors into the paged KV cache,
// dispatching on (input dtype, cache dtype).
//
// key/value:             [num_tokens, head_num, head_dim]
// key_cache/value_cache: [num_blocks, head_num, block_size, head_dim]
// sequence_lengths:      [batch_size]
// cu_seqlens:            [batch_size + 1]
// block_tables:          [batch_size, max_seq_len]
//
// Supported input dtypes: float, half, bfloat16. The cache either matches the
// input dtype, or is Byte (uint8_t), which selects the fp8-quantized cache
// path (values cast via CastFunctor inside the kernel).
void context_kv_cache_memcpy(
    torch::Tensor& key,               // [num_tokens, head_num, head_dim]
    torch::Tensor& value,             // [num_tokens, head_num, head_dim]
    torch::Tensor& key_cache,         // [num_blocks, head_num, block_size, head_dim]
    torch::Tensor& value_cache,       // [num_blocks, head_num, block_size, head_dim]
    torch::Tensor& sequence_lengths,  // [batch_size]
    torch::Tensor& cu_seqlens,        // [batch_size + 1]
    torch::Tensor& block_tables,      // [batch_size, max_seq_len]
    int max_seq_len_in_batch)
{
    TORCH_CHECK(key.scalar_type() == at::ScalarType::Float ||
                    key.scalar_type() == at::ScalarType::Half ||
                    key.scalar_type() == at::ScalarType::BFloat16,
                "Dtype of key should be float, half or bfloat16!");
    // Byte-typed cache means fp8 quantization; otherwise dtypes must match.
    TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte ||
                    key_cache.scalar_type() == key.scalar_type(),
                "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");

// Local dispatch helper; undefined again below so it cannot leak out of this
// function. (A single-underscore name is avoided deliberately: `_` is the
// conventional gettext macro and leaking it breaks unrelated code.)
#define CONTEXT_KV_CACHE_MEMCPY_DISPATCH(T, CacheT) \
    apply_context_kv_cache_memcpy<T, CacheT>(       \
        key,                                        \
        value,                                      \
        key_cache,                                  \
        value_cache,                                \
        sequence_lengths,                           \
        cu_seqlens,                                 \
        block_tables,                               \
        max_seq_len_in_batch)

    if (key_cache.scalar_type() == at::ScalarType::Byte) {
        // fp8 cache path: inputs are cast into uint8_t storage.
        switch (key.scalar_type()) {
            case at::ScalarType::Float:
                CONTEXT_KV_CACHE_MEMCPY_DISPATCH(float, uint8_t);
                break;
            case at::ScalarType::Half:
                CONTEXT_KV_CACHE_MEMCPY_DISPATCH(half, uint8_t);
                break;
            case at::ScalarType::BFloat16:
                CONTEXT_KV_CACHE_MEMCPY_DISPATCH(__nv_bfloat16, uint8_t);
                break;
            default:
                break;  // unreachable: dtype guarded by TORCH_CHECK above
        }
    } else {
        // Homogeneous path: cache dtype equals input dtype.
        switch (key.scalar_type()) {
            case at::ScalarType::Float:
                CONTEXT_KV_CACHE_MEMCPY_DISPATCH(float, float);
                break;
            case at::ScalarType::Half:
                CONTEXT_KV_CACHE_MEMCPY_DISPATCH(half, half);
                break;
            case at::ScalarType::BFloat16:
                CONTEXT_KV_CACHE_MEMCPY_DISPATCH(__nv_bfloat16, __nv_bfloat16);
                break;
            default:
                break;  // unreachable: dtype guarded by TORCH_CHECK above
        }
    }

// Fix: the original wrote `#undef _t`, which undefines a nonexistent macro and
// leaves the dispatch macro defined for the rest of the translation unit.
#undef CONTEXT_KV_CACHE_MEMCPY_DISPATCH
}
0 commit comments