Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
566b820
trt affine channel converter
zlsh80826 Mar 15, 2021
1502991
add trt affine channel base test
zlsh80826 Mar 17, 2021
c78a98e
add trt affine channel NHWC
zlsh80826 Mar 17, 2021
78c1ea4
remove asterisk for python2 compatibility
zlsh80826 Mar 18, 2021
10a2aac
trt affine channel converter
zlsh80826 Mar 15, 2021
e87eca2
add trt affine channel base test
zlsh80826 Mar 17, 2021
4874a17
add trt affine channel NHWC
zlsh80826 Mar 17, 2021
4909a83
remove asterisk for python2 compatibility
zlsh80826 Mar 18, 2021
c824589
fix rebase
zlsh80826 Mar 18, 2021
181f682
fix conflict
zlsh80826 Mar 18, 2021
b9a8c1c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Mar 18, 2021
10841dd
move LodTensor to Tensor
zlsh80826 Mar 19, 2021
2ba7788
add dbg info
zlsh80826 Mar 19, 2021
7074aa1
affine channel converter only support NCHW
zlsh80826 Mar 22, 2021
bb5b441
scale,bias are parameters, use create_parameters api
zlsh80826 Mar 22, 2021
dbba936
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Mar 22, 2021
b7eb3f2
reduce test input size to not exceed the timelimit of ci
zlsh80826 Mar 25, 2021
98a369b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Mar 25, 2021
2df6971
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Mar 25, 2021
fbd4ec0
refine affine channel unittest and add serialization/dynamic test
zlsh80826 Mar 25, 2021
872612c
change super to InferencePassTest for python2 compatibility
zlsh80826 Mar 25, 2021
307f560
change super to InferencePassTest for python2 compatibility
zlsh80826 Mar 25, 2021
1ae78de
fix affine channel fp16 serialize setting
zlsh80826 Mar 26, 2021
99c286a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Mar 26, 2021
a668849
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Mar 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,7 @@ USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
USE_TRT_CONVERTER(affine_channel);
USE_TRT_CONVERTER(multiclass_nms);
USE_TRT_CONVERTER(nearest_interp);
#endif
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ nv_library(tensorrt_converter
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
gather_op.cc
affine_channel_op.cc
multiclass_nms_op.cc
nearest_interp_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
Expand Down
94 changes: 94 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace framework {
class Scope;

namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
* Affine Channel Op
*/
class AffineChannelOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid affine_channel op to tensorrt scale nd layer";

framework::OpDesc op_desc(op, nullptr);
std::string input_name = op_desc.Input("X").front();
std::string scale_name = op_desc.Input("Scale").front();
std::string bias_name = op_desc.Input("Bias").front();
std::string output_name = op_desc.Output("Out").front();

auto input_tensor = engine_->GetITensor(input_name);
auto idim = input_tensor->getDimensions();

auto* scale_v = scope.FindVar(scale_name);
auto* scale_t = scale_v->GetMutable<framework::LoDTensor>();
float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t, false);

auto* bias_v = scope.FindVar(bias_name);
auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false);

auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout")));

PADDLE_ENFORCE_EQ(
data_layout, framework::DataLayout::kNCHW,
platform::errors::InvalidArgument(
"TensorRT affine channel converter can only convert NCHW format. "
"Other format should be run in fluid mode. Report a bug on github "
"issue if you see this line."));

// tensorrt scalend layer only support spatial dims >= 2,
// so nhwc is not availabe (spatial dims == 0)
const int channel_axis = engine_->with_dynamic_shape();

TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(scale_ptr),
(size_t)idim.d[channel_axis]};
TensorRTEngine::Weight bias_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_ptr),
(size_t)idim.d[channel_axis]};
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};

auto layer = TRT_ENGINE_ADD_LAYER(engine_, ScaleNd, *input_tensor,
nvinfer1::ScaleMode::kCHANNEL,
bias_weights.get(), scale_weights.get(),
power_weights.get(), channel_axis);

RreplenishLayerAndOutput(layer, "affine_channel", {output_name}, test_mode);
}
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(affine_channel, AffineChannelOpConverter);
10 changes: 9 additions & 1 deletion paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"flatten2",
"flatten",
"gather",
"affine_channel",
"multiclass_nms",
"nearest_interp",
};
Expand Down Expand Up @@ -196,6 +197,13 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false;
}

