Skip to content

Commit 989d454

Browse files
authored
[Paddle Tensorrt] add tensorrt converter and marker (#69208)
* add all ops * modify files * fix files * fix conflict
1 parent bf0e51c commit 989d454

File tree

14 files changed

+518
-13
lines changed

14 files changed

+518
-13
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ DEFINE_GENERAL_PATTERN(AssignValue_, paddle::dialect::AssignValue_Op)
8282
DEFINE_GENERAL_PATTERN(Tile, paddle::dialect::TileOp)
8383
DEFINE_GENERAL_PATTERN(Share_Data, paddle::dialect::ShareDataOp)
8484
DEFINE_GENERAL_PATTERN(AssignOut, paddle::dialect::AssignOut_Op)
85+
DEFINE_GENERAL_PATTERN(Swish, paddle::dialect::SwishOp)
86+
DEFINE_GENERAL_PATTERN(Log, paddle::dialect::LogOp)
87+
DEFINE_GENERAL_PATTERN(Floor, paddle::dialect::FloorOp)
8588
DEFINE_GENERAL_PATTERN(Roll, paddle::dialect::RollOp)
8689

8790
#undef DEFINE_GENERAL_PATTERN
@@ -1500,6 +1503,87 @@ class WherePattern : public pir::OpRewritePattern<paddle::dialect::WhereOp> {
15001503
}
15011504
};
15021505

1506+
// Marks pd_op.equal as runnable by TensorRT. Before TRT 8.6 the elementwise
// EQUAL layer cannot consume 0-dim (scalar) tensors, so those are rejected.
class EqualOpPattern : public pir::OpRewritePattern<paddle::dialect::EqualOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::EqualOp>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::EqualOp op,
                       pir::PatternRewriter &rewriter) const override {
    // Already marked — nothing further to do for this op.
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
#if IS_TRT_VERSION_LT(8600)
    pir::Value input = op.operand_source(0);
    auto tensor_type =
        input.type().dyn_cast<paddle::dialect::DenseTensorType>();
    const int rank = tensor_type.dims().size();
    if (rank < 1) {
      VLOG(3)
          << "pd_op.equal op does not support 0 dim input when TensorRT < 8.6.";
      return false;
    }
#endif

    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};
1531+
1532+
class NotEqualOpPattern
1533+
: public pir::OpRewritePattern<paddle::dialect::NotEqualOp> {
1534+
public:
1535+
using pir::OpRewritePattern<paddle::dialect::NotEqualOp>::OpRewritePattern;
1536+
bool MatchAndRewrite(paddle::dialect::NotEqualOp op,
1537+
pir::PatternRewriter &rewriter) const override {
1538+
if (op->HasAttribute(kCanRunTrtAttr) &&
1539+
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
1540+
return false;
1541+
}
1542+
#if IS_TRT_VERSION_LT(8600)
1543+
pir::Value x = op.operand_source(0);
1544+
auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
1545+
auto x_shape = x_type.dims();
1546+
int dims = x_shape.size();
1547+
if (dims < 1) {
1548+
VLOG(3) << "pd_op.not_equal op does not support 0 dim input when "
1549+
"TensorRT < 8.6.";
1550+
return false;
1551+
}
1552+
#endif
1553+
1554+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
1555+
return true;
1556+
}
1557+
};
1558+
1559+
class FullLikeOpPattern
1560+
: public pir::OpRewritePattern<paddle::dialect::FullLikeOp> {
1561+
public:
1562+
using pir::OpRewritePattern<paddle::dialect::FullLikeOp>::OpRewritePattern;
1563+
bool MatchAndRewrite(paddle::dialect::FullLikeOp op,
1564+
pir::PatternRewriter &rewriter) const override {
1565+
if (op->HasAttribute(kCanRunTrtAttr) &&
1566+
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
1567+
return false;
1568+
}
1569+
pir::Value x = op.operand_source(0);
1570+
auto x_dtype = pir::GetDataTypeFromValue(x);
1571+
bool hasAttr = op->HasAttribute("dtype");
1572+
auto dtype =
1573+
op->attribute<paddle::dialect::DataTypeAttribute>("dtype").data();
1574+
1575+
if (dtype == phi::DataType::BOOL ||
1576+
(!hasAttr && x_dtype.isa<pir::BoolType>())) {
1577+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
1578+
VLOG(3) << "the pd_op.full_like supports input of BOOL by trt8.4 above";
1579+
return true;
1580+
}
1581+
1582+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
1583+
return true;
1584+
}
1585+
};
1586+
15031587
class FullWithTensorPattern
15041588
: public pir::OpRewritePattern<paddle::dialect::FullWithTensorOp> {
15051589
public:
@@ -1642,6 +1726,9 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
16421726
ADD_PATTERN(AssignValue_)
16431727
ADD_PATTERN(Tile)
16441728
ADD_PATTERN(Share_Data)
1729+
ADD_PATTERN(Swish)
1730+
ADD_PATTERN(Log)
1731+
ADD_PATTERN(Floor)
16451732
ADD_PATTERN(Roll)
16461733
#if IS_TRT_VERSION_GE(8600)
16471734
ADD_PATTERN(Layer_norm)
@@ -1692,9 +1779,12 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
16921779
ps.Add(std::make_unique<StackOpPattern>(context));
16931780
ps.Add(std::make_unique<TanhOpPattern>(context));
16941781
ps.Add(std::make_unique<WherePattern>(context));
1782+
ps.Add(std::make_unique<FullLikeOpPattern>(context));
16951783
ps.Add(std::make_unique<FullWithTensorPattern>(context));
16961784
ps.Add(std::make_unique<StridedSliceOpPattern>(context));
16971785
ps.Add(std::make_unique<TopkOpPattern>(context));
1786+
ps.Add(std::make_unique<EqualOpPattern>(context));
1787+
ps.Add(std::make_unique<NotEqualOpPattern>(context));
16981788
return ps;
16991789
}
17001790
};

