-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Transformer decoding support fuse qkv #1455
Changes from 8 commits
bbfe3a9
4c888f0
0a2a5d9
71dd70f
3fe3437
a57fc66
8fb6739
c213ea4
46392f2
384e23a
262fe27
f301919
1b6b70a
9bef786
f807a6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,4 +55,4 @@ class CublasHandle { | |
cublasLtHandle_t cublaslt_handle_; | ||
|
||
~CublasHandle(); | ||
}; | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,8 @@ limitations under the License. */ | |
#include <string> | ||
#include <vector> | ||
|
||
#include "cublas_handle.h" | ||
|
||
#include "fastertransformer/decoding_beamsearch.h" | ||
#include "fastertransformer/decoding_sampling.h" | ||
#include "fastertransformer/open_decoder.h" | ||
|
@@ -67,15 +69,16 @@ std::vector<paddle::Tensor> DecodingCUDAForward( | |
paddle::Tensor& output_ids, | ||
paddle::Tensor& parent_ids, | ||
paddle::Tensor& sequence_length, | ||
std::string decoding_strategy, | ||
int beam_size, | ||
int topk, | ||
float topp, | ||
int n_head, | ||
int size_per_head, | ||
int num_layer, | ||
int bos_id, | ||
int eos_id, | ||
int64_t max_len, | ||
float beam_search_diversity_rate, | ||
float alpha); | ||
const std::string& decoding_strategy, | ||
const int& beam_size, | ||
const int& topk, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 理论上int值的参数,是不需要加const和引用的。因为没有任何加速意义。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Truth. 这里加 const 单纯是不需要修改的形参都习惯性加上 const,加上引用是想让形参列表看起来统一。我新的修改删除了引用,不过对于这样不需要修改的形参,还是建议保留 const。 |
||
const float& topp, | ||
const int& n_head, | ||
const int& size_per_head, | ||
const int& num_layer, | ||
const int& bos_id, | ||
const int& eos_id, | ||
const int64_t& max_len, | ||
const float& beam_search_diversity_rate, | ||
const float& alpha, | ||
const bool& fuse_qkv); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不另外加这个fuse_qkv的attr了吧,直接根据size来判断吧,也保证对之前模型的兼容性
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thanks.