@ceci3 ceci3 commented Aug 16, 2021

PR types

Performance optimization

PR changes

Others

Describe

matmul can be set to int8 in multihead.
Envs: CUDA 10.2, cuDNN 8.0, TRT 7.2.2.3

| layer | prec | prec | prec |
| --- | --- | --- | --- |
| mul | int8 | int8 | FP16 |
| skip ln | FP16 | FP16 | FP16 |
| qkv | int8 | FP16 | FP16 |
| acc | 0.7699 (3857.0/5010) | (/5010) | (/5010) |
| qps (T4) | 2538 seq/s | 2310 seq/s | 1898 seq/s |
| latency (T4, bs=40) | 20.9 ms | 22.9 ms | 32.5 ms |

The int8 results can fluctuate slightly because TRT may select different underlying kernels.
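For context on the numbers above: in TRT-style int8, a layer's "out_threshold" attribute holds the maximum absolute activation observed during calibration, and the per-tensor scale is obtained by dividing it by 127. A minimal sketch of that convention (helper names are illustrative, not Paddle's actual API):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical helper (not Paddle's actual API): convert an "out_threshold"
// calibration value into the per-tensor int8 scale (threshold / 127).
float ThresholdToScale(float out_threshold) { return out_threshold / 127.0f; }

// Quantize a float activation to int8 with the derived scale (round, clamp).
int8_t QuantizeToInt8(float x, float scale) {
  float q = std::round(x / scale);
  q = std::max(-127.0f, std::min(127.0f, q));
  return static_cast<int8_t>(q);
}
```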

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@ceci3 ceci3 changed the title [paddle-TRT]update ernie int8 [WIP][paddle-TRT]update ernie int8 Aug 16, 2021
@ceci3 ceci3 changed the title [WIP][paddle-TRT]update ernie int8 [paddle-TRT]support matmul set to int8 in multihead Aug 23, 2021

@qingqing01 qingqing01 left a comment


Need to add a unit test.

@qingqing01 qingqing01 requested a review from Superjomn August 24, 2021 09:08
```cpp
    BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
auto out_scale2 =
    BOOST_GET_CONST(float, mul2_op_desc->GetAttr("out_threshold"));
    BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
```

out_scale0, out_scale1, and out_scale2 are the scales of q, k, and v respectively, and their maximum is set as the multihead_op's "out_threshold". multihead_op contains an FC and the QKVToContextPlugin, and this "out_threshold" is used as the FC's output scale.
Question: could "out_threshold" be given a different name? Reasons:

  1. The current name is not intuitive and is easily mistaken for the output scale of multihead_op.
  2. "out_threshold" is only used to pass information from the fusion pass to the TRT op converter, so the name does not need to follow the quantized model's storage format, nor is it bound by the attr constraints in the OP Maker.
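The scale fusion this comment describes can be sketched as follows (the function name is hypothetical; the real logic lives in the fusion pass):

```cpp
#include <algorithm>

// Sketch (illustrative names, not the actual pass code): q, k, and v each
// carry their own calibrated scale, and the fused multihead op records the
// maximum so a single scale safely covers all three branches of the shared
// FC output.
float FuseQkvOutThreshold(float out_scale0, float out_scale1,
                          float out_scale2) {
  return std::max({out_scale0, out_scale1, out_scale2});
}
```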

```cpp
  multihead_op_desc.SetAttr("dp_probs", qkv_plugin_scale);
}
auto* matmul_qk_op_desc = matmul_qk->Op();
if (matmul_qk_op_desc->HasAttr("X_scale")) {
```

If this check is false, are lines 897–902 still necessary?


```cpp
float* weight_data = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8");
bool qkv_plugin_int8 =
```

Could the naming be more intuitive? For example, include the keyword "qkv2context".

```cpp
dp_probs = out_scale / 127.0;
dp_probs =
    BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0;
dp_probs = dp_probs / 127.0;
```

The computation here is a bit convoluted; please add an explanation.
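One way to read the dp_probs branches above (my interpretation, an assumption not verified against the full converter): dp_probs ends up as the int8 scale for the softmax output fed into the QKV plugin, and since softmax outputs lie in [0, 1], a threshold of 1.0 is implied when no calibrated attribute exists. A hedged sketch:

```cpp
// Assumption: dp_probs is the int8 scale for the softmax output consumed by
// the QKV plugin. Softmax outputs lie in [0, 1], so a threshold of 1.0 is
// implied when no calibrated "dp_probs" attribute is present; otherwise the
// calibrated threshold is used. (Illustrative function, not converter code.)
float ComputeDpProbs(bool has_dp_probs_attr, float calibrated_threshold) {
  float threshold = has_dp_probs_attr ? calibrated_threshold : 1.0f;
  return threshold / 127.0f;  // int8 scale = threshold / 127
}
```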

Comment on lines 169 to 176
```cpp
        ? nvinfer1::DataType::kHALF
        : nvinfer1::DataType::kFLOAT);
if (enable_int8) {
  type = static_cast<int>(nvinfer1::DataType::kHALF);
  if (qkv_plugin_int8) {
    type = static_cast<int>(nvinfer1::DataType::kINT8);
  } else {
    type = static_cast<int>(nvinfer1::DataType::kHALF);
  }
```
@wanghaoshuang wanghaoshuang Aug 25, 2021


Are you sure the conditions in lines 169–176 have no redundancy? For example, can line 175 be removed?
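The reviewer's point can be illustrated with an equivalent, de-duplicated form of the branch (the enum is stubbed locally for illustration; in the converter the values come from nvinfer1::DataType):

```cpp
// Equivalent, de-duplicated form of the type-selection branch questioned
// above. The enum is a local stand-in for nvinfer1::DataType.
enum class DataType { kFLOAT, kHALF, kINT8 };

DataType SelectPluginType(bool with_fp16, bool enable_int8,
                          bool qkv_plugin_int8) {
  if (enable_int8) {
    // The inner else-branch repeats the kHALF assignment made just before
    // the inner if, so the whole block collapses to a single ternary.
    return qkv_plugin_int8 ? DataType::kINT8 : DataType::kHALF;
  }
  return with_fp16 ? DataType::kHALF : DataType::kFLOAT;
}
```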

cryoco previously approved these changes Aug 27, 2021

@cryoco cryoco left a comment


LGTM

@ceci3 ceci3 merged commit 0043fa8 into PaddlePaddle:develop Aug 30, 2021
@ceci3 ceci3 deleted the update_ernie_int8 branch August 30, 2021 02:52