
Commit 4f21985

[Paddle TensorRT] add pd_op.split_with_num and pd_op.split converter (#68608)

* pd_op.split_with_num
* split
* adapt several converters to add_elementwise_layer
* add an extra helper, get_shape_with_dynamic_shape
* fix
* pd_op.split
* add unit tests
* required CI checks not showing up; re-submitting the commit
* revise split_with_num
* CI still stuck

1 parent eb514a6 commit 4f21985

File tree

7 files changed: +504 -59 lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 56 additions & 52 deletions
@@ -89,6 +89,13 @@ class Pool2dOpPattern
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
       return false;
     }
+    paddle::dialect::FullIntArrayOp full_int_array_op =
+        pir::GetDefiningOpForInput(op, 1)
+            ->dyn_cast<paddle::dialect::FullIntArrayOp>();
+    if (!full_int_array_op) {
+      VLOG(3) << "Cannot find FullIntArrayOp";
+      return false;
+    }
     auto padding_attr = op->attribute<pir::ArrayAttribute>("paddings");
     std::vector<int32_t> paddings;
     for (const auto &attr : padding_attr.AsVector()) {
@@ -122,28 +129,19 @@ class Pool2dOpPattern
     if (!op->attribute<pir::BoolAttribute>("global_pooling").data()) {
       if (op->HasAttribute("exclusive")) {
         if (op->attribute<pir::BoolAttribute>("exclusive").data()) {
-          paddle::dialect::FullIntArrayOp full_int_array_op =
-              pir::GetDefiningOpForInput(op, 1)
-                  ->dyn_cast<paddle::dialect::FullIntArrayOp>();
-          if (!full_int_array_op) {
-            VLOG(3) << "Cannot find FullIntArrayOp";
-            return false;
-          } else {
-            auto attr_value =
-                full_int_array_op->attribute<pir::ArrayAttribute>(
-                    "value");
-            std::vector<int64_t> kernel_size;
-            for (const auto &attr : attr_value.AsVector()) {
-              kernel_size.push_back(
-                  attr.dyn_cast<pir::Int64Attribute>().data());
-            }
-            for (size_t i = 0; i < kernel_size.size(); ++i) {
-              if (kernel_size[i] <= paddings[i]) {
-                VLOG(3) << "the padding size should be less than the "
-                           "filter size "
-                           "for exclusive-counting pooling.";
-                return false;
-              }
+          auto attr_value =
+              full_int_array_op->attribute<pir::ArrayAttribute>("value");
+          std::vector<int64_t> kernel_size;
+          for (const auto &attr : attr_value.AsVector()) {
+            kernel_size.push_back(
+                attr.dyn_cast<pir::Int64Attribute>().data());
+          }
+          for (size_t i = 0; i < kernel_size.size(); ++i) {
+            if (kernel_size[i] <= paddings[i]) {
+              VLOG(3) << "the padding size should be less than the "
+                         "filter size "
+                         "for exclusive-counting pooling.";
+              return false;
             }
           }
         }
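A side note on the check that was hoisted above: a small standalone illustration (not part of the commit) of the rule that exclusive-counting pooling is rejected once any padding is at least as large as the kernel in that dimension.

```python
# Illustration only: mirrors the marker-pass rule "padding must be smaller
# than the kernel size for exclusive-counting pooling". With padding >= kernel,
# a border window can contain only padded values, so the exclusive average
# would have nothing to count.
kernel_size = [3, 3]
paddings = [3, 1]

can_run_trt = all(k > p for k, p in zip(kernel_size, paddings))
print(can_run_trt)  # False, because kernel_size[0] <= paddings[0]
```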
@@ -796,42 +794,42 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
       return false;
     }

-    paddle::dialect::FullOp full_op =
-        pir::GetDefiningOpForInput(op, 2)->dyn_cast<paddle::dialect::FullOp>();
-    if (!full_op) {
-      VLOG(3) << "Can not find full op";
+    pir::Value axis_tensor = op.operand_source(2);
+    if (!axis_tensor) {
+      VLOG(3) << "pd_op.split can not find axis input";
       return false;
-    } else {
+    }
+    auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();
+    if (pir::GetDefiningOpForInput(op, 2)->isa<paddle::dialect::FullOp>()) {
+      paddle::dialect::FullOp full_op =
+          pir::GetDefiningOpForInput(op, 2)
+              ->dyn_cast<paddle::dialect::FullOp>();
       auto axis = full_op->attribute<paddle::dialect::ScalarAttribute>("value")
                       .data()
                       .to<int>();
       auto x_shape = op.operand_source(0)
                          .type()
                          .dyn_cast<paddle::dialect::DenseTensorType>()
                          .dims();
-      auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();

-      paddle::dialect::FullIntArrayOp full_sections_op =
-          pir::GetDefiningOpForInput(op, 1)
-              ->dyn_cast<paddle::dialect::FullIntArrayOp>();
-      if (!full_sections_op) {
-        VLOG(3) << "Can not find FullIntArrayOp";
+      axis += (axis < 0) ? x_shape.size() : 0;
+
+      if (x_shape[axis] == -1) {
+        VLOG(3) << "The (" << axis << ") dim of input should not be -1";
         return false;
       }
+    }

+    if (pir::GetDefiningOpForInput(op, 1)
+            ->isa<paddle::dialect::FullIntArrayOp>()) {
+      paddle::dialect::FullIntArrayOp full_sections_op =
+          pir::GetDefiningOpForInput(op, 1)
+              ->dyn_cast<paddle::dialect::FullIntArrayOp>();
       auto sections = full_sections_op->attribute<pir::ArrayAttribute>("value");
-
       std::vector<int64_t> output_lengths;
       for (const auto &attr : sections.AsVector()) {
         output_lengths.push_back(attr.dyn_cast<pir::Int64Attribute>().data());
       }
-      axis += (axis < 0) ? x_shape.size() : 0;
-
-      if (x_shape[axis] == -1) {
-        VLOG(3) << "The (" << axis << ") dim of input should not be -1";
-        return false;
-      }
-
       if (output_lengths.size() != out_vector_type.size()) {
         VLOG(3) << "The output_length should be equal to the output size.";
         return false;
@@ -853,33 +851,38 @@ class SplitWithNumOpPattern
         op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
       return false;
     }
-    paddle::dialect::FullOp full_op =
-        pir::GetDefiningOpForInput(op, 1)->dyn_cast<paddle::dialect::FullOp>();
-    if (!full_op) {
-      VLOG(3) << "Can not find full op";
+
+    pir::Value axis_tensor = op.operand_source(1);
+    if (!axis_tensor) {
+      VLOG(3) << "pd_op.split_with_num can not find axis input";
       return false;
-    } else {
-      auto axis = full_op->attribute<paddle::dialect::ScalarAttribute>("value")
+    }
+    if (pir::GetDefiningOpForInput(op, 1)
+            ->isa<paddle::dialect::FullIntArrayOp>()) {
+      paddle::dialect::FullIntArrayOp full_int_array_op =
+          pir::GetDefiningOpForInput(op, 1)
+              ->dyn_cast<paddle::dialect::FullIntArrayOp>();
+      auto axis = full_int_array_op
+                      ->attribute<paddle::dialect::ScalarAttribute>("value")
                       .data()
                       .to<int>();
       auto x_shape = op.operand_source(0)
                          .type()
                          .dyn_cast<paddle::dialect::DenseTensorType>()
                          .dims();
-      auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();

       axis += (axis < 0) ? x_shape.size() : 0;
       if (x_shape[axis] == -1) {
         VLOG(3) << "The (" << axis << ") dim of input should not be -1";
         return false;
       }
-
       if (!op->HasAttribute("num")) {
         VLOG(3) << "split_with_num op must has num attributes";
         return false;
       }
       int num = op->attribute<pir::Int32Attribute>("num").data();
       std::vector<int64_t> output_lengths;
+
       if (num > 0) {
         int64_t in_axis_dim = x_shape[axis];
         if (in_axis_dim % num != 0) {
@@ -893,14 +896,15 @@ class SplitWithNumOpPattern
           output_lengths.push_back(out_axis_dim);
         }
       }
-
+      auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();
       if (out_vector_type.size() != output_lengths.size()) {
         VLOG(3) << "The output_length should be equal to the output size.";
         return false;
       }
-      op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
-      return true;
     }
+
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
   }
 };
 class GreaterEqualOpPattern
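For readers who want the gist of the new marker logic without reading the PIR C++: below is a minimal Python sketch (illustrative, not code from the commit) of the static checks that pd_op.split and pd_op.split_with_num must now pass before being handed to TensorRT, assuming axis, sections, and num were read from constant Full / FullIntArray producers.

```python
def can_convert_split(x_shape, axis, sections, num_outputs):
    # SplitOpPattern, simplified: the split dimension must be static and the
    # number of sections must match the number of outputs.
    axis += len(x_shape) if axis < 0 else 0
    if x_shape[axis] == -1:
        return False
    return len(sections) == num_outputs


def can_convert_split_with_num(x_shape, axis, num, num_outputs):
    # SplitWithNumOpPattern, simplified: num must be positive, evenly divide
    # the split dimension, and match the number of outputs.
    axis += len(x_shape) if axis < 0 else 0
    if x_shape[axis] == -1 or num <= 0:
        return False
    if x_shape[axis] % num != 0:
        return False
    return num == num_outputs


print(can_convert_split([3, 8, 5], -2, [2, 3, 3], 3))    # True
print(can_convert_split_with_num([3, 8, 5], 1, 3, 3))    # False: 8 % 3 != 0
```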

python/paddle/tensorrt/converter.py

Lines changed: 16 additions & 6 deletions
@@ -173,6 +173,9 @@ def convert_subgraph_to_trt(self, program, group_op):
             value_to_trt_tensor[value.id] = input_tensor

         for op in operations:
+            # Adding marker labels to builtin ops facilitates convert processing, but they ultimately do not enter the TensorRT subgraph.
+            if op.name() == "builtin.split":
+                continue
             operands = []
             for operand in op.operands():
                 source = operand.source()
@@ -205,7 +208,18 @@ def convert_subgraph_to_trt(self, program, group_op):

             trt_outs = self.convert(network, op, operands)

+            results = []
+
             for idx, result in enumerate(op.results()):
+                if result.is_combine():
+                    used_ops = result.all_used_ops()
+                    for use_op in used_ops:
+                        if use_op.name() == "builtin.split":
+                            split_outputs = use_op.results()
+                            results.extend(split_outputs)
+                else:
+                    results.append(result)
+            for idx, result in enumerate(results):
                 if idx < len(trt_outs):
                     value_to_trt_tensor[result.id] = trt_outs[idx]
                 else:
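To make the intent of the new result-mapping loop explicit, here is a self-contained sketch (the Fake* classes are hypothetical stand-ins for pir values and ops, not Paddle APIs): a combined result that feeds a builtin.split is replaced by that split's individual outputs, so the TRT output tensors can be paired with them one-to-one.

```python
class FakeResult:
    # Hypothetical stand-in for a pir result value.
    def __init__(self, rid, combine=False, users=()):
        self.id, self._combine, self._users = rid, combine, list(users)

    def is_combine(self):
        return self._combine

    def all_used_ops(self):
        return self._users


class FakeSplitOp:
    # Hypothetical stand-in for a builtin.split consumer.
    def name(self):
        return "builtin.split"

    def results(self):
        return [FakeResult("split_out_0"), FakeResult("split_out_1")]


def flatten_results(op_results):
    # Same rule as the diff: expand a combined result into the outputs of the
    # builtin.split that consumes it, otherwise keep the result as-is.
    results = []
    for result in op_results:
        if result.is_combine():
            for use_op in result.all_used_ops():
                if use_op.name() == "builtin.split":
                    results.extend(use_op.results())
        else:
            results.append(result)
    return results


outs = flatten_results(
    [FakeResult("combined", combine=True, users=[FakeSplitOp()])]
)
print([r.id for r in outs])  # ['split_out_0', 'split_out_1']
```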
@@ -409,14 +423,10 @@ def convert(self, network, paddle_op, inputs):
                 f"Converter for {op_name} not implemented."
             )
         outs = converter_func(network, paddle_op, inputs)
-        if isinstance(outs, tuple):
-            return outs
-        elif isinstance(outs, trt.ITensor):
+        if isinstance(outs, trt.ITensor):
             return (outs,)
         else:
-            raise TypeError(
-                f"Expected outputs to be a tuple or ITensor, but got {type(outs)}"
-            )
+            return outs

     def convert_program_to_trt(self):
         for op in self.program.global_block().ops:
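The effect of the relaxed return handling, as a small sketch (assumes a TensorRT installation; not code from the repository): a converter that returns a single ITensor is normalized to a 1-tuple, while converters that already return a tuple or list of tensors, such as the split converters added in this PR, are passed through unchanged.

```python
import tensorrt as trt  # assumption: TensorRT Python bindings are available


def normalize_converter_outputs(outs):
    # Sketch of Converter.convert's output contract after this change.
    if isinstance(outs, trt.ITensor):
        return (outs,)
    return outs  # tuples/lists from multi-output converters pass through
```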

python/paddle/tensorrt/converter_utils.py

Lines changed: 91 additions & 0 deletions
@@ -213,6 +213,11 @@ def get_shape_tensor_element(network, x, index):
     return gather_layer.get_output(0)


+def trt_less(network, a, b):
+    layer = network.add_elementwise(a, b, trt.ElementWiseOperation.LESS)
+    return layer.get_output(0)
+
+
 def trt_sum(network, a, b):
     layer = network.add_elementwise(a, b, trt.ElementWiseOperation.SUM)
     return layer.get_output(0)
@@ -231,3 +236,89 @@ def trt_sub(network, a, b):
 def trt_min(network, a, b):
     layer = network.add_elementwise(a, b, trt.ElementWiseOperation.MIN)
     return layer.get_output(0)
+
+
+def trt_mul(network, a, b):
+    layer = network.add_elementwise(a, b, trt.ElementWiseOperation.PROD)
+    return layer.get_output(0)
+
+
+def trt_div(network, a, b):
+    layer = network.add_elementwise(a, b, trt.ElementWiseOperation.DIV)
+    return layer.get_output(0)
+
+
+def trt_floor_div(network, a, b):
+    layer = network.add_elementwise(a, b, trt.ElementWiseOperation.FLOOR_DIV)
+    return layer.get_output(0)
+
+
+def trt_equal(network, a, b):
+    layer = network.add_elementwise(a, b, trt.ElementWiseOperation.EQUAL)
+    return layer.get_output(0)
+
+
+def cast_tensor(network, input_tensor, dtype):
+    layer = network.add_identity(input_tensor)
+    layer.set_output_type(0, dtype)
+    return layer.get_output(0)
+
+
+def build_start_tensor(network, rank, axis_tensor, offset):
+    # Create indices_tensor [0, 1, ..., rank-1]
+    indices = np.arange(rank, dtype=np.int32)
+    indices_tensor = network.add_constant([rank], indices).get_output(0)
+
+    # Create mask: mask = (indices == axis_tensor)
+    mask = network.add_elementwise(
+        indices_tensor, axis_tensor, trt.ElementWiseOperation.EQUAL
+    ).get_output(0)
+    mask_int = cast_tensor(network, mask, trt.int32)
+
+    # Calculate start_tensor = mask_int * offset
+    start_tensor = network.add_elementwise(
+        mask_int, offset, trt.ElementWiseOperation.PROD
+    ).get_output(0)
+
+    return start_tensor
+
+
+def build_size_tensor(
+    network, rank, axis_tensor, size_value, input_shape_tensor
+):
+    # Create indices_tensor [0, 1, ..., rank-1]
+    indices = np.arange(rank, dtype=np.int32)
+    indices_tensor = network.add_constant([rank], indices).get_output(0)
+
+    # Create mask: mask = (indices == axis_tensor)
+    mask = network.add_elementwise(
+        indices_tensor, axis_tensor, trt.ElementWiseOperation.EQUAL
+    ).get_output(0)
+    mask_int = cast_tensor(network, mask, trt.int32)
+
+    # Create ones_tensor
+    ones_tensor = network.add_constant(
+        [rank], np.ones([rank], dtype=np.int32)
+    ).get_output(0)
+
+    # Calculate inverse_mask = ones_tensor - mask_int
+    inverse_mask = network.add_elementwise(
+        ones_tensor, mask_int, trt.ElementWiseOperation.SUB
+    ).get_output(0)
+
+    # Calculate size_tensor = mask_int * size_value + inverse_mask * input_shape_tensor
+    size_value_broadcast = network.add_elementwise(
+        mask_int, size_value, trt.ElementWiseOperation.PROD
+    ).get_output(0)
+
+    input_shape_broadcast = network.add_elementwise(
+        inverse_mask, input_shape_tensor, trt.ElementWiseOperation.PROD
+    ).get_output(0)
+
+    size_tensor = network.add_elementwise(
+        size_value_broadcast,
+        input_shape_broadcast,
+        trt.ElementWiseOperation.SUM,
+    ).get_output(0)
+
+    return size_tensor
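build_start_tensor and build_size_tensor are building blocks for slicing along an axis that is only available as a tensor at network-construction time. The following usage sketch is illustrative only (slice_along_dynamic_axis is a hypothetical helper, not part of the commit); it assumes the two helpers are importable from paddle.tensorrt.converter_utils and wires them into a TensorRT slice layer whose start and shape are supplied via set_input, which is roughly what a split converter needs per output.

```python
import tensorrt as trt

from paddle.tensorrt.converter_utils import (  # assumed import path
    build_size_tensor,
    build_start_tensor,
)


def slice_along_dynamic_axis(network, x, axis_tensor, offset, size_value):
    # axis_tensor, offset and size_value are int32 ITensors of shape [1];
    # all dimensions other than the selected axis keep their input size.
    rank = len(x.shape)
    input_shape_tensor = network.add_shape(x).get_output(0)

    start_tensor = build_start_tensor(network, rank, axis_tensor, offset)
    size_tensor = build_size_tensor(
        network, rank, axis_tensor, size_value, input_shape_tensor
    )

    # Static start/shape/stride are placeholders; the dynamic start and shape
    # are provided through set_input on inputs 1 and 2 of the slice layer.
    slice_layer = network.add_slice(x, [0] * rank, [1] * rank, [1] * rank)
    slice_layer.set_input(1, start_tensor)
    slice_layer.set_input(2, size_tensor)
    return slice_layer.get_output(0)
```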