if (op_type == "affine_channel") {
if (!desc.HasAttr("data_layout")) return false;
auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, desc.GetAttr("data_layout")));
if (data_layout != framework::DataLayout::kNCHW) return false;
}

if (op_type == "multiclass_nms") {
if (with_dynamic_shape) return false;
auto* block = desc.Block();
Expand Down Expand Up @@ -238,6 +246,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
}

if (op_type == "nearest_interp") {
std::vector<std::string> attrs{"data_layout", "interp_method",
"align_corners", "scale",
Expand All @@ -254,7 +263,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(std::string, desc.GetAttr("interp_method"));
if (interp_method != "nearest") return false;
}

if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import itertools
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig


class TRTAffineChannelTest(InferencePassTest):
def setUp(self):
self.bs = 2
self.channel = 8
self.height = 16
self.width = 16
self.data_layout = 'NCHW'
self.precision = AnalysisConfig.Precision.Float32
self.serialize = False
self.enable_trt = True

def build(self):
# set min_graph_size to 2,
# because affine channel doesn't support nhwc format
self.trt_parameters = InferencePassTest.TensorRTParam(
1 << 30, self.bs, 2, self.precision, self.serialize, False)

with fluid.program_guard(self.main_program, self.startup_program):
if self.data_layout == 'NCHW':
shape = [-1, self.channel, self.height, self.width]
else:
shape = [-1, self.height, self.width, self.channel]

data = fluid.data(name='in', shape=shape, dtype='float32')
# set scale, bias by constant
scale = fluid.layers.create_parameter(
shape=[self.channel],
dtype='float32',
default_initializer=fluid.initializer.Constant(2.))
bias = fluid.layers.create_parameter(
shape=[self.channel],
dtype='float32',
default_initializer=fluid.initializer.Constant(.5))
affine_channel_out = fluid.layers.affine_channel(
data, scale=scale, bias=bias, data_layout=self.data_layout)
out = fluid.layers.batch_norm(affine_channel_out, is_test=True)

shape[0] = self.bs
self.feeds = {'in': np.random.random(shape).astype('float32'), }
self.fetch_list = [out]

def check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
atol = 1e-5
if self.trt_parameters.precision == AnalysisConfig.Precision.Half:
atol = 1e-3
self.check_output_with_option(use_gpu, atol, flatten=True)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))

def run_test(self):
self.build()
self.check_output()

def run_test_all(self):
precision_opt = [
AnalysisConfig.Precision.Float32, AnalysisConfig.Precision.Half
]
serialize_opt = [False, True]

if self.data_layout == 'NCHW':
min_shape = [
self.bs, self.channel, self.height // 2, self.width // 2
]
max_shape = [self.bs, self.channel, self.height * 2, self.width * 2]
opt_shape = [self.bs, self.channel, self.height, self.width]

if self.data_layout == 'NHWC':
min_shape = [
self.bs, self.height // 2, self.width // 2, self.channel
]
max_shape = [self.bs, self.height * 2, self.width * 2, self.channel]
opt_shape = [self.bs, self.height, self.width, self.channel]

dynamic_shape_profile = InferencePassTest.DynamicShapeParam({
'in': min_shape
}, {'in': max_shape}, {'in': opt_shape}, False)
dynamic_shape_opt = [None, dynamic_shape_profile]

for precision, serialize, dynamic_shape in itertools.product(
precision_opt, serialize_opt, dynamic_shape_opt):
self.precision = precision
self.serialize = serialize
self.dynamic_shape_params = dynamic_shape
self.run_test()

def test_base(self):
self.run_test()

def test_fp16(self):
self.precision = AnalysisConfig.Precision.Half
self.run_test()

def test_serialize(self):
self.serialize = True
self.run_test()

def test_dynamic(self):
self.dynamic_shape_params = InferencePassTest.DynamicShapeParam({
'in': [self.bs, self.channel, self.height // 2, self.width // 2]
}, {'in': [self.bs, self.channel, self.height * 2, self.width * 2]
}, {'in': [self.bs, self.channel, self.height, self.width]}, False)
self.run_test()

def test_nchw_all(self):
self.run_test_all()

def test_nhwc(self):
self.data_layout = 'NHWC'
self.run_test_all()


if __name__ == "__main__":
unittest.main()