From a14bb2fee9e50540e46e69a39551946a6dacef53 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Sat, 6 Jul 2024 00:07:48 +0800 Subject: [PATCH] Optimizing the performance of fused_layer_norm and top_p_sampling operators (#65711) * optim fused_layer_norm and top_p_sampling * update * update * update * support hip * fix comment * update --- .../kernels/fusion/gpu/blha_get_max_len.cu | 5 +- .../fusion/gpu/fused_layernorm_kernel.cu | 2 +- .../phi/kernels/gpu/top_p_sampling_kernel.cu | 591 ++++++------------ 3 files changed, 188 insertions(+), 410 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu index 78a46e16989e50..9073575cf4f90c 100644 --- a/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu +++ b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu @@ -65,4 +65,7 @@ PD_REGISTER_KERNEL(blha_get_max_len, ALL_LAYOUT, phi::fusion::BlhaGetMaxLenKernel, int, - int64_t) {} + int64_t) { + kernel->OutputAt(0).SetBackend(phi::Backend::CPU); + kernel->OutputAt(1).SetBackend(phi::Backend::CPU); +} diff --git a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu index 221019531a5486..7d0e69a38fbb13 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu @@ -537,7 +537,7 @@ inline GPU(Error_t) // Note(Zhengzekang): We choose a fixed blocksize to avoid layernorm diff, by // RichardWooSJTU. - constexpr int block_size_conf_1 = 128; + constexpr int block_size_conf_1 = 512; int dev = 0; { diff --git a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu index 3b2c3c28d587f5..a6826c7e6f3717 100644 --- a/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu +++ b/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu @@ -35,6 +35,12 @@ namespace cub = hipcub; #include "paddle/phi/kernels/funcs/top_k_function_cuda.h" #include "paddle/phi/kernels/primitive/functor_primitives.h" +#ifdef PADDLE_WITH_HIP +#define GPU(str) hip##str +#else +#define GPU(str) cu##str +#endif + // #define DEBUG_TOPP namespace phi { @@ -327,22 +333,16 @@ __device__ inline T exponential_transform(T val, T lambda) { #endif } -template +template __global__ void KeMatrixTopPBeamTopK(const T* src, const T* threshold, + GPU(randState_t) * states, T* top_ps, int64_t* out_id, // topk id T* out_val, // topk val int64_t* topk_ids, T* topk_scores, int vocab_size, - const int64_t* seed, - const uint64_t seed_num, - const uint64_t seed_offset, int* count_iter, int* count_iter_begin, const int k, @@ -354,29 +354,6 @@ __global__ void KeMatrixTopPBeamTopK(const T* src, const float threshold_now = threshold ? static_cast(threshold[bid]) : 0.f; -#ifdef PADDLE_WITH_HIP - hiprandState_t state; - if constexpr (INFER_SEEDS) { - hiprand_init(static_cast(seed[bid]), tid, seed_offset, &state); - } else { - if (need_batch_random) { - hiprand_init(seed_num, bid * blockDim.x + tid, seed_offset, &state); - } else { - hiprand_init(seed_num, tid, seed_offset, &state); - } - } -#else - curandState_t state; - if constexpr (INFER_SEEDS) { - curand_init(static_cast(seed[bid]), tid, seed_offset, &state); - } else { - if (need_batch_random) { - curand_init(seed_num, bid * blockDim.x + tid, seed_offset, &state); - } else { - curand_init(seed_num, tid, seed_offset, &state); - } - } -#endif int top_num = TopPBeamTopK; float top_p_num = static_cast(top_ps[bid]); const int offset = bid * vocab_size; @@ -426,21 +403,15 @@ __global__ void KeMatrixTopPBeamTopK(const T* src, float max_val = 0.f; int max_id = -1; for (int i = 0; i < TopPBeamTopK; i++) { - if (k > 0 && i < k) { + if (i < k) { topk_ids_now[i] = static_cast(beam_max[i].id); topk_scores_now[i] = beam_max[i].v; } if (!flag) { float val = static_cast(beam_max[i].v); sum_prob += val; -#ifdef PADDLE_WITH_HIP - float random_ratio = - exponential_transform(hiprand_uniform(&state), 1.0f); -#else float random_ratio = - exponential_transform(curand_uniform(&state), 1.0f); -#endif - + exponential_transform(GPU(rand_uniform)(states + bid), 1.0f); float random_val = (val >= threshold_now ? val : 0.f) / random_ratio; if (max_val < random_val) { max_val = random_val; @@ -466,26 +437,20 @@ __global__ void KeMatrixTopPBeamTopK(const T* src, } } -template -__global__ void KeMatrixTopPBeamTopKTruncated(const T* src, - const T* threshold, - T* top_ps, - int64_t* out_id, // topk id - T* out_val, // topk val - int64_t* topk_ids, - T* topk_scores, - int vocab_size, - const int64_t* seed, - const uint64_t seed_num, - const uint64_t seed_offset, - int* count_iter, - int* count_iter_begin, - const int k, - const bool need_batch_random) { +template +__global__ void KeMatrixTopPBeamTopKFt(const T* src, + const T* threshold, + GPU(randState_t) * states, + T* top_ps, + int64_t* out_id, // topk id + T* out_val, // topk val + int64_t* topk_ids, + T* topk_scores, + int vocab_size, + int* count_iter, + int* count_iter_begin, + const int k, + const bool need_batch_random) { const int tid = threadIdx.x; const int wid = tid / 32; const int lane = tid % 32; @@ -493,29 +458,6 @@ __global__ void KeMatrixTopPBeamTopKTruncated(const T* src, const float threshold_now = threshold ? static_cast(threshold[bid]) : 0.f; -#ifdef PADDLE_WITH_HIP - hiprandState_t state; - if constexpr (INFER_SEEDS) { - hiprand_init(static_cast(seed[bid]), tid, seed_offset, &state); - } else { - if (need_batch_random) { - hiprand_init(seed_num, bid * blockDim.x + tid, seed_offset, &state); - } else { - hiprand_init(seed_num, tid, seed_offset, &state); - } - } -#else - curandState_t state; - if constexpr (INFER_SEEDS) { - curand_init(static_cast(seed[bid]), tid, seed_offset, &state); - } else { - if (need_batch_random) { - curand_init(seed_num, bid * blockDim.x + tid, seed_offset, &state); - } else { - curand_init(seed_num, tid, seed_offset, &state); - } - } -#endif int top_num = TopPBeamTopK; float top_p_num = static_cast(top_ps[bid]); int64_t* topk_ids_now = nullptr; @@ -558,16 +500,12 @@ __global__ void KeMatrixTopPBeamTopKTruncated(const T* src, } if (tid == 0) { count_iter_begin[bid] = count_iter[bid]; -#ifdef PADDLE_WITH_HIP - float rand_top_p = hiprand_uniform(&state) * top_p_num; -#else - float rand_top_p = curand_uniform(&state) * top_p_num; -#endif + float rand_top_p = GPU(rand_uniform)(states + bid) * top_p_num; top_ps[bid] = (T)rand_top_p; float sum_prob = 0.0f; bool flag = false; for (int i = 0; i < TopPBeamTopK; i++) { - if (k > 0 && i < k) { + if (i < k) { topk_ids_now[i] = static_cast(beam_max[i].id); topk_scores_now[i] = beam_max[i].v; } @@ -633,133 +571,61 @@ template void DispatchKeMatrixTopPBeamTopK(const Context& dev_ctx, const T* src, const T* threshold, + GPU(randState_t) * states, T* top_ps, int64_t* out_id, // topk id T* out_val, // topk val int64_t* topk_ids, T* topk_scores, int vocab_size, - const int64_t* seed, - const uint64_t seed_num, - const uint64_t seed_offset, int* count_iter, int* count_iter_begin, const int k, const int bs, - const std::string& mode, - const bool need_batch_random) { + const bool need_batch_random, + const std::string& mode) { int BlockSize = GetBlockSize(vocab_size); - if (mode == "truncated") { - if (seed) { - switch (BlockSize) { - FIXED_BLOCK_DIM( - KeMatrixTopPBeamTopKTruncated - <<>>(src, - threshold, - top_ps, - out_id, - out_val, - topk_ids, - topk_scores, - vocab_size, - seed, - 0, - 0, - count_iter, - count_iter_begin, - k, - need_batch_random)); - default: - PD_THROW( - "the input data shape has error in the topp_beam_topk kernel."); - } - } else { - switch (BlockSize) { - FIXED_BLOCK_DIM( - KeMatrixTopPBeamTopKTruncated - <<>>(src, - threshold, - top_ps, - out_id, - out_val, - topk_ids, - topk_scores, - vocab_size, - nullptr, - seed_num, - seed_offset, - count_iter, - count_iter_begin, - k, - need_batch_random)); - default: - PD_THROW( - "the input data shape has error in the topp_beam_topk kernel."); - } + if (mode == "truncate") { + switch (BlockSize) { + FIXED_BLOCK_DIM( + KeMatrixTopPBeamTopKFt + <<>>(src, + threshold, + states, + top_ps, + out_id, + out_val, + topk_ids, + topk_scores, + vocab_size, + count_iter, + count_iter_begin, + k, + need_batch_random)); + default: + PD_THROW( + "the input data shape has error in the topp_beam_topk kernel."); } } else { - if (seed) { - switch (BlockSize) { - FIXED_BLOCK_DIM( - KeMatrixTopPBeamTopK - <<>>(src, - threshold, - top_ps, - out_id, - out_val, - topk_ids, - topk_scores, - vocab_size, - seed, - 0, - 0, - count_iter, - count_iter_begin, - k, - need_batch_random)); - default: - PD_THROW( - "the input data shape has error in the topp_beam_topk kernel."); - } - } else { - switch (BlockSize) { - FIXED_BLOCK_DIM( - KeMatrixTopPBeamTopK - <<>>(src, - threshold, - top_ps, - out_id, - out_val, - topk_ids, - topk_scores, - vocab_size, - nullptr, - seed_num, - seed_offset, - count_iter, - count_iter_begin, - k, - need_batch_random)); - default: - PD_THROW( - "the input data shape has error in the topp_beam_topk kernel."); - } + switch (BlockSize) { + FIXED_BLOCK_DIM( + KeMatrixTopPBeamTopK + <<>>(src, + threshold, + states, + top_ps, + out_id, + out_val, + topk_ids, + topk_scores, + vocab_size, + count_iter, + count_iter_begin, + k, + need_batch_random)); + default: + PD_THROW( + "the input data shape has error in the topp_beam_topk kernel."); } } } @@ -792,16 +658,14 @@ struct MaxOp { } }; -template +template __global__ void topp_sampling(T* sorted_probs, int64_t* sorted_id, T* out_val, int64_t* out_id, const T* top_ps, const T* threshold, - const int64_t* seed, - const uint64_t seed_num, - const uint64_t seed_offset, + GPU(randState_t) * states, const int p_num, const int vocab_size, const bool need_batch_random, @@ -833,30 +697,6 @@ __global__ void topp_sampling(T* sorted_probs, BlockPrefixCallbackOp prefix_op(0); int offset = bid * vocab_size; - -#ifdef PADDLE_WITH_HIP - hiprandState_t state; - if constexpr (INFER_SEEDS) { - hiprand_init(static_cast(seed[bid]), tid, seed_offset, &state); - } else { - if (need_batch_random) { - hiprand_init(seed_num, bid * blockDim.x + tid, seed_offset, &state); - } else { - hiprand_init(seed_num, tid, seed_offset, &state); - } - } -#else - curandState_t state; - if constexpr (INFER_SEEDS) { - curand_init(static_cast(seed[bid]), tid, seed_offset, &state); - } else { - if (need_batch_random) { - curand_init(seed_num, bid * blockDim.x + tid, seed_offset, &state); - } else { - curand_init(seed_num, tid, seed_offset, &state); - } - } -#endif #ifdef DEBUG_TOPP if (tid == 0) { printf( @@ -880,11 +720,8 @@ __global__ void topp_sampling(T* sorted_probs, if (thread_offset < p_t || (thread_offset >= p_t && thread_offset - thread_count < p_t)) { -#ifdef PADDLE_WITH_HIP - float random_ratio = exponential_transform(hiprand_uniform(&state), 1.0f); -#else - float random_ratio = exponential_transform(curand_uniform(&state), 1.0f); -#endif + float random_ratio = + exponential_transform(GPU(rand_uniform)(states + bid), 1.0f); float tmp_val = (thread_count >= threshold_now ? thread_count : 0.f) / random_ratio; if (static_cast(max_thread_pair.v) < tmp_val) { @@ -929,6 +766,14 @@ __global__ void topp_sampling(T* sorted_probs, } } __syncthreads(); + if (stop_shared == 0) { + if (tid == 0) { + out_id[bid] = sorted_id[offset]; + out_val[bid] = sorted_probs[offset]; + } + return; + } + Pair max_pair = BlockReduce(temp_storage_reduce) .Reduce(max_thread_pair, MaxOp>()); if (tid == 0) { @@ -945,21 +790,19 @@ __global__ void topp_sampling(T* sorted_probs, } } -template -__global__ void topp_sampling_truncated(T* sorted_probs, - int64_t* sorted_id, - T* out_val, - int64_t* out_id, - const T* top_ps, - const T* threshold, - const int64_t* seed, - const uint64_t seed_num, - const uint64_t seed_offset, - const int p_num, - const int vocab_size, - const bool need_batch_random, - int* count_iter, - int* count_iter_begin) { +template +__global__ void topp_sampling_ft(T* sorted_probs, + int64_t* sorted_id, + T* out_val, + int64_t* out_id, + const T* top_ps, + const T* threshold, + GPU(randState_t) * states, + const int p_num, + const int vocab_size, + const bool need_batch_random, + int* count_iter, + int* count_iter_begin) { __shared__ int stop_shared; __shared__ float rand_p; const int tid = threadIdx.x; @@ -1095,28 +938,11 @@ __global__ void topp_sampling_truncated(T* sorted_probs, BlockReduce(temp_storage_reduce).Reduce(threshold_id, MaxOp()); #ifdef PADDLE_WITH_HIP hiprandStatePhilox4_32_10_t rng; - if constexpr (INFER_SEEDS) { - hiprand_init( - static_cast(seed[bid]), tid, seed_offset, &rng); - } else { - if (need_batch_random) { - hiprand_init(seed_num, bid * blockDim.x + tid, seed_offset, &rng); - } else { - hiprand_init(seed_num, tid, seed_offset, &rng); - } - } + hiprand_init(bid * blockDim.x + tid, tid, 0, &rng); int random_id = hiprand(&rng) % (max_id + 1); #else curandStatePhilox4_32_10_t rng; - if constexpr (INFER_SEEDS) { - curand_init(static_cast(seed[bid]), tid, seed_offset, &rng); - } else { - if (need_batch_random) { - curand_init(seed_num, bid * blockDim.x + tid, seed_offset, &rng); - } else { - curand_init(seed_num, tid, seed_offset, &rng); - } - } + curand_init(bid * blockDim.x + tid, tid, 0, &rng); int random_id = curand(&rng) % (max_id + 1); #endif out_id[bid] = sorted_id[offset + random_id]; @@ -1137,143 +963,79 @@ void DispatchTopPSampling(const Context& dev_ctx, int64_t* out_id, const T* top_ps, const T* threshold, - const int64_t* seed, - const uint64_t seed_num, - const uint64_t seed_offset, + GPU(randState_t) * states, const int p_num, const int vocab_size, const int bs, + const bool need_batch_random, int* count_iter, int* count_iter_begin, - const std::string& mode, - const bool need_batch_random) { + const std::string& mode) { int BlockSize = GetBlockSize(vocab_size); - if (mode == "truncated") { - if (seed) { - switch (BlockSize) { - FIXED_BLOCK_DIM( - topp_sampling_truncated - <<>>(sorted_probs, - sorted_id, - out_val, - out_id, - top_ps, - threshold, - seed, - 0, - 0, - p_num, - vocab_size, - need_batch_random, - count_iter, - count_iter_begin)); - default: - PD_THROW( - "the input data shape has error in the topp_sampling kernel."); - } - } else { - switch (BlockSize) { - FIXED_BLOCK_DIM( - topp_sampling_truncated - <<>>(sorted_probs, - sorted_id, - out_val, - out_id, - top_ps, - threshold, - nullptr, - seed_num, - seed_offset, - p_num, - vocab_size, - need_batch_random, - count_iter, - count_iter_begin)); - default: - PD_THROW( - "the input data shape has error in the topp_sampling kernel."); - } + if (mode == "truncate") { + switch (BlockSize) { + FIXED_BLOCK_DIM( + topp_sampling_ft + <<>>(sorted_probs, + sorted_id, + out_val, + out_id, + top_ps, + threshold, + states, + p_num, + vocab_size, + need_batch_random, + count_iter, + count_iter_begin)); + default: + PD_THROW("the input data shape has error in the topp_sampling kernel."); } } else { - if (seed) { - switch (BlockSize) { - FIXED_BLOCK_DIM( - topp_sampling - <<>>(sorted_probs, - sorted_id, - out_val, - out_id, - top_ps, - threshold, - seed, - 0, - 0, - p_num, - vocab_size, - need_batch_random, - count_iter, - count_iter_begin)); - default: - PD_THROW( - "the input data shape has error in the topp_sampling kernel."); - } - } else { - switch (BlockSize) { - FIXED_BLOCK_DIM( - topp_sampling - <<>>(sorted_probs, - sorted_id, - out_val, - out_id, - top_ps, - threshold, - nullptr, - seed_num, - seed_offset, - p_num, - vocab_size, - need_batch_random, - count_iter, - count_iter_begin)); - default: - PD_THROW( - "the input data shape has error in the topp_sampling kernel."); - } + switch (BlockSize) { + FIXED_BLOCK_DIM( + topp_sampling + <<>>(sorted_probs, + sorted_id, + out_val, + out_id, + top_ps, + threshold, + states, + p_num, + vocab_size, + need_batch_random, + count_iter, + count_iter_begin)); + default: + PD_THROW("the input data shape has error in the topp_sampling kernel."); } } } -__global__ void set_sorted_num(int* need_sorted_num, int bs) { - *need_sorted_num = bs; -} - -#ifdef PADDLE_WITH_HIP -template -__global__ void print_kernel(T* input, int size) { - for (int i = 0; i < size; i++) { - printf("["); - if (i != size - 1) { - printf("%f, ", static_cast(input[i])); - } else { - printf("%f]\n", static_cast(input[i])); - } +__global__ void setup_kernel(GPU(randState_t) * state, + int64_t* seed, + const int bs) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { + GPU(rand_init)(static_cast(seed[i]), 0, 0, &state[i]); } } -#else -template -__global__ void print_kernel(T* input, int size) { - for (int i = 0; i < size; i++) { - std::stringstream ss; - ss << "["; - if (i != size - 1) { - ss << static_cast(input[i]) << ", "; + +__global__ void setup_kernel(GPU(randState_t) * state, + const uint64_t seed, + const uint64_t offset, + const int bs, + const bool need_batch_random) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { + if (need_batch_random) { + GPU(rand_init)(seed, i, offset, &state[i]); } else { - ss << static_cast(input[i]) << "]\n"; + GPU(rand_init)(seed, 0, offset, &state[i]); } - VLOG(0) << ss.str(); } } -#endif template T* SafeGetTensorPtr(const DenseTensor& t) { @@ -1347,17 +1109,35 @@ void TopPSamplingKernel(const Context& dev_ctx, PD_THROW("the input data shape has error in the FillIndex kernel."); } int64_t* infer_seed = SafeGetTensorPtr(topp_seed); + + GPU(randState_t) * states{nullptr}; + phi::Allocator::AllocationPtr rand_states_buf{nullptr}; + rand_states_buf = phi::memory_utils::Alloc( + dev_ctx.GetPlace(), + bs * sizeof(GPU(randState_t)), + phi::Stream(reinterpret_cast(dev_ctx.stream()))); + states = reinterpret_cast(rand_states_buf->ptr()); + uint64_t seed_now = seed; uint64_t offset = 0; bool need_batch_random = false; - if (seed_now == -1) { - VLOG(1) << "use paddle seed gen"; - need_batch_random = true; - auto gen_cuda = dev_ctx.GetGenerator(); - uint64_t increment = x.numel() * 4; - auto seed_offset = gen_cuda->IncrementOffset(increment); - seed_now = seed_offset.first; - offset = seed_offset.second; + + if (infer_seed) { + setup_kernel<<<1, 256, 0, cu_stream>>>(states, infer_seed, bs); + } else { + if (seed == -1) { + need_batch_random = true; + auto gen_cuda = dev_ctx.GetGenerator(); + uint64_t increment = ps.numel() * 4; + auto seed_offset = gen_cuda->IncrementOffset(increment); + seed = seed_offset.first; + offset = seed_offset.second; + setup_kernel<<<1, 256, 0, cu_stream>>>( + states, seed, offset, bs, need_batch_random); + } else { + setup_kernel<<<1, 256, 0, cu_stream>>>( + states, seed, offset, bs, need_batch_random); + } } DenseTensor count_iter; @@ -1371,27 +1151,25 @@ void TopPSamplingKernel(const Context& dev_ctx, T* threshold_data = SafeGetTensorPtr(threshold); constexpr int TopKMaxLength = 2; - constexpr int TopPBeamTopK = 10; + constexpr int TopPBeamTopK = 5; DispatchKeMatrixTopPBeamTopK( dev_ctx, x.data(), threshold_data, + states, ps_now.data(), ids_ptr, out_ptr, topk_ids_data, topk_scores_data, vocab_size, - infer_seed, - seed_now, - offset, count_iter.data(), count_iter_begin.data(), k, bs, - mode, - need_batch_random); + need_batch_random, + mode); size_t temp_storage_bytes = 0; @@ -1446,17 +1224,14 @@ void TopPSamplingKernel(const Context& dev_ctx, ids_ptr, ps_now.data(), threshold_data, - infer_seed, - seed_now, - offset, + states, p_num, vocab_size, bs, + need_batch_random, count_iter.data(), count_iter_begin.data(), - mode, - need_batch_random); - return; + mode); } } // namespace phi