From 981a46104682e1aaaa76efca3491b2a5ef81b918 Mon Sep 17 00:00:00 2001
From: Li Zhang
Date: Mon, 31 Jul 2023 16:36:09 +0800
Subject: [PATCH] [Fix] Remove unused code to reduce binary size (#181)

* clean-up
* fix lint
* fix lint
---
 CMakeLists.txt                                 |   6 -
 src/turbomind/kernels/CMakeLists.txt           |  18 -
 src/turbomind/kernels/activation_kernels.cu    | 353 +-------
 .../kernels/beam_search_penalty_kernels.cu     | 313 -------
 .../kernels/beam_search_penalty_kernels.h      |  48 -
 .../kernels/beam_search_topk_kernels.cu        | 845 ------------------
 .../kernels/beam_search_topk_kernels.h         |  94 --
 .../kernels/bert_preprocess_kernels.cu         | 293 ------
 .../kernels/bert_preprocess_kernels.h          |  58 --
 .../decoder_masked_multihead_attention.h       |   1 -
 src/turbomind/kernels/decoding_kernels.cu      | 509 +----------
 src/turbomind/kernels/decoding_kernels.h       |  76 --
 .../kernels/gen_relative_pos_bias.cu           | 304 -------
 src/turbomind/kernels/gen_relative_pos_bias.h  |  56 --
 src/turbomind/kernels/logprob_kernels.cu       |  50 +-
 .../online_softmax_beamsearch_kernels.cu       | 739 ---------------
 .../online_softmax_beamsearch_kernels.h        |  41 -
 src/turbomind/layers/CMakeLists.txt            |   7 +-
 src/turbomind/layers/DynamicDecodeLayer.cc     | 137 +--
 src/turbomind/layers/DynamicDecodeLayer.h      |   3 -
 src/turbomind/layers/FfnFP8Layer.cc            | 535 -----------
 src/turbomind/layers/FfnFP8Layer.h             | 133 ---
 src/turbomind/layers/FfnFP8Weight.h            |  30 -
 src/turbomind/layers/FfnINT8Weight.h           |  28 -
 src/turbomind/layers/FfnLayerINT8.cc           | 340 -------
 src/turbomind/layers/FfnLayerINT8.h            | 146 ---
 .../attention_layers_fp8/AttentionFP8Weight.h  |  34 -
 .../BaseAttentionFP8Layer.h                    |  65 --
 .../attention_layers_fp8/CMakeLists.txt        |  15 -
 .../AttentionINT8Weight.h                      |  29 -
 .../attention_layers_int8/CMakeLists.txt       |  15 -
 .../beam_search_layers/BaseBeamSearchLayer.cu  | 291 ------
 .../beam_search_layers/BaseBeamSearchLayer.h   |  80 --
 .../beam_search_layers/BeamSearchLayer.cu      | 354 --------
 .../beam_search_layers/BeamSearchLayer.h       |  68 --
 .../layers/beam_search_layers/CMakeLists.txt   |  30 -
 .../OnlineBeamSearchLayer.cu                   | 249 ------
 .../OnlineBeamSearchLayer.h                    |  65 --
 src/turbomind/models/llama/CMakeLists.txt      |   1 -
 src/turbomind/models/llama/prefix_cache.cu     |  55 --
 src/turbomind/models/llama/prefix_cache.h      |   9 -
 41 files changed, 89 insertions(+), 6434 deletions(-)
 delete mode 100644 src/turbomind/kernels/beam_search_penalty_kernels.cu
 delete mode 100644 src/turbomind/kernels/beam_search_penalty_kernels.h
 delete mode 100644 src/turbomind/kernels/beam_search_topk_kernels.cu
 delete mode 100644 src/turbomind/kernels/beam_search_topk_kernels.h
 delete mode 100644 src/turbomind/kernels/gen_relative_pos_bias.cu
 delete mode 100644 src/turbomind/kernels/gen_relative_pos_bias.h
 delete mode 100644 src/turbomind/kernels/online_softmax_beamsearch_kernels.cu
 delete mode 100644 src/turbomind/kernels/online_softmax_beamsearch_kernels.h
 delete mode 100644 src/turbomind/layers/FfnFP8Layer.cc
 delete mode 100644 src/turbomind/layers/FfnFP8Layer.h
 delete mode 100644 src/turbomind/layers/FfnFP8Weight.h
 delete mode 100644 src/turbomind/layers/FfnINT8Weight.h
 delete mode 100644 src/turbomind/layers/FfnLayerINT8.cc
 delete mode 100644 src/turbomind/layers/FfnLayerINT8.h
 delete mode 100644 src/turbomind/layers/attention_layers_fp8/AttentionFP8Weight.h
 delete mode 100644 src/turbomind/layers/attention_layers_fp8/BaseAttentionFP8Layer.h
 delete mode 100644 src/turbomind/layers/attention_layers_fp8/CMakeLists.txt
 delete mode 100644 src/turbomind/layers/attention_layers_int8/AttentionINT8Weight.h
 delete mode 100644 src/turbomind/layers/attention_layers_int8/CMakeLists.txt
 delete mode 100644 src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.cu
 delete mode 100644 src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.h
 delete mode 100644 src/turbomind/layers/beam_search_layers/BeamSearchLayer.cu
 delete mode 100644 src/turbomind/layers/beam_search_layers/BeamSearchLayer.h
 delete mode 100644 src/turbomind/layers/beam_search_layers/CMakeLists.txt
 delete mode 100644 src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.cu
 delete mode 100644 src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.h
 delete mode 100644 src/turbomind/models/llama/prefix_cache.cu
 delete mode 100644 src/turbomind/models/llama/prefix_cache.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4fe23a91ce..0a19014897 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -299,21 +299,16 @@ endif()
 ########################################
 add_library(transformer-shared SHARED
- $
  $
- $
  $
  $
  $
  $
- $
  $
  $
  $
  $
  $
- $
- $
  $
  $
  $
  $
@@ -329,7 +324,6 @@ add_library(transformer-shared SHARED
  $
  $
  $
- $
  $
  $
  $
diff --git a/src/turbomind/kernels/CMakeLists.txt b/src/turbomind/kernels/CMakeLists.txt
index d718b1fdf1..853131c440 100644
--- a/src/turbomind/kernels/CMakeLists.txt
+++ b/src/turbomind/kernels/CMakeLists.txt
@@ -26,11 +26,6 @@ add_library(activation_kernels STATIC activation_kernels.cu)
 set_property(TARGET activation_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET activation_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

-add_library(gen_relative_pos_bias STATIC gen_relative_pos_bias.cu)
-set_property(TARGET gen_relative_pos_bias PROPERTY POSITION_INDEPENDENT_CODE ON)
-set_property(TARGET gen_relative_pos_bias PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
-target_link_libraries(gen_relative_pos_bias PUBLIC activation_kernels)
-
 add_library(logprob_kernels STATIC logprob_kernels.cu)
 set_property(TARGET logprob_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET logprob_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
@@ -51,10 +46,6 @@ add_library(decoder_masked_multihead_attention STATIC ${decoder_masked_multihead
 set_property(TARGET decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

-add_library(online_softmax_beamsearch_kernels STATIC online_softmax_beamsearch_kernels.cu)
-set_property(TARGET online_softmax_beamsearch_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
-set_property(TARGET online_softmax_beamsearch_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
-
 add_library(decoding_kernels STATIC decoding_kernels.cu)
 set_property(TARGET decoding_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET decoding_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
@@ -63,15 +54,6 @@ add_library(gpt_kernels STATIC gpt_kernels.cu)
 set_property(TARGET gpt_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET gpt_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

-add_library(beam_search_penalty_kernels STATIC beam_search_penalty_kernels.cu)
-set_property(TARGET beam_search_penalty_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
-set_property(TARGET beam_search_penalty_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
-target_link_libraries(beam_search_penalty_kernels PRIVATE cuda_utils)
-
-add_library(beam_search_topk_kernels STATIC beam_search_topk_kernels.cu)
-set_property(TARGET beam_search_topk_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
-set_property(TARGET beam_search_topk_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) - add_library(sampling_topk_kernels STATIC sampling_topk_kernels.cu) set_property(TARGET sampling_topk_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET sampling_topk_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/turbomind/kernels/activation_kernels.cu b/src/turbomind/kernels/activation_kernels.cu index 664ae68a14..ef0465ce01 100644 --- a/src/turbomind/kernels/activation_kernels.cu +++ b/src/turbomind/kernels/activation_kernels.cu @@ -306,17 +306,17 @@ void invokeGenericActivation(T* out, const int seq_len, \ cudaStream_t stream); -INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, float, float); -INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, half, half); -#ifdef ENABLE_BF16 -INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, __nv_bfloat16, __nv_bfloat16); -#endif - -INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, float, float); -INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, half, half); -#ifdef ENABLE_BF16 -INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, __nv_bfloat16, __nv_bfloat16); -#endif +// INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, float, float); +// INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, half, half); +// #ifdef ENABLE_BF16 +// INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, __nv_bfloat16, __nv_bfloat16); +// #endif + +// INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, float, float); +// INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, half, half); +// #ifdef ENABLE_BF16 +// INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, __nv_bfloat16, __nv_bfloat16); +// #endif INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, float, float); INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, half, half); @@ -324,335 +324,4 @@ INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, half, half); INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16); #endif -INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, float); -INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, half, half); -INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, half); -#ifdef ENABLE_BF16 -INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, __nv_bfloat16, __nv_bfloat16); -INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, __nv_bfloat16); -#endif -#undef INSTANCIATE_GENERIC_ACTIVATION - -template -__global__ void add_bias_tanh(T* out, const T* __restrict bias, int m, int n) -{ - for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { - T val = out[id]; - if (bias != nullptr) { - val = val + ldg(&bias[id % n]); - } - out[id] = tanhf(val); - } -} - -template<> -__global__ void add_bias_tanh(half* out, const half* __restrict bias, int m, int n) -{ - half2* out_ptr = (half2*)out; - const half2* bias_ptr = (half2*)bias; - - for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { - half2 val = out_ptr[id]; - if (bias != nullptr) { - val = val + __ldg(&bias_ptr[id % n]); - } - val.x = tanhf(val.x); - val.y = tanhf(val.y); - out_ptr[id] = val; - } -} - -#ifdef ENABLE_BF16 -template<> -__global__ void add_bias_tanh(__nv_bfloat16* out, const __nv_bfloat16* __restrict bias, int m, int n) -{ - __nv_bfloat162* out_ptr = (__nv_bfloat162*)out; - const __nv_bfloat162* bias_ptr = (__nv_bfloat162*)bias; - - for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { - __nv_bfloat162 val = out_ptr[id]; - if (bias != nullptr) { - val = bf16hadd2(val, ldg(&bias_ptr[id % n])); - } - val.x 
= tanhf(val.x); - val.y = tanhf(val.y); - out_ptr[id] = val; - } -} -#endif - -template -void invokeAddBiasTanh(T* out, const T* bias, const int m, const int n, cudaStream_t stream) -{ - const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 - dim3 block, grid; - if (n / 4 / data_type_factor <= 1024) { - block.x = n / 4 / data_type_factor; - grid.x = m; - } - else { - block.x = 1024; - grid.x = ceil(m * n / 1024.); - } - add_bias_tanh<<>>(out, bias, m, n / data_type_factor); -} - -template void invokeAddBiasTanh(float* out, const float* bias, const int m, const int n, cudaStream_t stream); -template void invokeAddBiasTanh(half* out, const half* bias, const int m, const int n, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void -invokeAddBiasTanh(__nv_bfloat16* out, const __nv_bfloat16* bias, const int m, const int n, cudaStream_t stream); -#endif - -template -__global__ void addBiasGeluV2(T2* out, - const T2* __restrict bias, - const int* ia3_tasks, - const T2* ia3_weights, - const int size, - const int* padding_offset, - const int seq_len) -{ - const bool with_ia3 = ia3_tasks != nullptr; - for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) { - T2 val = out[id]; - if (bias != nullptr) { - T2 reg_bias = ldg(&bias[id % N]); - val = hadd2(val, reg_bias); - } - val = GeluActivation::apply(val); - if (with_ia3) { - const int word_id = id / N; - const int offset = padding_offset == nullptr ? 0 : padding_offset[word_id]; - const int batch_id = (word_id + offset) / seq_len; - const int task = ia3_tasks[batch_id]; - val = val * ia3_weights[task * N + (id % N)]; - } - out[id] = val; - } -} - -template -__global__ void addBiasGeluV3(T2* out, - const T2* __restrict bias, - const int* ia3_tasks, - const T2* ia3_weights, - const int size, - const int* padding_offset, - const int seq_len) -{ - const bool with_ia3 = ia3_tasks != nullptr; - T2 buffer[ELEMENT_PER_ROUND]; - T2 tmp_bias[ELEMENT_PER_ROUND]; - for (int id = blockIdx.x * blockDim.x * ELEMENT_PER_ROUND + threadIdx.x * ELEMENT_PER_ROUND; id < size; - id += blockDim.x * gridDim.x * ELEMENT_PER_ROUND) { -#pragma unroll - for (int i = 0; i < ELEMENT_PER_ROUND; i++) { - buffer[i] = out[id + i]; - if (bias != nullptr) { - tmp_bias[i] = ldg(&bias[(id + i) % N]); - } - } -#pragma unroll - for (int i = 0; i < ELEMENT_PER_ROUND; i++) { - if (bias != nullptr) { - buffer[i] = hadd2(buffer[i], tmp_bias[i]); - } - buffer[i] = GeluActivation::apply(buffer[i]); - if (with_ia3) { - const int word_id = (id + i) / N; - const int offset = padding_offset == nullptr ? 
0 : padding_offset[word_id]; - const int batch_id = (word_id + offset) / seq_len; - const int task = ia3_tasks[batch_id]; - buffer[i] = buffer[i] * ia3_weights[task * N + ((id + i) % N)]; - } - out[id + i] = buffer[i]; - } - } -} - -#define ADD_BIAS_GELU(HALF_N, ELEMENT_PER_ROUND) \ - case HALF_N: \ - if (ELEMENT_PER_ROUND > 1) { \ - grid.x = grid.x / ELEMENT_PER_ROUND; \ - addBiasGeluV3<<>>( \ - (T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n, padding_offset, seq_len); \ - } \ - else { \ - addBiasGeluV2<<>>( \ - (T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n, padding_offset, seq_len); \ - } \ - break; - -template -void invokeAddBiasGeluV2(T* out, - const T* bias, - const int* ia3_tasks, - const T* ia3_weights, - const int* padding_offset, - const int seq_len, - const int m, - const int n, - cudaStream_t stream) -{ - if (n % 2 == 0 && sizeof(T) == 2) { - const int half_n = n / 2; - dim3 block, grid; - block.x = std::min(half_n, 512); - grid.x = (m * half_n + (block.x - 1)) / block.x; - using T2 = typename TypeConverter::Type; - - if (grid.x >= 512) { - switch (half_n) { - ADD_BIAS_GELU(256, 1) - ADD_BIAS_GELU(512, 1) - ADD_BIAS_GELU(1024, 1) - ADD_BIAS_GELU(1536, 1) - ADD_BIAS_GELU(2048, 1) - ADD_BIAS_GELU(4096, 2) - ADD_BIAS_GELU(8192, 2) - ADD_BIAS_GELU(16384, 2) - ADD_BIAS_GELU(24576, 2) - ADD_BIAS_GELU(40960, 4) - default: - invokeGenericActivation(out, - bias, - (T*)nullptr, - (T*)nullptr, - ia3_tasks, - ia3_weights, - m, - n, - 0, - (float*)nullptr, - (float*)nullptr, - padding_offset, - seq_len, - stream); - break; - } - } - else { - switch (half_n) { - ADD_BIAS_GELU(256, 1) - ADD_BIAS_GELU(512, 1) - ADD_BIAS_GELU(1024, 1) - ADD_BIAS_GELU(1536, 1) - ADD_BIAS_GELU(2048, 1) - ADD_BIAS_GELU(4096, 1) - ADD_BIAS_GELU(8192, 2) - ADD_BIAS_GELU(16384, 2) - ADD_BIAS_GELU(24576, 2) - ADD_BIAS_GELU(40960, 2) - default: - invokeGenericActivation(out, - bias, - (T*)nullptr, - (T*)nullptr, - ia3_tasks, - ia3_weights, - m, - n, - 0, - (float*)nullptr, - (float*)nullptr, - padding_offset, - seq_len, - stream); - break; - } - } - } - else { - invokeGenericActivation(out, - bias, - (T*)nullptr, - (T*)nullptr, - ia3_tasks, - ia3_weights, - m, - n, - 0, - (float*)nullptr, - (float*)nullptr, - padding_offset, - seq_len, - stream); - } -} - -#undef ADD_BIAS_GELU - -template void invokeAddBiasGeluV2(float* out, - const float* bias, - const int* ia3_tasks, - const float* ia3_weights, - const int* padding_offset, - const int seq_len, - const int m, - const int n, - cudaStream_t stream); -template void invokeAddBiasGeluV2(half* out, - const half* bias, - const int* ia3_tasks, - const half* ia3_weights, - const int* padding_offset, - const int seq_len, - const int m, - const int n, - cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeAddBiasGeluV2(__nv_bfloat16* out, - const __nv_bfloat16* bias, - const int* ia3_tasks, - const __nv_bfloat16* ia3_weights, - const int* padding_offset, - const int seq_len, - const int m, - const int n, - cudaStream_t stream); -#endif // ENABLE_BF16 - -template -__global__ void sigmoid_kernel(T* data, const int size, const float scale) -{ - const int index = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x; - if (index < size) { - float val = cuda_cast(data[index]); - val = 1.0f / (1.0f + exp(-val)) * scale; - data[index] = T(val); - } -} - -template<> -__global__ void sigmoid_kernel(half2* data, const int size, const float scale) -{ - const int index = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + 
threadIdx.x; - if (index < size / 2) { - half2 val = data[index]; - float2 val_float2 = cuda_cast(val); - val_float2.x = 1.0f / (1.0f + exp(-val_float2.x)) * scale; - val_float2.y = 1.0f / (1.0f + exp(-val_float2.y)) * scale; - data[index] = cuda_cast(val_float2); - } -} - -template -void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream) -{ - if (std::is_same::value || (size % 2 != 0)) { - dim3 block(128); - dim3 grid((size + 127) / 128); - sigmoid_kernel<<>>(data, size, scale); - } - else { - dim3 block(128); - dim3 grid((size + 255) / 256); - sigmoid_kernel<<>>((half2*)data, size, scale); - } -} - -template void invokeSigmoid(float* data, const int size, const float scale, cudaStream_t stream); -template void invokeSigmoid(half* data, const int size, const float scale, cudaStream_t stream); - } // namespace turbomind diff --git a/src/turbomind/kernels/beam_search_penalty_kernels.cu b/src/turbomind/kernels/beam_search_penalty_kernels.cu deleted file mode 100644 index 99b65be57a..0000000000 --- a/src/turbomind/kernels/beam_search_penalty_kernels.cu +++ /dev/null @@ -1,313 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "src/turbomind/kernels/beam_search_penalty_kernels.h" -#include "src/turbomind/kernels/reduce_kernel_utils.cuh" - -namespace turbomind { - -template -__global__ void add_bias_temperature(T* logits, - const T* bias, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const float temperature) -{ - int tid = threadIdx.x; - int bid = blockIdx.x; - int bbid = blockIdx.y; - - logits += bbid * vocab_size_padded; - - const T MASK_VAL = (std::is_same::value) ? -HALF_FLT_MAX : -FLT_MAX; - const T inv_temp = static_cast(1.0f / (temperature + 1e-6f)); - for (int i = tid + bid * blockDim.x; i < vocab_size_padded; i += blockDim.x * gridDim.x) { - if (i < vocab_size) { - T bias_val = bias == nullptr ? (T)(0.0f) : bias[i]; - logits[i] = (logits[i] + bias_val) * inv_temp; - } - else { - logits[i] = MASK_VAL; - } - } -} - -template<> -__global__ void add_bias_temperature(half2* logits, - const half2* bias, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const float temperature) -{ - assert(vocab_size % 2 == 0); - assert(vocab_size_padded % 2 == 0); - - const int tid = threadIdx.x; - const int bid = blockIdx.x; - const int bbid = blockIdx.y; - - const half2 mask_val = __float2half2_rn(-HALF_FLT_MAX); - const half2 inv_temp = __float2half2_rn(1.0f / (temperature + 1e-6f)); - - const int half_vocab_size = vocab_size / 2; - const int half_vocab_size_padded = vocab_size_padded / 2; - - logits += bbid * half_vocab_size_padded; - for (int index = tid + bid * blockDim.x; index < half_vocab_size_padded; index += blockDim.x * gridDim.x) { - int vocab_idx = index % half_vocab_size_padded; - half2 logit = vocab_idx < half_vocab_size ? 
__ldg(&logits[index]) : mask_val; - if (vocab_idx < half_vocab_size) { - if (bias != nullptr) { - logit = __hadd2(logit, bias[vocab_idx]); - } - logit = __hmul2(logit, inv_temp); - } - logits[index] = logit; - } -} - -template -__global__ void apply_repetition_penalty(T* logits, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const int step, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, - const int max_input_length, - const float repetition_penalty) -{ - assert(step > 0); - - const int tid = threadIdx.x; - const int bbid = blockIdx.x; - const int batch_id = bbid / beam_width; - const int bbsize = batch_size * beam_width; - - logits += bbid * vocab_size_padded; - extern __shared__ char sbuf[]; - T* penalty_logits = reinterpret_cast(sbuf); - // prevent misaligment when sizeof(T) = 2 - int* penalty_indices = reinterpret_cast(sbuf + (sizeof(T) * step + 31) / 32 * 32); - const int input_length = (input_lengths != nullptr) ? input_lengths[bbid] : max_input_length; - if (tid == 0) { - T repet_penalty = static_cast(repetition_penalty); - int prev_id = current_ids[bbid]; - T prev_logit = logits[prev_id]; - penalty_indices[step - 1] = prev_id; - - if (IS_ADDITIVE) { - penalty_logits[step - 1] = prev_logit - repet_penalty; - } - else { - penalty_logits[step - 1] = prev_logit > T(0) ? prev_logit / repet_penalty : prev_logit * repet_penalty; - } - if (step > 1) { - int parent_beam = bbid % beam_width; - for (int i = step - 2; i >= 0; --i) { - // Skip the padded tokens. - if (i >= input_length && i < max_input_length) { - continue; - } - parent_beam = parent_ids[i * bbsize + batch_id * beam_width + parent_beam]; - prev_id = previous_ids[i * bbsize + batch_id * beam_width + parent_beam]; - prev_logit = logits[prev_id]; - penalty_indices[i] = prev_id; - if (IS_ADDITIVE) { - penalty_logits[i] = prev_logit - repet_penalty; - } - else { - penalty_logits[i] = prev_logit > T(0) ? prev_logit / repet_penalty : prev_logit * repet_penalty; - } - } - } - } - __syncthreads(); - for (int i = tid; i < step; i += blockDim.x) { - if (i >= input_length && i < max_input_length) { - continue; - } - logits[penalty_indices[i]] = penalty_logits[i]; - } -} - -template -__global__ void apply_min_length_penalty(T* logits, - const int min_length, - const int* end_ids, - const int* sequence_lengths, - const int max_input_length, - const int beam_width, - const int vocab_size_padded) -{ - int bbid = threadIdx.x + blockIdx.x * blockDim.x; // batch-beam index - int bid = bbid / beam_width; // batch index - // We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1, - // which is equal to the length of k/v caches. - if (sequence_lengths[bbid] + 1 - max_input_length < min_length) { - T mask_val = (std::is_same::value) ? 
-HALF_FLT_MAX : -FLT_MAX; - logits[bbid * vocab_size_padded + end_ids[bid]] = mask_val; - } -} - -template -void invokeAddBiasApplyPenalties(int step, - T* logits, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, - const int* sequence_lengths, - const T* bias, - const int ite, - const int max_input_length, - const int local_batch_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const int* end_ids, - const float temperature, - const float repetition_penalty, - const RepetitionPenaltyType repetition_penalty_type, - const int min_length, - cudaStream_t stream) -{ - if (bias != nullptr || temperature != 1.0f || vocab_size != vocab_size_padded) { - dim3 block(512); - if (std::is_same::value && vocab_size % 2 == 0 && vocab_size_padded % 2 == 0) { - dim3 grid((vocab_size_padded / 2 + block.x - 1) / block.x, beam_width * local_batch_size); - add_bias_temperature<<>>(reinterpret_cast(logits), - reinterpret_cast(bias), - batch_size, - beam_width, - vocab_size, - vocab_size_padded, - temperature); - } - else { - dim3 grid((vocab_size_padded + block.x - 1) / block.x, beam_width * local_batch_size); - add_bias_temperature<<>>( - logits, bias, batch_size, beam_width, vocab_size, vocab_size_padded, temperature); - } - } - - if (repetition_penalty_type != RepetitionPenaltyType::None && step > 0) { - if (repetition_penalty != getDefaultPenaltyValue(repetition_penalty_type)) { - size_t smem_size = (sizeof(T) * step + 31) / 32 * 32 + sizeof(int) * step; - dim3 block(256); - dim3 grid(beam_width * local_batch_size); - if (repetition_penalty_type == RepetitionPenaltyType::Multiplicative) { - apply_repetition_penalty - <<>>(logits, - batch_size, - beam_width, - vocab_size, - vocab_size_padded, - step, - current_ids, - previous_ids, - // TODO(jaedeokk): - // Remove (+ite ...) by getting parent_ids with offset - // and then remove 'ite' argument from the function. 
- parent_ids + ite * beam_width * local_batch_size, - input_lengths, - max_input_length, - repetition_penalty); - } - else if (repetition_penalty_type == RepetitionPenaltyType::Additive) { - apply_repetition_penalty - <<>>(logits, - batch_size, - beam_width, - vocab_size, - vocab_size_padded, - step, - current_ids, - previous_ids, - parent_ids + ite * beam_width * local_batch_size, - input_lengths, - max_input_length, - repetition_penalty); - } - } - } - - if (step - max_input_length < min_length) { - FT_CHECK_WITH_INFO(sequence_lengths != nullptr, "Need sequence_lengths to apply min length penlaty"); - FT_CHECK_WITH_INFO(end_ids != nullptr, "Need end_id to apply min length penlaty"); - - const int block_size = min(local_batch_size * beam_width, 1024); - const int grid_size = (local_batch_size * beam_width + block_size - 1) / block_size; - apply_min_length_penalty<<>>( - logits, min_length, end_ids, sequence_lengths, max_input_length, beam_width, vocab_size_padded); - } -} - -template void invokeAddBiasApplyPenalties(int step, - float* logits, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, - const int* sequence_lengths, - const float* bias, - const int ite, - const int max_input_length, - const int local_batch_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const int* end_ids, - const float temperature, - const float repetition_penalty, - const RepetitionPenaltyType repetition_penalty_type, - const int min_length, - cudaStream_t stream); - -template void invokeAddBiasApplyPenalties(int step, - half* logits, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, - const int* sequence_lengths, - const half* bias, - const int ite, - const int max_input_length, - const int local_batch_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const int* end_ids, - const float temperature, - const float repetition_penalty, - const RepetitionPenaltyType repetition_penalty_type, - const int min_length, - cudaStream_t stream); - -} // namespace turbomind diff --git a/src/turbomind/kernels/beam_search_penalty_kernels.h b/src/turbomind/kernels/beam_search_penalty_kernels.h deleted file mode 100644 index f91b7fac8f..0000000000 --- a/src/turbomind/kernels/beam_search_penalty_kernels.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include - -#include "src/turbomind/kernels/penalty_types.h" -#include "src/turbomind/utils/cuda_utils.h" - -namespace turbomind { - -template -void invokeAddBiasApplyPenalties(int step, - T* logits, - const int* current_ids, - const int* previous_ids, - const int* parent_ids, - const int* input_lengths, - const int* sequence_lengths, - const T* bias, - const int ite, - const int max_input_length, - const int local_batch_size, - const int batch_size, - const int beam_width, - const int vocab_size, - const int vocab_size_padded, - const int* end_ids, - const float temperature, - const float repetition_penalty, - const RepetitionPenaltyType repetition_penalty_type, - const int min_length, - cudaStream_t stream); - -} // namespace turbomind diff --git a/src/turbomind/kernels/beam_search_topk_kernels.cu b/src/turbomind/kernels/beam_search_topk_kernels.cu deleted file mode 100644 index d758c93f48..0000000000 --- a/src/turbomind/kernels/beam_search_topk_kernels.cu +++ /dev/null @@ -1,845 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef CUDART_VERSION -#error CUDART_VERSION Undefined! -#elif (CUDART_VERSION >= 11050) -#include -#else -#include "3rdparty/cub/cub.cuh" -#endif - -#include "src/turbomind/kernels/beam_search_topk_kernels.h" -#include "src/turbomind/kernels/reduce_kernel_utils.cuh" -#include "src/turbomind/utils/cuda_type_utils.cuh" -#include "src/turbomind/utils/cuda_utils.h" -#include "src/turbomind/utils/logger.h" - -namespace turbomind { - -template -__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) -{ - // score = log(prob) / (length)^length_penalty. - if (length_penalty == 0.0f || length == 1) { - return log_prob; - } - return log_prob / static_cast(powf((float)length, length_penalty)); -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel(const T* log_probs, - int* topk_tmp_id_buf, - T* topk_tmp_val_buf, - const bool* finished, - const int* sequence_lengths, - const int vocab_size, - T diversity_rate, - float length_penalty) -{ - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int thread_id = threadIdx.x; - int block_id = blockIdx.x; // batch beam index. - TopK partial; - - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - -#pragma unroll - for (int i = 0; i < MAX_K; ++i) { - partial.p[i] = -1; - partial.u[i] = -MAX_T_VAL; - } - -#pragma unroll - for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) { - int index = elem_id + block_id * vocab_size; - T score = length_penalty == 0.0f ? log_probs[index] : - apply_length_penalty(log_probs[index], - finished[block_id] ? 
sequence_lengths[block_id] : - sequence_lengths[block_id] + 1, - length_penalty); - partial.insert(score, index); - } - - TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); - - if (thread_id == 0) { - int index = block_id * MAX_K; - -#pragma unroll - for (int i = 0; i < MAX_K; ++i) { - topk_tmp_id_buf[index + i] = total.p[i]; - topk_tmp_val_buf[index + i] = total.u[i] + diversity_rate * (T)i; - } - } -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ - void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) -{ - int thread_id = threadIdx.x; - int block_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - TopK partial; - if (thread_id == 0) { - for (int i = 0; i < MAX_K; ++i) { - partial.p[i] = -1; - partial.u[i] = -MAX_T_VAL; - } - - int index = block_id * MAX_K * MAX_K; - for (int i = 0; i < MAX_K * MAX_K; i++) { - partial.insert((T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); - } - - index = block_id * MAX_K; - for (int i = 0; i < MAX_K; i++) { - id_buf[index + i] = partial.p[i]; - } - } -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ - void batch_topK_kernel_v2(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) -{ - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int tid = threadIdx.x; - int bid = blockIdx.x; - TopK partial; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - -#pragma unroll - for (int i = 0; i < MAX_K; ++i) { - partial.p[i] = -1; - partial.u[i] = -MAX_T_VAL; - } - - int ite = MAX_K * MAX_K / THREADBLOCK_SIZE; -#pragma unroll - for (int i = 0; i < ite; i++) { - int index = bid * MAX_K * MAX_K + i * THREADBLOCK_SIZE + tid; - partial.insert((T)topk_tmp_val_buf[index], topk_tmp_id_buf[index]); - } - - TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); - - if (tid == 0) { -#pragma unroll - for (int i = 0; i < MAX_K; i++) { - id_buf[bid * MAX_K + i] = total.p[i]; - } - } -} - -template -__global__ void topk_stage_1_opt3(const T* __restrict log_probs, - T* tmp_log_probs, - int* topk_tmp_id_buf, - T* topk_tmp_val_buf, - const bool* finished, - const int* sequence_lengths, - const int k, - const int vocab_size, - const float length_penalty, - const int* end_ids) -{ - typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - const int tid = threadIdx.x; - const int bid = blockIdx.x; - - const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index) - const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam - const int tmp_log_buf_index = row_id * vocab_size; - const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; - TopK_2 partial; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; - - if (finished != nullptr && finished[row_id] == true) { - if (tid < k) { - const int index = tmp_topk_buf_index + tid; - if (block_lane == 0 && tid == 0) { - const int end_id = end_ids[row_id / k]; - topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; - topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id]; - } - else { - topk_tmp_id_buf[index] = -1; - topk_tmp_val_buf[index] = -MAX_T_VAL; - } - } - return; - } - - for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; - elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { - int index = elem_id + tmp_log_buf_index; - tmp_log_probs[index] = log_probs[index]; - } - - for (int ite = 0; ite < k; ite++) { - partial.init(); -#pragma unroll - for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; - elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { - int index = elem_id + tmp_log_buf_index; - partial.insert(tmp_log_probs[index], index); - } - - TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); - - if (tid == 0) { - const int index = tmp_topk_buf_index + ite; - topk_tmp_id_buf[index] = total.p; - topk_tmp_val_buf[index] = total.u; - tmp_log_probs[total.p] = -MAX_T_VAL; - } - __syncthreads(); - } -} - -template -__global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, - T* topk_tmp_val_buf, - int* ids, - BeamHypotheses beam_hyps, - const int* end_ids, - const int vocab_size, - const int k) -{ - const int size = k * k * BLOCKS_PER_BEAM_; - const int tid = threadIdx.x; - const int batch_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - - typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - extern __shared__ char array[]; - T* s_val = topk_tmp_val_buf + batch_id * size; - int* s_id = (int*)(array); - - __shared__ int selected_beams; - __shared__ bool is_stop; - - if (tid == 0) { - selected_beams = 0; - is_stop = false; - } - __syncthreads(); - if (beam_hyps.num_beams != nullptr) { - const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) { - // initialize the buffer - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - } - else if (beam_hyps.num_beams[global_batch_idx] == k) { - return; - } - } - - TopK_2 partial; - - // In some cases, we may encounter k finished sentences, but scores are bad. So, the max iteration - // is 2*k here - for (int ite = 0; ite < 2 * k; ite++) { - partial.init(); -#pragma unroll - for (int i = tid; i < size; i += BLOCK_SIZE_) { - partial.insert(s_val[i], i); - } - - TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); - - if (tid == 0) { - if (beam_hyps.num_beams != nullptr - && topk_tmp_id_buf[batch_id * size + total.p] % vocab_size == end_ids[batch_id]) { - // if beam_token does not belong to top num_beams tokens, it should not be added. 
Refer from - // https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257 - if (ite >= k) { - s_val[total.p] = -MAX_T_VAL; - } - else { - const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - const float normed_score = - apply_length_penalty(s_val[total.p], beam_hyps.step, beam_hyps.length_penalty); - const int num_beam = beam_hyps.num_beams[global_batch_idx]; - int beam_idx = num_beam; - // If there are beam_width finished sentences, check that the score of selected candidatet - // is higher than min_normed_score or not. If current score is better, replace worst one - // and update the min_normed_score. - if (num_beam == k) { - if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) { - // end the tracing and exist this for loop - selected_beams = k; - is_stop = true; - break; - } - else { - // find the beam index which's score = min_normed_score, erase it. - for (int j = 0; j < k; j++) { - if (beam_hyps.normed_scores[global_batch_idx * k + j] - == beam_hyps.min_normed_scores[global_batch_idx]) { - beam_idx = j; - beam_hyps.num_beams[global_batch_idx]--; - - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - beam_hyps.normed_scores[global_batch_idx * k + j] = normed_score; - for (int l = 0; l < k; l++) { - beam_hyps.min_normed_scores[global_batch_idx] = - min(beam_hyps.min_normed_scores[global_batch_idx], - beam_hyps.normed_scores[global_batch_idx * k + l]); - } - break; - } - } - } - } - const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) - * (beam_hyps.max_seq_len); - beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; - - int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; - for (int j = beam_hyps.step - 1; j >= 0; j--) { - const int src_idx = j * beam_hyps.batch_size * k - + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; - - beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; - prev_id = beam_hyps.parent_ids_src[src_idx]; - } - const int tgt_beam_idx = global_batch_idx * k + beam_idx; - beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; - beam_hyps.normed_scores[tgt_beam_idx] = normed_score; - beam_hyps.min_normed_scores[global_batch_idx] = - min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); - - s_val[total.p] = -MAX_T_VAL; - - beam_hyps.num_beams[global_batch_idx]++; - } - } - else { - s_id[selected_beams] = total.p; - s_val[total.p] = -MAX_T_VAL; - selected_beams++; - } - } - __syncthreads(); - if (selected_beams >= k) { - break; - } - } - if (tid < k && is_stop == false) { - ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]]; - } -} - -template -__global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, - T* tmp_log_probs, - int* topk_tmp_id_buf, - T* topk_tmp_val_buf, - const bool* finished, - const int* sequence_lengths, - const int k, - const int vocab_size, - const float length_penalty) -{ - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; - typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - const int tid = threadIdx.x; - const int bid = blockIdx.x; - const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs - const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam - const int tmp_log_buf_index = row_id * vocab_size; - const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k; - TopK_2 partial; - - for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) { - int index = elem_id + tmp_log_buf_index; - tmp_log_probs[index] = log_probs[index]; - } - - for (int ite = 0; ite < k; ite++) { - partial.init(); -#pragma unroll - for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; - elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) { - int index = elem_id + tmp_log_buf_index; - partial.insert(tmp_log_probs[index], index); - } - - TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); - - if (tid == 0) { - const int index = tmp_topk_buf_index + ite; - topk_tmp_id_buf[index] = total.p; - topk_tmp_val_buf[index] = total.u; - tmp_log_probs[total.p] = -MAX_T_VAL; - } - __syncthreads(); - } -} - -template -__global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, - T* topk_tmp_val_buf, - int* ids, - BeamHypotheses beam_hyps, - const int* end_ids, - const int k, - const int vocab_size) -{ - const int size = k * k * BLOCKS_PER_BEAM; - const int tid = threadIdx.x; - const int batch_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - - typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - extern __shared__ char array[]; - T* s_val = topk_tmp_val_buf + batch_id * size; - int* s_id = (int*)(array); - - __shared__ int selected_beams; - __shared__ bool is_stop; - - if (tid == 0) { - selected_beams = 0; - is_stop = false; - } - __syncthreads(); - if (beam_hyps.num_beams != nullptr) { - const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) { - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - } - else if (beam_hyps.num_beams[global_batch_idx] == k) { - return; - } - } - - TopK_2 partial; - - // In some cases, we may encounter k finished sentences, but scores are bad. So, the max iteration - // is 2*k here - for (int ite = 0; ite < 2 * k; ite++) { - partial.init(); -#pragma unroll - for (int i = tid; i < size; i += BLOCK_SIZE) { - partial.insert(s_val[i], i); - } - - TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); - - if (tid == 0) { - if (beam_hyps.num_beams != nullptr - && topk_tmp_id_buf[batch_id * size + total.p] % vocab_size == end_ids[batch_id]) { - // if beam_token does not belong to top num_beams tokens, it should not be added. 
Refer from - // https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257 - if (ite >= k) { - s_val[total.p] = -MAX_T_VAL; - } - else { - const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - const float normed_score = - apply_length_penalty(s_val[total.p], beam_hyps.step, beam_hyps.length_penalty); - const int num_beam = beam_hyps.num_beams[global_batch_idx]; - int beam_idx = num_beam; - // If there are beam_width finished sentences, check that the score of selected candidatet - // is higher than min_normed_score or not. If current score is better, replace worst one - // and update the min_normed_score. - if (num_beam == k) { - if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) { - // end the tracing and exist this for loop - selected_beams = k; - is_stop = true; - break; - } - else { - // find the beam index which's score = min_normed_score, erase it. - for (int j = 0; j < k; j++) { - if (beam_hyps.normed_scores[global_batch_idx * k + j] - == beam_hyps.min_normed_scores[global_batch_idx]) { - beam_idx = j; - beam_hyps.num_beams[global_batch_idx]--; - - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - beam_hyps.normed_scores[global_batch_idx * k + j] = normed_score; - for (int l = 0; l < k; l++) { - beam_hyps.min_normed_scores[global_batch_idx] = - min(beam_hyps.min_normed_scores[global_batch_idx], - beam_hyps.normed_scores[global_batch_idx * k + l]); - } - break; - } - } - } - } - const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) - * (beam_hyps.max_seq_len); - beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; - - int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; - for (int j = beam_hyps.step - 1; j >= 0; j--) { - const int src_idx = j * beam_hyps.batch_size * k - + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; - - beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; - prev_id = beam_hyps.parent_ids_src[src_idx]; - } - const int tgt_beam_idx = global_batch_idx * k + beam_idx; - beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; - beam_hyps.normed_scores[tgt_beam_idx] = normed_score; - beam_hyps.min_normed_scores[global_batch_idx] = - min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); - - s_val[total.p] = -MAX_T_VAL; - - beam_hyps.num_beams[global_batch_idx]++; - } - } - else { - s_id[selected_beams] = total.p; - s_val[total.p] = -MAX_T_VAL; - selected_beams++; - } - } - __syncthreads(); - if (selected_beams >= k) { - break; - } - } - if (tid < k && is_stop == false) { - ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]]; - } -} - -#define CASE_K_DIV(K, BLOCK_SIZE_1, BLOCK_SIZE_2) \ - case K: \ - beam_topK_kernel<<>>(log_probs, \ - topk_tmp_id_buf, \ - topk_tmp_val_buf, \ - finished, \ - sequence_lengths, \ - vocab_size, \ - diversity_rate, \ - length_penalty); \ - if (K < 10) \ - batch_topK_kernel \ - <<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \ - else \ - batch_topK_kernel_v2<<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \ - break; - -#define CASE_K(K, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \ - case K: \ - topk_stage_1_opt3 \ - <<>>(log_probs, \ - temp_log_probs, \ - topk_tmp_id_buf, \ - topk_tmp_val_buf, \ - finished, \ - sequence_lengths, \ - beam_width, \ - vocab_size, \ - length_penalty, \ - end_ids); \ - topk_stage_2_opt3 \ - <<>>( \ - 
topk_tmp_id_buf, topk_tmp_val_buf, ids, *beam_hyps, end_ids, vocab_size, beam_width); \ - sync_check_cuda_error(); \ - break; - -template -void invokeTopkBeamSearch(void* workspace, - size_t& workspace_size, - T* log_probs, - int* ids, - BeamHypotheses* beam_hyps, - const bool* finished, - const int* sequence_lengths, - const int batch_size, - const int beam_width, - const int vocab_size_padded_, - const T diversity_rate, - const float length_penalty, - const int* end_ids, - cudaStream_t stream) -{ - TM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); - // log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a token. - const int vocab_size = vocab_size_padded_; - // Beam size should be less than or equal to vocab size. - assert(beam_width <= vocab_size); - // Beam search needs the sequence lengths of beams to apply length penalty. - assert(length_penalty == 0.0f || sequence_lengths != nullptr); - const int max_block_per_beam = 8; - int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float - int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int - int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float - - // prevent memory misaligned address - temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4; - topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; - topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; - - if (workspace == nullptr) { - workspace_size = sizeof(float) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size - + sizeof(float) * topk_tmp_val_buf_size; - return; - } - else { - T* temp_log_probs = (T*)workspace; - int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size); - T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); - if (diversity_rate == 0.0f) { - switch (beam_width) { - CASE_K(1, 128, 128, 8); - CASE_K(4, 128, 128, 8); - CASE_K(10, 128, 128, 8); - CASE_K(16, 128, 128, 5); - CASE_K(32, 256, 128, 1); - CASE_K(64, 256, 256, 1); - default: - topk_stage_1_opt2_general - <<>>(log_probs, - temp_log_probs, - topk_tmp_id_buf, - topk_tmp_val_buf, - finished, - sequence_lengths, - beam_width, - vocab_size, - length_penalty); - topk_stage_2_opt2_general - <<>>( - topk_tmp_id_buf, topk_tmp_val_buf, ids, *beam_hyps, end_ids, beam_width, vocab_size); - break; - } - } - else { - switch (beam_width) { - CASE_K_DIV(1, 256, 256); - CASE_K_DIV(4, 256, 256); - CASE_K_DIV(16, 256, 64); - CASE_K_DIV(32, 256, 64); - CASE_K_DIV(64, 256, 64); - default: - FT_CHECK_WITH_INFO(false, fmtstr("Topk kernel does not support beamwidth = %d \n", beam_width)); - break; - } - } - return; - } -} - -#undef CASE_K -#undef CASE_K_DIV - -template void invokeTopkBeamSearch(void* workspace, - size_t& workspace_size, - float* log_probs, - int* ids, - BeamHypotheses* beam_hyps, - const bool* finished, - const int* sequence_lengths, - const int batch_size, - const int beam_width, - const int vocab_size_padded_, - const float diversity_rate, - const float length_penalty, - const int* end_ids, - cudaStream_t stream); - -template -__global__ void tileEncoderResults(T* tiled_output, - int* tiled_sequence_length, - const T* output, - const int* sequence_length, - const uint batch_size, - const uint beam_width, - const uint d_model) -{ - if (blockIdx.x == 0) { - for (uint i = threadIdx.x; i < batch_size * beam_width; i += blockDim.x) { - tiled_sequence_length[i] = sequence_length[i / beam_width]; - } - } - - 
int tgt_offset = - blockIdx.x * gridDim.y * gridDim.z * d_model + blockIdx.y * gridDim.z * d_model + blockIdx.z * d_model; - int src_offset = blockIdx.x * gridDim.z * d_model + blockIdx.z * d_model; - for (uint i = threadIdx.x; i < d_model; i += blockDim.x) { - tiled_output[i + tgt_offset] = output[i + src_offset]; - } -} - -template -void invokeTileEncoderResults(T* tiled_output, - int* tiled_sequence_length, - const T* output, - const int* sequence_length, - const size_t batch_size, - const size_t beam_width, - const size_t mem_max_seq_len, - const size_t d_model, - cudaStream_t stream) -{ - // tiled_output: [batch_size, beam_width, mem_max_seq_len, d_model] - // tiled_sequence_length: [batch_size, beam_width] - - // output: [batch_size, mem_max_seq_len, d_model] - // sequence_length [batch_size] - - dim3 grid(batch_size, beam_width, mem_max_seq_len); - bool is_half2 = (std::is_same::value) && (d_model % 2 == 0); - - if (is_half2) { - using T2 = typename TypeConverter::Type; // fp16 to half2, bf16 to bf162 - dim3 block(min(512, (int)(d_model / 2))); - tileEncoderResults<<>>((T2*)tiled_output, - tiled_sequence_length, - (const T2*)output, - sequence_length, - batch_size, - beam_width, - d_model / 2); - } - else { - dim3 block(min(512, (int)d_model)); - tileEncoderResults<<>>( - tiled_output, tiled_sequence_length, output, sequence_length, batch_size, beam_width, d_model); - } -} - -template void invokeTileEncoderResults(float* tiled_output, - int* tiled_sequence_length, - const float* output, - const int* sequence_length, - const size_t batch_size, - const size_t beam_width, - const size_t mem_max_seq_len, - const size_t d_model, - cudaStream_t stream); - -template void invokeTileEncoderResults(half* tiled_output, - int* tiled_sequence_length, - const half* output, - const int* sequence_length, - const size_t batch_size, - const size_t beam_width, - const size_t mem_max_seq_len, - const size_t d_model, - cudaStream_t stream); - -template void invokeTileEncoderResults(half2* tiled_output, - int* tiled_sequence_length, - const half2* output, - const int* sequence_length, - const size_t batch_size, - const size_t beam_width, - const size_t mem_max_seq_len, - const size_t d_model, - cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, - int* tiled_sequence_length, - const __nv_bfloat16* output, - const int* sequence_length, - const size_t batch_size, - const size_t beam_width, - const size_t mem_max_seq_len, - const size_t d_model, - cudaStream_t stream); -#endif - -__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, - const bool* finished, - const float* cum_log_probs, - const int batch_size, - const int beam_width) -{ - const int bid = blockIdx.x; - const int tgt_start_idx = beam_hyps.num_beams[bid]; - if (beam_hyps.is_done[bid]) { - return; - } - for (int i = 0; i < beam_width; i++) { - if (threadIdx.x == 0) { - const int src_beam_idx = bid * beam_width + i; - const int tgt_beam_idx = bid * beam_width * 2 + i + tgt_start_idx; - - const int length = beam_hyps.sequence_lengths_src[src_beam_idx]; - - beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] = - beam_hyps.output_ids_src[length * batch_size * beam_width + src_beam_idx]; - if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) { - beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] = - beam_hyps.log_probs_src[length * batch_size * beam_width + src_beam_idx]; - } - int prev_id = 
beam_hyps.parent_ids_src[length * batch_size * beam_width + src_beam_idx]; - for (int j = length - 1; j >= 0; j--) { - // output_ids_tgt need to use max_seq_len + 1 because its shape is - // [bs, beam_width, max_seq_len + 1] - beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] = - beam_hyps.output_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id]; - if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) { - beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] = - beam_hyps.log_probs_src[j * batch_size * beam_width + bid * beam_width + prev_id]; - } - prev_id = beam_hyps.parent_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id]; - } - beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = length; - - beam_hyps.normed_scores[tgt_beam_idx] = apply_length_penalty( - cum_log_probs[src_beam_idx], finished[src_beam_idx] ? length + 1 : length, beam_hyps.length_penalty); - beam_hyps.cum_log_probs[tgt_beam_idx] = cum_log_probs[src_beam_idx]; - - beam_hyps.num_beams[bid]++; - } - } -} - -void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, - const bool* finished, - const float* cum_log_probs, - const int batch_size, - const int beam_width, - cudaStream_t stream) -{ - insertUnfinishedPath<<>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width); -} - -} // namespace turbomind diff --git a/src/turbomind/kernels/beam_search_topk_kernels.h b/src/turbomind/kernels/beam_search_topk_kernels.h deleted file mode 100644 index 993dda1dbd..0000000000 --- a/src/turbomind/kernels/beam_search_topk_kernels.h +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#pragma once - -namespace turbomind { - -// In original beam search implementation, if a beam is finished, we set it as finished -// and only continue to do beam search on remain beams (namely, beam_width - 1 beams in next step) -// -// In this implementation, when a beam is finished, we trace the path and record it in output_ids_tgt, -// and also record the normalized scores. And the beam search continue to use `beam_width` beams in -// next step. -// -// After we collect `beam_width` beams, we will sort them by their norm_scores. 
-struct BeamHypotheses { - int* output_ids_tgt = nullptr; - int* sequence_lengths_tgt = nullptr; - float* cum_log_probs = nullptr; // cum_log - float* normed_scores = nullptr; // cum_log / (length**length_penalty) - float* log_probs = nullptr; // log probs of each generated token - float* min_normed_scores = nullptr; // record the min normed scores for each batch - int* num_beams = nullptr; // the number of finished beams we collect - bool* is_done = nullptr; - - // Used to set inputs - const int* output_ids_src; - const int* parent_ids_src; - const int* sequence_lengths_src; - const int* end_ids; - const float* log_probs_src; - - // some variables for kernels - int step; - int ite; - int batch_size; - int local_batch_size; - int max_seq_len; - float length_penalty; - - bool early_stopping = true; - bool is_return_normed_score = true; // return normed_cum_log_probs or cum_log_probs -}; - -template -void invokeTopkBeamSearch(void* workspace, - size_t& workspace_size, - T* log_probs, - int* ids, - BeamHypotheses* beam_hyps, - const bool* finished, - const int* sequence_lengths, - const int batch_size, - const int beam_width, - const int vocab_size_padded_, - const T diversity_rate, - const float length_penalty, - const int* end_ids, - cudaStream_t stream); - -template -void invokeTileEncoderResults(T* tiled_encoder_output, - int* tiled_encoder_sequence_length, - const T* encoder_output, - const int* encoder_sequence_length, - const size_t batch_size, - const size_t beam_width, - const size_t mem_max_seq_len, - const size_t d_model, - cudaStream_t stream); - -void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, - const bool* finished, - const float* cum_log_probs, - const int batch_size, - const int beam_width, - cudaStream_t stream); - -} // namespace turbomind diff --git a/src/turbomind/kernels/bert_preprocess_kernels.cu b/src/turbomind/kernels/bert_preprocess_kernels.cu index e7246173ee..9c32000e6d 100644 --- a/src/turbomind/kernels/bert_preprocess_kernels.cu +++ b/src/turbomind/kernels/bert_preprocess_kernels.cu @@ -68,120 +68,6 @@ void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num, sync_check_cuda_error(); } -template -__global__ void buildEncoderAttentionMaskKernel(T* attention_mask, const int* sequence_lengths, const int max_seq_len) -{ - // sequence_lengths: [batch_size] - // attention_mask: [batch_size, 1, max_seq_len, max_seq_len] - attention_mask += blockIdx.x * max_seq_len * max_seq_len; - const int length = sequence_lengths[blockIdx.x]; - for (int i = threadIdx.x; i < max_seq_len * max_seq_len; i += blockDim.x) { - // int row_id = i / max_seq_len; - int col_id = i % max_seq_len; - // if (row_id < length && col_id < length) { - // TODO (bhsueh) check this modification is ok or not on other rmodel - if (col_id < length) { - attention_mask[i] = (T)(1.0f); - } - else { - attention_mask[i] = (T)(0.0f); - } - } -} - -template -void invokeBuildEncoderAttentionMask( - T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream) -{ - buildEncoderAttentionMaskKernel<<>>(attention_mask, sequence_lengths, max_seq_len); -} - -template void invokeBuildEncoderAttentionMask(float* attention_mask, - const int* sequence_lengths, - const int batch_size, - const int max_seq_len, - cudaStream_t stream); -template void invokeBuildEncoderAttentionMask(half* attention_mask, - const int* sequence_lengths, - const int batch_size, - const int max_seq_len, - cudaStream_t stream); -#ifdef ENABLE_FP8 -template void 
invokeBuildEncoderAttentionMask(__nv_fp8_e4m3* attention_mask, - const int* sequence_lengths, - const int batch_size, - const int max_seq_len, - cudaStream_t stream); -#endif // ENABLE_FP8 -#ifdef ENABLE_BF16 -template void invokeBuildEncoderAttentionMask(__nv_bfloat16* attention_mask, - const int* sequence_lengths, - const int batch_size, - const int max_seq_len, - cudaStream_t stream); -#endif - -__global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int* sequence_length, const int batch_size) -{ - // use for get tensorrt fused mha padding offset - // when we remove the padding - - extern __shared__ int tmp_offset[]; - if (threadIdx.x == 0) { - tmp_offset[0] = 0; - for (int i = 0; i < batch_size; i++) { - tmp_offset[i + 1] = tmp_offset[i] + sequence_length[i]; - } - } - __syncthreads(); - - for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) { - trt_mha_padding_offset[i] = tmp_offset[i]; - } -} - -void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, - const int batch_size, - cudaStream_t stream) -{ - getTrtPaddingOffsetKernel<<<1, 256, sizeof(int) * (batch_size + 1), stream>>>( - trt_mha_padding_offset, sequence_length, batch_size); -} - -__global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, - const int* sequence_length, - const int request_batch_size, - const int request_seq_len) -{ - // use for get tensorrt fused mha padding offset - // when we keep the padding - - extern __shared__ int tmp_offset[]; - if (threadIdx.x == 0) { - tmp_offset[0] = 0; - for (int i = 0; i < request_batch_size; i++) { - tmp_offset[i * 2 + 1] = tmp_offset[i * 2] + sequence_length[i]; - tmp_offset[i * 2 + 2] = request_seq_len * (i + 1); - } - } - __syncthreads(); - - for (int i = threadIdx.x; i < 2 * request_batch_size + 1; i += blockDim.x) { - trt_mha_padding_offset[i] = tmp_offset[i]; - } -} - -void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, - const int request_batch_size, - const int request_seq_len, - cudaStream_t stream) -{ - getTrtPaddingOffsetKernel<<<1, 256, sizeof(int) * (2 * request_batch_size + 1), stream>>>( - trt_mha_padding_offset, sequence_length, request_batch_size, request_seq_len); -} - template __global__ void rebuild_sequence_length_padding(const T* src, T* dst, const int* padding_offset, const int n) { @@ -287,183 +173,4 @@ template void invokeRemovePadding(__nv_bfloat16* dst, cudaStream_t stream); #endif -template -__global__ void buildRelativeAttentionBias(T* relative_attention_bias, - const T* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance) -{ - - const int head_id = blockIdx.x; - for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x) { - int row_id = seq_id / seq_len; - int col_id = seq_id % seq_len; - - int relative_position = col_id - row_id; - - int relative_buckets = 0; - int tmp_num_bucket = num_bucket; - if (is_bidirectional) { - tmp_num_bucket /= 2; - if (relative_position > 0) { - relative_buckets += tmp_num_bucket; - } - else { - relative_position *= -1; - } - } - else { - relative_position = abs(relative_position); - } - - int max_exact = tmp_num_bucket / 2; - bool is_small = relative_position < max_exact; - - int relative_position_if_large = - max_exact - + (int)(logf(relative_position * 1.0f / max_exact) / logf((float)max_distance / max_exact) - * (tmp_num_bucket - max_exact)); - - relative_position_if_large = 
min(relative_position_if_large, tmp_num_bucket - 1); - - relative_buckets += is_small ? relative_position : relative_position_if_large; - - relative_attention_bias[head_id * seq_len * seq_len + seq_id] = - relative_attention_bias_table[head_id * num_bucket + relative_buckets]; - } -} - -template -void invokeBuildRelativeAttentionBias(T* relative_attention_bias, - const T* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance, - const PositionEmbeddingType position_embedding_type, - cudaStream_t stream) -{ - if (position_embedding_type == PositionEmbeddingType::absolute) { - return; - } - dim3 grid(head_num); - dim3 block(256); - buildRelativeAttentionBias<<>>(relative_attention_bias, - relative_attention_bias_table, - head_num, - seq_len, - num_bucket, - is_bidirectional, - max_distance); -} - -template void invokeBuildRelativeAttentionBias(float* relative_attention_bias, - const float* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance, - const PositionEmbeddingType position_embedding_type, - cudaStream_t stream); - -template void invokeBuildRelativeAttentionBias(half* relative_attention_bias, - const half* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance, - const PositionEmbeddingType position_embedding_type, - cudaStream_t stream); - -#ifdef ENABLE_BF16 -template void invokeBuildRelativeAttentionBias(__nv_bfloat16* relative_attention_bias, - const __nv_bfloat16* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance, - const PositionEmbeddingType position_embedding_type, - cudaStream_t stream); -#endif - -#ifdef ENABLE_FP8 - -template -__global__ void getLastTokenDequantize(getLastTokenDequantizeParam param) -{ - param.output[blockIdx.x * param.d_model + threadIdx.x] = (T_OUT)( - (float)param.input[blockIdx.x * param.max_seq_len * param.d_model + threadIdx.x] * __ldg(param.input_scale)); -} - -template -void invokeGetLastTokenDequantize(getLastTokenDequantizeParam param) -{ - FT_CHECK(param.d_model <= 1024); - getLastTokenDequantize<<>>(param); -} - -template void invokeGetLastTokenDequantize<__nv_bfloat16, __nv_fp8_e4m3>( - getLastTokenDequantizeParam<__nv_bfloat16, __nv_fp8_e4m3> param); - -template -__global__ void quantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam param) -{ - for (int i = threadIdx.x; i < param.d_model; i += blockDim.x) { - int padded_row_id = blockIdx.x + (param.padding_offset == nullptr ? 0 : param.padding_offset[blockIdx.x]); - if (quantize_mode == QUANTIZE_MODE::PER_TENSOR) { - param.dst[padded_row_id * param.d_model + i] = - (T_OUT)((float)param.src[blockIdx.x * param.d_model + i] * __ldg(param.scale)); - } - else if (quantize_mode == QUANTIZE_MODE::PER_CHANNEL) { - param.dst[padded_row_id * param.d_model + i] = - (T_OUT)((float)param.src[blockIdx.x * param.d_model + i] * __ldg(param.scale + i)); - } - } -} - -template<> -__global__ void -quantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam param) -{ - int padded_row_id = blockIdx.x + (param.padding_offset == nullptr ? 
0 : __ldg(¶m.padding_offset[blockIdx.x])); - __nv_fp8x4_e4m3* src_ptr = ((__nv_fp8x4_e4m3*)param.src) + blockIdx.x * (param.d_model / 4); - half2* dst_ptr = ((half2*)param.dst) + padded_row_id * (param.d_model / 2); - half2 scale = cuda_cast(__ldg(param.scale)); - for (int i = threadIdx.x; i < param.d_model / 4; i += blockDim.x) { - half2 val_0; - half2 val_1; - fp8x4_e4m3_to_half2(&val_0, &val_1, src_ptr + i); - - val_0 = hmul2(val_0, scale); - val_1 = hmul2(val_1, scale); - - dst_ptr[2 * i + 0] = val_0; - dst_ptr[2 * i + 1] = val_1; - } -} - -template -void invokeQuantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam param) -{ - dim3 grid(param.token_num); - dim3 block(param.d_model); - FT_CHECK(block.x <= 1024); - if (block.x % 4 == 0) { - block.x /= 4; - } - quantizeMatrixRebuildPadding<<>>(param); -} - -template void invokeQuantizeMatrixRebuildPadding( - QuantizeMatrixRebuildPaddingParam param); - -#endif - } // namespace turbomind diff --git a/src/turbomind/kernels/bert_preprocess_kernels.h b/src/turbomind/kernels/bert_preprocess_kernels.h index deab2826f9..867aaf6b8f 100644 --- a/src/turbomind/kernels/bert_preprocess_kernels.h +++ b/src/turbomind/kernels/bert_preprocess_kernels.h @@ -15,7 +15,6 @@ */ #pragma once -#include "src/turbomind/kernels/gen_relative_pos_bias.h" #include "src/turbomind/utils/cuda_utils.h" #include #include @@ -46,21 +45,6 @@ inline void invokeGetPaddingOffset(size_t* h_pinned_token_num, h_pinned_token_num, h_token_num, tmp_mask_offset, nullptr, sequence_length, batch_size, max_seq_len, stream); } -template -void invokeBuildEncoderAttentionMask( - T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream); - -void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, - const int request_batch_size, - cudaStream_t stream); - -void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, - const int* sequence_length, - const int request_batch_size, - const int request_seq_len, - cudaStream_t stream); - template void invokeRebuildPadding( T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); @@ -69,46 +53,4 @@ template void invokeRemovePadding( T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); -template -void invokeBuildRelativeAttentionBias(T* relative_attention_bias, - const T* relative_attention_bias_table, - const int head_num, - const int seq_len, - const int num_bucket, - const bool is_bidirectional, - const int max_distance, - const PositionEmbeddingType position_embedding_type, - cudaStream_t stream); - -template -struct getLastTokenDequantizeParam { - T_OUT* const output; - T_IN const* const input; - float const* const input_scale; - - const int batch_size; - const int max_seq_len; - const int d_model; - cudaStream_t stream; -}; - -template -void invokeGetLastTokenDequantize(getLastTokenDequantizeParam param); - -#ifdef ENABLE_FP8 -template -struct QuantizeMatrixRebuildPaddingParam { - T_OUT* dst; - const T_IN* src; - const int* padding_offset; - const int token_num; - const int d_model; - const float* scale; - cudaStream_t stream; -}; - -template -void invokeQuantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam param); -#endif // ENABLE_FP8 - } // namespace turbomind diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention.h b/src/turbomind/kernels/decoder_masked_multihead_attention.h index 5cf502555d..cc441494e1 
100644 --- a/src/turbomind/kernels/decoder_masked_multihead_attention.h +++ b/src/turbomind/kernels/decoder_masked_multihead_attention.h @@ -16,7 +16,6 @@ #pragma once -#include "src/turbomind/layers/attention_layers_fp8/AttentionFP8Weight.h" #include "src/turbomind/utils/cuda_bf16_wrapper.h" #include "src/turbomind/utils/cuda_fp8_utils.h" #include diff --git a/src/turbomind/kernels/decoding_kernels.cu b/src/turbomind/kernels/decoding_kernels.cu index 98fb5e5a48..ffce48c460 100644 --- a/src/turbomind/kernels/decoding_kernels.cu +++ b/src/turbomind/kernels/decoding_kernels.cu @@ -21,81 +21,6 @@ namespace turbomind { -// static const float HALF_FLT_MAX = 65504.F; - -template -__global__ void decodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - T* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length) -{ - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? (T)HALF_FLT_MAX : (T)1e20f; // BF16 and FP32 have the same dynamic range - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width; - index += blockDim.x * gridDim.x) { - finished[index] = false; - sequence_length[index] = max_input_length; - if (word_ids != nullptr) { - word_ids[index] = sentence_ids[index / beam_width]; - } - cum_log_probs[index] = (index % beam_width == 0) ? (T)0.0f : (T)-MAX_T_VAL; - } -} - -template -void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - T* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, - cudaStream_t stream) -{ - dim3 grid((int)ceil(batch_size * beam_width * 1.0 / 256)); - dim3 block(256); - - decodingInitialize<<>>( - finished, sequence_length, word_ids, cum_log_probs, sentence_ids, batch_size, beam_width, max_input_length); -} - -template void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - float* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, - cudaStream_t stream); - -template void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - half* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, - cudaStream_t stream); - -#ifdef ENABLE_BF16 -template void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - __nv_bfloat16* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, - cudaStream_t stream); -#endif - // PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts template __global__ void embeddingLookupPosEncoding(T* from_tensor, @@ -364,33 +289,33 @@ void invokePaddingEmbedding(T* padded_embedding_kernel, vocab_size_padded); } -template void invokePaddingEmbedding(float* padded_embedding_kernel, - float* padded_embedding_bias, - const float* embedding_kernel, - const float* embedding_bias, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, - cudaStream_t stream); - -template void invokePaddingEmbedding(half* padded_embedding_kernel, - half* padded_embedding_bias, - const half* embedding_kernel, - const half* embedding_bias, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, - cudaStream_t stream); -#ifdef ENABLE_BF16 -template void 
invokePaddingEmbedding(__nv_bfloat16* padded_embedding_kernel, - __nv_bfloat16* padded_embedding_bias, - const __nv_bfloat16* embedding_kernel, - const __nv_bfloat16* embedding_bias, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, - cudaStream_t stream); -#endif +// template void invokePaddingEmbedding(float* padded_embedding_kernel, +// float* padded_embedding_bias, +// const float* embedding_kernel, +// const float* embedding_bias, +// const int hidden_unit, +// const int vocab_size, +// const int vocab_size_padded, +// cudaStream_t stream); + +// template void invokePaddingEmbedding(half* padded_embedding_kernel, +// half* padded_embedding_bias, +// const half* embedding_kernel, +// const half* embedding_bias, +// const int hidden_unit, +// const int vocab_size, +// const int vocab_size_padded, +// cudaStream_t stream); +// #ifdef ENABLE_BF16 +// template void invokePaddingEmbedding(__nv_bfloat16* padded_embedding_kernel, +// __nv_bfloat16* padded_embedding_bias, +// const __nv_bfloat16* embedding_kernel, +// const __nv_bfloat16* embedding_bias, +// const int hidden_unit, +// const int vocab_size, +// const int vocab_size_padded, +// cudaStream_t stream); +// #endif template __global__ void paddingEmbeddingKernel(T* padded_embedding_kernel, @@ -426,256 +351,28 @@ void invokePaddingEmbeddingKernel(T* padded_embedding_kernel, padded_embedding_kernel, embedding_kernel, hidden_unit, vocab_size, vocab_size_padded); } -template void invokePaddingEmbeddingKernel(float* padded_embedding_kernel, - const float* embedding_kernel, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, - cudaStream_t stream); - -template void invokePaddingEmbeddingKernel(half* padded_embedding_kernel, - const half* embedding_kernel, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, - cudaStream_t stream); - -#ifdef ENABLE_BF16 -template void invokePaddingEmbeddingKernel(__nv_bfloat16* padded_embedding_kernel, - const __nv_bfloat16* embedding_kernel, - const int hidden_unit, - const int vocab_size, - const int vocab_size_padded, - cudaStream_t stream); -#endif - -__global__ void gatherTree(gatherTreeParam param) -{ - // PREFIX SOFT PROMPT - // beam: have six parts - // [prompt | input | input_padding | prompt_padding | generated output | padding (use end_token)] - // parents: have five parts - // [prompt | input | input_padding | prompt_padding | generated output | padding (use 0)] - // step_ids: need to remove prompt, input_padding and prompt_padding - // the shape is [input_length + requested_output_length, bs, beam_width] - // need to transpose to output_ids [bs, beam_width, input_length + requested_output_length] - // max_input_length: input + input_padding + prompt_padding - - // P/PROMPT TUNING - // NOTE: input (real ids | prompt virtual ids) have already been preprocessed during embedding lookup, no prompt - // templates now beam: [input (real ids | prompt virtual ids) | input_padding | generated output | padding (use - // end_token)] parents: [input (real ids | prompt virtual ids) | input_padding | generated output | padding (use - // 0)] step_ids: need to remove virtual prompt ids in input ids - // the shape is [input_length (real input length, prompt length) + requested_output_length, bs, beam_width] - // need to transpose to output_ids [bs, beam_width, input_length + requested_output_length] - // max_input_length: input (real ids | prompt virtual ids) + input_padding - - const int max_input_length = param.input_lengths 
== nullptr ? 0 : param.max_input_length; - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < param.batch_size * param.beam_width; - i += gridDim.x * blockDim.x) { - const int batch = i / param.beam_width; - const int beam = i % param.beam_width; - const int prompt_len = - param.prefix_soft_prompt_lengths == nullptr ? 0 : param.prefix_soft_prompt_lengths[batch]; - int input_len = param.input_lengths == nullptr ? 0 : param.input_lengths[i]; - // virtual prompts mean the prompt embedded in input ids (with prompt templates) [p/prompt tuning] - const int virtual_prompt_length = - param.p_prompt_tuning_prompt_lengths == nullptr ? 0 : param.p_prompt_tuning_prompt_lengths[batch]; - // real input length (without virtual prompts) [p/prompt tuning] - input_len -= virtual_prompt_length; - - const int* parent_ids = param.parent_ids; - const int* step_ids = param.step_ids; - - // TODO(bhsueh) optimize the reduce_max operation for large beam_width - int max_len = -1; - bool update_response_input_length = param.response_input_lengths != nullptr; - // int selected_beam_index = 0; - for (int j = 0; j < param.beam_width; j++) { - int tmp_len = - param.max_sequence_lengths[batch * param.beam_width + j] + param.max_sequence_length_final_step; - // also remove the length of the soft prompts, p_prompt_tuning - param.max_sequence_lengths[batch * param.beam_width + j] = - tmp_len - param.max_prefix_soft_prompt_length - - (param.max_input_length - param.max_input_without_prompt_length); - // update the response input length - if (update_response_input_length) { - param.response_input_lengths[batch * param.beam_width + j] = input_len - prompt_len; - } - if (tmp_len > max_len) { - max_len = tmp_len; - // selected_beam_index = j; - } - } - const int max_seq_len_b = min(param.max_time, max_len); - if (max_seq_len_b <= 0) { - continue; - } - -#define GET_IX(time_ix, beam_ix) \ - (param.batch_size * param.beam_width * (time_ix) + param.beam_width * batch + (beam_ix)) - - const int padding_offset_and_prompt_offset = max_input_length - input_len + prompt_len; - const int initial_tgt_ix = GET_IX(max_seq_len_b - 1 - padding_offset_and_prompt_offset, beam); - const int initial_parent_ix = GET_IX(max_seq_len_b - 1, beam); - param.beams[initial_tgt_ix] = __ldg(step_ids + initial_parent_ix); - int parent = parent_ids == nullptr ? 0 : __ldg(parent_ids + initial_parent_ix) % param.beam_width; - bool found_bad = false; - - for (int level = max_seq_len_b - 2; level >= 0; --level) { - if (level < prompt_len || (level >= input_len && level < max_input_length)) { - continue; - } - int tgt_level = level >= max_input_length ? level - padding_offset_and_prompt_offset : level - prompt_len; - const int level_beam_ix = GET_IX(tgt_level, beam); - const int level_parent_ix = GET_IX(level, parent); - if (parent < 0 || parent > param.beam_width) { - // param.beams[level_beam_ix] = -1; - param.beams[level_beam_ix] = param.end_tokens[batch]; - parent = -1; - found_bad = true; - } - else { - param.beams[level_beam_ix] = __ldg(step_ids + level_parent_ix); - parent = parent_ids == nullptr ? 
0 : __ldg(parent_ids + level_parent_ix) % param.beam_width; - } - } - - // set the padded part as end_token - // input_len - for (int index = max_len - padding_offset_and_prompt_offset; - index < param.max_time - param.max_prefix_soft_prompt_length; - ++index) { - param.beams[GET_IX(index, beam)] = param.end_tokens[batch]; - } - - // Not necessary when using a BeamSearchDecoder, but necessary - // when a user feeds in possibly broken trajectory (i.e., non-eos - // entries in a beam following eos entries). - if (!found_bad) { - bool finished = false; - // skip the step 0 because it is often the start token - int start_step = max_input_length == 0 ? 1 : max_input_length; - for (int time = start_step; time < max_seq_len_b; ++time) { - const int level_beam_ix = GET_IX(time, beam); - if (finished) { - param.beams[level_beam_ix] = param.end_tokens[batch]; - } - else if (param.beams[level_beam_ix] == param.end_tokens[batch]) { - finished = true; - } - } - } -#undef GET_IX - - // transpose on output_ids - // remove p_prompt tuning virtual tokens (end tokens) - int actual_output_length = param.max_time - param.max_prefix_soft_prompt_length - - (param.max_input_length - param.max_input_without_prompt_length); - if (param.output_ids != nullptr) { - for (int j = 0; j < actual_output_length; j++) { - param.output_ids[i * actual_output_length + j] = - param.beams[j * param.batch_size * param.beam_width + i]; - } - } - } -} - -void invokeGatherTree(int* beams, - int* max_sequence_lengths, - const int max_time, - const int batch_size, - const int beam_width, - const int* step_ids, - const int* parent_ids, - const int* end_tokens, - cudaStream_t stream) -{ - gatherTreeParam param; - param.beams = beams; - param.max_sequence_lengths = max_sequence_lengths; - param.max_time = max_time; - param.batch_size = batch_size; - param.beam_width = beam_width; - param.step_ids = step_ids; - param.parent_ids = parent_ids; - param.end_tokens = end_tokens; - param.max_input_length = 1; - param.prefix_soft_prompt_lengths = nullptr; - param.stream = stream; - invokeGatherTree(param); -} - -void invokeGatherTree(int* beams, - int* max_sequence_lengths, - const int max_time, - const int batch_size, - const int beam_width, - const int* step_ids, - const int* parent_ids, - const int* end_tokens, - const int max_input_length, - cudaStream_t stream) -{ - gatherTreeParam param; - param.beams = beams; - param.max_sequence_lengths = max_sequence_lengths; - param.max_time = max_time; - param.batch_size = batch_size; - param.beam_width = beam_width; - param.step_ids = step_ids; - param.parent_ids = parent_ids; - param.end_tokens = end_tokens; - param.max_input_length = max_input_length; - param.prefix_soft_prompt_lengths = nullptr; - param.stream = stream; - invokeGatherTree(param); -} - -void invokeGatherTree(gatherTreeParam param) -{ - int batchbeam = param.batch_size * param.beam_width; - dim3 grid(1), block(batchbeam); - // though decoder do not support > 1024 for now - if (batchbeam > 1024) { - grid.x = ceil(param.batch_size * param.beam_width / 1024.); - block.x = 1024; - } - gatherTree<<>>(param); -} - -__global__ void minusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num) -{ - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < token_num; i += blockDim.x * gridDim.x) { - if (finished[i] == false) { - sequence_lengths[i] -= 1; - } - } -} - -void invokeMinusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num, cudaStream_t stream) -{ - dim3 block(min(256, 
token_num)); - dim3 grid(ceil(token_num / 256.)); - minusUnfinishedSeqlen<<>>(sequence_lengths, finished, token_num); -} - -__global__ void plusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num) -{ - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < token_num; i += blockDim.x * gridDim.x) { - if (finished[i] == false) { - sequence_lengths[i] += 1; - } - } -} - -void invokePlusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num, cudaStream_t stream) -{ - dim3 block(min(256, token_num)); - dim3 grid(ceil(token_num / 256.)); - plusUnfinishedSeqlen<<>>(sequence_lengths, finished, token_num); -} +// template void invokePaddingEmbeddingKernel(float* padded_embedding_kernel, +// const float* embedding_kernel, +// const int hidden_unit, +// const int vocab_size, +// const int vocab_size_padded, +// cudaStream_t stream); + +// template void invokePaddingEmbeddingKernel(half* padded_embedding_kernel, +// const half* embedding_kernel, +// const int hidden_unit, +// const int vocab_size, +// const int vocab_size_padded, +// cudaStream_t stream); + +// #ifdef ENABLE_BF16 +// template void invokePaddingEmbeddingKernel(__nv_bfloat16* padded_embedding_kernel, +// const __nv_bfloat16* embedding_kernel, +// const int hidden_unit, +// const int vocab_size, +// const int vocab_size_padded, +// cudaStream_t stream); +// #endif template __global__ void plusScalar(T* buf, const T val, const int size) @@ -695,112 +392,4 @@ void invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream) template void invokePlusScalar(int* buf, const int val, const int size, cudaStream_t stream); -__global__ void finalize(int* output_ids, - int* sequence_lengths, - float* cum_log_probs, - float* output_log_probs, - const int* topk_output_ids, - const int* topk_sequence_lengths, - const float* scores, - const float* topk_cum_log_probs, - const float* topk_log_probs, - const int* num_beams, - const int beam_width, - const int max_seq_len) -{ - // output_ids: [bs, beam_width, max_seq_len] - // sequence_lengths: [bs, beam_width] - // cum_log_probs: [bs, beam_width] - // output_log_probs: [bs, beam_width, max_seq_len] - // topk_output_ids: [bs, 2 * beam_width, max_seq_len + 1] - // topk_sequence_lengths: [bs, 2 * beam_width] - // scores: [bs, 2 * beam_width] - // topk_cum_log_probs: [bs, 2 * beam_width] - // topk_log_probs: [bs, 2 * beam_width, max_seq_len + 1] - // num_beams: [bs] - - // This kernel do a sorting for scores first, and then put the topk_output_ids - // into output_ids by the rank of scores. - // Note that we remove the start_token (the id at first position) from topk_output_ids - - extern __shared__ char array[]; - int* rank = (int*)(array); - float* s_scores = (float*)(rank + beam_width); - if (threadIdx.x < num_beams[blockIdx.x]) { - s_scores[threadIdx.x] = scores[blockIdx.x * beam_width * 2 + threadIdx.x]; - } - __syncthreads(); - - for (int i = 0; i < beam_width; i++) { - float score = threadIdx.x < num_beams[blockIdx.x] ? 
s_scores[threadIdx.x] : -FLT_MAX; - float max_score = blockReduceMax(score); - - if (threadIdx.x == 0) { - for (int j = 0; j < beam_width * 2; j++) { - if (s_scores[j] == max_score) { - rank[i] = j; - s_scores[j] = -FLT_MAX; - break; - } - } - } - __syncthreads(); - } - - if (threadIdx.x < beam_width) { - sequence_lengths[blockIdx.x * beam_width + threadIdx.x] = - topk_sequence_lengths[blockIdx.x * beam_width * 2 + rank[threadIdx.x]]; - if (cum_log_probs != nullptr) { - cum_log_probs[blockIdx.x * beam_width + threadIdx.x] = - topk_cum_log_probs[blockIdx.x * beam_width * 2 + rank[threadIdx.x]]; - } - } - for (int beam_idx = 0; beam_idx < beam_width; beam_idx++) { - // start from step 1 to skip the start token - for (int i = threadIdx.x; i < sequence_lengths[blockIdx.x * beam_width + beam_idx]; i += blockDim.x) { - output_ids[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i] = - topk_output_ids[blockIdx.x * (beam_width * 2) * (max_seq_len + 1) + rank[beam_idx] * (max_seq_len + 1) - + (i + 1)]; - if (output_log_probs != nullptr) { - output_log_probs[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i] = - topk_log_probs[blockIdx.x * (beam_width * 2) * (max_seq_len + 1) - + rank[beam_idx] * (max_seq_len + 1) + (i + 1)]; - } - } - } -} - -void invokeFinalize(int* output_ids, - int* sequence_lengths, - float* cum_log_probs, - float* output_log_probs, - const int* topk_output_ids, - const int* topk_sequence_lengths, - const float* scores, - const float* topk_cum_log_probs, - const float* topk_log_probs, - const int* num_beams, - const int beam_width, - const int max_seq_len, - const int batch_size, - cudaStream_t stream) -{ - dim3 block(beam_width * 2); - block.x = (block.x + 31) / 32 * 32; - FT_CHECK(block.x < 1024); - finalize<<>>( - output_ids, - sequence_lengths, - cum_log_probs, - output_log_probs, - topk_output_ids, - topk_sequence_lengths, - scores, - topk_cum_log_probs, - topk_log_probs, - num_beams, - beam_width, - max_seq_len); -} - } // namespace turbomind diff --git a/src/turbomind/kernels/decoding_kernels.h b/src/turbomind/kernels/decoding_kernels.h index f4024d687b..49db3c3fcb 100644 --- a/src/turbomind/kernels/decoding_kernels.h +++ b/src/turbomind/kernels/decoding_kernels.h @@ -22,17 +22,6 @@ namespace turbomind { -template -void invokeDecodingInitialize(bool* finished, - int* sequence_length, - int* word_ids, - T* cum_log_probs, - const int* sentence_ids, - const int batch_size, - const int beam_width, - const int max_input_length, - cudaStream_t stream); - // get token from all_ids at step, then lookup from the embedding table // by the token template @@ -99,72 +88,7 @@ void invokePaddingEmbeddingKernel(T* padded_embedding_kernel, const int vocab_size_padded, cudaStream_t stream); -void invokeGatherTree(int* beams, - int* max_sequence_lengths, - const int max_time, - const int batch_size, - const int beam_width, - const int* step_ids, - const int* parent_ids, - const int* end_tokens, - cudaStream_t stream); - -void invokeGatherTree(int* beams, - int* max_sequence_lengths, - const int max_time, - const int batch_size, - const int beam_width, - const int* step_ids, - const int* parent_ids, - const int* end_tokens, - const int max_input_length, - cudaStream_t stream); - -struct gatherTreeParam { - int* beams = nullptr; - int* max_sequence_lengths = nullptr; - int max_sequence_length_final_step = 0; - const int* input_lengths = nullptr; - // response input lengths (used to slice the ids during postprocessing) - int* response_input_lengths = nullptr; 
- int max_time = 0; - int batch_size = 0; - int beam_width = 0; - const int* step_ids = nullptr; - const int* parent_ids = nullptr; - const int* end_tokens = nullptr; - int max_input_length = 0; - const int* prefix_soft_prompt_lengths = nullptr; - // p_prompt_tuning prompt leangths, used to remove prompts during post-processing - const int* p_prompt_tuning_prompt_lengths = nullptr; - int max_input_without_prompt_length = 0; - // prefix soft prompt - int max_prefix_soft_prompt_length = 0; - int* output_ids = nullptr; - cudaStream_t stream; -}; - -void invokeGatherTree(gatherTreeParam param); - -void invokeMinusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num, cudaStream_t stream); -void invokePlusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num, cudaStream_t stream); - template void invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream); -void invokeFinalize(int* output_ids, - int* sequence_lengths, - float* cum_log_probs, - float* output_log_probs, - const int* topk_output_ids, - const int* topk_sequence_lengths, - const float* scores, - const float* topk_cum_log_probs, - const float* topk_log_probs, - const int* num_beams, - const int beam_width, - const int max_seq_len, - const int batch_size, - cudaStream_t stream); - } // namespace turbomind diff --git a/src/turbomind/kernels/gen_relative_pos_bias.cu b/src/turbomind/kernels/gen_relative_pos_bias.cu deleted file mode 100644 index ddf34f44bd..0000000000 --- a/src/turbomind/kernels/gen_relative_pos_bias.cu +++ /dev/null @@ -1,304 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "cublas_v2.h" -#include "gen_relative_pos_bias.h" -#include "reduce_kernel_utils.cuh" -#include "src/turbomind/kernels/activation_kernels.h" -#include "src/turbomind/utils/cuda_utils.h" -#include - -namespace turbomind { - -/******************* invokeGenRelativePosBias ***********************/ -// relative_position_bias_table is [(2*window_size-1)*(2*window_size-1), headNum] -// relative_position_bias is [head_num, window_size^2, window_size^2] -// grid(window_size*window_size, head_num) -// block(window_size*window_size) - -template -__global__ void gen_relative_pos_bias(T* relative_position_bias, - const T* relative_position_bias_table, - const Tindex* relative_position_bias_index, - const int window_size, - const int head_num) -{ - const int h_in_window = blockIdx.x / window_size; - const int w_in_window = blockIdx.x % window_size; - const int h_in_token = threadIdx.x / window_size; - const int w_in_token = threadIdx.x % window_size; - const int head_idx = blockIdx.y; - const int elements_per_window = window_size * window_size; - const size_t elements_per_window_2 = elements_per_window * elements_per_window; - const size_t output_idx = head_idx * elements_per_window_2 + blockIdx.x * elements_per_window + threadIdx.x; - if (output_idx < head_num * elements_per_window_2) { - const Tindex idx_in_table = - relative_position_bias_index[(h_in_window * window_size + w_in_window) * elements_per_window - + h_in_token * window_size + w_in_token]; - relative_position_bias[output_idx] = relative_position_bias_table[idx_in_table * head_num + head_idx]; - } -} - -template -void invokeGenRelativePosBias(T* relative_position_bias, - const T* relative_position_bias_table, - const Tindex* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream) -{ - dim3 grid(window_size * window_size, head_num); - dim3 block(window_size * window_size); - - if (block.x > 1024) { - printf("[ERROR][invokeGenRelativePosBias] window_size*window_size > 1024.\n"); - exit(-1); - } - - gen_relative_pos_bias<<>>( - relative_position_bias, relative_position_bias_table, relative_position_bias_index, window_size, head_num); -} - -/******************* invokeGenRelativePosBiasV2 ***********************/ -template -void invokeGenRelativePosBiasV2(T* relative_position_bias, - const T* relative_coords_table, - const Tindex* relative_position_bias_index, - const T* cpb_mlp_weight1, - const T* cpb_mlp_bias1, - const T* cpb_mlp_weight2, - const int window_size, - const int cpb_mlp_in_dim, - const int cpb_mlp_out_dim, - const int head_num, - cudaStream_t stream) -{ - - dim3 grid(window_size * window_size, head_num); - dim3 block(window_size * window_size); - - if (block.x > 1024) { - printf("[ERROR][invokeGenRelativePosBias] window_size*window_size > 1024.\n"); - exit(-1); - } - - T* relative_position_bias_table; - check_cuda_error(cudaMalloc(&relative_position_bias_table, - ((2 * window_size - 1) * (2 * window_size - 1) * head_num) * sizeof(T))); - T* cpb_mlp_1; - check_cuda_error( - cudaMalloc(&cpb_mlp_1, ((2 * window_size - 1) * (2 * window_size - 1) * cpb_mlp_out_dim) * sizeof(T))); - cublasHandle_t cublas_handle; - check_cuda_error(cublasCreate(&cublas_handle)); - - int m = (2 * window_size - 1) * (2 * window_size - 1); - T alpha = (T)1.0f; - T beta = (T)0.0f; - cudaDataType_t type = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -#if (CUDART_VERSION >= 11000) - cublasComputeType_t compute_type = std::is_same::value ? 
CUBLAS_COMPUTE_32F : CUBLAS_COMPUTE_16F; -#else - cudaDataType_t compute_type = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; -#endif - cublasGemmAlgo_t algo = std::is_same::value ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP; - check_cuda_error(cublasGemmEx(cublas_handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - cpb_mlp_out_dim, - m, - cpb_mlp_in_dim, - &alpha, - cpb_mlp_weight1, - type, - cpb_mlp_in_dim, - relative_coords_table, - type, - cpb_mlp_in_dim, - &beta, - cpb_mlp_1, - type, - cpb_mlp_out_dim, - compute_type, - algo)); - - invokeGenericActivation( - cpb_mlp_1, cpb_mlp_bias1, nullptr, nullptr, nullptr, nullptr, m, cpb_mlp_out_dim, 0, nullptr, nullptr, stream); - - check_cuda_error(cublasGemmEx(cublas_handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - head_num, - m, - cpb_mlp_out_dim, - &alpha, - cpb_mlp_weight2, - type, - cpb_mlp_out_dim, - cpb_mlp_1, - type, - cpb_mlp_out_dim, - &beta, - relative_position_bias_table, - type, - head_num, - compute_type, - algo)); - - gen_relative_pos_bias<<>>( - relative_position_bias, relative_position_bias_table, relative_position_bias_index, window_size, head_num); - - invokeSigmoid( - relative_position_bias, window_size * window_size * window_size * window_size * head_num, 16.0f, stream); - check_cuda_error(cudaFree(relative_position_bias_table)); - check_cuda_error(cudaFree(cpb_mlp_1)); - check_cuda_error(cublasDestroy(cublas_handle)); -} - -/******************* instantiation ***********************/ - -template void invokeGenRelativePosBias(float* relative_position_bias, - const float* relative_position_bias_table, - const int* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream); - -template void invokeGenRelativePosBias(half* relative_position_bias, - const half* relative_position_bias_table, - const int* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream); - -template void invokeGenRelativePosBias(float* relative_position_bias, - const float* relative_position_bias_table, - const int64_t* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream); - -template void invokeGenRelativePosBias(half* relative_position_bias, - const half* relative_position_bias_table, - const int64_t* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream); - -__host__ __device__ uint32_t pow2_rounddown(uint32_t x) -{ - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - x |= x >> 8; - x |= x >> 16; - x >>= 1; - return x + 1; -} - -template -__global__ void generate_alibi_slopes(T* alibi_slopes, const size_t num_heads) -{ - if (threadIdx.x < num_heads) { - // The nearest power of 2 greater than num_heads followed by HF's implementation. - int num_heads_pow2 = pow2_rounddown(num_heads); - // Loop over the attention head. - for (int h = threadIdx.x; h < num_heads; h += blockDim.x) { - if (h < num_heads_pow2) { - alibi_slopes[h] = static_cast(powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2) - 3.f)), h + 1)); - } - else { - alibi_slopes[h] = static_cast( - powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2 << 1) - 3.f)), (h - num_heads_pow2) * 2 + 1)); - } - } - } -} - -template -void invokeBuildAlibiSlopes(T* alibi_slopes, const size_t num_heads, cudaStream_t stream) -{ - // Generate the slopes of a linear attention linear bias. 
- // - // Paper: https://arxiv.org/abs/2108.12409 - // HF's implementation - // https://github.com/huggingface/transformers/blob/56ef0ba44765162f830873c140bd40bdc975cc34/src/transformers/models/bloom/modeling_bloom.py#L86 - // Author's implementation - // https://github.com/ofirpress/attention_with_linear_biases/blob/02aa87e7a29e9340efd28d6d169018eafb3aa57a/fairseq/models/transformer.py#L760 - // - // alibi_slopes: [num_heads], - // strictly follows how HF implements. which treats power-of-2 heads, and non-power-of-2 heads differently. - // what paper generates differs with HF's when number of heads is not a power of 2. - // num_heads: the number of attention heads. - // stream: a cuda stream. - - dim3 block(min((int)num_heads, 512)); - generate_alibi_slopes<<<1, block, 0, stream>>>(alibi_slopes, num_heads); -} - -template void invokeBuildAlibiSlopes(float* alibi_slopes, const size_t num_heads, cudaStream_t stream); -template void invokeBuildAlibiSlopes(half* alibi_slopes, const size_t num_heads, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeBuildAlibiSlopes(__nv_bfloat16* alibi_slopes, const size_t num_heads, cudaStream_t stream); -#endif - -template void invokeGenRelativePosBiasV2(float* relative_position_bias, - const float* relative_coords_table, - const int* relative_position_bias_index, - const float* cpb_mlp_weight1, - const float* cpb_mlp_bias1, - const float* cpb_mlp_weight2, - const int window_size, - const int cpb_mlp_in_dim, - const int cpb_mlp_out_dim, - const int head_num, - cudaStream_t stream); - -template void invokeGenRelativePosBiasV2(half* relative_position_bias, - const half* relative_coords_table, - const int* relative_position_bias_index, - const half* cpb_mlp_weight1, - const half* cpb_mlp_bias1, - const half* cpb_mlp_weight2, - const int window_size, - const int cpb_mlp_in_dim, - const int cpb_mlp_out_dim, - const int head_num, - cudaStream_t stream); - -template void invokeGenRelativePosBiasV2(float* relative_position_bias, - const float* relative_coords_table, - const int64_t* relative_position_bias_index, - const float* cpb_mlp_weight1, - const float* cpb_mlp_bias1, - const float* cpb_mlp_weight2, - const int window_size, - const int cpb_mlp_in_dim, - const int cpb_mlp_out_dim, - const int head_num, - cudaStream_t stream); - -template void invokeGenRelativePosBiasV2(half* relative_position_bias, - const half* relative_coords_table, - const int64_t* relative_position_bias_index, - const half* cpb_mlp_weight1, - const half* cpb_mlp_bias1, - const half* cpb_mlp_weight2, - const int window_size, - const int cpb_mlp_in_dim, - const int cpb_mlp_out_dim, - const int head_num, - cudaStream_t stream); -} // namespace turbomind diff --git a/src/turbomind/kernels/gen_relative_pos_bias.h b/src/turbomind/kernels/gen_relative_pos_bias.h deleted file mode 100644 index 209448483b..0000000000 --- a/src/turbomind/kernels/gen_relative_pos_bias.h +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "src/turbomind/utils/cuda_bf16_wrapper.h" - -#include -#include -#include - -namespace turbomind { - -enum class PositionEmbeddingType -{ - relative, - absolute, -}; - -template -void invokeGenRelativePosBias(T* relative_position_bias, - const T* relative_position_bias_table, - const Tindex* relative_position_bias_index, - const int window_size, - const int head_num, - cudaStream_t stream); - -template -void invokeBuildAlibiSlopes(T* linear_position_bias_slopes, const size_t head_num, cudaStream_t stream); - -template -void invokeGenRelativePosBiasV2(T* relative_position_bias, - const T* relative_coords_table, - const Tindex* relative_position_bias_index, - const T* cpb_mlp_weight1, - const T* cpb_mlp_bias1, - const T* cpb_mlp_weight2, - const int window_size, - const int cpb_mlp_in_dim, - const int cpb_mlp_out_dim, - const int head_num, - cudaStream_t stream); -} // namespace turbomind diff --git a/src/turbomind/kernels/logprob_kernels.cu b/src/turbomind/kernels/logprob_kernels.cu index c94c4f45be..20474a7ab2 100644 --- a/src/turbomind/kernels/logprob_kernels.cu +++ b/src/turbomind/kernels/logprob_kernels.cu @@ -182,29 +182,29 @@ void invokeLogProbFromLogits(float* cum_log_probs, cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first); } -template void invokeLogProbFromLogits(float* cum_log_probs, - const float* logits, - const int* input_ids, - const int* input_lengths, - const size_t max_input_length, - const size_t batch_size, - const size_t vocab_size, - const size_t vocab_size_padded, - void* workspace, - const size_t workspace_size, - cudaStream_t stream, - const bool batch_first); - -template void invokeLogProbFromLogits(float* cum_log_probs, - const half* logits, - const int* input_ids, - const int* input_lengths, - const size_t max_input_length, - const size_t batch_size, - const size_t vocab_size, - const size_t vocab_size_padded, - void* workspace, - const size_t workspace_size, - cudaStream_t stream, - const bool batch_first); +// template void invokeLogProbFromLogits(float* cum_log_probs, +// const float* logits, +// const int* input_ids, +// const int* input_lengths, +// const size_t max_input_length, +// const size_t batch_size, +// const size_t vocab_size, +// const size_t vocab_size_padded, +// void* workspace, +// const size_t workspace_size, +// cudaStream_t stream, +// const bool batch_first); + +// template void invokeLogProbFromLogits(float* cum_log_probs, +// const half* logits, +// const int* input_ids, +// const int* input_lengths, +// const size_t max_input_length, +// const size_t batch_size, +// const size_t vocab_size, +// const size_t vocab_size_padded, +// void* workspace, +// const size_t workspace_size, +// cudaStream_t stream, +// const bool batch_first); } // end of namespace turbomind diff --git a/src/turbomind/kernels/online_softmax_beamsearch_kernels.cu b/src/turbomind/kernels/online_softmax_beamsearch_kernels.cu deleted file mode 100644 index 93c163b10d..0000000000 --- a/src/turbomind/kernels/online_softmax_beamsearch_kernels.cu +++ /dev/null @@ -1,739 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef CUDART_VERSION -#error CUDART_VERSION Undefined! -#elif (CUDART_VERSION >= 11050) -#include -#else -#include "3rdparty/cub/cub.cuh" -#endif - -#include "src/turbomind/kernels/online_softmax_beamsearch_kernels.h" -#include "src/turbomind/kernels/reduce_kernel_utils.cuh" -#include "src/turbomind/utils/cuda_utils.h" - -namespace turbomind { - -#define DO_SPLIT_SMALL_TOP_K_SOFTMAX -static const int SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256; - -#define TOPK_FP16_STORAGE 0 - -template -__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) -{ - // score = log(prob) / (length)^length_penalty. - if (length_penalty == 0.0f || length == 1) { - return log_prob; - } - return log_prob / static_cast(powf(length, length_penalty)); -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ - void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) -{ - int thread_id = threadIdx.x; - int block_id = blockIdx.x; - TopK partial; - if (thread_id == 0) { - for (int i = 0; i < MAX_K; ++i) { - partial.p[i] = -1; - partial.u[i] = -FLT_MAX; - } - - int index = block_id * MAX_K * MAX_K; - for (int i = 0; i < MAX_K * MAX_K; i++) { - partial.insert((T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); - } - - index = block_id * MAX_K; - for (int i = 0; i < MAX_K; i++) { - id_buf[index + i] = partial.p[i]; - } - } -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* __restrict topk_tmp_id_buf, - const T* __restrict topk_tmp_val_buf, - int* __restrict id_buf, - T* __restrict val_buf) -{ - int thread_id = threadIdx.x; - int block_id = blockIdx.x; - TopK partial; - if (thread_id == 0) { - for (int i = 0; i < MAX_K; ++i) { - partial.p[i] = -1; - partial.u[i] = -FLT_MAX; - } - - int index = block_id * MAX_K * MAX_K; - for (int i = 0; i < MAX_K * MAX_K; i++) { - partial.insert((T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); - } - - index = block_id * MAX_K; - for (int i = 0; i < MAX_K; i++) { - id_buf[index + i] = partial.p[i]; - val_buf[index + i] = partial.u[i]; - } - } -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* __restrict x, - const T* __restrict y, - int* __restrict z, - float* __restrict v, - float* output_log_probs, - const bool* finished, - const int* sequence_lengths, - BeamHypotheses beam_hyps, - const int V, - const int K, - const int vocab_size, - const float length_penalty, - const T diversity_rate) -{ - int thread_id = threadIdx.x; - int vector_id = blockIdx.x; - - // reposition x, y to data for the current vector - x += vector_id * V; - y += vector_id * V; - - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ int selected_beams; - __shared__ float old_cum_log_probs[MAX_K]; - - if (thread_id == 0) { - selected_beams = 0; - } - if (thread_id < K) { - old_cum_log_probs[thread_id] = v[vector_id * K + thread_id]; - } - __syncthreads(); - if (beam_hyps.num_beams != nullptr) { - const int global_batch_idx = beam_hyps.ite 
* beam_hyps.local_batch_size + vector_id; - if (beam_hyps.num_beams[global_batch_idx] == 0 && thread_id == 0) { - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - } - else if (beam_hyps.num_beams[global_batch_idx] == K) { - return; - } - } - - TopK partial; - for (int i = 0; i < MAX_K; ++i) { - partial.p[i] = -1; - partial.u[i] = -FLT_MAX; - } - - for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) { - int i = elem_id % K; - T elem = length_penalty == 0.0f ? y[elem_id] : - apply_length_penalty(y[elem_id], - finished[vector_id] ? sequence_lengths[vector_id] : - sequence_lengths[vector_id] + 1, - length_penalty); - elem += diversity_rate * (T)i; - int elem_idx = elem_id; // x[elem_id]; - partial.insert(elem, elem_idx); - } - - TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); - - if (thread_id == 0) { - z += vector_id * K; - v += vector_id * K; - - for (int i = 0; i < MAX_K; ++i) { - if (beam_hyps.num_beams != nullptr && x[total.p[i]] % vocab_size == beam_hyps.end_ids[vector_id]) { - // if beam_token does not belong to top num_beams tokens, it should not be added. Refer from - // https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257 - if (i >= K) { - // do nothing - } - else { - const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + vector_id; - const float normed_score = (float)total.u[i]; - const int num_beam = beam_hyps.num_beams[global_batch_idx]; - int beam_idx = num_beam; - // If there are beam_width finished sentences, check that the score of selected candidatet - // is higher than min_normed_score or not. If current score is better, replace worst one - // and update the min_normed_score. - if (num_beam == K) { - if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) { - // end the tracing and exist this for loop - selected_beams = K; - break; - } - else { - // find the beam index which's score = min_normed_score, erase it. 
- for (int j = 0; j < K; j++) { - if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] - == beam_hyps.min_normed_scores[global_batch_idx]) { - beam_idx = j; - beam_hyps.num_beams[global_batch_idx]--; - - beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; - beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score; - for (int l = 0; l < K; l++) { - beam_hyps.min_normed_scores[global_batch_idx] = - min(beam_hyps.min_normed_scores[global_batch_idx], - beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]); - } - break; - } - } - } - } - const int tgt_id_offset = - ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx) - * (beam_hyps.max_seq_len); - beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = beam_hyps.end_ids[vector_id]; - if (beam_hyps.log_probs != nullptr) { - beam_hyps.log_probs[tgt_id_offset + beam_hyps.step] = - (float)y[total.p[i]] - old_cum_log_probs[(x[total.p[i]] / vocab_size) % K]; - } - - int prev_id = (x[total.p[i]] / vocab_size) % K; - for (int j = beam_hyps.step - 1; j >= 0; j--) { - const int src_idx = j * beam_hyps.batch_size * K - + beam_hyps.ite * beam_hyps.local_batch_size * K + vector_id * K + prev_id; - - beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; - if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) { - beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx]; - } - prev_id = beam_hyps.parent_ids_src[src_idx]; - } - const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx; - beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; - beam_hyps.normed_scores[tgt_beam_idx] = normed_score; - beam_hyps.min_normed_scores[global_batch_idx] = - min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); - - beam_hyps.num_beams[global_batch_idx]++; - beam_hyps.cum_log_probs[tgt_beam_idx] = (float)y[total.p[i]]; - } - } - else if ((beam_hyps.num_beams != nullptr && i < 2 * K) || (beam_hyps.num_beams == nullptr && i < K)) { - z[selected_beams] = x[total.p[i]]; - if (output_log_probs != nullptr) { - output_log_probs[vector_id * K + selected_beams] = - (float)y[total.p[i]] - old_cum_log_probs[(z[selected_beams] / vocab_size) % K]; - } - v[selected_beams] = (float)y[total.p[i]]; - selected_beams++; - } - __syncthreads(); - if (selected_beams >= K) { - break; - } - } - } - if (threadIdx.x == 0 && beam_hyps.num_beams != nullptr) { - if (beam_hyps.num_beams[blockIdx.x] < K) { - beam_hyps.is_done[blockIdx.x] = false; - } - else if (beam_hyps.early_stopping) { - beam_hyps.is_done[blockIdx.x] = true; - } - } -} - -struct __align__(8) MD -{ - float m; - float d; -}; - -__device__ __forceinline__ MD reduce_md_op(MD a, MD b) -{ - bool a_bigger = (a.m > b.m); - MD bigger_m = a_bigger ? a : b; - MD smaller_m = a_bigger ? 
b : a; - MD res; - res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m); - res.m = bigger_m.m; - return res; -} - -template -struct TopKMD { - MD md; - TopK topk; -}; - -template -__device__ __forceinline__ TopKMD reduce_topk_md_op(const TopKMD& a, const TopKMD& b) -{ - TopKMD res; - res.md = reduce_md_op(a.md, b.md); - res.topk = reduce_topk_op(a.topk, b.topk); - return res; -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict x, - const T* __restrict b, - const float* __restrict c, - const bool* __restrict finished, - int* __restrict z, - T* __restrict v, - int V, - int K, - const int* __restrict end_ids) -{ - int thread_id = threadIdx.x; - int vector_id = blockIdx.x; - - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - - // reposition y to data for the current vector - x += vector_id * V; - - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - TopKMD partial; - bool finish = finished[vector_id]; - for (int i = 0; i < MAX_K; ++i) { - partial.topk.p[i] = -1; - partial.topk.u[i] = -MAX_T_VAL; - } - partial.md.m = -MAX_T_VAL; - partial.md.d = 0.0F; - - if (finish) { - for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) { - float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; - MD new_elem{elem, 1.0F}; - partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, elem_id); - // if (elem_id > THREADBLOCK_SIZE * MAX_K && (elem_id == E)) break; - } - } - else { - for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) { - float elem = x[elem_id] + b[elem_id]; - MD new_elem{elem, 1.0F}; - partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, elem_id); - } - } - - TopKMD total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op); - - if (thread_id == 0) { - z += vector_id * K; - v += vector_id * K; - c += vector_id; - - // float d_total_inverse = __fdividef(1.0F, total.md.d); - float d_total_log = logf(total.md.d); - for (int i = 0; i < MAX_K; ++i) { - // float val = __expf(total.topk.u[i] - total.md.m) * d_total_inverse; - float val = total.topk.u[i] - total.md.m - d_total_log; - if (i < K) { - z[i] = total.topk.p[i] + vector_id * V; // faster transformer needs absolute id - v[i] = val + c[0]; - } - } - } -} - -template -__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ - void beam_online_softmax_topk_stage1_kernel(const T* __restrict x, - const T* __restrict b, - const bool* __restrict finished, - float* __restrict t, - int V, - int K, - const int* __restrict end_ids) -{ - int thread_id = threadIdx.x; - int vector_id = blockIdx.x; // batch beam index. - - const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2; - - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - - // one will have multiple sections per V - const int v_local = (V + gridDim.y - 1) / gridDim.y; - const int section_start = v_local * blockIdx.y; - int section_end = section_start + v_local; - section_end = (section_end > V) ? 
V : section_end; - - // reposition x to data for the current vector - x += vector_id * V; -#if TOPK_FP16_STORAGE == 1 - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; -#else - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; -#endif - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result - -#if TOPK_FP16_STORAGE == 1 - TopKMD<__half, MAX_K> partial; -#else - TopKMD partial; -#endif - bool finish = finished[vector_id]; - for (int i = 0; i < MAX_K; ++i) { - partial.topk.p[i] = -1; - partial.topk.u[i] = -MAX_T_VAL; - } - partial.md.m = -MAX_T_VAL; - partial.md.d = 0.0F; - - if (finish) { -#pragma unroll 1 - for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) { - float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; - MD new_elem{elem, 1.0F}; - partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, elem_id); - } - } - else { -#pragma unroll 1 - for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) { - T bias = b == nullptr ? (T)0.0f : b[elem_id]; // gpt-2 does not use bias - T elem = x[elem_id] + bias; - MD new_elem{elem, 1.0F}; - partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, elem_id); - } - } - -#if TOPK_FP16_STORAGE == 1 - TopKMD<__half, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op<__half, MAX_K>); -#else - TopKMD total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op); -#endif - - if (thread_id == 0) { - for (int i = 0; i < 2 * K; i++) { - reinterpret_cast(buf_s)[i] = total.topk.p[i] + vector_id * V; // faster transformer needs absolute id - buf_s[MAX_K + i] = total.topk.u[i]; - } - buf_s[2 * MAX_K] = total.md.d; - buf_s[2 * MAX_K + 1] = total.md.m; - } - __syncthreads(); - for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE) { - t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id]; - } -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel( - const float* __restrict x, const float* __restrict c, int* __restrict z, T* __restrict v, int K, int parts_per_beam) -{ - const int vector_id = blockIdx.x; - const int thread_id = threadIdx.x; - const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2; - - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; - - extern __shared__ char buf_s_[]; // intermediate result - float* buf_s = reinterpret_cast(buf_s_); - //__shared__ float buf_s[PACKED_TOP_KMD_SIZE * THREADBLOCK_SIZE]; // intermediate result - - typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - x += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam; - - TopKMD partial; - for (int i = 0; i < MAX_K; ++i) { - partial.topk.p[i] = -1; - partial.topk.u[i] = -MAX_T_VAL; - } - partial.md.m = -MAX_T_VAL; - partial.md.d = 0.0F; - - // load and unpack into registers through smem - for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE) { - buf_s[idx] = x[idx]; - } - __syncthreads(); - - if (threadIdx.x < parts_per_beam) { - float* b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE; - for (int i = 0; i < 2 * K; i++) { - partial.topk.p[i] = reinterpret_cast(b_s)[i]; - partial.topk.u[i] = b_s[MAX_K + i]; - } - partial.md.d = b_s[2 * MAX_K]; - partial.md.m = b_s[2 * MAX_K + 1]; - } - __syncthreads(); - - TopKMD total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op); - - if (thread_id == 0) { - z += vector_id * 2 * K; - v += vector_id * 2 * K; - c += vector_id; - - float d_total_log = logf(total.md.d); - for (int i = 0; i < MAX_K; ++i) { - float val = (float)total.topk.u[i] - total.md.m - d_total_log; - if (i < 2 * K) { - z[i] = total.topk.p[i]; - v[i] = (float)val + (float)c[0]; - } - } - } -} - -template -void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, - const float* cum_log_probs, - int* ids, - T* vals, - int batch_size, - int beam_width, - int parts_per_beam, - cudaStream_t stream) -{ - // might rewrite beam_online_softmax_topk_stage2_kernel no to depend on constant block size - // in oreder to reduce compilation time - int smem_stage2_size = parts_per_beam * (2 * MAX_K + 2) * sizeof(float); - - if (parts_per_beam <= 32) { - beam_online_softmax_topk_stage2_kernel<<>>( - temp_storage, cum_log_probs, ids, vals, beam_width, parts_per_beam); - return; - } - if (parts_per_beam <= 64) { - beam_online_softmax_topk_stage2_kernel<<>>( - temp_storage, cum_log_probs, ids, vals, beam_width, parts_per_beam); - return; - } - if (parts_per_beam <= 128) { - beam_online_softmax_topk_stage2_kernel - <<>>( - temp_storage, cum_log_probs, ids, vals, beam_width, parts_per_beam); - return; - } - assert(0); -} - -template -void topK_softMax_kernelLauncher(const T* log_probs, - const T* bias, - const bool* finished, - const int* sequence_lengths, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* temp_storage, - const int temp_storage_size, - BeamHypotheses* beam_hyps, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - T diversity_rate, - const float length_penalty, - cudaStream_t stream) -{ - const int items_per_thread = 1; - const int block_sz = (MAX_K < 16) ? (MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128 : 64; - // const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE; - - assert(temp_storage_size % 2 == 0); - assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2); - // Beam search needs the sequence lengths of beams to apply length penalty. - assert(length_penalty == 0.0f || sequence_lengths != nullptr); - - const int topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) 
* 4; - int* topk_tmp_id_buf = reinterpret_cast(temp_storage); - T* topk_tmp_val_buf = reinterpret_cast(topk_tmp_id_buf + topk_buf_offset); - float* tmp_buffer = reinterpret_cast(topk_tmp_val_buf + topk_buf_offset); - -#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX - int voc_parts = 4; - if (batch_size * beam_width < 256) { - // Volta has 80 SMs, so we aim for three waves - voc_parts = (240 + batch_size * beam_width - 1) / (batch_size * beam_width); - voc_parts = std::min(128, voc_parts); // we implement up to 128 - } - dim3 grid(batch_size * beam_width, voc_parts); - cudaFuncSetAttribute(beam_online_softmax_topk_stage1_kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, - cudaSharedmemCarveoutMaxL1); - beam_online_softmax_topk_stage1_kernel - <<>>(log_probs, bias, finished, tmp_buffer, vocab_size, beam_width, end_ids); - sync_check_cuda_error(); -#endif - if (beam_width > 1) { -#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX - beam_online_softmax_topk_stage2_kernelLauncher( - tmp_buffer, cum_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, batch_size, beam_width, voc_parts, stream); - sync_check_cuda_error(); -#else - beam_online_softmax_topk_kernel - <<>>(log_probs, - bias, - cum_log_probs, - finished, - topk_tmp_id_buf, - topk_tmp_val_buf, - vocab_size, - beam_width, - end_ids); -#endif -#if 0 - // wrong result with diversity_rate != 0.f - batch_topK_kernel<<>> - (topk_tmp_id_buf, topk_tmp_val_buf, ids, cum_log_probs); -#else - // We need 2*MAX_K candidates because at most k candidates are finished, and we - // will not put them into next iteration - batch_topk_kernel<<>>(topk_tmp_id_buf, - topk_tmp_val_buf, - ids, - cum_log_probs, - output_log_probs, - finished, - sequence_lengths, - *beam_hyps, - beam_width * beam_width * 2, - beam_width, - vocab_size, - length_penalty, - diversity_rate); - sync_check_cuda_error(); -#endif - } - else { - FT_CHECK(false); -#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX - beam_online_softmax_topk_stage2_kernelLauncher( - tmp_buffer, cum_log_probs, ids, cum_log_probs, batch_size, beam_width, voc_parts, stream); -#else - beam_online_softmax_topk_kernel - <<>>( - log_probs, bias, cum_log_probs, finished, ids, cum_log_probs, vocab_size, beam_width, end_ids); -#endif - } -} - -#define CASE_K(K, MAX_K) \ - case K ... 
MAX_K: \ - topK_softMax_kernelLauncher(log_probs, \ - bias, \ - finished, \ - sequence_lengths, \ - cum_log_probs, \ - output_log_probs, \ - ids, \ - temp_storage, \ - temp_storage_size, \ - beam_hyps, \ - batch_size, \ - beam_width, \ - vocab_size, \ - end_ids, \ - diversity_rate, \ - length_penalty, \ - stream); \ - break; - -template -void invokeTopkSoftMax(const T* log_probs, - const T* bias, - const bool* finished, - const int* sequence_lengths, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* temp_storage, - const int temp_storage_size, - BeamHypotheses* beam_hyps, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - const float diversity_rate, - const float length_penalty, - cudaStream_t stream) -{ - switch (beam_width) { - CASE_K(1, 4); - CASE_K(5, 8); - CASE_K(9, 16); - CASE_K(17, 32); - CASE_K(33, 64); - default: - throw std::runtime_error(fmtstr("Topk kernel of beam search does not support beam_width=%d", beam_width)); - } -} - -#undef CASE_K - -template void invokeTopkSoftMax(const float* log_probs, - const float* bias, - const bool* finished, - const int* sequence_lengths, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* tmp_storage, - const int temp_storage_size, - BeamHypotheses* beam_hyps, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - const float diversity_rate, - const float length_penalty, - cudaStream_t stream); - -template void invokeTopkSoftMax(const half* log_probs, - const half* bias, - const bool* finished, - const int* sequence_lengths, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* tmp_storage, - const int temp_storage_size, - BeamHypotheses* beam_hyps, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - const float diversity_rate, - const float length_penalty, - cudaStream_t stream); - -} // end of namespace turbomind diff --git a/src/turbomind/kernels/online_softmax_beamsearch_kernels.h b/src/turbomind/kernels/online_softmax_beamsearch_kernels.h deleted file mode 100644 index 7717fbaed1..0000000000 --- a/src/turbomind/kernels/online_softmax_beamsearch_kernels.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include "src/turbomind/kernels/beam_search_topk_kernels.h" - -namespace turbomind { - -template -void invokeTopkSoftMax(const T* log_probs, - const T* bias, - const bool* finished, - const int* sequence_lengths, - float* cum_log_probs, - float* output_log_probs, - int* ids, - void* tmp_storage, - const int temp_storage_size, - BeamHypotheses* beam_hyps, - const int batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - const float diversity_rate, - const float length_penalty, - cudaStream_t stream); - -} // namespace turbomind diff --git a/src/turbomind/layers/CMakeLists.txt b/src/turbomind/layers/CMakeLists.txt index 1333c6c5b9..1a32d233a2 100644 --- a/src/turbomind/layers/CMakeLists.txt +++ b/src/turbomind/layers/CMakeLists.txt @@ -14,13 +14,10 @@ cmake_minimum_required(VERSION 3.8) -add_subdirectory(beam_search_layers) add_subdirectory(sampling_layers) add_library(DynamicDecodeLayer STATIC DynamicDecodeLayer.cc) set_property(TARGET DynamicDecodeLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET DynamicDecodeLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(DynamicDecodeLayer PUBLIC -lcudart - TopKSamplingLayer TopPSamplingLayer - OnlineBeamSearchLayer BeamSearchLayer ban_bad_words stop_criteria - gpt_kernels tensor nvtx_utils) +target_link_libraries(DynamicDecodeLayer PUBLIC -lcudart TopKSamplingLayer + TopPSamplingLayer ban_bad_words stop_criteria gpt_kernels tensor nvtx_utils) diff --git a/src/turbomind/layers/DynamicDecodeLayer.cc b/src/turbomind/layers/DynamicDecodeLayer.cc index a5f7a6ffda..3f79bfe3e1 100644 --- a/src/turbomind/layers/DynamicDecodeLayer.cc +++ b/src/turbomind/layers/DynamicDecodeLayer.cc @@ -17,11 +17,9 @@ #include "src/turbomind/layers/DynamicDecodeLayer.h" #include "src/turbomind/kernels/ban_bad_words.h" #include "src/turbomind/kernels/stop_criteria_kernels.h" -#include "src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.h" -#include "src/turbomind/layers/beam_search_layers/BeamSearchLayer.h" -#include "src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.h" #include "src/turbomind/layers/sampling_layers/TopKSamplingLayer.h" #include "src/turbomind/layers/sampling_layers/TopPSamplingLayer.h" +#include "src/turbomind/utils/cuda_utils.h" namespace turbomind { @@ -45,37 +43,6 @@ template void DynamicDecodeLayer::initialize() { TM_LOG_DEBUG(__PRETTY_FUNCTION__); - online_beamsearch_decode_ = new OnlineBeamSearchLayer(0, // max_batch_size, deprecated - 0, // local_head_num, deprecated - 0, // size_per_head, deprecated - 0, // beam_width, deprecated - vocab_size_, - vocab_size_padded_, - 0, // end_id, deprecated - 0.0f, // beam_search_diversity_rate_, deprecated - 1.0f, // temperature_, deprecated - 0.0f, // len_penalty_, deprecated - 1.0f, // repetition_penalty_, deprecated - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_); - - beamsearch_decode_ = new BeamSearchLayer(0, // max_batch_size, deprecated - 0, // local_head_num, deprecated - 0, // size_per_head, deprecated - 0, // beam_width, deprecated - vocab_size_, - vocab_size_padded_, - 0, // end_id, deprecated - 0.0f, // beam_search_diversity_rate_, deprecated - 1.0f, // temperature_, deprecated - 0.0f, // len_penalty_, deprecated - 1.0f, // repetition_penalty_, deprecated - stream_, - cublas_wrapper_, - allocator_, - is_free_buffer_after_forward_); topk_decode_ = new TopKSamplingLayer(0, vocab_size_, @@ -131,8 +98,6 @@ template DynamicDecodeLayer::~DynamicDecodeLayer() { 
TM_LOG_DEBUG(__PRETTY_FUNCTION__); - delete online_beamsearch_decode_; - delete beamsearch_decode_; delete topk_decode_; delete topp_decode_; freeBuffer(); @@ -284,105 +249,7 @@ void DynamicDecodeLayer::forward(TensorMap* output_tensors, TensorMap* input_ // dynamic decode GPT if (beam_width > 1) { - // Because we still not support batch beam search now, so we need to compute one by one if there are different - // runtime arguments. - const size_t dynamic_decode_batch_size = has_diff_runtime_args_ ? 1 : local_batch_size; - const int dynamic_decode_total_iteration = local_batch_size / dynamic_decode_batch_size; - - for (uint dynamic_ite = ite * dynamic_decode_total_iteration; - dynamic_ite < (ite + 1) * dynamic_decode_total_iteration; - ++dynamic_ite) { - const int dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beam_width; - const int dynamic_decode_vocab_size_units_offset = dynamic_id_offset * vocab_size_padded_; - - // common inputs - Tensor logits = input_tensors->at("logits"); - Tensor end_id = input_tensors->at("end_id"); - - TensorMap dynamic_decode_input_tensors( - {{"logits", - Tensor{logits.where, - logits.type, - {dynamic_decode_batch_size, logits.shape[1], logits.shape[2]}, - logits.getPtrWithOffset(dynamic_decode_vocab_size_units_offset)}}, - {"step", input_tensors->at("step")}, - {"max_input_length", input_tensors->at("max_input_length")}, - {"end_id", - Tensor{end_id.where, - end_id.type, - {dynamic_decode_batch_size}, - end_id.getPtrWithOffset(dynamic_ite * dynamic_decode_batch_size)}}, - {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &dynamic_ite}}}); - - if (input_tensors->isExist("embedding_bias")) { - dynamic_decode_input_tensors.insert({"embedding_bias", input_tensors->at("embedding_bias")}); - } - if (input_tensors->isExist("input_lengths")) { - Tensor input_lengths = input_tensors->at("input_lengths"); - dynamic_decode_input_tensors.insert( - {"input_lengths", - input_lengths.slice({dynamic_decode_batch_size, input_lengths.shape[1]}, dynamic_id_offset)}); - } - for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { - if (t->first.find("random_seed") == std::string::npos) { - dynamic_decode_input_tensors.insert(*t); - } - } - - // common outputs - TensorMap dynamic_decode_output_tensors({{"output_ids", output_tensors->at("output_ids")}}); - if (output_tensors->isExist("sequence_length")) { - Tensor sequence_length = output_tensors->at("sequence_length"); - dynamic_decode_output_tensors.insert({"sequence_length", - Tensor{sequence_length.where, - sequence_length.type, - {dynamic_decode_batch_size * beam_width}, - sequence_length.getPtrWithOffset(dynamic_id_offset)}}); - } - if (output_tensors->isExist("finished")) { - Tensor finished = output_tensors->at("finished"); - dynamic_decode_output_tensors.insert({"finished", - Tensor{finished.where, - finished.type, - {dynamic_decode_batch_size * beam_width}, - finished.getPtrWithOffset(dynamic_id_offset)}}); - } - if (output_tensors->isExist("cum_log_probs")) { - Tensor cum_log_probs = output_tensors->at("cum_log_probs"); - dynamic_decode_output_tensors.insert({"cum_log_probs", - Tensor{cum_log_probs.where, - cum_log_probs.type, - {dynamic_decode_batch_size * beam_width}, - cum_log_probs.getPtrWithOffset(dynamic_id_offset)}}); - } - if (output_tensors->isExist("beam_hyps")) { - dynamic_decode_output_tensors.insert("beam_hyps", output_tensors->at("beam_hyps")); - } - - if (output_tensors->isExist("output_log_probs")) { - dynamic_decode_output_tensors.insert({"output_log_probs", 
output_tensors->at("output_log_probs")}); - } - - dynamic_decode_input_tensors.insert({"src_cache_indirection", input_tensors->at("src_cache_indirection")}); - - dynamic_decode_output_tensors.insert({"parent_ids", output_tensors->at("parent_ids")}); - dynamic_decode_output_tensors.insert( - {"tgt_cache_indirection", output_tensors->at("tgt_cache_indirection")}); - - FT_CHECK_WITH_INFO(dynamic_decode_output_tensors.isExist("cum_log_probs"), - "cum_log_probs should be provided in beam search."); - - if (true || beam_width < 16 - || (output_tensors->isExist("beam_hyps") - && input_tensors->getVal("beam_search_diversity_rate", 0.0f) != 0.0f)) { - // only online_beamsearch_decode_ support beam_search_diversity_rate when beam_hyps is used - online_beamsearch_decode_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); - } - else { - FT_CHECK(false); // deprecate this module - beamsearch_decode_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors); - } - } // end of dynamic_ite + FT_CHECK_WITH_INFO(0, "Beam-search is not supported."); } else { // beam_width=1 // In sampling, we have supported batch sampling. So, we always compute all sentences once. diff --git a/src/turbomind/layers/DynamicDecodeLayer.h b/src/turbomind/layers/DynamicDecodeLayer.h index fd4a77600e..cae2118c19 100644 --- a/src/turbomind/layers/DynamicDecodeLayer.h +++ b/src/turbomind/layers/DynamicDecodeLayer.h @@ -19,7 +19,6 @@ #include #include -#include "src/turbomind/kernels/beam_search_topk_kernels.h" #include "src/turbomind/layers/BaseLayer.h" #include "src/turbomind/layers/DynamicDecodeBaseLayer.h" #include "src/turbomind/layers/sampling_layers/TopPSamplingLayer.h" @@ -34,8 +33,6 @@ class DynamicDecodeLayer: public BaseLayer { void initialize(); bool hasDiffRuntimeArgs(TensorMap* input_tensors); - DynamicDecodeBaseLayer* online_beamsearch_decode_; - DynamicDecodeBaseLayer* beamsearch_decode_; DynamicDecodeBaseLayer* topk_decode_; DynamicDecodeBaseLayer* topp_decode_; diff --git a/src/turbomind/layers/FfnFP8Layer.cc b/src/turbomind/layers/FfnFP8Layer.cc deleted file mode 100644 index 28f5e5ed8b..0000000000 --- a/src/turbomind/layers/FfnFP8Layer.cc +++ /dev/null @@ -1,535 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "src/turbomind/layers/FfnFP8Layer.h" -#include "src/turbomind/kernels/activation_fp8_kernels.h" -#include "src/turbomind/utils/cublasFP8MMWrapper.h" -#include "src/turbomind/utils/nvtx_utils.h" - -namespace turbomind { - -template -void FfnFP8Layer::forward(TensorMap* output_tensors, - TensorMap* input_tensors, - const FfnFP8Weight* ffn_weights) -{ - // input tensors: - // input_hidden_state [token_num, d_model], - - // output tensors: - // output_hidden_state [token_num, d_model], - - TM_LOG_DEBUG(__PRETTY_FUNCTION__); - FT_CHECK(input_tensors->size() == 1); - FT_CHECK(output_tensors->size() == 1); - - const int m = input_tensors->at("input_hidden_state").shape[0]; - const int d_model = input_tensors->at("input_hidden_state").shape[1]; - const T1* input_hidden_state = input_tensors->at("input_hidden_state").getPtr(); - Tensor output_tensor = output_tensors->at("output_hidden_state"); - allocateBuffer(m); - -#ifdef FUSE_GEMM_ACT - if (fp8_mode_ == 1) { - const float alpha = 1.0f; - const float beta = 0.0f; - reinterpret_cast(cublas_wrapper_) - ->Gemm(inter_buf_bf16_, - (int)1, - (int)m, - (int)inter_size_, - (int)d_model, - (int64_t)0, - (int64_t)0, - (int64_t)0, - &alpha, - &beta, - input_hidden_state, - ffn_weights->intermediate_weight.kernel, - ffn_weights->intermediate_weight.input_scale, - ffn_weights->intermediate_weight.per_channel_scale_min, // identity_scale - stream_); - invokeAddBiasActivation(m, - ffn_weights->intermediate_weight.bias, - ffn_weights->intermediate_weight.output_scale, - ffn_weights->intermediate_weight.scale, - ffn_weights->intermediate_weight.per_channel_scale_min, - ffn_weights->output_weight.input_scale_inv); - } - else if (fp8_mode_ == 2) { -#ifdef USE_QGMMA - if (getActivationType() == ActivationType::Gelu) { - PUSH_RANGE("FFN gemm 1 bias gelu"); - reinterpret_cast(cublas_wrapper_) - ->Conv1x1Gemm(inter_buf_, - m, - inter_size_, - d_model, - input_hidden_state, - ffn_weights->intermediate_weight.kernel, - ffn_weights->intermediate_weight.bias, - *(ffn_weights->intermediate_weight.input_h_scale), // scale_a, - *(ffn_weights->intermediate_weight.weight_h_scale), // scale_b, - *(ffn_weights->output_weight.input_h_scale_inv), // scale_d, - stream_); - POP_RANGE; - } - else if (getActivationType() == ActivationType::Relu) { - reinterpret_cast(cublas_wrapper_) - ->Conv1x1Gemm(inter_buf_, - m, - inter_size_, - d_model, - input_hidden_state, - ffn_weights->intermediate_weight.kernel, - ffn_weights->intermediate_weight.bias, - *(ffn_weights->intermediate_weight.input_h_scale), // scale_a, - *(ffn_weights->intermediate_weight.weight_h_scale), // scale_b, - *(ffn_weights->output_weight.input_h_scale_inv), // scale_d, - stream_); - } -#else // USE_QGMMA - const float alpha = 1.0f; - const float beta = 0.0f; - if (getActivationType() == ActivationType::Gelu) { - reinterpret_cast(cublas_wrapper_) -#ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE - ->Gemm_Bias_Act(inter_buf_bf16_, -#else // FP8_GEMM_OUTPUT_QUANT_DISABLE - ->Gemm_Bias_Act(inter_buf_, -#endif // FP8_GEMM_OUTPUT_QUANT_DISABLE - (int)1, - (int)m, - (int)inter_size_, - (int)d_model, - (int64_t)0, - (int64_t)0, - (int64_t)0, - &alpha, - &beta, - input_hidden_state, - ffn_weights->intermediate_weight.kernel, - ffn_weights->intermediate_weight.input_scale, - ffn_weights->intermediate_weight.weight_scale, - ffn_weights->intermediate_weight.bias, - ffn_weights->intermediate_weight.output_scale, - stream_); - } - else if (getActivationType() == ActivationType::Relu) { - reinterpret_cast(cublas_wrapper_) -#ifdef 
FP8_GEMM_OUTPUT_QUANT_DISABLE - ->Gemm_Bias_Act(inter_buf_bf16_, -#else // FP8_GEMM_OUTPUT_QUANT_DISABLE - ->Gemm_Bias_Act(inter_buf_, -#endif // #ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE - (int)1, - (int)m, - (int)inter_size_, - (int)d_model, - (int64_t)0, - (int64_t)0, - (int64_t)0, - &alpha, - &beta, - input_hidden_state, - ffn_weights->intermediate_weight.kernel, - ffn_weights->intermediate_weight.input_scale, - ffn_weights->intermediate_weight.weight_scale, - ffn_weights->intermediate_weight.bias, - ffn_weights->intermediate_weight.output_scale, - stream_); - } -#ifdef FP8_GEMM_OUTPUT_QUANT_DISABLE - invokeQuantizeMatrix( - inter_buf_, ffn_weights->output_weight.input_scale_inv, inter_buf_bf16_, m * inter_size_, 1, stream_); -#endif FP8_GEMM_OUTPUT_QUANT_DISABLE -#endif // USE_QGMMA - } - -#else // FUSE_GEMM_ACT - PUSH_RANGE("FFN gemm 1"); -#ifdef SPARSITY_ENABLED - int m_tmp = m; - if (m_tmp % 8 != 0) { - m_tmp = (m_tmp / 8 + 1) * 8; - } - const int m_padded = m_tmp; - if (sparse_ && cublas_wrapper_->isUseSparse(1, inter_size_, m, d_model)) { - FT_CHECK(false); - // cublas_wrapper_->SpGemm(CUBLAS_OP_N, - // CUBLAS_OP_N, - // inter_size_, - // m_padded, - // d_model, - // ffn_weights->intermediate_weight.sp_kernel, - // input_hidden_state, - // inter_buf_); - } - else { -#endif // SPARSITY_ENABLED - if (fp8_mode_ == 1) { - const float alpha = 1.0f; - const float beta = 0.0f; - reinterpret_cast(cublas_wrapper_) - ->Gemm(inter_buf_bf16_, - (int)1, - (int)m, - (int)inter_size_, - (int)d_model, - (int64_t)0, - (int64_t)0, - (int64_t)0, - &alpha, - &beta, - input_hidden_state, - ffn_weights->intermediate_weight.kernel, - ffn_weights->intermediate_weight.input_scale, - ffn_weights->intermediate_weight.per_channel_scale_min, // identity_scale - stream_); - } - else if (fp8_mode_ == 2) { - const float alpha = 1.0f; - const float beta = 0.0f; - reinterpret_cast(cublas_wrapper_) - ->Gemm(inter_buf_bf16_, - (int)1, - (int)m, - (int)inter_size_, - (int)d_model, - (int64_t)0, - (int64_t)0, - (int64_t)0, - &alpha, - &beta, - input_hidden_state, - ffn_weights->intermediate_weight.kernel, - ffn_weights->intermediate_weight.input_scale, - ffn_weights->intermediate_weight.weight_scale, - stream_); - } -#ifdef SPARSITY_ENABLED - } -#endif // SPARSITY_ENABLED - POP_RANGE; - - PUSH_RANGE("FFN add bias act"); - if (fp8_mode_ == 1) { - invokeAddBiasActivation(m, - ffn_weights->intermediate_weight.bias, - ffn_weights->intermediate_weight.output_scale, - ffn_weights->intermediate_weight.scale, - ffn_weights->intermediate_weight.per_channel_scale_min, - ffn_weights->output_weight.input_scale_inv); - } - else if (fp8_mode_ == 2) { - invokeAddBiasActivation(m, - ffn_weights->intermediate_weight.bias, - ffn_weights->intermediate_weight.output_scale, - nullptr, - nullptr, - ffn_weights->output_weight.input_scale_inv); - } - sync_check_cuda_error(); - POP_RANGE; -#endif // FUSE_GEMM_ACT - - PUSH_RANGE("FFN gemm 2"); -#ifdef SPARSITY_ENABLED - if (sparse_ && cublas_wrapper_->isUseSparse(1, d_model, m, inter_size_)) { - FT_CHECK(false); - // cublas_wrapper_->SpGemm(CUBLAS_OP_N, - // CUBLAS_OP_N, - // d_model, - // m_padded, - // inter_size_, - // ffn_weights->output_weight.sp_kernel, - // inter_buf_, - // output_tensor); - } - else { -#endif SPARSITY_ENABLED - if (fp8_mode_ == 1) { - const float alpha = 1.0f; - const float beta = 0.0f; - if (output_tensor.type == TYPE_BF16) { - reinterpret_cast(cublas_wrapper_) - ->Gemm(output_tensor.getPtr(), - (int)1, - (int)m, - (int)d_model, - (int)inter_size_, - (int64_t)0, - 
(int64_t)0, - (int64_t)0, - &alpha, - &beta, - (const __nv_fp8_e4m3*)inter_buf_, - (const __nv_fp8_e4m3*)ffn_weights->output_weight.kernel, - ffn_weights->output_weight.input_scale, - ffn_weights->identity_scale, - stream_); - } - else if (output_tensor.type == TYPE_FP8_E4M3) { - const float alpha = 1.0f; - const float beta = 0.0f; - reinterpret_cast(cublas_wrapper_) - ->Gemm(output_tensor.getPtr(), - (int)1, - (int)m, - (int)d_model, - (int)inter_size_, - (int64_t)0, - (int64_t)0, - (int64_t)0, - &alpha, - &beta, - (const __nv_fp8_e4m3*)inter_buf_, - (const __nv_fp8_e4m3*)ffn_weights->output_weight.kernel, - ffn_weights->output_weight.input_scale, - ffn_weights->output_weight.per_channel_scale_min, - ffn_weights->output_weight.output_scale_inv, - stream_); - } - else { - FT_CHECK(false); - } - } - else if (fp8_mode_ == 2) { - if (output_tensor.type == TYPE_BF16) { - const float alpha = 1.0f; - const float beta = 0.0f; - reinterpret_cast(cublas_wrapper_) - ->Gemm(output_tensor.getPtr(), - (int)1, - (int)m, - (int)d_model, - (int)inter_size_, - (int64_t)0, - (int64_t)0, - (int64_t)0, - &alpha, - &beta, - (const __nv_fp8_e4m3*)inter_buf_, - (const __nv_fp8_e4m3*)ffn_weights->output_weight.kernel, - ffn_weights->output_weight.input_scale, - ffn_weights->output_weight.weight_scale, - stream_); - } - else if (output_tensor.type == TYPE_FP8_E4M3) { - // It looks like conv1x1Gemm does not bring better performance for this gemm - // because the k dimension of this gemm is large - // #ifdef USE_QGMMA - // reinterpret_cast(cublas_wrapper_) - // ->Conv1x1Gemm(output_tensor.getPtr(), - // m, - // d_model, - // inter_size_, - // inter_buf_, - // ffn_weights->output_weight.kernel, - // ffn_weights->output_weight.bias, - // *(ffn_weights->output_weight.input_h_scale), // - // scale_a, - // *(ffn_weights->output_weight.weight_h_scale), // - // scale_b, - // *(ffn_weights->output_weight.output_h_scale_inv), // - // scale_d, stream_); - // #else // USE_QGMMA - const float alpha = 1.0f; - const float beta = 0.0f; - reinterpret_cast(cublas_wrapper_) - ->Gemm(output_tensor.getPtr(), - (int)1, - (int)m, - (int)d_model, - (int)inter_size_, - (int64_t)0, - (int64_t)0, - (int64_t)0, - &alpha, - &beta, - (const __nv_fp8_e4m3*)inter_buf_, - (const __nv_fp8_e4m3*)ffn_weights->output_weight.kernel, - ffn_weights->output_weight.input_scale, - ffn_weights->output_weight.weight_scale, - ffn_weights->output_weight.output_scale_inv, - stream_); - // #endif // USE_QGMMA - } - else { - FT_CHECK(false); - } - } -#ifdef SPARSITY_ENABLED - } -#endif // SPARSITY_ENABLED - POP_RANGE; - - sync_check_cuda_error(); - if (is_free_buffer_after_forward_ == true) { - freeBuffer(); - } - sync_check_cuda_error(); -} - -template -FfnFP8Layer::FfnFP8Layer(size_t inter_size, - int fp8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): - BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr, sparse), - inter_size_(inter_size), - fp8_mode_(fp8_mode) -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); -} - -template -FfnFP8Layer::FfnFP8Layer(FfnFP8Layer const& ffn_layer): - BaseLayer(ffn_layer.stream_, - ffn_layer.cublas_wrapper_, - ffn_layer.allocator_, - ffn_layer.is_free_buffer_after_forward_, - ffn_layer.cuda_device_prop_, - ffn_layer.sparse_), - inter_size_(ffn_layer.inter_size_), - fp8_mode_(ffn_layer.fp8_mode_) -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); -} - -template -FfnFP8Layer::~FfnFP8Layer() -{ - 
TM_LOG_DEBUG(__PRETTY_FUNCTION__); - cublas_wrapper_ = nullptr; - freeBuffer(); -} - -template -void FfnFP8Layer::allocateBuffer() -{ - FT_CHECK(false); -} - -template -void FfnFP8Layer::allocateBuffer(size_t token_num) -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); - - inter_buf_ = (T1*)allocator_->reMalloc(inter_buf_, sizeof(T1) * token_num * inter_size_, false); - inter_buf_bf16_ = (T2*)allocator_->reMalloc(inter_buf_bf16_, sizeof(T2) * token_num * inter_size_, false); - is_allocate_buffer_ = true; -} - -template -void FfnFP8Layer::freeBuffer() -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); - if (is_allocate_buffer_) { - allocator_->free((void**)(&inter_buf_)); - allocator_->free((void**)(&inter_buf_bf16_)); - is_allocate_buffer_ = false; - } -} - -template class FfnFP8Layer<__nv_fp8_e4m3, __nv_bfloat16>; - -template -GeluFfnFP8Layer::GeluFfnFP8Layer(size_t inter_size, - int fp8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): - FfnFP8Layer(inter_size, fp8_mode, stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse) -{ -} - -template -GeluFfnFP8Layer::GeluFfnFP8Layer(GeluFfnFP8Layer const& gelu_ffn_layer): - FfnFP8Layer(gelu_ffn_layer) -{ -} - -template -void GeluFfnFP8Layer::invokeAddBiasActivation(const int m, - const T2* bias, - const float* input_scale, - const float* input_scale_2, - const float* input_scale_2_min, - const float* output_scale) -{ - FP8ActivationParam param{inter_buf_bf16_, - inter_buf_, - bias, - input_scale, - input_scale_2, - input_scale_2_min, - output_scale, - (uint32_t)m, - (uint32_t)inter_size_, - stream_}; - invokeFP8AddBiasGelu(param); -} - -template class GeluFfnFP8Layer<__nv_fp8_e4m3, __nv_bfloat16>; - -template -ReluFfnFP8Layer::ReluFfnFP8Layer(size_t inter_size, - int fp8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): - FfnFP8Layer(inter_size, fp8_mode, stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse) -{ -} - -template -ReluFfnFP8Layer::ReluFfnFP8Layer(ReluFfnFP8Layer const& relu_ffn_layer): - FfnFP8Layer(relu_ffn_layer) -{ -} - -template -void ReluFfnFP8Layer::invokeAddBiasActivation(const int m, - const T2* bias, - const float* input_scale, - const float* input_scale_2, - const float* input_scale_2_min, - const float* output_scale) -{ - FP8ActivationParam param{inter_buf_bf16_, - inter_buf_, - bias, - input_scale, - input_scale_2, - input_scale_2_min, - output_scale, - (uint32_t)m, - (uint32_t)inter_size_, - stream_}; - invokeFP8AddBiasRelu(param); -} - -template class ReluFfnFP8Layer<__nv_fp8_e4m3, __nv_bfloat16>; - -} // namespace turbomind diff --git a/src/turbomind/layers/FfnFP8Layer.h b/src/turbomind/layers/FfnFP8Layer.h deleted file mode 100644 index d4a31b111a..0000000000 --- a/src/turbomind/layers/FfnFP8Layer.h +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "src/turbomind/layers/BaseLayer.h" -#include "src/turbomind/layers/FfnFP8Weight.h" -#include "src/turbomind/layers/FfnLayer.h" -#include "src/turbomind/utils/memory_utils.h" -#include - -namespace turbomind { - -template -class FfnFP8Layer: public BaseLayer { -private: - void allocateBuffer() override; - void freeBuffer() override; - void allocateBuffer(size_t token_num); - -protected: - const int fp8_mode_; - T1* inter_buf_ = nullptr; - T2* inter_buf_bf16_ = nullptr; - size_t inter_size_; - virtual void invokeAddBiasActivation(const int m, - const T2* bias, - const float* input_scale, - const float* input_scale_2, - const float* input_scale_2_min, - const float* output_scale) = 0; - -public: - FfnFP8Layer(size_t inter_size, - int fp8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); - - FfnFP8Layer(FfnFP8Layer const& ffn_layer); - - virtual ~FfnFP8Layer(); - - virtual void forward(TensorMap* output_tensors, TensorMap* input_tensors, const FfnFP8Weight* ffn_weights); - virtual ActivationType getActivationType() = 0; -}; - -template -class GeluFfnFP8Layer: public FfnFP8Layer { -public: - GeluFfnFP8Layer(size_t inter_size, - int fp8_mode_, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); - - GeluFfnFP8Layer(GeluFfnFP8Layer const& ffn_layer); - - virtual ~GeluFfnFP8Layer() = default; - ActivationType getActivationType() override - { - return ActivationType::Gelu; - }; - -protected: - using FfnFP8Layer::stream_; - -private: - using FfnFP8Layer::inter_buf_; - using FfnFP8Layer::inter_size_; - using FfnFP8Layer::fp8_mode_; - using FfnFP8Layer::inter_buf_bf16_; - void invokeAddBiasActivation(const int m, - const T2* bias, - const float* input_scale, - const float* input_scale_2, - const float* input_scale_2_min, - const float* output_scale) override; -}; - -template -class ReluFfnFP8Layer: public FfnFP8Layer { -public: - ReluFfnFP8Layer(size_t inter_size, - int fp8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); - - ReluFfnFP8Layer(ReluFfnFP8Layer const& ffn_layer); - - virtual ~ReluFfnFP8Layer() = default; - ActivationType getActivationType() override - { - return ActivationType::Relu; - }; - -protected: - using FfnFP8Layer::stream_; - -private: - using FfnFP8Layer::inter_buf_; - using FfnFP8Layer::inter_size_; - using FfnFP8Layer::fp8_mode_; - using FfnFP8Layer::inter_buf_bf16_; - void invokeAddBiasActivation(const int m, - const T2* bias, - const float* input_scale, - const float* input_scale_2, - const float* input_scale_2_min, - const float* output_scale) override; -}; - -} // namespace turbomind diff --git a/src/turbomind/layers/FfnFP8Weight.h b/src/turbomind/layers/FfnFP8Weight.h deleted file mode 100644 index 223fea17f1..0000000000 --- a/src/turbomind/layers/FfnFP8Weight.h +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "FfnWeight.h" -#include "src/turbomind/utils/ScaleList.h" -namespace turbomind { - -template -struct FfnFP8Weight: FfnWeight { - ScaleList* scale_list_ptr; - float* identity_scale; - float* identity_h_scale; -}; - -} // namespace turbomind diff --git a/src/turbomind/layers/FfnINT8Weight.h b/src/turbomind/layers/FfnINT8Weight.h deleted file mode 100644 index d208ea33da..0000000000 --- a/src/turbomind/layers/FfnINT8Weight.h +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "FfnWeight.h" -#include "src/turbomind/utils/ScaleList.h" -namespace turbomind { - -template -struct FfnINT8Weight: FfnWeight { - ScaleList* scale_list_ptr; -}; - -} // namespace turbomind diff --git a/src/turbomind/layers/FfnLayerINT8.cc b/src/turbomind/layers/FfnLayerINT8.cc deleted file mode 100644 index 88ccd447e1..0000000000 --- a/src/turbomind/layers/FfnLayerINT8.cc +++ /dev/null @@ -1,340 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "FfnLayerINT8.h" -#include "src/turbomind/utils/nvtx_utils.h" - -namespace turbomind { - -template -void FfnLayerINT8::forward(std::vector* output_tensors, - const std::vector* input_tensors, - const FfnWeight* ffn_weights) -{ - // input_tensors: [input (token_num, hidden_dimension)] - // output_tensors: [output (token_num, hidden_dimension)] - ScaleList* scale_list = ((const FfnINT8Weight*)ffn_weights)->scale_list_ptr; - - cublasINT8MMWrapper* cublas_wrapper = (cublasINT8MMWrapper*)cublas_wrapper_; - - FT_CHECK(isValidTokenNum(input_tensors->at(0).shape[0])); - allocateBuffer(); - - const int m = static_cast(input_tensors->at(0).shape[0]); -#ifdef SPARSITY_ENABLED - int m_tmp = m; - if (m_tmp % 16 != 0) { - m_tmp = (m_tmp / 16 + 1) * 16; - } - const int m_padded = m_tmp; -#endif - - int32_t* output_tensor = output_tensors->at(0).getPtr(); - const int8_t* input_tensor = input_tensors->at(0).getPtr(); - - PUSH_RANGE("FFN gemm 1"); - if (int8_mode_ == 1) { - cublas_wrapper->Gemm(inter_int_buf_, - 1, - m, - inter_size_, - hidden_units_, - 0, - 0, - 0, - input_tensor, - (int8_t*)(ffn_weights->intermediate_weight.kernel)); - } - else if (int8_mode_ == 2 || int8_mode_ == 3) { -#ifdef SPARSITY_ENABLED - if (sparse_) { - cublas_wrapper->SpGemm(inter_size_, - m_padded, - hidden_units_, - scale_list->h_scale_list_[scale_list->p3_offset_ + 6], - (int8_t*)(ffn_weights->intermediate_weight.sp_kernel), - input_tensor, - (int8_t*)inter_int_buf_); - } - else { -#endif - cublas_wrapper->Gemm((int8_t*)inter_int_buf_, - 1, - m, - inter_size_, - hidden_units_, - 0, - 0, - 0, - scale_list->h_scale_list_[scale_list->p3_offset_ + 6], - input_tensor, - (int8_t*)(ffn_weights->intermediate_weight.kernel)); -#ifdef SPARSITY_ENABLED - } -#endif - } - POP_RANGE; - - PUSH_RANGE("add bias act"); - invokeAddBiasActivation(m, ffn_weights->intermediate_weight.bias, scale_list); - POP_RANGE; - sync_check_cuda_error(); - - PUSH_RANGE("FFN gemm 2"); - if (int8_mode_ == 1) { - cublas_wrapper->Gemm(output_tensor, - 1, - m, - hidden_units_, - inter_size_, - 0, - 0, - 0, - inter_buf_, - (int8_t*)(ffn_weights->output_weight.kernel)); - } - else if (int8_mode_ == 2 || int8_mode_ == 3) { -#ifdef SPARSITY_ENABLED - if (sparse_) { - cublas_wrapper->SpGemm(hidden_units_, - m_padded, - inter_size_, - scale_list->h_scale_list_[scale_list->p3_offset_ + 7], - (int8_t*)(ffn_weights->output_weight.sp_kernel), - inter_buf_, - (int8_t*)output_tensor); - } - else { -#endif - cublas_wrapper->Gemm((int8_t*)output_tensor, - 1, - m, - hidden_units_, - inter_size_, - 0, - 0, - 0, - scale_list->h_scale_list_[scale_list->p3_offset_ + 7], - inter_buf_, - (int8_t*)(ffn_weights->output_weight.kernel)); -#ifdef SPARSITY_ENABLED - } -#endif - } - POP_RANGE; - - sync_check_cuda_error(); - if (is_free_buffer_after_forward_ == true) { - freeBuffer(); - } - sync_check_cuda_error(); -} - -template -FfnLayerINT8::FfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): - BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), - max_token_num_(max_batch_size * max_seq_len), - head_num_(head_num), - size_per_head_(size_per_head), - hidden_units_(head_num * size_per_head), - inter_size_(inter_size), - int8_mode_(int8_mode), - sparse_(sparse) -{ -} - -template -FfnLayerINT8::FfnLayerINT8(FfnLayerINT8 const& 
ffn_layer): - BaseLayer( - ffn_layer.stream_, ffn_layer.cublas_wrapper_, ffn_layer.allocator_, ffn_layer.is_free_buffer_after_forward_), - max_token_num_(ffn_layer.max_token_num_), - head_num_(ffn_layer.head_num_), - size_per_head_(ffn_layer.size_per_head_), - hidden_units_(ffn_layer.hidden_units_), - inter_size_(ffn_layer.inter_size_), - int8_mode_(ffn_layer.int8_mode_), - sparse_(ffn_layer.sparse_) -{ -} - -template -FfnLayerINT8::~FfnLayerINT8() -{ - cublas_wrapper_ = nullptr; - freeBuffer(); -} - -template -void FfnLayerINT8::allocateBuffer() -{ - if (is_allocate_buffer_ == false) { - inter_int_buf_ = - (int32_t*)allocator_->reMalloc(inter_int_buf_, sizeof(int32_t) * max_token_num_ * inter_size_, false); - inter_buf_ = (int8_t*)allocator_->reMalloc(inter_buf_, sizeof(int8_t) * max_token_num_ * inter_size_, false); - is_allocate_buffer_ = true; - } -} - -template -void FfnLayerINT8::freeBuffer() -{ - if (is_allocate_buffer_ == true) { - allocator_->free((void**)(&inter_int_buf_)); - allocator_->free((void**)(&inter_buf_)); - is_allocate_buffer_ = false; - } -} - -template -bool FfnLayerINT8::isValidTokenNum(size_t token_num) -{ - if (max_token_num_ == 0) { - max_token_num_ = token_num; - return true; - } - else { - return token_num <= max_token_num_; - } -} - -template class FfnLayerINT8; -template class FfnLayerINT8; - -template -GeluFfnLayerINT8::GeluFfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse): - FfnLayerINT8(max_batch_size, - max_seq_len, - head_num, - size_per_head, - inter_size, - int8_mode, - stream, - cublas_wrapper, - allocator, - is_free_buffer_after_forward, - sparse) -{ -} - -template -GeluFfnLayerINT8::GeluFfnLayerINT8(GeluFfnLayerINT8 const& gelu_ffn_layer): FfnLayerINT8(gelu_ffn_layer) -{ -} - -template -void GeluFfnLayerINT8::invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) -{ - if (int8_mode_ == 1) { - invokeAddBiasGeluCol32(inter_buf_, - inter_int_buf_, - bias, - m, - inter_size_, - stream_, - &(scale_list->d_scale_list_[scale_list->p2_offset_ + 4 * hidden_units_]), - &(scale_list->d_scale_list_[44 + 2]), - &(scale_list->d_scale_list_[52 + 3])); - } - else if (int8_mode_ == 2 || int8_mode_ == 3) { -#ifdef SPARSITY_ENABLED - if (sparse_) { - invokeAddBiasGeluRow(inter_buf_, - (const int8_t*)inter_int_buf_, - bias, - m, - inter_size_, - stream_, - &(scale_list->d_scale_list_[48 + 1]), - &(scale_list->d_scale_list_[52 + 3])); - } - else { -#endif - invokeAddBiasGeluCol32(inter_buf_, - (const int8_t*)inter_int_buf_, - bias, - m, - inter_size_, - stream_, - &(scale_list->d_scale_list_[48 + 1]), - &(scale_list->d_scale_list_[52 + 3])); -#ifdef SPARSITY_ENABLED - } -#endif - } -} - -template class GeluFfnLayerINT8; -template class GeluFfnLayerINT8; - -template -ReluFfnLayerINT8::ReluFfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): - FfnLayerINT8(max_batch_size, - max_seq_len, - head_num, - size_per_head, - inter_size, - int8_mode, - stream, - cublas_wrapper, - allocator, - is_free_buffer_after_forward) -{ -} - -template -ReluFfnLayerINT8::ReluFfnLayerINT8(ReluFfnLayerINT8 const& relu_ffn_layer): 
FfnLayerINT8(relu_ffn_layer) -{ -} - -template -void ReluFfnLayerINT8::invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) -{ - // TODO -} - -template class ReluFfnLayerINT8; -template class ReluFfnLayerINT8; - -} // namespace turbomind diff --git a/src/turbomind/layers/FfnLayerINT8.h b/src/turbomind/layers/FfnLayerINT8.h deleted file mode 100644 index dc3e938208..0000000000 --- a/src/turbomind/layers/FfnLayerINT8.h +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "FfnINT8Weight.h" -#include "src/turbomind/kernels/activation_int8_kernels.h" -#include "src/turbomind/layers/BaseLayer.h" -#include "src/turbomind/utils/ScaleList.h" -#include "src/turbomind/utils/Tensor.h" -#include "src/turbomind/utils/allocator.h" -#include "src/turbomind/utils/cublasINT8MMWrapper.h" -#include "src/turbomind/utils/memory_utils.h" -#include - -namespace turbomind { - -template -class GeluFfnLayerINT8; - -template -class ReluFfnLayerINT8; - -template -class FfnLayerINT8: public BaseLayer { -private: - // buffer handling - size_t max_token_num_ = 0; - - // meta data - size_t head_num_; - size_t size_per_head_; - - // calculated data - size_t hidden_units_; - - void allocateBuffer() override; - void freeBuffer() override; - bool isValidTokenNum(size_t token_num); - -protected: - size_t inter_size_; - int int8_mode_; - bool sparse_; - - int* inter_int_buf_; - int8_t* inter_buf_; - virtual void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) = 0; - -public: - FfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); - - FfnLayerINT8(FfnLayerINT8 const& ffn_layer); - - ~FfnLayerINT8(); - - void forward(std::vector* output_tensors, - const std::vector* input_tensors, - const FfnWeight* ffn_weights); - - friend GeluFfnLayerINT8; - friend ReluFfnLayerINT8; -}; - -template -class GeluFfnLayerINT8: public FfnLayerINT8 { -public: - GeluFfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false); - - GeluFfnLayerINT8(GeluFfnLayerINT8 const& ffn_layer); - - ~GeluFfnLayerINT8() = default; - -private: - using FfnLayerINT8::inter_int_buf_; - using FfnLayerINT8::inter_buf_; - using FfnLayerINT8::inter_size_; - using FfnLayerINT8::stream_; - using FfnLayerINT8::int8_mode_; - using FfnLayerINT8::sparse_; - using FfnLayerINT8::hidden_units_; - void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) override; -}; - -template -class ReluFfnLayerINT8: public FfnLayerINT8 { 
-public: - ReluFfnLayerINT8(size_t max_batch_size, - size_t max_seq_len, - size_t head_num, - size_t size_per_head, - size_t inter_size, - int int8_mode, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); - - ReluFfnLayerINT8(ReluFfnLayerINT8 const& ffn_layer); - - ~ReluFfnLayerINT8() = default; - -private: - using FfnLayerINT8::inter_int_buf_; - using FfnLayerINT8::inter_buf_; - using FfnLayerINT8::inter_size_; - using FfnLayerINT8::stream_; - using FfnLayerINT8::int8_mode_; - using FfnLayerINT8::hidden_units_; - void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) override; -}; - -} // namespace turbomind diff --git a/src/turbomind/layers/attention_layers_fp8/AttentionFP8Weight.h b/src/turbomind/layers/attention_layers_fp8/AttentionFP8Weight.h deleted file mode 100644 index 68d807664d..0000000000 --- a/src/turbomind/layers/attention_layers_fp8/AttentionFP8Weight.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "src/turbomind/layers/attention_layers/AttentionWeight.h" -#include "src/turbomind/utils/ScaleList.h" - -namespace turbomind { - -template -struct AttentionFP8Weight: public AttentionWeight { - const float* qk_scale; - const float* qk_scale_inv; - float* qk_h_scale; - float* qk_h_scale_inv; - float* identity_scale; - float* identity_h_scale; -}; - -} // namespace turbomind diff --git a/src/turbomind/layers/attention_layers_fp8/BaseAttentionFP8Layer.h b/src/turbomind/layers/attention_layers_fp8/BaseAttentionFP8Layer.h deleted file mode 100644 index 956cd88066..0000000000 --- a/src/turbomind/layers/attention_layers_fp8/BaseAttentionFP8Layer.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include - -#include "src/turbomind/layers/BaseLayer.h" -#include "src/turbomind/layers/attention_layers/BaseAttentionLayer.h" -#include "src/turbomind/layers/attention_layers_fp8/AttentionFP8Weight.h" -#include "src/turbomind/utils/Tensor.h" -#include "src/turbomind/utils/allocator.h" -#include "src/turbomind/utils/cublasFP8MMWrapper.h" -#include "src/turbomind/utils/memory_utils.h" - -namespace turbomind { - -// template -// AttentionType getAttentionType(size_t size_per_head, const int sm, const bool remove_padding, const int max_seq_len, -// const bool is_fuse = true) -// { -// if (std::is_same::value && (sm == kSM_70 || sm == kSM_86 || sm == kSM_80 || sm == kSM_75 || sm == -// kSM_72) -// && size_per_head == 64 && max_seq_len <= 384 && is_fuse == true) { -// return remove_padding ? AttentionType::FUSED_MHA : AttentionType::FUSED_PADDED_MHA; -// } -// else { -// return remove_padding ? AttentionType::UNFUSED_MHA : AttentionType::UNFUSED_PADDED_MHA; -// } -// } - -template -class BaseAttentionFP8Layer: public BaseLayer { - -public: - virtual void forward(TensorMap* output_tensors, - TensorMap* input_tensors, - const AttentionFP8Weight* attention_weights) = 0; - - BaseAttentionFP8Layer(cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward, - bool sparse = false): - BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr, sparse) - { - } - virtual ~BaseAttentionFP8Layer() = default; -}; - -} // namespace turbomind diff --git a/src/turbomind/layers/attention_layers_fp8/CMakeLists.txt b/src/turbomind/layers/attention_layers_fp8/CMakeLists.txt deleted file mode 100644 index 4c3b8f8cec..0000000000 --- a/src/turbomind/layers/attention_layers_fp8/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -cmake_minimum_required(VERSION 3.8) diff --git a/src/turbomind/layers/attention_layers_int8/AttentionINT8Weight.h b/src/turbomind/layers/attention_layers_int8/AttentionINT8Weight.h deleted file mode 100644 index da6153db26..0000000000 --- a/src/turbomind/layers/attention_layers_int8/AttentionINT8Weight.h +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "src/turbomind/layers/attention_layers/AttentionWeight.h" -#include "src/turbomind/utils/ScaleList.h" - -namespace turbomind { - -template -struct AttentionINT8Weight: AttentionWeight { - ScaleList* scale_list_ptr; -}; - -} // namespace turbomind diff --git a/src/turbomind/layers/attention_layers_int8/CMakeLists.txt b/src/turbomind/layers/attention_layers_int8/CMakeLists.txt deleted file mode 100644 index 0d1a96fef3..0000000000 --- a/src/turbomind/layers/attention_layers_int8/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -cmake_minimum_required(VERSION 3.8) diff --git a/src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.cu b/src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.cu deleted file mode 100644 index 4b68622e95..0000000000 --- a/src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.cu +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/turbomind/kernels/beam_search_penalty_kernels.h" -#include "src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.h" -#include "src/turbomind/utils/cuda_utils.h" - -namespace turbomind { - -__global__ void update_indir_cache_kernel(int* tgt_indir_cache, - const int* src_indir_cache, - const int* beam_ids, - const bool* finished, - int start_step, - int batch_dim, - int local_batch_size, - int beam_width, - int max_seq_len, - int step) -{ - int time_step = threadIdx.x + blockIdx.x * blockDim.x; - int bb_id = threadIdx.y + blockIdx.y * blockDim.y; - const int batch_id = bb_id / beam_width; - const int beam_id = bb_id % beam_width; - - if (bb_id >= beam_width * local_batch_size || time_step >= min(step + 1, max_seq_len) || finished[bb_id]) { - return; - } - time_step += start_step; - const int time_step_circ = time_step % max_seq_len; - - const int src_beam = beam_ids[batch_id * beam_width + beam_id]; - - const uint tgt_offset = batch_id * beam_width * max_seq_len + beam_id * max_seq_len + time_step_circ; - const uint src_offset = batch_id * beam_width * max_seq_len + src_beam * max_seq_len + time_step_circ; - - tgt_indir_cache[tgt_offset] = (time_step == step) ? 
beam_id : src_indir_cache[src_offset]; -} - -void update_indir_cache_kernelLauncher(int* tgt_indir_cache, - const int* src_indir_cache, - const int* beam_ids, - const bool* finished, - int batch_dim, - int local_batch_size, - int beam_width, - int max_seq_len, - int step, - cudaStream_t stream) -{ - const dim3 block(32); - const int start_step = max(0, step + 1 - max_seq_len); - const int num_steps = min(step + 1, max_seq_len); - // Update indirections steps [start_step, step], included - const dim3 grid((num_steps + block.x - 1) / block.x, local_batch_size * beam_width); - update_indir_cache_kernel<<>>(tgt_indir_cache, - src_indir_cache, - beam_ids, - finished, - start_step, - batch_dim, - local_batch_size, - beam_width, - max_seq_len, - step); -} - -template -BaseBeamSearchLayer::BaseBeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): - DynamicDecodeBaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, nullptr), - vocab_size_(vocab_size), - vocab_size_padded_(vocab_size_padded) -{ -} - -template -BaseBeamSearchLayer::BaseBeamSearchLayer(BaseBeamSearchLayer const& beam_search_layer): - DynamicDecodeBaseLayer(beam_search_layer), - vocab_size_(beam_search_layer.vocab_size_), - vocab_size_padded_(beam_search_layer.vocab_size_padded_), - topk_softmax_workspace_size_(beam_search_layer.topk_softmax_workspace_size_) -{ -} - -template -BaseBeamSearchLayer::~BaseBeamSearchLayer() -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); - freeBuffer(); -} - -template -void BaseBeamSearchLayer::freeBuffer() -{ - if (is_allocate_buffer_) { - allocator_->free((void**)(&topk_softmax_workspace_)); - is_allocate_buffer_ = false; - } -} - -template -void BaseBeamSearchLayer::setup(const size_t batch_size, const size_t beam_width, TensorMap* runtime_args) -{ - // do nothing. 
-} - -template -void BaseBeamSearchLayer::forward(std::vector* output_tensors, const std::vector* input_tensors) -{ - // input_tensors: - // logits [local_batch_size, beam_width, vocab_size_padded] - // embedding_bias [vocab_size_padded] - // step [1] on cpu - // src_cache_indirection [local_batch_size, beam_width, max_seq_len] - // max_input_length [1] on cpu - // input_lengths [local_batch_size * beam_width] - // ite [1] on cpu - - // output_tensors: - // output_ids [max_seq_len, batch_size, beam_width] - // finished [local_batch_size * beam_width] - // cum_log_probs [local_batch_size * beam_width] - // parent_ids [max_seq_len, batch_size * beam_width] - // sequence_length [local_batch_size * beam_width] - // tgt_cache_indirection [local_batch_size, beam_width, max_seq_len] - - std::unordered_map input_tensors_map{{"logits", input_tensors->at(0)}, - {"embedding_bias", input_tensors->at(1)}, - {"step", input_tensors->at(2)}, - {"src_cache_indirection", input_tensors->at(4)}, - {"max_input_length", input_tensors->at(5)}, - {"input_lengths", input_tensors->at(6)}, - {"ite", input_tensors->at(7)}}; - - std::unordered_map output_tensors_map{{"output_ids", output_tensors->at(0)}, - {"finished", output_tensors->at(1)}, - {"cum_log_probs", output_tensors->at(2)}, - {"parent_ids", output_tensors->at(3)}, - {"sequence_length", output_tensors->at(4)}, - {"tgt_cache_indirection", output_tensors->at(5)}}; - forward(&output_tensors_map, &input_tensors_map); -} - -template -void BaseBeamSearchLayer::forward(std::unordered_map* output_tensors, - const std::unordered_map* input_tensors) -{ - TensorMap input_map(*input_tensors); - TensorMap output_map(*output_tensors); - forward(&output_map, &input_map); -} - -template -void BaseBeamSearchLayer::forward(TensorMap* output_tensors, TensorMap* input_tensors) -{ - // input_tensors: - // logits [local_batch_size, beam_width, vocab_size_padded] - // embedding_bias [vocab_size_padded] - // step [1] on cpu - // src_cache_indirection [local_batch_size, beam_width, max_seq_len] - // end_id [local_batch_size] - // max_input_length [1] on cpu - // input_lengths [local_batch_size * beam_width], optional - // ite [1] on cpu - // beam_search_diversity_rate [1] on cpu, optional - // temperature [1] on cpu, optional - // len_penalty [1] on cpu, optional - // repetition_penalty [1] on cpu, optional - // presence_penalty [1] on cpu, optional - // Only one of repetition and presence penalties is allowed. 
- // min_length [1] on cpu, int, optional - - // output_tensors: - // output_ids [max_seq_len, batch_size, beam_width] - // finished [local_batch_size * beam_width], optional - // cum_log_probs [local_batch_size * beam_width] - // parent_ids [max_seq_len, batch_size * beam_width] - // sequence_length [local_batch_size * beam_width], optional - // tgt_cache_indirection [local_batch_size, beam_width, max_seq_len] - // output_log_probs [max_seq_len, batch_size, beam_width], optional - // beam_hyps, optional - - FT_CHECK(input_tensors->size() >= 7); - FT_CHECK(output_tensors->size() >= 5); - const int batch_size = output_tensors->at("output_ids").shape[1]; - const int beam_width = output_tensors->at("output_ids").shape[2]; - allocateBuffer(batch_size, beam_width); - - const int step = input_tensors->at("step").getVal(); - const int ite = input_tensors->at("ite").getVal(); - const int local_batch_size = input_tensors->at("logits").shape[0]; - - const float temperature = input_tensors->getVal("temperature", 1.0f); - const T* embedding_bias = input_tensors->getPtr("embedding_bias", nullptr); - - RepetitionPenaltyType repetition_penalty_type = RepetitionPenaltyType::None; - float repetition_penalty = getDefaultPenaltyValue(repetition_penalty_type); - if (input_tensors->isExist("repetition_penalty") || input_tensors->isExist("presence_penalty")) { - FT_CHECK_WITH_INFO( - !(input_tensors->isExist("repetition_penalty") && input_tensors->isExist("presence_penalty")), - "Found ambiguous parameters repetition_penalty and presence_penalty which are mutually exclusive. " - "Please provide one of repetition_penalty or presence_penalty."); - repetition_penalty_type = input_tensors->isExist("repetition_penalty") ? RepetitionPenaltyType::Multiplicative : - RepetitionPenaltyType::Additive; - repetition_penalty = repetition_penalty_type == RepetitionPenaltyType::Multiplicative ? 
- input_tensors->getVal("repetition_penalty") : - input_tensors->getVal("presence_penalty"); - } - - invokeAddBiasApplyPenalties( - step, - input_tensors->at("logits").getPtr(), - output_tensors->at("output_ids") - .getPtrWithOffset((step - 1) * batch_size * beam_width + ite * local_batch_size * beam_width), - output_tensors->getPtr("output_ids"), - output_tensors->getPtr("parent_ids"), - input_tensors->getPtr("input_lengths", nullptr), - output_tensors->getPtr("sequence_length", nullptr), - embedding_bias, - ite, - input_tensors->getVal("max_input_length"), - local_batch_size, - batch_size, - beam_width, - vocab_size_, - vocab_size_padded_, - input_tensors->getPtr("end_id", nullptr), - temperature, - repetition_penalty, - repetition_penalty_type, - input_tensors->getVal("min_length", 0), - stream_); - sync_check_cuda_error(); - - invokeSoftMax(output_tensors, input_tensors); - - if (beam_width > 1) { - const int max_seq_len = output_tensors->at("output_ids").shape[0]; - - update_indir_cache_kernelLauncher( - output_tensors->at("tgt_cache_indirection").getPtr(), - input_tensors->at("src_cache_indirection").getPtr(), - output_tensors->at("parent_ids") - .getPtrWithOffset(+step * beam_width * batch_size + ite * local_batch_size * beam_width), - output_tensors->at("finished").getPtr(), - batch_size, - local_batch_size, - beam_width, - max_seq_len, - step, - stream_); - sync_check_cuda_error(); - } - sync_check_cuda_error(); - if (is_free_buffer_after_forward_) { - freeBuffer(); - } - sync_check_cuda_error(); -} - -template class BaseBeamSearchLayer; -template class BaseBeamSearchLayer; - -} // namespace turbomind diff --git a/src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.h b/src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.h deleted file mode 100644 index 42a4303463..0000000000 --- a/src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.h +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "src/turbomind/kernels/penalty_types.h" -#include "src/turbomind/layers/DynamicDecodeBaseLayer.h" - -namespace turbomind { - -template -class BaseBeamSearchLayer: public DynamicDecodeBaseLayer { -private: - void freeBuffer(); - -protected: - // meta data - size_t vocab_size_; - size_t vocab_size_padded_; - - size_t topk_softmax_workspace_size_; - void* topk_softmax_workspace_ = nullptr; - - virtual void allocateBuffer() = 0; - virtual void allocateBuffer(size_t batch_size, size_t beam_width) = 0; - virtual void invokeSoftMax(TensorMap* output_tensors, TensorMap* input_tensors) = 0; - -public: - BaseBeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); - - BaseBeamSearchLayer(BaseBeamSearchLayer const& beam_search_layer); - - ~BaseBeamSearchLayer(); - - void setup(const size_t batch_size, const size_t beam_width, TensorMap* runtime_args) override; - void forward(std::vector* output_tensors, - const std::vector* input_tensors) override; - void forward(std::unordered_map* output_tensors, - const std::unordered_map* input_tensors) override; - void forward(TensorMap* output_tensors, TensorMap* input_tensors) override; -}; - -void update_indir_cache_kernelLauncher(int* tgt_indir_cache, - const int* src_indir_cache, - const int* beam_ids, - const bool* finished, - int batch_dim, - int beam_width, - int max_seq_len, - int ite, - cudaStream_t stream); - -} // namespace turbomind diff --git a/src/turbomind/layers/beam_search_layers/BeamSearchLayer.cu b/src/turbomind/layers/beam_search_layers/BeamSearchLayer.cu deleted file mode 100644 index 80de5ea585..0000000000 --- a/src/turbomind/layers/beam_search_layers/BeamSearchLayer.cu +++ /dev/null @@ -1,354 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/turbomind/kernels/reduce_kernel_utils.cuh" -#include "src/turbomind/layers/beam_search_layers/BeamSearchLayer.h" - -namespace turbomind { - -template -__global__ void logProbAddCumLogProb(float* log_probs, - const T* logits, - const float* cum_log_probs, - const int* end_ids, - const bool* finished, - const int beam_width, - const int n) -{ - int bid = blockIdx.x; - bool finish = finished != nullptr ? finished[bid] : false; - int offset = bid * n; - - float max_val = -1 * FLT_MAX; - __shared__ float s_max_val; - __shared__ float s_sum_val; - - if (finish) { - for (int tid = threadIdx.x; tid < n; tid += blockDim.x) { - log_probs[offset + tid] = (tid == end_ids[bid / beam_width]) ? 
cum_log_probs[bid] : -FLT_MAX; - } - } - else { - for (int tid = threadIdx.x; tid < n; tid += blockDim.x) { - log_probs[offset + tid] = (float)(logits[offset + tid]); - max_val = max(max_val, log_probs[offset + tid]); - } - - max_val = blockReduceMax(max_val); - if (threadIdx.x == 0) { - s_max_val = max_val; - } - __syncthreads(); - - float sum_val = 0.0f; - for (int tid = threadIdx.x; tid < n; tid += blockDim.x) { - log_probs[offset + tid] = __expf(log_probs[offset + tid] - s_max_val); - sum_val += log_probs[offset + tid]; - } - - sum_val = blockReduceSum(sum_val); - if (threadIdx.x == 0) { - s_sum_val = sum_val + 1e-6f; - } - __syncthreads(); - - for (int tid = threadIdx.x; tid < n; tid += blockDim.x) { - log_probs[offset + tid] = logf(log_probs[offset + tid] / s_sum_val) + cum_log_probs[bid]; - } - } -} - -template -void invokeLogProbAddCumLogProb(float* log_probs, - const T* logits, - const float* cum_log_probs, - const int* end_ids, - const bool* finished, - const int m, - const int beam_width, - const int n, - cudaStream_t stream) -{ - dim3 grid(m); - dim3 block(min(n, 1024)); - /*n is the vocab_size, e.g., 30000, 7000.... vocab_size is usually very big. */ - logProbAddCumLogProb<<>>( - log_probs, logits, cum_log_probs, end_ids, finished, beam_width, n); -} - -template -__global__ void updateStatesKernel(T* log_probs, - T* cum_log_probs, - float* output_log_probs, - bool* finished, - int* parent_ids, - int* sequence_length, - int* word_ids, - int* output_ids, - BeamHypotheses beam_hyps, - const int local_batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids) -{ - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < local_batch_size * beam_width; - index += blockDim.x * gridDim.x) { - - int batch_id = index / beam_width; - sequence_length[index] = finished[index] ? sequence_length[index] : sequence_length[index] + 1; - - int beam_id = (word_ids[index] / vocab_size) % beam_width; - int word_id = word_ids[index] % vocab_size; - - if (output_log_probs != nullptr) { - // get the cum_log_probs of previous run - output_log_probs[index] = log_probs[batch_id * beam_width * vocab_size + beam_id * vocab_size + word_id] - - cum_log_probs[batch_id * beam_width + beam_id]; - } - cum_log_probs[index] = log_probs[batch_id * beam_width * vocab_size + beam_id * vocab_size + word_id]; - sequence_length[index] = sequence_length[batch_id * beam_width + beam_id]; - finished[index] = word_id == end_ids[batch_id] ? 
1 : 0; - parent_ids[index] = beam_id; - word_ids[index] = word_id; - output_ids[index] = word_id; - - if (beam_hyps.num_beams != nullptr) { - if (beam_hyps.num_beams[beam_hyps.ite * beam_hyps.local_batch_size + batch_id] == beam_width) { - for (int i = 0; i < beam_width; i++) { - finished[batch_id * beam_width + i] = true; - } - } - } - } -} - -void invokeUpdateStates(float* log_probs, - float* cum_log_probs, - float* output_log_probs, - bool* finished, - int* parent_ids, - int* sequence_length, - int* word_ids, - int* output_ids, - BeamHypotheses* beam_hyps, - const int local_batch_size, - const int beam_width, - const int vocab_size, - const int* end_ids, - cudaStream_t stream) -{ - dim3 grid((int)ceil(local_batch_size * beam_width * 1.0 / 256)); - dim3 block(256); - - updateStatesKernel<<>>(log_probs, - cum_log_probs, - output_log_probs, - finished, - parent_ids, - sequence_length, - word_ids, - output_ids, - *beam_hyps, - local_batch_size, - beam_width, - vocab_size, - end_ids); -} - -template -void BeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMap* input_tensors) -{ - // input_tensors: - // logits [local_batch_size, beam_width, vocab_size_padded] - // embedding_bias [vocab_size_padded] - // step [1] on cpu - // src_cache_indirection [local_batch_size, beam_width, max_seq_len] - // max_input_length [1] on cpu - // input_lengths [local_batch_size * beam_width] - // ite [1] on cpu - // beam_search_diversity_rate [1] on cpu, optional - // temperature [1] on cpu, optional - // len_penalty [1] on cpu, optional - // repetition_penalty [1] on cpu, optional - - // output_tensors: - // output_ids [max_seq_len, batch_size, beam_width] - // finished [local_batch_size * beam_width] - // cum_log_probs [local_batch_size * beam_width] - // parent_ids [max_seq_len, batch_size * beam_width] - // sequence_length [local_batch_size * beam_width] - // tgt_cache_indirection [local_batch_size, beam_width, max_seq_len] - // output_log_probs [max_seq_len, batch_size * beam_width], optional - // beam_hyps, optional - - FT_CHECK(input_tensors->size() >= 7); - FT_CHECK(output_tensors->size() >= 6); - - const int batch_size = output_tensors->at("output_ids").shape[1]; - const int beam_width = output_tensors->at("output_ids").shape[2]; - const int step = input_tensors->at("step").getVal(); - const int ite = input_tensors->at("ite").getVal(); - const int local_batch_size = input_tensors->at("logits").shape[0]; - const float diversity_rate = input_tensors->isExist("beam_search_diversity_rate") ? - input_tensors->at("beam_search_diversity_rate").getVal() : - 0.0f; - const float length_penalty = - input_tensors->isExist("len_penalty") ? 
input_tensors->at("len_penalty").getVal() : 0.0f; - - const int id_offset = step * batch_size * beam_width + ite * local_batch_size * beam_width; - invokeLogProbAddCumLogProb(float_log_prob_buf_, - input_tensors->at("logits").getPtr(), - output_tensors->at("cum_log_probs").getPtr(), - input_tensors->at("end_id").getPtr(), - output_tensors->at("finished").getPtr(), - local_batch_size * beam_width, - beam_width, - vocab_size_padded_, - stream_); - sync_check_cuda_error(); - - BeamHypotheses beam_hyps; - if (output_tensors->isExist("beam_hyps") && diversity_rate == 0.0f) { - beam_hyps = *((BeamHypotheses*)(output_tensors->at("beam_hyps").getPtr())); - beam_hyps.step = step; - beam_hyps.ite = ite; - beam_hyps.local_batch_size = local_batch_size; - beam_hyps.batch_size = output_tensors->at("output_ids").shape[1]; - beam_hyps.max_seq_len = output_tensors->at("output_ids").shape[0]; - beam_hyps.output_ids_src = output_tensors->at("output_ids").getPtr(); - beam_hyps.parent_ids_src = output_tensors->at("parent_ids").getPtr(); - beam_hyps.sequence_lengths_src = output_tensors->at("sequence_length").getPtr(); - beam_hyps.length_penalty = length_penalty; - } - - invokeTopkBeamSearch(topk_softmax_workspace_, - topk_softmax_workspace_size_, - float_log_prob_buf_, - output_tensors->at("output_ids").getPtrWithOffset(id_offset), - &beam_hyps, - output_tensors->at("finished").getPtr(), - output_tensors->isExist("sequence_length") ? - output_tensors->at("sequence_length").getPtr() : - (int*)nullptr, - local_batch_size, - beam_width, - vocab_size_padded_, - diversity_rate, - length_penalty, - input_tensors->at("end_id").getPtr(), - stream_); - sync_check_cuda_error(); - - invokeUpdateStates(float_log_prob_buf_, - output_tensors->at("cum_log_probs").getPtr(), - output_tensors->getPtrWithOffset("output_log_probs", id_offset, nullptr), - output_tensors->at("finished").getPtr(), - output_tensors->at("parent_ids").getPtrWithOffset(id_offset), - output_tensors->at("sequence_length").getPtr(), - output_tensors->at("output_ids").getPtrWithOffset(id_offset), - output_tensors->at("output_ids").getPtrWithOffset(id_offset), - &beam_hyps, - local_batch_size, - beam_width, - vocab_size_padded_, - input_tensors->at("end_id").getPtr(), - stream_); - sync_check_cuda_error(); -} - -template -void BeamSearchLayer::allocateBuffer() -{ - FT_CHECK(false); -} - -template -void BeamSearchLayer::allocateBuffer(size_t batch_size, size_t beam_width) -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); - - invokeTopkBeamSearch(nullptr, - topk_softmax_workspace_size_, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - batch_size, - beam_width, - vocab_size_padded_, - 0.0f, // diversity rate - 0.0f, // length penalty - nullptr, - stream_); - topk_softmax_workspace_ = reinterpret_cast(allocator_->reMalloc( - topk_softmax_workspace_, - topk_softmax_workspace_size_ + sizeof(float) * batch_size * beam_width * vocab_size_padded_, - false)); - float_log_prob_buf_ = (float*)((char*)topk_softmax_workspace_ + topk_softmax_workspace_size_); - is_allocate_buffer_ = true; -} - -template -BeamSearchLayer::BeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): - BaseBeamSearchLayer(max_batch_size, - head_num, - size_per_head, - beam_width, 
- vocab_size, - vocab_size_padded, - end_id, - diversity_rate, - temperature, - len_penalty, - repetition_penalty, - stream, - cublas_wrapper, - allocator, - is_free_buffer_after_forward) -{ -} - -template -BeamSearchLayer::BeamSearchLayer(BeamSearchLayer const& beam_search_layer): - BaseBeamSearchLayer(beam_search_layer) -{ -} - -template -BeamSearchLayer::~BeamSearchLayer() -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); -} - -template class BeamSearchLayer; -template class BeamSearchLayer; - -} // namespace turbomind diff --git a/src/turbomind/layers/beam_search_layers/BeamSearchLayer.h b/src/turbomind/layers/beam_search_layers/BeamSearchLayer.h deleted file mode 100644 index 64dacf1ca0..0000000000 --- a/src/turbomind/layers/beam_search_layers/BeamSearchLayer.h +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "src/turbomind/kernels/beam_search_topk_kernels.h" -#include "src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.h" -#include - -namespace turbomind { - -template -class BeamSearchLayer: public BaseBeamSearchLayer { -private: - // meta data - using BaseBeamSearchLayer::vocab_size_; - using BaseBeamSearchLayer::vocab_size_padded_; - - using BaseBeamSearchLayer::topk_softmax_workspace_size_; - using BaseBeamSearchLayer::topk_softmax_workspace_; - - void allocateBuffer() override; - void allocateBuffer(size_t batch_size, size_t beam_width) override; - void invokeSoftMax(TensorMap* output_tensors, TensorMap* input_tensors) override; - - using BaseBeamSearchLayer::stream_; - using BaseBeamSearchLayer::is_allocate_buffer_; - using BaseBeamSearchLayer::allocator_; - - float* float_log_prob_buf_ = nullptr; - -protected: -public: - BeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); - - BeamSearchLayer(BeamSearchLayer const& beam_search_layer); - - ~BeamSearchLayer(); -}; - -} // namespace turbomind diff --git a/src/turbomind/layers/beam_search_layers/CMakeLists.txt b/src/turbomind/layers/beam_search_layers/CMakeLists.txt deleted file mode 100644 index 2708722334..0000000000 --- a/src/turbomind/layers/beam_search_layers/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -cmake_minimum_required(VERSION 3.8) - -add_library(BaseBeamSearchLayer STATIC BaseBeamSearchLayer.cu) -set_property(TARGET BaseBeamSearchLayer PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET BaseBeamSearchLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(BaseBeamSearchLayer PUBLIC -lcudart beam_search_penalty_kernels cuda_utils) - -add_library(OnlineBeamSearchLayer STATIC OnlineBeamSearchLayer.cu) -set_property(TARGET OnlineBeamSearchLayer PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET OnlineBeamSearchLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(OnlineBeamSearchLayer PUBLIC -lcudart BaseBeamSearchLayer online_softmax_beamsearch_kernels) - -add_library(BeamSearchLayer STATIC BeamSearchLayer.cu) -set_property(TARGET BeamSearchLayer PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET BeamSearchLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(BeamSearchLayer PUBLIC -lcudart BaseBeamSearchLayer beam_search_topk_kernels) diff --git a/src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.cu b/src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.cu deleted file mode 100644 index 7dcfc99f4a..0000000000 --- a/src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.cu +++ /dev/null @@ -1,249 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.h" - -namespace turbomind { - -static const int SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS = 128; -static const int MAX_K = 4; - -template -__global__ void update_kernel(bool* finished, - int* parent_ids, - int* sequence_length, - int* word_ids, - int* output_ids, - BeamHypotheses beam_hyps, - const int vocab_size, - const int* end_ids, - const int local_batch_size, - const int beam_width) -{ - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < local_batch_size * beam_width; - index += blockDim.x * gridDim.x) { - - int batch_id = index / beam_width; - sequence_length[index] = finished[index] ? sequence_length[index] : sequence_length[index] + 1; - - int beam_id = (word_ids[index] / vocab_size) % beam_width; - int word_id = word_ids[index] % vocab_size; - - sequence_length[index] = sequence_length[batch_id * beam_width + beam_id]; - finished[index] = word_id == end_ids[index / beam_width] ? 
1 : 0; - parent_ids[index] = beam_id; - word_ids[index] = word_id; - output_ids[index] = word_id; - - if (beam_hyps.num_beams != nullptr) { - if (beam_hyps.num_beams[beam_hyps.ite * beam_hyps.local_batch_size + batch_id] == beam_width) { - for (int i = 0; i < beam_width; i++) { - finished[batch_id * beam_width + i] = true; - } - } - } - } -} - -void invokeUpdate(bool* finished, - int* parent_ids, - int* sequence_length, - int* word_ids, - int* output_ids, - BeamHypotheses* beam_hyps, - const int local_batch_size, - const int beam_width, - const int vocab_size_padded, - const int* end_ids, - cudaStream_t stream) -{ - dim3 grid((int)ceil(local_batch_size * beam_width * 1.0 / 256)); - dim3 block(256); - - update_kernel<<>>(finished, - parent_ids, - sequence_length, - word_ids, - output_ids, - *beam_hyps, - vocab_size_padded, - end_ids, - local_batch_size, - beam_width); -} - -template -void OnlineBeamSearchLayer::invokeSoftMax(TensorMap* output_tensors, TensorMap* input_tensors) -{ - // input_tensors: - // logits [local_batch_size, beam_width, vocab_size_padded] - // embedding_bias [vocab_size_padded] - // step [1] on cpu - // src_cache_indirection [local_batch_size, beam_width, max_seq_len] - // max_input_length [1] on cpu - // input_lengths [local_batch_size * beam_width] - // ite [1] on cpu - // beam_search_diversity_rate [1] on cpu, optional - // temperature [1] on cpu, optional - // len_penalty [1] on cpu, optional - // repetition_penalty [1] on cpu, optional - - // output_tensors: - // output_ids [max_seq_len, batch_size, beam_width] - // finished [local_batch_size * beam_width] - // cum_log_probs [local_batch_size * beam_width] - // parent_ids [max_seq_len, batch_size * beam_width] - // sequence_length [local_batch_size * beam_width] - // tgt_cache_indirection [local_batch_size, beam_width, max_seq_len] - // output_log_probs [max_seq_len, batch_size, beam_width] - - FT_CHECK(input_tensors->size() >= 7); - FT_CHECK(output_tensors->size() >= 6); - - const int batch_size = output_tensors->at("output_ids").shape[1]; - const int beam_width = output_tensors->at("output_ids").shape[2]; - const int step = input_tensors->at("step").getVal(); - const int ite = input_tensors->at("ite").getVal(); - const int local_batch_size = input_tensors->at("logits").shape[0]; - const float diversity_rate = input_tensors->isExist("beam_search_diversity_rate") ? - input_tensors->at("beam_search_diversity_rate").getVal() : - 0.0f; - const float length_penalty = - input_tensors->isExist("len_penalty") ? 
input_tensors->at("len_penalty").getVal() : 0.0f; - - const int id_offset = step * batch_size * beam_width + local_batch_size * ite * beam_width; - - BeamHypotheses beam_hyps; - if (output_tensors->isExist("beam_hyps")) { - beam_hyps = *((BeamHypotheses*)(output_tensors->at("beam_hyps").getPtr())); - beam_hyps.step = step; - beam_hyps.ite = ite; - beam_hyps.local_batch_size = local_batch_size; - beam_hyps.batch_size = output_tensors->at("output_ids").shape[1]; - beam_hyps.max_seq_len = output_tensors->at("output_ids").shape[0]; - beam_hyps.output_ids_src = output_tensors->at("output_ids").getPtr(); - beam_hyps.parent_ids_src = output_tensors->at("parent_ids").getPtr(); - beam_hyps.sequence_lengths_src = output_tensors->at("sequence_length").getPtr(); - beam_hyps.log_probs_src = output_tensors->getPtr("output_log_probs", nullptr); - beam_hyps.length_penalty = length_penalty; - beam_hyps.end_ids = input_tensors->at("end_id").getPtr(); - } - - invokeTopkSoftMax(input_tensors->at("logits").getPtr(), - (const T*)(nullptr), - output_tensors->at("finished").getPtr(), - output_tensors->at("sequence_length").getPtr(), - output_tensors->at("cum_log_probs").getPtr(), - output_tensors->getPtrWithOffset("output_log_probs", id_offset, nullptr), - output_tensors->at("output_ids").getPtrWithOffset(id_offset), - topk_softmax_workspace_, - topk_softmax_workspace_size_, - &beam_hyps, - local_batch_size, - beam_width, - vocab_size_padded_, - input_tensors->at("end_id").getPtr(), - diversity_rate, - length_penalty, - stream_); - sync_check_cuda_error(); - - invokeUpdate(output_tensors->at("finished").getPtr(), - output_tensors->at("parent_ids").getPtrWithOffset(id_offset), - output_tensors->at("sequence_length").getPtr(), - output_tensors->at("output_ids").getPtrWithOffset(id_offset), - output_tensors->at("output_ids").getPtrWithOffset(id_offset), - &beam_hyps, - local_batch_size, - beam_width, - vocab_size_padded_, - input_tensors->at("end_id").getPtr(), - stream_); - sync_check_cuda_error(); -} - -template -void OnlineBeamSearchLayer::allocateBuffer() -{ - FT_CHECK(false); -} - -template -void OnlineBeamSearchLayer::allocateBuffer(size_t batch_size, size_t beam_width) -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); - // we need to check 2 * beam_width candidates each time - // 64 is the max beam width we support now. - topk_softmax_workspace_size_ = - (size_t)(ceil(batch_size * 64 * (64 * 2) / 4.) * 4 * 2 - + ceil(batch_size * (64 * 2) * SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS * (2 * (MAX_K * 2) + 2) / 4.) 
* 4); - - topk_softmax_workspace_ = reinterpret_cast( - allocator_->reMalloc(topk_softmax_workspace_, sizeof(float) * topk_softmax_workspace_size_, true)); - is_allocate_buffer_ = true; -} - -template -OnlineBeamSearchLayer::OnlineBeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward): - BaseBeamSearchLayer(max_batch_size, - head_num, - size_per_head, - beam_width, - vocab_size, - vocab_size_padded, - end_id, - diversity_rate, - temperature, - len_penalty, - repetition_penalty, - stream, - cublas_wrapper, - allocator, - is_free_buffer_after_forward) -{ -} - -template -OnlineBeamSearchLayer::OnlineBeamSearchLayer(OnlineBeamSearchLayer const& beam_search_layer): - BaseBeamSearchLayer(beam_search_layer) -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); -} - -template -OnlineBeamSearchLayer::~OnlineBeamSearchLayer() -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); -} - -template class OnlineBeamSearchLayer; -template class OnlineBeamSearchLayer; - -} // namespace turbomind diff --git a/src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.h b/src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.h deleted file mode 100644 index 15727d1925..0000000000 --- a/src/turbomind/layers/beam_search_layers/OnlineBeamSearchLayer.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "src/turbomind/kernels/online_softmax_beamsearch_kernels.h" -#include "src/turbomind/layers/beam_search_layers/BaseBeamSearchLayer.h" - -namespace turbomind { - -template -class OnlineBeamSearchLayer: public BaseBeamSearchLayer { -private: - // meta data - using BaseBeamSearchLayer::vocab_size_; - using BaseBeamSearchLayer::vocab_size_padded_; - - using BaseBeamSearchLayer::topk_softmax_workspace_size_; - using BaseBeamSearchLayer::topk_softmax_workspace_; - - void allocateBuffer() override; - void allocateBuffer(size_t batch_size, size_t beam_width) override; - void invokeSoftMax(TensorMap* output_tensors, TensorMap* input_tensors) override; - - using BaseBeamSearchLayer::stream_; - using BaseBeamSearchLayer::is_allocate_buffer_; - using BaseBeamSearchLayer::allocator_; - -protected: -public: - OnlineBeamSearchLayer(size_t max_batch_size, - size_t head_num, - size_t size_per_head, - size_t beam_width, - size_t vocab_size, - size_t vocab_size_padded, - int end_id, - float diversity_rate, - float temperature, - float len_penalty, - float repetition_penalty, - cudaStream_t stream, - cublasMMWrapper* cublas_wrapper, - IAllocator* allocator, - bool is_free_buffer_after_forward); - - OnlineBeamSearchLayer(OnlineBeamSearchLayer const& beam_search_layer); - - ~OnlineBeamSearchLayer(); -}; - -} // namespace turbomind diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index dde6c18aa8..0fbc361911 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -23,7 +23,6 @@ set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Llama PUBLIC -lcudart cublasMMWrapper DynamicDecodeLayer - BaseBeamSearchLayer activation_kernels decoder_masked_multihead_attention bert_preprocess_kernels diff --git a/src/turbomind/models/llama/prefix_cache.cu b/src/turbomind/models/llama/prefix_cache.cu deleted file mode 100644 index 253b90fc30..0000000000 --- a/src/turbomind/models/llama/prefix_cache.cu +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. 
- -#include "src/turbomind/models/llama/prefix_cache.h" - -// -> -template -__global__ void insertKeyCache(T* key_cache, const T* src, int L, int H, int Dx, int s, int X, size_t S) -{ - for (int i = threadIdx.x; i < L * H * Dx * s * X; i += blockDim.x) { - int i0 = i / X; - int x = i % X; - - int i1 = i0 / s; - int t = i0 % s; - - size_t j = (i1 * S + t) * X + x; - key_cache[j] = src[i]; - } -} - -template -void invokeInsertKeyCache(T* key_cache, const T* src, int L, int H, int Dx, int s, int X, int S, cudaStream_t st) -{ - insertKeyCache<<<1, 512, 0, st>>>(key_cache, src, L, H, Dx, s, X, S); -} -template void -invokeInsertKeyCache(float* key_cache, const float* src, int L, int H, int Dx, int s, int X, int S, cudaStream_t st); -template void -invokeInsertKeyCache(half* key_cache, const half* src, int L, int H, int Dx, int s, int X, int S, cudaStream_t st); - -// -> -template -__global__ void insertValueCache(T* value_cache, const T* src, int L, int H, int s, int D, size_t S) -{ - for (int i = threadIdx.x; i < L * H * s * D; i += blockDim.x) { - int i0 = i / D; - int d = i % D; - - int i1 = i0 / s; - int t = i0 % s; - - size_t j = (i1 * S + t) * D + d; - value_cache[j] = src[i]; - } -} - -template -void invokeInsertValueCache(T* value_cache, const T* src, int L, int H, int s, int D, int S, cudaStream_t st) -{ - insertValueCache<<<1, 512, 0, st>>>(value_cache, src, L, H, s, D, S); -} -template void -invokeInsertValueCache(float* value_cache, const float* src, int L, int H, int s, int D, int S, cudaStream_t st); -template void -invokeInsertValueCache(half* value_cache, const half* src, int L, int H, int s, int D, int S, cudaStream_t st); diff --git a/src/turbomind/models/llama/prefix_cache.h b/src/turbomind/models/llama/prefix_cache.h deleted file mode 100644 index a00ef864a4..0000000000 --- a/src/turbomind/models/llama/prefix_cache.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include - -template -void invokeInsertKeyCache(T* key_cache, const T* src, int L, int H, int Dx, int s, int X, int S, cudaStream_t st); - -template -void invokeInsertValueCache(T* value_cache, const T* src, int L, int H, int s, int D, int S, cudaStream_t st);
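
For reference, the deleted insertKeyCache/insertValueCache kernels copy a contiguous prefix of s cached steps into a cache whose per-sequence capacity is S (s <= S), re-striding the time index from s to S while leaving the leading (layer, head) and trailing feature dimensions untouched. Below is a minimal host-side sketch of the value-cache indexing, assuming an [L, H, S, D] cache layout and an [L, H, s, D] source; the helper name is illustrative only and is not part of this patch:

    // Host-side reference of the layout transform performed by the removed
    // insertValueCache kernel (sketch under the layout assumptions above).
    #include <cstddef>
    #include <vector>

    template<typename T>
    void insertValueCacheHost(std::vector<T>&       value_cache,  // size L*H*S*D
                              const std::vector<T>& src,          // size L*H*s*D
                              int L, int H, int s, int D, size_t S)
    {
        for (int i = 0; i < L * H * s * D; ++i) {
            int i0 = i / D;  // flattened (layer, head, step) index of the prefix
            int d  = i % D;  // feature index
            int i1 = i0 / s; // flattened (layer, head) index
            int t  = i0 % s; // step within the prefix, t < s <= S
            // re-stride the step dimension from s (source) to S (cache)
            size_t j       = (static_cast<size_t>(i1) * S + t) * D + d;
            value_cache[j] = src[i];
        }
    }

The key-cache variant follows the same pattern with the feature dimension split into Dx * X blocks, which is why its inner index carries an extra X factor.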