Skip to content

Commit a83062b

Browse files
committed
add relu6
1 parent c84addc commit a83062b

File tree

3 files changed

+25
-0
lines changed

3 files changed

+25
-0
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ DEFINE_GENERAL_PATTERN(Bmm, paddle::dialect::BmmOp)
5959
DEFINE_GENERAL_PATTERN(Concat, paddle::dialect::ConcatOp)
6060
DEFINE_GENERAL_PATTERN(Nonzero, paddle::dialect::NonzeroOp)
6161
DEFINE_GENERAL_PATTERN(Gelu, paddle::dialect::GeluOp)
62+
DEFINE_GENERAL_PATTERN(Relu6, paddle::dialect::Relu6Op)
6263
DEFINE_GENERAL_PATTERN(Fused_gemm_epilogue,
6364
paddle::dialect::FusedGemmEpilogueOp)
6465
DEFINE_GENERAL_PATTERN(Layer_norm, paddle::dialect::LayerNormOp)
@@ -2134,6 +2135,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
21342135
ADD_PATTERN(DepthwiseConv2d)
21352136
ADD_PATTERN(Nonzero)
21362137
ADD_PATTERN(Gelu)
2138+
ADD_PATTERN(Relu6)
21372139
ADD_PATTERN(Shape)
21382140
ADD_PATTERN(Shape64)
21392141
ADD_PATTERN(Expand)

python/paddle/tensorrt/impls/activation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def activation_converter(network, paddle_op, inputs):
4545
return layer.get_output(0)
4646

4747

48+
@converter_registry.register("pd_op.relu6", trt_version="trt_version_ge=8.0")
def relu6_converter(network, paddle_op, inputs):
    """Lower ``pd_op.relu6`` to a TensorRT CLIP activation.

    relu6(x) = min(max(x, 0), 6), which maps exactly onto TensorRT's
    CLIP activation with alpha (lower bound) = 0 and beta (upper bound) = 6.

    Args:
        network: the TensorRT network being built.
        paddle_op: the Paddle op being converted (unused here; the op
            carries no attributes relevant to relu6).
        inputs: converter inputs; ``inputs[0]`` is the activation input tensor.

    Returns:
        The output tensor of the CLIP activation layer.
    """
    clip_layer = network.add_activation(inputs[0], trt.ActivationType.CLIP)
    clip_layer.alpha = 0.0  # lower clamp bound
    clip_layer.beta = 6.0   # upper clamp bound
    return clip_layer.get_output(0)
54+
55+
4856
@converter_registry.register("pd_op.softmax", trt_version="trt_version_ge=8.0")
4957
def softmax_converter(network, paddle_op, inputs):
5058
axis = paddle_op.attrs().get("axis", 0)

test/tensorrt/test_converter_activation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ def test_trt_result(self):
9090
self.check_trt_result()
9191

9292

93+
class TestRelu6TRTPattern(TensorRTBaseTest):
    """Checks the pd_op.relu6 TensorRT converter against the Paddle reference.

    The attribute names set in ``setUp`` (``python_api``, ``api_args``,
    ``program_config``, ``min_shape``, ``max_shape``) form the contract
    consumed by ``TensorRTBaseTest.check_trt_result``.
    """

    def setUp(self):
        # A small 1-D float32 tensor is enough to exercise the converter;
        # min/max shapes give TRT a dynamic-shape range of [1] to [5].
        self.python_api = paddle.nn.functional.relu6
        self.api_args = {"x": np.random.randn(3).astype("float32")}
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [1]}
        self.max_shape = {"x": [5]}

    def test_trt_result(self):
        # Default (fp32) precision comparison against the eager result.
        self.check_trt_result()

    def test_trt_result_fp16(self):
        # Same comparison under reduced fp16 precision.
        self.check_trt_result(precision_mode="fp16")
106+
107+
93108
class TestTanhTRTPattern(TensorRTBaseTest):
94109
def setUp(self):
95110
self.python_api = paddle.tanh

0 commit comments

Comments
 (0)