[Paddle Tensorrt] add tensorrt converter and marker #69208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 5 commits · Nov 13, 2024
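
This PR extends Paddle's PIR-to-TensorRT lowering in two coordinated places: the marker pass in trt_op_marker_pass.cc tags ops that are safe to offload by setting the kCanRunTrtAttr attribute, and the Python modules under python/paddle/tensorrt/impls/ turn each tagged pd_op into TensorRT layers through converter_registry. As a minimal sketch of the Python half of that contract (illustrative only; relu_converter below is not part of this PR's changes, while the decorator and signature mirror the converters in the diff that follows):

import tensorrt as trt

from paddle.tensorrt.register import converter_registry


# Illustrative sketch only: the registry maps a pd_op name to a function that
# receives the TensorRT network, the PIR op, and its input ITensors, and
# returns the ITensor that replaces the op's output.
@converter_registry.register("pd_op.relu", trt_version="8.x")
def relu_converter(network, paddle_op, inputs):
    layer = network.add_activation(inputs[0], trt.ActivationType.RELU)
    return layer.get_output(0)
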
90 changes: 90 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -82,6 +82,9 @@ DEFINE_GENERAL_PATTERN(AssignValue_, paddle::dialect::AssignValue_Op)
DEFINE_GENERAL_PATTERN(Tile, paddle::dialect::TileOp)
DEFINE_GENERAL_PATTERN(Share_Data, paddle::dialect::ShareDataOp)
DEFINE_GENERAL_PATTERN(AssignOut, paddle::dialect::AssignOut_Op)
DEFINE_GENERAL_PATTERN(Swish, paddle::dialect::SwishOp)
DEFINE_GENERAL_PATTERN(Log, paddle::dialect::LogOp)
DEFINE_GENERAL_PATTERN(Floor, paddle::dialect::FloorOp)
DEFINE_GENERAL_PATTERN(Roll, paddle::dialect::RollOp)

#undef DEFINE_GENERAL_PATTERN
@@ -1500,6 +1503,87 @@ class WherePattern : public pir::OpRewritePattern<paddle::dialect::WhereOp> {
}
};

class EqualOpPattern : public pir::OpRewritePattern<paddle::dialect::EqualOp> {
public:
using pir::OpRewritePattern<paddle::dialect::EqualOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::EqualOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
#if IS_TRT_VERSION_LT(8600)
pir::Value x = op.operand_source(0);
auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto x_shape = x_type.dims();
int dims = x_shape.size();
if (dims < 1) {
VLOG(3)
<< "pd_op.equal op does not support 0 dim input when TensorRT < 8.6.";
return false;
}
#endif

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class NotEqualOpPattern
: public pir::OpRewritePattern<paddle::dialect::NotEqualOp> {
public:
using pir::OpRewritePattern<paddle::dialect::NotEqualOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::NotEqualOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
#if IS_TRT_VERSION_LT(8600)
pir::Value x = op.operand_source(0);
auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto x_shape = x_type.dims();
int dims = x_shape.size();
if (dims < 1) {
VLOG(3) << "pd_op.not_equal op does not support 0 dim input when "
"TensorRT < 8.6.";
return false;
}
#endif

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
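
The TensorRT-version guard in the two marker patterns above only fires for 0-D (scalar) inputs. A hedged Python-level illustration of the case being excluded (assuming a Paddle build with 0-D tensor support; not part of this PR):

import paddle

# A 0-D comparison: x_shape.size() is 0, so EqualOpPattern/NotEqualOpPattern
# refuse to set kCanRunTrtAttr when built against TensorRT < 8.6 and the op
# stays on the Paddle side instead of being offloaded.
x = paddle.to_tensor(1.0)           # 0-D tensor, shape []
y = paddle.to_tensor(2.0)
print(paddle.equal(x, y).shape)     # []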

class FullLikeOpPattern
: public pir::OpRewritePattern<paddle::dialect::FullLikeOp> {
public:
using pir::OpRewritePattern<paddle::dialect::FullLikeOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::FullLikeOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value x = op.operand_source(0);
auto x_dtype = pir::GetDataTypeFromValue(x);
bool hasAttr = op->HasAttribute("dtype");
auto dtype =
op->attribute<paddle::dialect::DataTypeAttribute>("dtype").data();

if (dtype == phi::DataType::BOOL ||
(!hasAttr && x_dtype.isa<pir::BoolType>())) {
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
VLOG(3) << "the pd_op.full_like supports input of BOOL by trt8.4 above";
return true;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class FullWithTensorPattern
: public pir::OpRewritePattern<paddle::dialect::FullWithTensorOp> {
public:
@@ -1642,6 +1726,9 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ADD_PATTERN(AssignValue_)
ADD_PATTERN(Tile)
ADD_PATTERN(Share_Data)
ADD_PATTERN(Swish)
ADD_PATTERN(Log)
ADD_PATTERN(Floor)
ADD_PATTERN(Roll)
#if IS_TRT_VERSION_GE(8600)
ADD_PATTERN(Layer_norm)
@@ -1692,9 +1779,12 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<StackOpPattern>(context));
ps.Add(std::make_unique<TanhOpPattern>(context));
ps.Add(std::make_unique<WherePattern>(context));
ps.Add(std::make_unique<FullLikeOpPattern>(context));
ps.Add(std::make_unique<FullWithTensorPattern>(context));
ps.Add(std::make_unique<StridedSliceOpPattern>(context));
ps.Add(std::make_unique<TopkOpPattern>(context));
ps.Add(std::make_unique<EqualOpPattern>(context));
ps.Add(std::make_unique<NotEqualOpPattern>(context));
return ps;
}
};
12 changes: 12 additions & 0 deletions python/paddle/tensorrt/impls/activation.py
@@ -17,13 +17,16 @@

from paddle.tensorrt.converter_utils import (
get_trt_plugin,
trt_mul,
)
from paddle.tensorrt.register import converter_registry

activation_type_map = {
"pd_op.tanh": trt.ActivationType.TANH,
"pd_op.relu": trt.ActivationType.RELU,
"pd_op.sigmoid": trt.ActivationType.SIGMOID,
"pd_op.silu": trt.ActivationType.SIGMOID,
"pd_op.swish": trt.ActivationType.SIGMOID,
}


@@ -99,3 +102,12 @@ def hardswish_converter(network, paddle_op, inputs):
x, hardsigmoid_layer.get_output(0), trt.ElementWiseOperation.PROD
)
return hardswish_layer.get_output(0)


@converter_registry.register("pd_op.swish", trt_version="8.x")
@converter_registry.register("pd_op.silu", trt_version="8.x")
def swish_silu_converter(network, paddle_op, inputs):
layer_output = network.add_activation(
inputs[0], activation_type_map[paddle_op.name()]
).get_output(0)
return trt_mul(network, inputs[0], layer_output)
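
Both pd_op.silu and pd_op.swish are lowered to the same graph here: a SIGMOID activation layer followed by an elementwise product with the original input via trt_mul, which matches silu(x) = x * sigmoid(x). A quick NumPy check of that identity, as an illustrative sketch only (not part of the PR):

import numpy as np

x = np.random.randn(2, 3).astype("float32")
sigmoid = 1.0 / (1.0 + np.exp(-x))

# What the converter builds: a sigmoid activation, then multiply by the input.
composed = x * sigmoid

# Reference silu definition.
reference = x / (1.0 + np.exp(-x))
assert np.allclose(composed, reference)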
64 changes: 64 additions & 0 deletions python/paddle/tensorrt/impls/creation.py
@@ -19,6 +19,7 @@
from paddle.tensorrt.converter_utils import (
add_1D_constant_layer,
cast_tensor,
trt_cast,
trt_floor_div,
trt_max,
trt_reduce_to_scalar,
@@ -98,6 +99,69 @@ def arange_converter(network, paddle_op, inputs):
return fill_layer.get_output(0)


@converter_registry.register("pd_op.full_like", trt_version="8.x")
def full_like_converter(network, paddle_op, inputs):
shape = tuple(paddle_op.operands()[0].source().shape)
ndims = len(shape)

out_dtype = int(paddle_op.attrs().get("dtype", None))
# Reference paddle/phi/common/data_type.h enum DataType
if out_dtype == 1: # paddle.bool
out_dtype = trt.int32
elif out_dtype == 7: # paddle.int32
out_dtype = trt.int32
elif out_dtype == 9: # paddle.int64
out_dtype = trt.int32
elif out_dtype == 10: # paddle.float32
out_dtype = trt.float32
elif out_dtype == 11: # paddle.float64
out_dtype = trt.float32
else:
raise RuntimeError(
f"cast converter currently doesn't support dtype: {out_dtype}"
)

value_op = paddle_op.operands()[1].source().get_defining_op()
if value_op.name() == "pd_op.full":
fill_value = value_op.attrs()["value"]
value = network.add_constant(
(1,),
np.array(
[
fill_value,
],
dtype=np.float32,
),
).get_output(0)
value = trt_cast(network, value, out_dtype)
else:
value = inputs[1]

shuffle_layer = network.add_shuffle(value)
shuffle_layer.reshape_dims = (1,) * ndims

start_vec = np.zeros((ndims,), dtype=np.int32)
start_tensor = network.add_constant((ndims,), start_vec).get_output(0)
shape_tensor = network.add_shape(inputs[0]).get_output(0)
stride_tensor = network.add_constant(
(ndims,), np.ones((ndims,), dtype=np.int32)
).get_output(0)

slice_layer = network.add_slice(
shuffle_layer.get_output(0),
start_vec,
[1] * ndims,
np.ones((ndims,), dtype=np.int32),
)
slice_layer.mode = trt.SliceMode.FILL
slice_layer.set_input(1, start_tensor)
slice_layer.set_input(2, shape_tensor)
slice_layer.set_input(3, stride_tensor)
value = trt_cast(network, value, out_dtype)
slice_layer.set_input(4, value)
return slice_layer.get_output(0)
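
full_like_converter builds the constant through an ISliceLayer in SliceMode.FILL: the fill value is reshaped to a rank-ndims tensor of all ones, and the slice is expanded to the runtime shape of inputs[0], so every output element takes the fill value. The dtype switch above also maps bool and int64 to int32 and float64 to float32 before the fill. A hedged NumPy sketch of the intended semantics, not of the TRT calls themselves (illustrative only):

import numpy as np

x = np.zeros((2, 3, 4), dtype="float32")       # stands in for inputs[0]
fill_value = 1.5

# Reshape the scalar to (1, 1, 1), then fill it out to x's runtime shape,
# which is what the FILL-mode slice achieves in the TensorRT graph.
value = np.array([fill_value], dtype="float32").reshape((1,) * x.ndim)
out = np.broadcast_to(value, x.shape).copy()

assert np.array_equal(out, np.full_like(x, fill_value))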


@converter_registry.register("pd_op.full_with_tensor", trt_version="8.x")
def full_with_tensor_converter(network, paddle_op, inputs):
value_input = inputs[0]
8 changes: 8 additions & 0 deletions python/paddle/tensorrt/impls/linalg.py
@@ -63,3 +63,11 @@ def transpose_converter(network, paddle_op, inputs):
transposed_tensor = network.add_shuffle(inputs[0])
transposed_tensor.second_transpose = perm
return transposed_tensor.get_output(0)


@converter_registry.register("pd_op.bmm", trt_version="8.x")
def bmm_converter(network, paddle_op, inputs):
out = network.add_matrix_multiply(
inputs[0], trt.MatrixOperation.NONE, inputs[1], trt.MatrixOperation.NONE
)
return out.get_output(0)
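
pd_op.bmm lowers to a single matrix-multiply layer with MatrixOperation.NONE on both operands, i.e. a plain batched matmul with no transposes. Shape-wise, as an illustrative sketch (not part of the PR):

import numpy as np

a = np.random.randn(2, 3, 4).astype("float32")
b = np.random.randn(2, 4, 5).astype("float32")

# Same contraction the matrix-multiply layer performs, batch dim preserved.
out = np.matmul(a, b)
assert out.shape == (2, 3, 5)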
30 changes: 20 additions & 10 deletions python/paddle/tensorrt/impls/logic.py
@@ -20,18 +20,28 @@
)
from paddle.tensorrt.register import converter_registry

logic_type_map = {
"pd_op.greater_than": trt.ElementWiseOperation.GREATER,
"pd_op.less_than": trt.ElementWiseOperation.LESS,
"pd_op.equal": trt.ElementWiseOperation.EQUAL,
}


@converter_registry.register("pd_op.greater_than", trt_version="8.x")
@converter_registry.register("pd_op.less_than", trt_version="8.x")
@converter_registry.register("pd_op.equal", trt_version="8.x")
def logic_converter(network, paddle_op, inputs):
Review comment (Contributor):
Please refactor the logic here into a mapping-based form as well, following the approach used by the activation-related converters.

Reply (Contributor Author):
Refactored.

if paddle_op.name() == "pd_op.greater_than":
layer_output = add_elementwise_layer(
network, paddle_op, inputs, trt.ElementWiseOperation.GREATER
)
elif paddle_op.name() == "pd_op.less_than":
layer_output = add_elementwise_layer(
network, paddle_op, inputs, trt.ElementWiseOperation.LESS
)
else:
raise ValueError(f"Unexpected paddle_op: {paddle_op.name()}")
layer_output = add_elementwise_layer(
network, paddle_op, inputs, logic_type_map[paddle_op.name()]
)
return trt_cast(network, layer_output, inputs[0].dtype)


@converter_registry.register("pd_op.not_equal", trt_version="8.x")
def not_equal_converter(network, paddle_op, inputs):
layer_output = add_elementwise_layer(
network, paddle_op, inputs, trt.ElementWiseOperation.EQUAL
)
not_layer = network.add_unary(layer_output, trt.UnaryOperation.NOT)
layer_output = not_layer.get_output(0)
return trt_cast(network, layer_output, inputs[0].dtype)
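
not_equal is composed rather than mapped to a single elementwise operation: an EQUAL elementwise layer followed by a unary NOT, then a cast back to the input dtype. A NumPy sketch of the same composition (illustrative only, not part of the PR):

import numpy as np

x = np.array([1, 2, 3], dtype="float32")
y = np.array([1, 0, 3], dtype="float32")

# equal -> logical not, which is what the converter builds layer by layer.
composed = np.logical_not(np.equal(x, y))
assert np.array_equal(composed, np.not_equal(x, y))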
15 changes: 15 additions & 0 deletions python/paddle/tensorrt/impls/math.py
@@ -21,6 +21,7 @@
add_reduce_layer,
broadcast,
get_axes_for_reduce_op,
trt_cast,
trt_div,
trt_floor_div,
trt_mul,
@@ -170,3 +171,17 @@ def all_converter(network, paddle_op, inputs):
return add_cast_reduce_layer(
network, paddle_op, inputs, trt.ReduceOperation.MIN
)


@converter_registry.register("pd_op.floor_divide", trt_version="8.x")
def floor_divide_converter(network, paddle_op, inputs):
return add_elementwise_layer(
network, paddle_op, inputs, trt.ElementWiseOperation.FLOOR_DIV
)


@converter_registry.register("pd_op.log", trt_version="8.x")
def log_converter(network, paddle_op, inputs):
input_tensor = trt_cast(network, inputs[0], trt.float32)
layer = network.add_unary(input_tensor, trt.UnaryOperation.LOG)
return layer.get_output(0)
12 changes: 9 additions & 3 deletions python/paddle/tensorrt/impls/ops.py
@@ -15,11 +15,17 @@

from paddle.tensorrt.register import converter_registry

ops_type_map = {
"pd_op.sqrt": trt.UnaryOperation.SQRT,
"pd_op.sqrt_": trt.UnaryOperation.SQRT,
"pd_op.floor": trt.UnaryOperation.FLOOR,
}


@converter_registry.register("pd_op.sqrt", trt_version="8.x")
@converter_registry.register("pd_op.sqrt_", trt_version="8.x")
@converter_registry.register("pd_op.floor", trt_version="8.x")
def sqrt_converter(network, paddle_op, inputs):
input_tensor = inputs[0]

sqrt_layer = network.add_unary(input_tensor, trt.UnaryOperation.SQRT)
return sqrt_layer.get_output(0)
layer = network.add_unary(input_tensor, ops_type_map[paddle_op.name()])
return layer.get_output(0)
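
Following the review comment on logic.py, this converter is also table-driven: the op name indexes ops_type_map, so one function serves sqrt, sqrt_ and floor. A hypothetical sketch of how a further unary op could ride on the same mechanism, assuming this module's imports and the ops_type_map above (pd_op.exp is not added by this PR; it is shown only to illustrate the pattern):

import tensorrt as trt

from paddle.tensorrt.register import converter_registry

# Hypothetical extension, NOT part of this PR: one new entry in the existing
# ops_type_map plus one new registration is all the mapping-based form needs.
ops_type_map["pd_op.exp"] = trt.UnaryOperation.EXP


@converter_registry.register("pd_op.exp", trt_version="8.x")
def exp_converter(network, paddle_op, inputs):
    layer = network.add_unary(inputs[0], ops_type_map[paddle_op.name()])
    return layer.get_output(0)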
1 change: 1 addition & 0 deletions test/tensorrt/CMakeLists.txt
@@ -23,6 +23,7 @@ if(NOT WIN32 AND TENSORRT_FOUND)
set_tests_properties(test_converter_creation PROPERTIES TIMEOUT "300")
set_tests_properties(test_converter_attribute PROPERTIES TIMEOUT "300")
set_tests_properties(test_converter_common PROPERTIES TIMEOUT "300")
set_tests_properties(test_converter_linalg PROPERTIES TIMEOUT "100")
set_tests_properties(test_converter_search PROPERTIES TIMEOUT "300")
set_tests_properties(test_converter_logic PROPERTIES TIMEOUT "300")

28 changes: 28 additions & 0 deletions test/tensorrt/test_converter_activation.py
@@ -86,5 +86,33 @@ def test_trt_result(self):
self.check_trt_result()


class TestSiluFloatTRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.silu
self.api_args = {
"x": np.random.randn(2, 3).astype("float32"),
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [1, 3]}
self.max_shape = {"x": [5, 3]}

def test_trt_result(self):
self.check_trt_result()


class TestSwishFloatTRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.swish
self.api_args = {
"x": np.random.randn(2, 3).astype("float32"),
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [1, 3]}
self.max_shape = {"x": [5, 3]}

def test_trt_result(self):
self.check_trt_result()


if __name__ == '__main__':
unittest.main()
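
The two classes above only cover the silu and swish converters. For comparison, a hedged sketch of how another op added in this PR, say pd_op.floor, could be tested in the same style, assuming the same imports and TensorRTBaseTest harness as the tests above (illustrative only, not a test included in this diff):

# Illustrative sketch, not part of this PR: same structure as the tests above.
class TestFloorFloatTRTPattern(TensorRTBaseTest):
    def setUp(self):
        self.python_api = paddle.floor
        self.api_args = {
            "x": np.random.randn(2, 3).astype("float32"),
        }
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [1, 3]}
        self.max_shape = {"x": [5, 3]}

    def test_trt_result(self):
        self.check_trt_result()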