
Commit 341ac15

【SCU】【Paddle TensorRT No.43】Add pd_op.leaky_relu, pd_op.prelu converter (#70591)

* add
* Update test_converter_activation.py
* fix codestyle
* update
* Update test_converter_activation.py
* Update test_converter_activation.py
* update prelu
* fix codestyle
* Update activation.py
* Update trt_op_marker_pass.cc
* Update test_converter_activation.py
* Update test_converter_activation.py
* Update test_converter_activation.py

1 parent c794454 · commit 341ac15
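For reference, the two Paddle activations this commit lowers to TensorRT can be exercised directly in eager mode. A minimal sketch of the ops the new converters cover (shapes and values chosen arbitrarily for illustration):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.randn(2, 3, 4, 4).astype("float32"))

# pd_op.leaky_relu: a single scalar attribute, negative_slope.
y1 = paddle.nn.functional.leaky_relu(x, negative_slope=0.5)

# pd_op.prelu: a learnable alpha weight, here one value per channel (NCHW).
alpha = paddle.to_tensor(np.full([3], 0.25, dtype="float32"))
y2 = paddle.nn.functional.prelu(x, alpha, data_format="NCHW")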

File tree: 4 files changed, +238 −0 lines changed


paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 24 additions & 0 deletions

@@ -94,6 +94,8 @@ DEFINE_GENERAL_PATTERN(Flip, paddle::dialect::FlipOp)
 DEFINE_GENERAL_PATTERN(Mish, paddle::dialect::MishOp)
 DEFINE_GENERAL_PATTERN(AssignValue, paddle::dialect::AssignValueOp)
 DEFINE_GENERAL_PATTERN(AssignValue_, paddle::dialect::AssignValue_Op)
+DEFINE_GENERAL_PATTERN(LeakyRelu, paddle::dialect::LeakyReluOp)
+DEFINE_GENERAL_PATTERN(LeakyRelu_, paddle::dialect::LeakyRelu_Op)
 DEFINE_GENERAL_PATTERN(Anchor_Generator, paddle::dialect::AnchorGeneratorOp)
 DEFINE_GENERAL_PATTERN(Exp, paddle::dialect::ExpOp)
 DEFINE_GENERAL_PATTERN(Abs, paddle::dialect::AbsOp)

@@ -2538,6 +2540,25 @@ class AffineChannelOpPattern
   }
 };
 
+class PreluOpPattern : public pir::OpRewritePattern<paddle::dialect::PreluOp> {
+ public:
+  using pir::OpRewritePattern<paddle::dialect::PreluOp>::OpRewritePattern;
+  bool MatchAndRewrite(paddle::dialect::PreluOp op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+    pir::Value alpha_var = op.operand_source(1);
+    if (!alpha_var) {
+      VLOG(3) << "Variable Alpha of prelu TRT converter not found.";
+      return false;
+    }
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
+  }
+};
+
 class YoloBoxOpPattern
     : public pir::OpRewritePattern<paddle::dialect::YoloBoxOp> {
  public:

@@ -2762,6 +2783,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ADD_PATTERN(Mish)
     ADD_PATTERN(AssignValue)
     ADD_PATTERN(AssignValue_)
+    ADD_PATTERN(LeakyRelu)
+    ADD_PATTERN(LeakyRelu_)
     ADD_PATTERN(Anchor_Generator)
     ADD_PATTERN(Exp)
     ADD_PATTERN(Abs)

@@ -2880,6 +2903,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique<EinsumOpPattern>(context));
     ps.Add(std::make_unique<PNormOpPattern>(context));
     ps.Add(std::make_unique<AffineChannelOpPattern>(context));
+    ps.Add(std::make_unique<PreluOpPattern>(context));
     ps.Add(
         std::make_unique<FusedBiasDropoutResidualLayerNormOpPattern>(context));
     ps.Add(std::make_unique<YoloBoxOpPattern>(context));
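PreluOpPattern only marks a pd_op.prelu op as TensorRT-runnable when it is not already marked and its alpha operand actually exists. A minimal Python paraphrase of that predicate, using a hypothetical dict-based op record in place of the real PIR op (the C++ code keys off kCanRunTrtAttr; "can_run_trt" below is an assumed placeholder name):

# Hypothetical stand-in for a PIR op: an attribute dict plus an operand list.
def mark_prelu_for_trt(op):
    if op["attrs"].get("can_run_trt"):
        return False  # already marked; the pattern must not fire again
    if op["operands"][1] is None:
        # mirrors: VLOG(3) << "Variable Alpha of prelu TRT converter not found."
        return False
    op["attrs"]["can_run_trt"] = True
    return True

op = {"attrs": {}, "operands": ["x", "alpha"]}
assert mark_prelu_for_trt(op)        # first pass marks the op
assert not mark_prelu_for_trt(op)    # second pass is a no-op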

python/paddle/tensorrt/impls/activation.py

Lines changed: 53 additions & 0 deletions

@@ -16,7 +16,9 @@
 import tensorrt as trt
 
 from paddle.tensorrt.converter_utils import (
+    add_1D_constant_layer,
     add_constant_layer,
+    trt_concat,
     trt_div,
     trt_min,
     trt_pow,

@@ -276,6 +278,16 @@ def thresholded_relu_converter(network, paddle_op, inputs):
     return thresholded_relu_layer.get_output(0)
 
 
+@converter_registry.register("pd_op.leaky_relu", trt_version="8.x")
+@converter_registry.register("pd_op.leaky_relu_", trt_version="8.x")
+def leaky_relu_converter(network, paddle_op, inputs):
+    x = inputs[0]
+    negative_slope = paddle_op.attrs()["negative_slope"]
+    leaky_relu_layer = network.add_activation(x, trt.ActivationType.LEAKY_RELU)
+    leaky_relu_layer.alpha = negative_slope
+    return leaky_relu_layer.get_output(0)
+
+
 @converter_registry.register("pd_op.selu", trt_version="8.x")
 def selu_converter(network, paddle_op, inputs):
     x = inputs[0]

@@ -285,3 +297,44 @@ def selu_converter(network, paddle_op, inputs):
     selu_layer.alpha = alpha
     selu_layer.beta = scale
     return selu_layer.get_output(0)
+
+
+@converter_registry.register("pd_op.prelu", trt_version="8.x")
+def prelu_converter(network, paddle_op, inputs):
+    input, alpha_data = inputs
+    input_dims = input.shape
+    mode = paddle_op.attrs()["mode"]
+    data_format = paddle_op.attrs().get("data_format", "NCHW")
+    w_dims = trt.Dims(alpha_data.numpy().shape)
+    trt_w_dims = w_dims
+    alpha_tensor = network.add_constant(trt_w_dims, alpha_data).get_output(0)
+    alpha_dims = alpha_tensor.shape
+    real_alpha_tensor = alpha_tensor
+    if len(alpha_dims) != len(input_dims):
+        reshape_layer = network.add_shuffle(alpha_tensor)
+        c = alpha_dims[0]
+        n_tensor = add_1D_constant_layer(network, [1])
+        c_tensor = add_1D_constant_layer(network, [c])
+        hw_tensor = None
+        if len(input_dims) - 2 > 0:
+            hw_tensor = add_1D_constant_layer(
+                network, [1] * (len(input_dims) - 2)
+            )
+        if data_format == "NCHW":
+            if hw_tensor:
+                shape_tensor = trt_concat(
+                    network, [n_tensor, c_tensor, hw_tensor]
+                )
+            else:
+                shape_tensor = trt_concat(network, [n_tensor, c_tensor])
+        else:
+            if hw_tensor:
+                shape_tensor = trt_concat(
+                    network, [n_tensor, hw_tensor, c_tensor]
+                )
+            else:
+                shape_tensor = trt_concat(network, [n_tensor, c_tensor])
+        reshape_layer.set_input(1, shape_tensor)
+        real_alpha_tensor = reshape_layer.get_output(0)
+    layer = network.add_parametric_relu(input, real_alpha_tensor)
+    return layer.get_output(0)
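Both converters bottom out in two TensorRT network calls: add_activation with ActivationType.LEAKY_RELU (where alpha carries the negative slope) and add_parametric_relu with an alpha tensor that broadcasts against the input. A standalone TensorRT 8.x sketch of those two layers, outside any Paddle machinery (input shape and values are arbitrary):

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 3, 4, 4))

# Leaky ReLU: a single activation layer; alpha is the negative slope.
leaky = network.add_activation(x, trt.ActivationType.LEAKY_RELU)
leaky.alpha = 0.5

# PReLU: alpha enters as a constant tensor shaped to broadcast against
# the input (here (1, C, 1, 1) for NCHW), then a parametric ReLU layer.
alpha_np = np.full((1, 3, 1, 1), 0.25, dtype=np.float32)
alpha = network.add_constant(alpha_np.shape, trt.Weights(alpha_np)).get_output(0)
prelu = network.add_parametric_relu(x, alpha)

network.mark_output(leaky.get_output(0))
network.mark_output(prelu.get_output(0))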

python/paddle/tensorrt/util.py

Lines changed: 1 addition & 0 deletions

@@ -299,6 +299,7 @@ def weight_to_tensor(network, paddle_value, trt_tensor, use_op_name):
         "pd_op.depthwise_conv2d_transpose",
         "pd_op.fused_conv2d_add_act",
         "pd_op.affine_channel",
+        "pd_op.prelu",
         "pd_op.fused_bias_dropout_residual_layer_norm",
         "pd_op.deformable_conv",
     ]
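Ops on this list have their weight operands materialized as ITensors rather than passed around as raw weights, which is what add_parametric_relu requires for alpha. A hypothetical, heavily simplified version of that conversion (the real weight_to_tensor in util.py handles more cases):

import numpy as np
import tensorrt as trt

def weight_as_tensor(network, weights_np):
    # Hypothetical helper: wrap a numpy weight array as an ITensor via an
    # IConstantLayer, so converters like pd_op.prelu can consume it as a
    # network tensor instead of raw trt.Weights.
    const = network.add_constant(trt.Dims(weights_np.shape), trt.Weights(weights_np))
    return const.get_output(0)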

test/tensorrt/test_converter_activation.py

Lines changed: 160 additions & 0 deletions

@@ -369,9 +369,169 @@ def setUp(self):
     def test_trt_result(self):
         self.check_trt_result()
 
+
+class TestLeakyReluCase1TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.leaky_relu
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "negative_slope": 0.5,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.opt_shape = {"x": [2, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestLeakyReluCase2TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.leaky_relu
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "negative_slope": -0.5,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.opt_shape = {"x": [2, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestLeakyRelu_Case1TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.leaky_relu_
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "negative_slope": 0.5,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.opt_shape = {"x": [2, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestLeakyRelu_Case2TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.nn.functional.leaky_relu_
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "negative_slope": -0.5,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.opt_shape = {"x": [2, 3]}
+        self.max_shape = {"x": [5, 3]}
+
     def test_trt_result_fp16(self):
         self.check_trt_result(precision_mode="fp16")
 
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+def prelu_wrapper(x, alpha_shape, data_format='NCHW'):
+    alpha = paddle.create_parameter(
+        shape=alpha_shape, dtype='float32', name="alpha"
+    )
+    return paddle.nn.functional.prelu(x, alpha, data_format)
+
+
+class TestPReluCase1TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = prelu_wrapper
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "alpha_shape": [3],
+            "data_format": "NCHW",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.opt_shape = {"x": [2, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestPReluCase2TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = prelu_wrapper
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "alpha_shape": [3],
+            "data_format": "NHWC",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.opt_shape = {"x": [2, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestPReluCase3TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = prelu_wrapper
+        self.api_args = {
+            "x": np.random.randn(2, 3, 3).astype("float32"),
+            "alpha_shape": [3],
+            "data_format": "NCHW",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3, 3]}
+        self.opt_shape = {"x": [2, 3, 3]}
+        self.max_shape = {"x": [5, 3, 3]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
+
+class TestPReluCase4TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = prelu_wrapper
+        self.api_args = {
+            "x": np.random.randn(2, 3, 3).astype("float32"),
+            "alpha_shape": [3],
+            "data_format": "NHWC",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3, 3]}
+        self.opt_shape = {"x": [2, 3, 3]}
+        self.max_shape = {"x": [5, 3, 3]}
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+    def test_trt_result_fp32(self):
+        self.check_trt_result()
+
 
 if __name__ == '__main__':
     unittest.main()
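The new cases can be run individually with the standard unittest loader. A minimal sketch, assuming a TensorRT-enabled PaddlePaddle build and test/tensorrt as the working directory:

import unittest

# Load and run one of the new converter tests by its dotted name.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_converter_activation.TestPReluCase1TRTPattern"
)
unittest.TextTestRunner(verbosity=2).run(suite)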