python/paddle/tensorrt/impls/activation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717

1818
from paddle.tensorrt.converter_utils import (
1919
get_trt_plugin,
20+
trt_mul,
2021
)
2122
from paddle.tensorrt.register import converter_registry
2223

2324
# pd_op name -> TensorRT activation type. silu/swish both lower to SIGMOID;
# their converter multiplies the sigmoid output by the input (x * sigmoid(x)).
activation_type_map = dict(
    [
        ("pd_op.tanh", trt.ActivationType.TANH),
        ("pd_op.relu", trt.ActivationType.RELU),
        ("pd_op.sigmoid", trt.ActivationType.SIGMOID),
        ("pd_op.silu", trt.ActivationType.SIGMOID),
        ("pd_op.swish", trt.ActivationType.SIGMOID),
    ]
)
2831

2932

@@ -99,3 +102,12 @@ def hardswish_converter(network, paddle_op, inputs):
99102
x, hardsigmoid_layer.get_output(0), trt.ElementWiseOperation.PROD
100103
)
101104
return hardswish_layer.get_output(0)
105+
106+
107+
@converter_registry.register("pd_op.swish", trt_version="8.x")
@converter_registry.register("pd_op.silu", trt_version="8.x")
def swish_silu_converter(network, paddle_op, inputs):
    """Lower silu/swish as x * sigmoid(x) using a SIGMOID activation layer."""
    x = inputs[0]
    act_type = activation_type_map[paddle_op.name()]
    sigmoid_out = network.add_activation(x, act_type).get_output(0)
    return trt_mul(network, x, sigmoid_out)

