Transformer decoding support fuse qkv #1455

Merged 15 commits on Dec 17, 2021
2 changes: 1 addition & 1 deletion paddlenlp/ops/faster_transformer/sample/decoding_sample.py
@@ -35,7 +35,7 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
default="./sample/config/decoding.sample.yaml",
default="./faster_transformer/sample/config/decoding.sample.yaml",
type=str,
help="Path of the config file. ")
parser.add_argument(
2 changes: 1 addition & 1 deletion paddlenlp/ops/faster_transformer/src/cublas_handle.cc
@@ -25,4 +25,4 @@ CublasHandle* CublasHandle::GetInstance() {
CublasHandle::~CublasHandle() {
cublasDestroy(cublas_handle_);
cublasLtDestroy(cublaslt_handle_);
-}
\ No newline at end of file
+}
2 changes: 1 addition & 1 deletion paddlenlp/ops/faster_transformer/src/cublas_handle.h
@@ -55,4 +55,4 @@ class CublasHandle {
cublasLtHandle_t cublaslt_handle_;

~CublasHandle();
-};
\ No newline at end of file
+};
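These two files complete the move of cuBLAS/cuBLASLt handle ownership into a process-wide singleton, so each op invocation reuses a single handle pair instead of creating and destroying its own (the .cu changes below drop the per-call cublasCreate/cublasDestroy accordingly). A minimal sketch of the pattern — the private constructor is an assumption; only the destructor shown in the diff above is confirmed by this PR:

```cpp
#include <cublasLt.h>
#include <cublas_v2.h>

// Sketch of a lazily constructed, process-wide holder for the two handles.
class CublasHandle {
 public:
  static CublasHandle* GetInstance() {
    // Constructed on first use; initialization is thread-safe since C++11.
    static CublasHandle instance;
    return &instance;
  }

  cublasHandle_t cublas_handle_;
  cublasLtHandle_t cublaslt_handle_;

  ~CublasHandle() {
    cublasDestroy(cublas_handle_);
    cublasLtDestroy(cublaslt_handle_);
  }

 private:
  CublasHandle() {
    cublasCreate(&cublas_handle_);
    cublasLtCreate(&cublaslt_handle_);
  }
};
```

Callers then only bind the shared handle to the current stream on each call, e.g. cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream).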
12 changes: 8 additions & 4 deletions paddlenlp/ops/faster_transformer/src/fusion_decoding_op.cc
@@ -65,7 +65,8 @@ std::vector<paddle::Tensor> DecodingForward(
const int64_t& max_len,
const float& beam_search_diversity_rate,
const bool& rel_len,
const float& alpha) {
const float& alpha,
const bool& fuse_qkv) {
Contributor commented:

Let's not add a separate fuse_qkv attr; infer it directly from the weight sizes instead, which also keeps compatibility with previously exported models.

Contributor Author (@FrostML, Dec 16, 2021) replied:

Done. Thanks.
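The suggestion above amounts to shape-based detection. A minimal sketch of what such a check could look like — the helper name and the assumed fused layout (Q, K and V concatenated along the output dimension of the query weight) are illustrative, not the code that was merged:

```cpp
#include <cstdint>
#include <vector>

#include "paddle/extension.h"

// Hypothetical helper: decide whether a checkpoint uses fused QKV weights by
// inspecting the self-attention query weight's shape. Assumes a fused layout
// stores Q, K and V as one [hidden, 3 * hidden] matrix.
static bool InferFuseQKV(const paddle::Tensor& self_q_weight,
                         int head_num,
                         int size_per_head) {
  const int64_t hidden = static_cast<int64_t>(head_num) * size_per_head;
  const std::vector<int64_t> dims = self_q_weight.shape();
  // [hidden, 3 * hidden] -> fused QKV; [hidden, hidden] -> separate Q/K/V.
  return dims.size() == 2 && dims[1] == 3 * hidden;
}
```

Detecting the layout from the weights keeps previously exported models working unchanged, which is the compatibility concern raised in the comment.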

int batch_size = input.shape()[0];
int max_out_len = rel_len ? max_len + input.shape()[1] : max_len;

@@ -158,7 +159,8 @@ std::vector<paddle::Tensor> DecodingForward(
eos_id,
max_out_len,
beam_search_diversity_rate,
- alpha);
+ alpha,
+ fuse_qkv);
} else {
PD_THROW("Not implemented place. Only GPU is supported. ");
}
@@ -211,7 +213,8 @@ std::vector<std::vector<int64_t>> DecodingInferShape(
const int64_t& max_len,
const float& beam_search_diversity_rate,
const bool& rel_len,
const float& alpha) {
const float& alpha,
const bool& fuse_qkv) {
int batch_size = input_shape[0];

std::vector<int64_t> output_dims;
@@ -331,7 +334,8 @@ PD_BUILD_OP(fusion_decoding)
"max_len: int64_t",
"beam_search_diversity_rate: float",
"rel_len: bool",
"alpha: float"})
"alpha: float",
"fuse_qkv: bool"})
.SetKernelFn(PD_KERNEL(DecodingForward))
.SetInferShapeFn(PD_INFER_SHAPE(DecodingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(DecodingInferDtype));
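For reference, the registration above uses Paddle's custom-operator extension API, where each entry in .Attrs() is mirrored by one extra kernel argument. A toy sketch (hypothetical op, not part of this PR) of how a bool attribute such as fuse_qkv reaches the kernel function:

```cpp
#include <vector>

#include "paddle/extension.h"

// Hypothetical custom op: the "fuse_qkv: bool" attr arrives as an ordinary
// kernel argument, in the same order the attrs were declared.
std::vector<paddle::Tensor> ToyForward(const paddle::Tensor& x,
                                       const bool& fuse_qkv) {
  if (fuse_qkv) {
    // A real op would dispatch to a fused-QKV code path here.
  }
  return {x};
}

PD_BUILD_OP(toy_decoding)
    .Inputs({"X"})
    .Outputs({"Out"})
    .Attrs({"fuse_qkv: bool"})
    .SetKernelFn(PD_KERNEL(ToyForward));
```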
15 changes: 10 additions & 5 deletions paddlenlp/ops/faster_transformer/src/fusion_decoding_op.cu
@@ -21,7 +21,6 @@ limitations under the License. */
#include <sstream>
#include <vector>

#include "cublas_handle.h"
#include "fastertransformer/cuda/cub/cub.cuh"
#include "fusion_decoding_op.h"
#include "pd_traits.h"
@@ -77,6 +76,7 @@ std::vector<paddle::Tensor> decoding_kernel(
const int64_t& max_seq_len_,
const float& beam_search_diversity_rate_,
const float& alpha,
const bool& fuse_qkv,
cudaStream_t stream) {
int beam_width_ = (decoding_strategy == "beam_search" ||
decoding_strategy == "beam_search_v2")
@@ -261,7 +261,8 @@ std::vector<paddle::Tensor> decoding_kernel(
start_id_,
end_id_,
beam_search_diversity_rate_,
- true); // is_fuse_topk_softMax
+ true, // is_fuse_topk_softMax
+ fuse_qkv);

decoding_beam_search_->forward(params, decoding_params);

@@ -283,7 +284,7 @@ std::vector<paddle::Tensor> decoding_kernel(
end_id_,
beam_search_diversity_rate_,
true, // is_fuse_topk_softMax
- false, // is_fuse_qkv
+ fuse_qkv,
true, // keep_alive_beam
alpha);

@@ -307,7 +308,8 @@ std::vector<paddle::Tensor> decoding_kernel(
start_id_,
end_id_,
candidate_num_,
- probability_threshold_);
+ probability_threshold_,
+ fuse_qkv);

decoding_sampling_->forward(params, decoding_params);

@@ -371,7 +373,8 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
const int& eos_id,
const int64_t& max_len,
const float& beam_search_diversity_rate,
const float& alpha) {
const float& alpha,
const bool& fuse_qkv) {
auto stream = input.stream();

cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream);
@@ -430,6 +433,7 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
max_len,
beam_search_diversity_rate,
alpha,
+ fuse_qkv,
stream);
break;
}
@@ -484,6 +488,7 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
max_len,
beam_search_diversity_rate,
alpha,
+ fuse_qkv,
stream);
break;
}
5 changes: 4 additions & 1 deletion paddlenlp/ops/faster_transformer/src/fusion_decoding_op.h
@@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include <vector>

#include "cublas_handle.h"

#include "fastertransformer/decoding_beamsearch.h"
#include "fastertransformer/decoding_sampling.h"
#include "fastertransformer/open_decoder.h"
@@ -77,4 +79,5 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
const int& eos_id,
const int64_t& max_len,
const float& beam_search_diversity_rate,
const float& alpha);
const float& alpha,
const bool& fuse_qkv);
12 changes: 8 additions & 4 deletions paddlenlp/ops/faster_transformer/src/fusion_force_decoding_op.cc
@@ -66,7 +66,8 @@ std::vector<paddle::Tensor> DecodingForward(
const int64_t& max_len,
const float& beam_search_diversity_rate,
const bool& rel_len,
const float& alpha) {
const float& alpha,
const bool& fuse_qkv) {
int batch_size = input.shape()[0];
int max_out_len = rel_len ? max_len + input.shape()[1] : max_len;

@@ -159,7 +160,8 @@ std::vector<paddle::Tensor> DecodingForward(
eos_id,
max_out_len,
beam_search_diversity_rate,
- alpha);
+ alpha,
+ fuse_qkv);
} else {
PD_THROW("Not implemented place. Only GPU is supported. ");
}
@@ -213,7 +215,8 @@ std::vector<std::vector<int64_t>> DecodingInferShape(
const int64_t& max_len,
const float& beam_search_diversity_rate,
const bool& rel_len,
const float& alpha) {
const float& alpha,
const bool& fuse_qkv) {
int batch_size = input_shape[0];

std::vector<int64_t> output_dims;
@@ -334,7 +337,8 @@ PD_BUILD_OP(fusion_force_decoding)
"max_len: int64_t",
"beam_search_diversity_rate: float",
"rel_len: bool",
"alpha: float"})
"alpha: float",
"fuse_qkv: bool"})
.SetKernelFn(PD_KERNEL(DecodingForward))
.SetInferShapeFn(PD_INFER_SHAPE(DecodingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(DecodingInferDtype));
83 changes: 39 additions & 44 deletions paddlenlp/ops/faster_transformer/src/fusion_force_decoding_op.cu
@@ -83,20 +83,19 @@ std::vector<paddle::Tensor> decoding_kernel(
paddle::Tensor& output_ids,
paddle::Tensor& parent_ids,
paddle::Tensor& sequence_length,
- std::string decoding_strategy,
- int beam_size,
- int topk,
- float topp,
- int head_num_,
- int size_per_head_,
- int num_layer_,
- int start_id_,
- int end_id_,
- int64_t max_seq_len_,
- float beam_search_diversity_rate_,
- float alpha,
- cublasHandle_t cublas_handle_,
- cublasLtHandle_t cublaslt_handle_,
+ const std::string& decoding_strategy,
+ const int& beam_size,
+ const int& topk,
+ const float& topp,
+ const int& head_num_,
+ const int& size_per_head_,
+ const int& num_layer_,
+ const int& start_id_,
+ const int& end_id_,
+ const int64_t& max_seq_len_,
+ const float& beam_search_diversity_rate_,
+ const float& alpha,
+ const bool& fuse_qkv,
cudaStream_t stream) {
int beam_width_ = (decoding_strategy == "beam_search" ||
decoding_strategy == "beam_search_v2")
@@ -119,8 +118,9 @@
typedef typename traits_::data_t data_t_;

DecodingInitParam<DataType_> decoding_params;
- decoding_params.cublas_handle = cublas_handle_;
- decoding_params.cublaslt_handle = cublaslt_handle_;
+ decoding_params.cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+ decoding_params.cublaslt_handle =
+     CublasHandle::GetInstance()->cublaslt_handle_;

decoding_params.output_ids = output_ids.mutable_data<int>(input.place());
decoding_params.parent_ids = parent_ids.mutable_data<int>(input.place());
@@ -158,8 +158,8 @@

for (int i = 0; i < num_layer_; i++) {
params[i].stream = stream;
- params[i].cublas_handle = cublas_handle_;
- params[i].cublaslt_handle = cublaslt_handle_;
+ params[i].cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+ params[i].cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_;

if (decoding_strategy == "beam_search" ||
decoding_strategy == "beam_search_v2") {
@@ -292,7 +292,8 @@ std::vector<paddle::Tensor> decoding_kernel(
start_id_,
end_id_,
beam_search_diversity_rate_,
- true); // is_fuse_topk_softMax
+ true, // is_fuse_topk_softMax
+ fuse_qkv); // is_fuse_qkv

decoding_beam_search_->forward(params, decoding_params);

@@ -314,7 +315,7 @@
end_id_,
beam_search_diversity_rate_,
true, // is_fuse_topk_softMax
- false, // is_fuse_qkv
+ fuse_qkv, // is_fuse_qkv
true, // keep_alive_beam
alpha);

@@ -338,7 +339,8 @@
start_id_,
end_id_,
candidate_num_,
- probability_threshold_);
+ probability_threshold_,
+ fuse_qkv);

decoding_sampling_->forward(params, decoding_params);

@@ -392,24 +394,21 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
paddle::Tensor& output_ids,
paddle::Tensor& parent_ids,
paddle::Tensor& sequence_length,
- std::string decoding_strategy,
- int beam_size,
- int topk,
- float topp,
- int n_head,
- int size_per_head,
- int num_layer,
- int bos_id,
- int eos_id,
- int64_t max_len,
- float beam_search_diversity_rate,
- float alpha) {
+ const std::string& decoding_strategy,
+ const int& beam_size,
+ const int& topk,
+ const float& topp,
+ const int& n_head,
+ const int& size_per_head,
+ const int& num_layer,
+ const int& bos_id,
+ const int& eos_id,
+ const int64_t& max_len,
+ const float& beam_search_diversity_rate,
+ const float& alpha,
+ const bool& fuse_qkv) {
auto stream = input.stream();
- cublasHandle_t cublas_handle_;
- cublasCreate(&cublas_handle_);
- cublasLtHandle_t cublaslt_handle_;
- cublasLtCreate(&cublaslt_handle_);
- cublasSetStream(cublas_handle_, stream);
+ cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream);

std::vector<paddle::Tensor> ret;

@@ -466,8 +465,7 @@
max_len,
beam_search_diversity_rate,
alpha,
- cublas_handle_,
- cublaslt_handle_,
+ fuse_qkv,
stream);
break;
}
@@ -523,8 +521,7 @@
max_len,
beam_search_diversity_rate,
alpha,
- cublas_handle_,
- cublaslt_handle_,
+ fuse_qkv,
stream);
break;
}
@@ -536,7 +533,5 @@
}
}

- cublasDestroy(cublas_handle_);
- cublasLtDestroy(cublaslt_handle_);
return ret;
}
27 changes: 15 additions & 12 deletions paddlenlp/ops/faster_transformer/src/fusion_force_decoding_op.h
@@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include <vector>

#include "cublas_handle.h"

#include "fastertransformer/decoding_beamsearch.h"
#include "fastertransformer/decoding_sampling.h"
#include "fastertransformer/open_decoder.h"
@@ -67,15 +69,16 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
paddle::Tensor& output_ids,
paddle::Tensor& parent_ids,
paddle::Tensor& sequence_length,
- std::string decoding_strategy,
- int beam_size,
- int topk,
- float topp,
- int n_head,
- int size_per_head,
- int num_layer,
- int bos_id,
- int eos_id,
- int64_t max_len,
- float beam_search_diversity_rate,
- float alpha);
+ const std::string& decoding_strategy,
+ const int& beam_size,
+ const int& topk,
Member commented:

In principle, int-valued parameters need neither const nor a reference, since there is no speedup to gain. For a string, passing by reference reduces the copy from the whole string to a pointer-sized one; for an int, copying the reference's address costs the same as copying the int value itself, so adding const& changes nothing in practice (see the standalone sketch after this parameter list).

Contributor Author replied:

True. The const here is simply a habit of marking parameters the function never modifies as const, and the references were meant to make the parameter list look uniform. My latest revision removes the references, though for parameters that are never modified I'd still suggest keeping const.

const float& topp,
const int& n_head,
const int& size_per_head,
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
const float& beam_search_diversity_rate,
const float& alpha,
const bool& fuse_qkv);