Skip to content

Commit 17d6d93

Browse files
authored
[XPU]fuse small ops of idg models (PaddlePaddle#54245)
1 parent a087b9c commit 17d6d93

File tree

8 files changed

+420
-6
lines changed

8 files changed

+420
-6
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ if(WITH_XPU)
252252
xpu DEPS ${XPU_PASS_DEPS})
253253
pass_library(add_activation_xpu_fuse_pass inference DIR xpu DEPS
254254
${XPU_PASS_DEPS})
255+
pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
256+
${XPU_PASS_DEPS})
255257
endif()
256258

257259
cc_library(
@@ -536,4 +538,8 @@ if(WITH_XPU)
536538
test_multi_encoder_xpu_adaptive_seqlen_fuse_pass
537539
SRCS xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc
538540
DEPS multi_encoder_xpu_adaptive_seqlen_fuse_pass)
541+
cc_test(
542+
test_fold_interp_outsize_fuse_pass
543+
SRCS xpu/fold_interp_outsize_fuse_pass_test.cc
544+
DEPS fold_interp_outsize_fuse_pass)
539545
endif()

paddle/fluid/framework/ir/pass_tester_helper.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,20 +361,31 @@ struct Layers {
361361
return outs;
362362
}
363363

364-
std::vector<VarDesc*> split(VarDesc* x, int num_or_section, int axis = 0) {
365-
std::vector<VarDesc*> outs(num_or_section);
366-
for (int i = 0; i < num_or_section; i++) {
364+
std::vector<VarDesc*> split(VarDesc* x,
365+
int num_or_section = 0,
366+
int axis = 0,
367+
std::vector<int> sections = {-1}) {
368+
int out_num = num_or_section;
369+
if (num_or_section == 0) {
370+
out_num = sections.size();
371+
}
372+
std::vector<VarDesc*> outs(out_num);
373+
for (int i = 0; i < out_num; i++) {
367374
outs[i] = lod_tensor(unique_name());
368375
}
369-
std::vector<std::string> out_names(num_or_section);
370-
for (int i = 0; i < num_or_section; i++) {
376+
std::vector<std::string> out_names(out_num);
377+
for (int i = 0; i < out_num; i++) {
371378
out_names[i] = outs[i]->Name();
372379
}
373380
OpDesc* op = program_.MutableBlock(0)->AppendOp();
374381
op->SetType("split");
375382
op->SetInput("X", {x->Name()});
376383
op->SetOutput("Out", out_names);
377-
op->SetAttr("num_or_section", num_or_section);
384+
if (num_or_section == 0) {
385+
op->SetAttr("sections", sections);
386+
} else {
387+
op->SetAttr("num_or_section", num_or_section);
388+
}
378389
op->SetAttr("axis", axis);
379390
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
380391
static_cast<int>(OpRole::kForward));
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h"
16+
#include <string>
17+
18+
#include "glog/logging.h"
19+
20+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
21+
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
22+
#include "paddle/fluid/framework/op_version_registry.h"
23+
#include "paddle/fluid/platform/enforce.h"
24+
25+
namespace phi {
26+
class DenseTensor;
27+
} // namespace phi
28+
29+
namespace paddle {
30+
namespace framework {
31+
class Scope;
32+
} // namespace framework
33+
} // namespace paddle
34+
35+
namespace paddle {
36+
namespace framework {
37+
namespace ir {
38+
39+
namespace patterns {
40+
// Matches the constant-OutSize computation chain feeding bilinear_interp_v2:
//   x -> shape -> cast -> slice -> concat(w/ persistable y) -> split -> cast
// with x also wired directly into bilinear_interp_v2's "X" input.
// The pass folds this chain into the persistable concat input (see
// FoldInterpOutsizeFusePass::DetectorFuse).
struct DetectorFusePattern : public PatternBase {
  DetectorFusePattern(PDPattern* pattern, const std::string& name_scope);

  // declare operator node's name
  PATTERN_DECL_NODE(shape);
  PATTERN_DECL_NODE(cast1);
  PATTERN_DECL_NODE(slice);
  PATTERN_DECL_NODE(concat);
  PATTERN_DECL_NODE(split);
  PATTERN_DECL_NODE(cast2);
  PATTERN_DECL_NODE(bilinear_interp);
  // declare variable node's name
  PATTERN_DECL_NODE(x);
  PATTERN_DECL_NODE(shape_out);
  PATTERN_DECL_NODE(cast1_out);
  PATTERN_DECL_NODE(slice_out);
  PATTERN_DECL_NODE(concat_y);   // persistable second input of concat
  PATTERN_DECL_NODE(concat_out);
  PATTERN_DECL_NODE(split_out_0);
  PATTERN_DECL_NODE(split_out_1);
  PATTERN_DECL_NODE(cast2_out);  // becomes bilinear_interp_v2's "OutSize"
};
62+
63+
// Builds the pattern graph. Node registration and LinksFrom/LinksTo wiring
// order is kept exactly as authored — GraphPatternDetector matching depends
// on it.
DetectorFusePattern::DetectorFusePattern(PDPattern* pattern,
                                         const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  // x feeds both the shape-computation chain and the interp op itself.
  auto* x = pattern->NewNode(x_repr())
                ->assert_is_op_input("shape", "Input")
                ->assert_is_op_input("bilinear_interp_v2", "X");
  auto* shape = pattern->NewNode(shape_repr())->assert_is_op("shape");
  auto* shape_out = pattern->NewNode(shape_out_repr())
                        ->assert_is_op_output("shape", "Out")
                        ->assert_is_op_input("cast", "X");
  shape->LinksFrom({x}).LinksTo({shape_out});
  // First cast: in_dtype 2 -> out_dtype 3 (proto::VarType codes; presumably
  // INT32 -> INT64 — confirm against framework.proto).
  auto* cast1 = pattern->NewNode(cast1_repr())
                    ->assert_is_op("cast")
                    ->assert_more([&](Node* node) {
                      auto* op_desc = node->Op();
                      return op_desc->GetAttrIfExists<int>("in_dtype") == 2 &&
                             op_desc->GetAttrIfExists<int>("out_dtype") == 3;
                    });
  auto* cast1_out = pattern->NewNode(cast1_out_repr())
                        ->assert_is_op_output("cast", "Out")
                        ->assert_is_op_input("slice", "Input");
  cast1->LinksFrom({shape_out}).LinksTo({cast1_out});
  // Slice must take exactly dims [0, 2) along axis 0 (the NC part of NCHW).
  auto* slice =
      pattern->NewNode(slice_repr())
          ->assert_is_op("slice")
          ->assert_more([&](Node* node) {
            auto* op_desc = node->Op();
            return op_desc->GetAttrIfExists<std::vector<int>>("axes") ==
                       std::vector<int>{0} &&
                   op_desc->GetAttrIfExists<std::vector<int>>("starts") ==
                       std::vector<int>{0} &&
                   op_desc->GetAttrIfExists<std::vector<int>>("ends") ==
                       std::vector<int>{2};
          });
  auto* slice_out = pattern->NewNode(slice_out_repr())
                        ->assert_is_op_output("slice", "Out")
                        ->assert_is_op_nth_input("concat", "X", 0);
  slice->LinksFrom({cast1_out}).LinksTo({slice_out});
  auto* concat = pattern->NewNode(concat_repr())
                     ->assert_is_op("concat")
                     ->assert_more([&](Node* node) {
                       auto* op_desc = node->Op();
                       return op_desc->GetAttrIfExists<int>("axis") == 0;
                     });
  // The second concat input must be persistable: it is the constant the whole
  // chain is folded into.
  auto* concat_y = pattern->NewNode(concat_y_repr())
                       ->assert_is_op_nth_input("concat", "X", 1)
                       ->assert_is_persistable_var();
  auto* concat_out = pattern->NewNode(concat_out_repr())
                         ->assert_is_op_output("concat", "Out")
                         ->assert_is_op_input("split", "X");
  concat->LinksFrom({slice_out, concat_y}).LinksTo({concat_out});
  // Split of the 4-element concat result into two halves, either via
  // sections == {2, 2} or num == 2.
  auto* split = pattern->NewNode(split_repr())
                    ->assert_is_op("split")
                    ->assert_more([&](Node* node) {
                      auto* op_desc = node->Op();
                      return op_desc->GetAttrIfExists<int>("axis") == 0 &&
                             (op_desc->GetAttrIfExists<std::vector<int>>(
                                  "sections") == std::vector<int>{2, 2} ||
                              op_desc->GetAttrIfExists<int>("num") == 2);
                    });
  auto* split_out_0 = pattern->NewNode(split_out_0_repr())
                          ->assert_is_op_nth_output("split", "Out", 0);
  auto* split_out_1 = pattern->NewNode(split_out_1_repr())
                          ->assert_is_op_nth_output("split", "Out", 1)
                          ->assert_is_op_input("cast", "X");
  split->LinksFrom({concat_out}).LinksTo({split_out_0, split_out_1});
  // Second cast reverses the first one: in_dtype 3 -> out_dtype 2.
  auto* cast2 = pattern->NewNode(cast2_repr())
                    ->assert_is_op("cast")
                    ->assert_more([&](Node* node) {
                      auto* op_desc = node->Op();
                      return op_desc->GetAttrIfExists<int>("in_dtype") == 3 &&
                             op_desc->GetAttrIfExists<int>("out_dtype") == 2;
                    });
  auto* cast2_out = pattern->NewNode(cast2_out_repr())
                        ->assert_is_op_output("cast", "Out")
                        ->assert_is_op_input("bilinear_interp_v2", "OutSize");
  cast2->LinksFrom({split_out_1}).LinksTo({cast2_out});
  auto* bilinear_interp = pattern->NewNode(bilinear_interp_repr())
                              ->assert_is_op("bilinear_interp_v2");
  bilinear_interp->LinksFrom({x, cast2_out});
}
144+
145+
} // namespace patterns
146+
147+
// Rewrites each matched subgraph: since the cast2 output ultimately equals the
// persistable concat_y values, bilinear_interp_v2's "OutSize" input is
// redirected to concat_y and the whole shape/cast/slice/concat/split/cast
// chain is deleted. Mutation order below (rename, unlink, link, remove) is
// deliberate — do not reorder.
void FoldInterpOutsizeFusePass::DetectorFuse(ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::DetectorFusePattern pattern(gpd.mutable_pattern(), name_scope_);
  int found_subgraph_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle DetectorFuse";
    /* declare operator node's name */
    GET_IR_NODE(shape);
    GET_IR_NODE(cast1);
    GET_IR_NODE(slice);
    GET_IR_NODE(concat);
    GET_IR_NODE(split);
    GET_IR_NODE(cast2);
    GET_IR_NODE(bilinear_interp);
    /* declare variable node's name*/
    GET_IR_NODE(x);
    GET_IR_NODE(shape_out);
    GET_IR_NODE(cast1_out);
    GET_IR_NODE(slice_out);
    GET_IR_NODE(concat_y);
    GET_IR_NODE(concat_out);
    GET_IR_NODE(split_out_0);
    GET_IR_NODE(split_out_1);
    GET_IR_NODE(cast2_out);

    auto* scope = param_scope();
    PADDLE_ENFORCE_NOT_NULL(
        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

    // concat_y is asserted persistable by the pattern, so its variable exists
    // in the parameter scope.
    auto* concat_y_t =
        scope->GetVar(concat_y->Name())->GetMutable<phi::DenseTensor>();
    // concat_y int64 --> int32
    // (matches cast2, which converted the chain's output back to dtype 2
    // before feeding OutSize; done in place on the persistable tensor)
    auto tensor_type = concat_y_t->dtype();
    if (tensor_type == phi::DataType::INT64) {
      CastToInt32(concat_y_t, nullptr);
    }
    // Feed concat_y directly into bilinear_interp_v2 as its OutSize input.
    bilinear_interp->Op()->RenameInput(cast2_out->Name(), concat_y->Name());
    IR_NODE_UNLINK(x, shape);
    IR_NODE_UNLINK(cast2_out, bilinear_interp);
    IR_NODE_LINK_TO(concat_y, bilinear_interp);
    // delete useless node
    // (x and concat_y are kept — both are still inputs of bilinear_interp)
    std::unordered_set<const Node*> delete_nodes = {shape,
                                                    cast1,
                                                    slice,
                                                    concat,
                                                    split,
                                                    cast2,
                                                    shape_out,
                                                    cast1_out,
                                                    slice_out,
                                                    concat_out,
                                                    split_out_0,
                                                    split_out_1,
                                                    cast2_out};
    GraphSafeRemoveNodes(graph, delete_nodes);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}
210+
211+
// Pass entry point: validates the graph, initializes the fuse-pass base state
// under this pass's name scope, then runs the single fusion step.
void FoldInterpOutsizeFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  DetectorFuse(graph);
}
218+
219+
} // namespace ir
220+
} // namespace framework
221+
} // namespace paddle
222+
223+
// Register the pass under the name used by CMake/pass_library and inference
// configs.
REGISTER_PASS(fold_interp_outsize_fuse_pass,
              paddle::framework::ir::FoldInterpOutsizeFusePass);

// Pin compatibility: the pass only claims correctness against op version 0 of
// "shape" (the chain's anchor op).
REGISTER_PASS_CAPABILITY(fold_interp_outsize_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "shape", 0));
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <string>
17+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
18+
#include "paddle/fluid/framework/ir/pass.h"
19+
20+
namespace phi {
21+
class DenseTensor;
22+
} // namespace phi
23+
24+
namespace paddle {
25+
namespace framework {
26+
class Scope;
27+
} // namespace framework
28+
} // namespace paddle
29+
30+
namespace paddle {
31+
namespace framework {
32+
namespace ir {
33+
34+
// XPU inference fuse pass: folds the constant shape-computation chain that
// produces bilinear_interp_v2's "OutSize" into the chain's persistable
// concat input, removing six small ops per match (see diagram below).
class FoldInterpOutsizeFusePass : public FusePassBase {
 protected:
  // Pass entry point; delegates to DetectorFuse after FusePassBase::Init.
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  /*
  Origin subgraph:
            x
          /   \
         |   shape
         |     |
         |    cast
         |     |
         |   slice
         |     |
         |   concat
         |     |
         |   split
         |    |  \
         |    |   \
         | outvar_1 outvar_0
         |    |
         |   cast
         |   /
          \ /
    bilinear_interp_v2

  Fused subgraph:
         x
         |  concat_y
         |  /
    bilinear_interp_v2
  */
  // Matches the subgraph above and rewires bilinear_interp_v2's OutSize input
  // to the persistable concat_y tensor.
  void DetectorFuse(ir::Graph* graph) const;

  // Name scope used both for FusePassBase statistics and pattern naming.
  const std::string name_scope_{"fold_interp_outsize_fuse_pass"};
};
71+
72+
} // namespace ir
73+
} // namespace framework
74+
} // namespace paddle

0 commit comments

Comments
 (0)