Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 16 additions & 84 deletions paddle/fluid/inference/tensorrt/convert/fc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,91 +333,23 @@ class FcOpConverter : public OpConverter {
if (!engine_->with_dynamic_shape()) {
x_num_col_dims--;
}
// If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
// not add Shuffle layer in ernie's multihead.
if (x_dim.nbDims == 4 && x_num_col_dims == 1) {
if (enable_int8 || support_int8) {
// add conv1x1 layer
nvinfer1::DimsHW nv_ksize(1, 1);
auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_,
Convolution,
*X,
n_output,
nv_ksize,
weight.get(),
bias.get());
if (activation_type == "relu") {
fc_layer_int8->setName(
("ernie_fc_op_int8: Convolution (Output: " + output_name + ")")
.c_str());
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("out_threshold"),
true,
platform::errors::InvalidArgument(
"must have out threshold in fc layers in int8 mode"));
float out_scale = 0;
if (enable_int8) {
out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
} else {
out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out"));
}
engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0),
out_scale);
nvinfer1::IActivationLayer* relu_layer_int8 =
TRT_ENGINE_ADD_LAYER(engine_,
Activation,
*(fc_layer_int8->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer_int8,
"relu_after_ernie_fc_int8",
{output_name},
test_mode);
} else {
RreplenishLayerAndOutput(fc_layer_int8,
"ernie_fc_op_int8: Convolution",
{output_name},
test_mode);
}
} else {
// add fc layer
auto* fc_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, FullyConnected, *X, n_output, weight.get(), bias.get());
if (activation_type == "relu") {
fc_layer_float->setName(
("ernie_fc_op_float: (Output: " + output_name + ")").c_str());
nvinfer1::IActivationLayer* relu_layer_float =
TRT_ENGINE_ADD_LAYER(engine_,
Activation,
*(fc_layer_float->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer_float,
"relu_after_ernie_fc_float",
{output_name},
test_mode);
} else {
RreplenishLayerAndOutput(
fc_layer_float, "ernie_fc_op_float", {output_name}, test_mode);
}
}
} else { // need reshape input before and after fc
PADDLE_ENFORCE_GT(
x_dim.nbDims,
x_num_col_dims,
platform::errors::InvalidArgument(
"Params and input dims mismatch. Paddle-TRT FC "
"converter expects x_dim.nbDims > x_num_col_dims, but "
"x_dim.nbDims : %d, x_num_col_dims : %d.",
x_dim.nbDims,
x_num_col_dims));
auto* reshape_before_fc_layer =
reshape_before_fc(X, x_dim, x_num_col_dims, output_name);
auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
if (enable_int8 || support_int8) {
engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
}
regist_fc(reshape_itensor, n_output, weight, bias);
PADDLE_ENFORCE_GT(
x_dim.nbDims,
x_num_col_dims,
platform::errors::InvalidArgument(
"Params and input dims mismatch. Paddle-TRT FC "
"converter expects x_dim.nbDims > x_num_col_dims, but "
"x_dim.nbDims : %d, x_num_col_dims : %d.",
x_dim.nbDims,
x_num_col_dims));
// need reshape input before and after fc
auto* reshape_before_fc_layer =
reshape_before_fc(X, x_dim, x_num_col_dims, output_name);
auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
if (enable_int8 || support_int8) {
engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
}
regist_fc(reshape_itensor, n_output, weight, bias);
}
};

Expand Down
Loading