-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[Inference]Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding #5461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
isky-cd
merged 13 commits into
hpcaitech:feature/colossal-infer
from
isky-cd:context_flash_attn_branch
Mar 25, 2024
Merged
[Inference]Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding #5461
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
519d318
Support FP16/BF16 Flash Attention 2
isky-cd 28c43fc
fix bugs in test_kv_cache_memcpy.py
isky-cd ecbbd38
add context_kv_cache_memcpy_kernel.cu
isky-cd 85017a5
rm typename MT
isky-cd 93cba15
add tail process
isky-cd 554a4f6
add high_precision
isky-cd f86518f
add high_precision to config.py
isky-cd 7cc62b8
rm unused code
isky-cd 42d8a5d
change the comment for the high_precision parameter
isky-cd 128b439
update test_rotary_embdding_unpad.py
isky-cd c5adf1c
Merge branch 'feature/colossal-infer' into context_flash_attn_branch
isky-cd 32d77a6
fix vector_copy_utils.h
isky-cd 1c1ea0e
add comment for self.high_precision when using float32
isky-cd File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
176 changes: 110 additions & 66 deletions
176
colossalai/inference/modeling/models/nopadding_llama.py
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,195 @@ | ||
| #include <ATen/cuda/CUDAContext.h> | ||
| #include <torch/extension.h> | ||
|
|
||
| #include "utils/vector_copy_utils.h" | ||
| #include "../common/micros.h" | ||
|
|
||
/**
 * Copy the prompt-phase key/value tensors into the paged KV cache.
 *
 * Grid layout: blockIdx.y selects the sequence in the batch, blockIdx.x the
 * token position inside that sequence. Each thread block copies one token's
 * hidden vector (head_num * head_dim elements); each thread moves VecSize
 * contiguous elements per iteration, with a scalar tail loop for the
 * remainder when hidden_size is not a multiple of blockDim.x * VecSize.
 */
template<typename scalar_t, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
    const scalar_t* __restrict__ key,
    const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache,
    scalar_t* __restrict__ value_cache,
    const int* __restrict__ sequence_lengths,
    const int* __restrict__ cu_seqlens,
    const int* __restrict__ block_tables,
    const int head_num,
    const int head_dim,
    const int block_size,
    const int batch_size,
    const int block_table_stride,
    const int64_t key_stride,
    const int64_t value_stride
)
{
    const int token_pos = blockIdx.x;  // token position within its sequence
    const int batch_idx = blockIdx.y;  // which sequence of the batch
    const int block_id  = block_tables[batch_idx * block_table_stride + token_pos / block_size];

    // Skip padding positions beyond this sequence's length and unassigned cache blocks.
    if (block_id < 0 || token_pos >= sequence_lengths[batch_idx]) {
        return;
    }

    const int block_offset   = token_pos % block_size;
    const int hidden_size    = head_num * head_dim;
    // Flat token index into the packed [num_tokens, ...] key/value tensors.
    const int total_token_id = cu_seqlens[batch_idx] + token_pos;

    int idx = threadIdx.x * VecSize;

    // Vectorized path: VecSize elements per thread per iteration.
    for (; idx <= (hidden_size - VecSize); idx += blockDim.x * VecSize) {
        const int head_id         = idx / head_dim;
        const int head_offset     = idx % head_dim;
        const int64_t key_src_id   = total_token_id * key_stride + idx;
        const int64_t value_src_id = total_token_id * value_stride + idx;
        const int64_t target_id    = block_id * hidden_size * block_size
                                   + head_id * block_size * head_dim
                                   + block_offset * head_dim + head_offset;

        copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
        copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
    }

    // Scalar tail: handle the leftover elements that do not fill a full vector.
    for (; idx < hidden_size; ++idx) {
        const int head_id         = idx / head_dim;
        const int head_offset     = idx % head_dim;
        const int64_t key_src_id   = total_token_id * key_stride + idx;
        const int64_t value_src_id = total_token_id * value_stride + idx;
        const int64_t target_id    = block_id * hidden_size * block_size
                                   + head_id * block_size * head_dim
                                   + block_offset * head_dim + head_offset;

        key_cache[target_id]   = key[key_src_id];
        value_cache[target_id] = value[value_src_id];
    }
}
|
|
||
/**
 * Launch context_kv_cache_memcpy_kernel with a compile-time vector width.
 *
 * Factored out so the vec_size dispatch switch below does not triplicate the
 * 15-argument kernel launch for each supported VecSize.
 */