python/paddle/tensorrt/impls/creation.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from paddle.tensorrt.converter_utils import (
2020
add_1D_constant_layer,
2121
cast_tensor,
22+
trt_cast,
2223
trt_floor_div,
2324
trt_max,
2425
trt_reduce_to_scalar,
@@ -98,6 +99,69 @@ def arange_converter(network, paddle_op, inputs):
9899
return fill_layer.get_output(0)
99100

100101

102+
@converter_registry.register("pd_op.full_like", trt_version="8.x")
def full_like_converter(network, paddle_op, inputs):
    """Convert pd_op.full_like: broadcast a scalar fill value to the (possibly
    dynamic) shape of inputs[0] via a FILL-mode slice layer.

    Raises RuntimeError for unsupported / missing ``dtype`` attributes.
    """
    shape = tuple(paddle_op.operands()[0].source().shape)
    ndims = len(shape)

    # Reference paddle/phi/common/data_type.h enum DataType.
    # bool/int64 are narrowed to int32 and float64 to float32 because
    # TensorRT 8.x has no kernels for those dtypes.
    dtype_map = {
        1: trt.int32,  # paddle.bool
        7: trt.int32,  # paddle.int32
        9: trt.int32,  # paddle.int64
        10: trt.float32,  # paddle.float32
        11: trt.float32,  # paddle.float64
    }
    dtype_attr = paddle_op.attrs().get("dtype")
    # Guard before int(): int(None) would raise a confusing TypeError, and the
    # original error message wrongly said "cast converter".
    if dtype_attr is None or int(dtype_attr) not in dtype_map:
        raise RuntimeError(
            f"full_like converter currently doesn't support dtype: {dtype_attr}"
        )
    out_dtype = dtype_map[int(dtype_attr)]

    # When the fill value comes from a constant pd_op.full, bake it into a
    # 1-element constant layer; otherwise use the runtime tensor input.
    value_op = paddle_op.operands()[1].source().get_defining_op()
    if value_op.name() == "pd_op.full":
        fill_value = value_op.attrs()["value"]
        value = network.add_constant(
            (1,), np.array([fill_value], dtype=np.float32)
        ).get_output(0)
        value = trt_cast(network, value, out_dtype)
    else:
        value = inputs[1]

    # Reshape the scalar to rank `ndims` so the slice layer can broadcast it.
    shuffle_layer = network.add_shuffle(value)
    shuffle_layer.reshape_dims = (1,) * ndims

    start_vec = np.zeros((ndims,), dtype=np.int32)
    start_tensor = network.add_constant((ndims,), start_vec).get_output(0)
    shape_tensor = network.add_shape(inputs[0]).get_output(0)
    stride_tensor = network.add_constant(
        (ndims,), np.ones((ndims,), dtype=np.int32)
    ).get_output(0)

    # FILL-mode slice: the output takes the dynamic shape of inputs[0]
    # (input 2) and every element is read from the fill value (input 4).
    slice_layer = network.add_slice(
        shuffle_layer.get_output(0),
        start_vec,
        [1] * ndims,
        np.ones((ndims,), dtype=np.int32),
    )
    slice_layer.mode = trt.SliceMode.FILL
    slice_layer.set_input(1, start_tensor)
    slice_layer.set_input(2, shape_tensor)
    slice_layer.set_input(3, stride_tensor)
    value = trt_cast(network, value, out_dtype)
    slice_layer.set_input(4, value)
    return slice_layer.get_output(0)
163+
164+
101165
@converter_registry.register("pd_op.full_with_tensor", trt_version="8.x")
102166
def full_with_tensor_converter(network, paddle_op, inputs):
103167
value_input = inputs[0]

python/paddle/tensorrt/impls/linalg.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,11 @@ def transpose_converter(network, paddle_op, inputs):
6363
transposed_tensor = network.add_shuffle(inputs[0])
6464
transposed_tensor.second_transpose = perm
6565
return transposed_tensor.get_output(0)
66+
67+
68+
@converter_registry.register("pd_op.bmm", trt_version="8.x")
def bmm_converter(network, paddle_op, inputs):
    """Batched matrix multiply; neither operand is transposed."""
    lhs, rhs = inputs[0], inputs[1]
    layer = network.add_matrix_multiply(
        lhs, trt.MatrixOperation.NONE, rhs, trt.MatrixOperation.NONE
    )
    return layer.get_output(0)

python/paddle/tensorrt/impls/logic.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,28 @@
2020
)
2121
from paddle.tensorrt.register import converter_registry
2222

23+
# pd_op name -> TensorRT elementwise comparison operation.
logic_type_map = {
    "pd_op.greater_than": trt.ElementWiseOperation.GREATER,
    "pd_op.less_than": trt.ElementWiseOperation.LESS,
    "pd_op.equal": trt.ElementWiseOperation.EQUAL,
}


@converter_registry.register("pd_op.greater_than", trt_version="8.x")
@converter_registry.register("pd_op.less_than", trt_version="8.x")
@converter_registry.register("pd_op.equal", trt_version="8.x")
def logic_converter(network, paddle_op, inputs):
    """Elementwise comparison; result is cast back to the lhs input dtype."""
    trt_op = logic_type_map[paddle_op.name()]
    compared = add_elementwise_layer(network, paddle_op, inputs, trt_op)
    return trt_cast(network, compared, inputs[0].dtype)


@converter_registry.register("pd_op.not_equal", trt_version="8.x")
def not_equal_converter(network, paddle_op, inputs):
    """not_equal = NOT(equal); result is cast back to the lhs input dtype."""
    equal_out = add_elementwise_layer(
        network, paddle_op, inputs, trt.ElementWiseOperation.EQUAL
    )
    negated = network.add_unary(equal_out, trt.UnaryOperation.NOT).get_output(0)
    return trt_cast(network, negated, inputs[0].dtype)

python/paddle/tensorrt/impls/math.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
add_reduce_layer,
2222
broadcast,
2323
get_axes_for_reduce_op,
24+
trt_cast,
2425
trt_div,
2526
trt_floor_div,
2627
trt_mul,
@@ -170,3 +171,17 @@ def all_converter(network, paddle_op, inputs):
170171
return add_cast_reduce_layer(
171172
network, paddle_op, inputs, trt.ReduceOperation.MIN
172173
)
174+
175+
176+
@converter_registry.register("pd_op.floor_divide", trt_version="8.x")
def floor_divide_converter(network, paddle_op, inputs):
    """Elementwise floor division via TensorRT's FLOOR_DIV."""
    op_type = trt.ElementWiseOperation.FLOOR_DIV
    return add_elementwise_layer(network, paddle_op, inputs, op_type)
181+
182+
183+
@converter_registry.register("pd_op.log", trt_version="8.x")
def log_converter(network, paddle_op, inputs):
    """Natural logarithm.

    Renamed from the copy-pasted ``sqrt_converter`` — this converts pd_op.log,
    and the old name collides conceptually with the sqrt converter in ops.py.
    The input is cast to float32 first because TRT's LOG unary only supports
    floating-point tensors.
    """
    input_tensor = trt_cast(network, inputs[0], trt.float32)
    layer = network.add_unary(input_tensor, trt.UnaryOperation.LOG)
    return layer.get_output(0)

python/paddle/tensorrt/impls/ops.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,17 @@
1515

1616
from paddle.tensorrt.register import converter_registry
1717

18+
# pd_op name -> TensorRT unary operation, shared by the converter below.
ops_type_map = {
    "pd_op.sqrt": trt.UnaryOperation.SQRT,
    "pd_op.sqrt_": trt.UnaryOperation.SQRT,
    "pd_op.floor": trt.UnaryOperation.FLOOR,
}


@converter_registry.register("pd_op.sqrt", trt_version="8.x")
@converter_registry.register("pd_op.sqrt_", trt_version="8.x")
@converter_registry.register("pd_op.floor", trt_version="8.x")
def sqrt_converter(network, paddle_op, inputs):
    """Single-input unary ops (sqrt/floor), dispatched via ops_type_map.

    NOTE(review): the name predates the floor registration; consider renaming
    to a generic unary converter in a follow-up.
    """
    unary_op = ops_type_map[paddle_op.name()]
    layer = network.add_unary(inputs[0], unary_op)
    return layer.get_output(0)

test/tensorrt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ if(NOT WIN32 AND TENSORRT_FOUND)
2323
set_tests_properties(test_converter_creation PROPERTIES TIMEOUT "300")
2424
set_tests_properties(test_converter_attribute PROPERTIES TIMEOUT "300")
2525
set_tests_properties(test_converter_common PROPERTIES TIMEOUT "300")
26+
set_tests_properties(test_converter_linalg PROPERTIES TIMEOUT "100")
2627
set_tests_properties(test_converter_search PROPERTIES TIMEOUT "300")
2728
set_tests_properties(test_converter_logic PROPERTIES TIMEOUT "300")
2829

test/tensorrt/test_converter_activation.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,33 @@ def test_trt_result(self):
8686
self.check_trt_result()
8787

8888

89+
class TestSiluFloatTRTPattern(TensorRTBaseTest):
    """silu(x) on a float32 [2, 3] input with dynamic batch in [1, 5]."""

    def setUp(self):
        self.python_api = paddle.nn.functional.silu
        x = np.random.randn(2, 3).astype("float32")
        self.api_args = {"x": x}
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [1, 3]}
        self.max_shape = {"x": [5, 3]}

    def test_trt_result(self):
        self.check_trt_result()
102+
103+
class TestSwishFloatTRTPattern(TensorRTBaseTest):
    """swish(x) on a float32 [2, 3] input with dynamic batch in [1, 5]."""

    def setUp(self):
        self.python_api = paddle.nn.functional.swish
        x = np.random.randn(2, 3).astype("float32")
        self.api_args = {"x": x}
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [1, 3]}
        self.max_shape = {"x": [5, 3]}

    def test_trt_result(self):
        self.check_trt_result()
115+
116+
89117
if __name__ == '__main__':
90118
unittest.main()

0 commit comments

Comments
 (0)