template<typename scalar_t, int VecSize>
static void launch_context_kv_cache_memcpy(
    const dim3& grid,
    const dim3& block,
    const cudaStream_t stream,
    at::Tensor& key,
    at::Tensor& value,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    at::Tensor& sequence_lengths,
    at::Tensor& cu_seqlens,
    at::Tensor& block_tables,
    int head_num,
    int head_dim,
    int block_size,
    int batch_size,
    int block_table_stride,
    int64_t key_stride,
    int64_t value_stride)
{
    context_kv_cache_memcpy_kernel<scalar_t, VecSize><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        sequence_lengths.data_ptr<int>(),
        cu_seqlens.data_ptr<int>(),
        block_tables.data_ptr<int>(),
        head_num,
        head_dim,
        block_size,
        batch_size,
        block_table_stride,
        key_stride,
        value_stride
    );
}

/**
 * Typed implementation of the context-stage KV-cache copy.
 *
 * Picks the widest vectorized load size compatible with head_dim, configures
 * a (max_seq_len_in_batch, batch_size) grid — one thread block per
 * (token position, sequence) pair — and dispatches the kernel.
 */
template<typename scalar_t>
void apply_context_kv_cache_memcpy(
    at::Tensor& key,                 // [num_tokens, head_num, head_dim]
    at::Tensor& value,               // [num_tokens, head_num, head_dim]
    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& sequence_lengths,    // [batch_size]
    at::Tensor& cu_seqlens,          // [batch_size + 1]
    at::Tensor& block_tables,        // [batch_size, max_seq_len]
    int max_seq_len_in_batch)
{
    int num_tokens = key.size(0);
    int head_num = key.size(1);
    int head_dim = key.size(2);
    int block_size = key_cache.size(2);
    int batch_size = block_tables.size(0);

    int64_t key_stride = key.stride(0);
    int64_t value_stride = value.stride(0);
    int block_table_stride = block_tables.stride(0);

    int vec_size = get_vec_size<scalar_t>(key);

    if (head_dim % vec_size != 0) {
        // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
        vec_size = 1;
    }

    int thread_nums = head_num * head_dim / vec_size;

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    dim3 grid(max_seq_len_in_batch, batch_size);
    dim3 block(std::min(thread_nums, 512));

    switch (vec_size) {
        case 1:
            launch_context_kv_cache_memcpy<scalar_t, 1>(
                grid, block, stream, key, value, key_cache, value_cache,
                sequence_lengths, cu_seqlens, block_tables,
                head_num, head_dim, block_size, batch_size,
                block_table_stride, key_stride, value_stride);
            break;
        case 2:
            launch_context_kv_cache_memcpy<scalar_t, 2>(
                grid, block, stream, key, value, key_cache, value_cache,
                sequence_lengths, cu_seqlens, block_tables,
                head_num, head_dim, block_size, batch_size,
                block_table_stride, key_stride, value_stride);
            break;
        case 4:
            launch_context_kv_cache_memcpy<scalar_t, 4>(
                grid, block, stream, key, value, key_cache, value_cache,
                sequence_lengths, cu_seqlens, block_tables,
                head_num, head_dim, block_size, batch_size,
                block_table_stride, key_stride, value_stride);
            break;
        default:
            AT_ERROR("Unsupported vectorized size ", vec_size);
            break;
    }

    AT_CUDA_CHECK(cudaGetLastError());
}
|
|
||
/**
 * Public entry point: copy the prompt-phase key/value tensors into the paged
 * KV cache. Dispatches on key's scalar type (float / half / bfloat16 — the
 * types supported by the DISPATCH_FLOAT_HALF_AND_BFLOAT macro) to the typed
 * implementation apply_context_kv_cache_memcpy<scalar_t>.
 */
void context_kv_cache_memcpy(
    at::Tensor& key,                 // [num_tokens, head_num, head_dim]
    at::Tensor& value,               // [num_tokens, head_num, head_dim]
    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& sequence_lengths,    // [batch_size]
    at::Tensor& cu_seqlens,          // [batch_size + 1]
    at::Tensor& block_tables,        // [batch_size, max_seq_len]
    int max_seq_len_in_batch)
{
    // The macro expands the call once per supported dtype with `scalar_t` bound.
    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        key.scalar_type(),
        "context_kv_cache_memcpy",
        apply_context_kv_cache_memcpy<scalar_t>(
            key,
            value,
            key_cache,
            value_cache,
            sequence_lengths,
            cu_seqlens,
            block_tables,
            max_seq_len_in_batch
        );)
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.