
Commit 6208375

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into backward_TensorCore
2 parents: 30f443e + cf12ea5


54 files changed: +2163 / -340 lines

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
@@ -134,11 +134,13 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
         modify_op_lock_and_record_event_pass
         coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
         fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
-        sync_batch_norm_pass runtime_context_cache_pass)
+        sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass)
 if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM))
   set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
 endif()
 cc_library(build_strategy SRCS build_strategy.cc DEPS pass_builder ${IR_PASS_DEPS})
+cc_test(build_strategy_test SRCS build_strategy_test.cc
+        DEPS build_strategy op_registry op_proto_maker graph)
 
 if (WITH_MKLDNN)
   target_link_libraries(build_strategy mkldnn_placement_pass)

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 14 additions & 1 deletion
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
 
 DECLARE_bool(use_mkldnn);
+DECLARE_bool(convert_all_blocks);
 
 namespace paddle {
 namespace framework {
@@ -312,6 +313,11 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                                 DeviceType use_device) const {
 #endif
   VLOG(1) << "apply all passes";
+  if (FLAGS_convert_all_blocks) {
+    PADDLE_ENFORCE_EQ(
+        graph->IsMainGraph(), true,
+        platform::errors::InvalidArgument("This graph is not main_graph"));
+  }
   // Create a default one if not finalized by user.
   CreatePassesFromStrategy(false);
 
@@ -432,7 +438,14 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
       }
     }
     VLOG(1) << "Start Apply Pass " << pass->Type();
-    graph = pass->Apply(graph);
+    if (FLAGS_convert_all_blocks) {
+      for (size_t i = 0; i < graph->SubGraphsSize(); ++i) {
+        VLOG(3) << "Apply Pass " << pass->Type() << "to SubGraph " << i;
+        pass->Apply(graph->GetSubGraph(i));
+      }
+    } else {
+      graph = pass->Apply(graph);
+    }
     VLOG(1) << "Finish Apply Pass " << pass->Type();
   }
   VLOG(1) << "All Passes Applied";
paddle/fluid/framework/details/build_strategy_test.cc

Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "gtest/gtest-message.h"
+#include "gtest/gtest-test-part.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest_pred_impl.h"
+
+#include "paddle/fluid/framework/details/build_strategy.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/var_type_inference.h"
+#include "paddle/fluid/platform/place.h"
+
+DECLARE_bool(convert_all_blocks);
+
+namespace paddle {
+namespace framework {
+
+class SumOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "").AsDuplicable();
+    AddOutput("Out", "").AsDuplicable();
+    AddComment("");
+  }
+};
+
+class SumOpWithKernel : public OperatorWithKernel {
+ public:
+  using OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {}
+  OpKernelType GetExpectedKernelType(
+      const ExecutionContext &ctx) const override {
+    return OpKernelType(proto::VarType::FP32, ctx.Input<Tensor>("X")->place());
+  }
+};
+
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_OP_WITHOUT_GRADIENT(sum, paddle::framework::SumOpWithKernel,
+                             paddle::framework::SumOpMaker);
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+static std::vector<platform::Place> CreatePlaces(size_t num, bool use_cuda) {
+  std::vector<platform::Place> result;
+  result.reserve(num);
+  for (size_t i = 0; i < num; ++i) {
+    if (use_cuda) {
+      result.emplace_back(platform::CUDAPlace(i));
+    } else {
+      result.emplace_back(platform::CPUPlace());
+    }
+  }
+  return result;
+}
+
+void BuildStrategyApply(BuildStrategy *build_strategy, ir::Graph *graph) {
+  std::string loss_name = "";
+  Scope scope;
+  std::vector<Scope *> scopes = {&scope};
+
+  auto places = CreatePlaces(1, false);
+  auto device = platform::Place2DeviceType(places[0]);
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  platform::NCCLCommunicator ctxs;
+#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
+  platform::BKCLCommunicator ctxs;
+#endif
+
+  build_strategy->Apply(graph, places, loss_name, scopes, 1,
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+                        device, &ctxs);
+#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
+                        device, &ctxs);
+#else
+                        device);
+#endif
+}
+
+std::unique_ptr<ir::Graph> CreateGraph() {
+  ProgramDesc prog;
+  auto *op = prog.MutableBlock(0)->AppendOp();
+  op->SetType("sum");
+  op->SetInput("X", {"a1"});
+  op->SetOutput("Out", {"b1"});
+  op->SetAttr("op_role", 1);
+
+  prog.MutableBlock(0)->Var("a1")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(0)->Var("b1")->SetType(proto::VarType::LOD_TENSOR);
+
+  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  return g;
+}
+
+std::unique_ptr<ir::Graph> CreateMultiGraph() {
+  ProgramDesc prog;
+  prog.AppendBlock(prog.Block(0));
+  prog.AppendBlock(prog.Block(0));
+
+  // Set contents in block_0.
+  auto *op = prog.MutableBlock(0)->AppendOp();
+  op->SetType("sum");
+  op->SetInput("X", {"test_a", "test_b", "test_c"});
+  op->SetOutput("Out", {"test_out"});
+  op->SetAttr("op_role", 1);
+
+  prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_out");
+  op->InferVarType(prog.MutableBlock(0));
+
+  prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR);
+  op->InferVarType(prog.MutableBlock(0));
+
+  // Set contents in block_1.
+  op = prog.MutableBlock(1)->AppendOp();
+  op->SetType("sum");
+  op->SetInput("X", {"a1"});
+  op->SetOutput("Out", {"b1"});
+  op->SetAttr("op_role", 1);
+
+  prog.MutableBlock(1)->Var("a1")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(1)->Var("b1")->SetType(proto::VarType::LOD_TENSOR);
+
+  // Set contents in block_2.
+  op = prog.MutableBlock(2)->AppendOp();
+  op->SetType("sum");
+  op->SetInput("X", {"a2"});
+  op->SetOutput("Out", {"b2"});
+  op->SetAttr("op_role", 1);
+
+  prog.MutableBlock(2)->Var("a2")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(2)->Var("b2")->SetType(proto::VarType::LOD_TENSOR);
+
+  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
+  return g;
+}
+
+inline bool CheckSubGraphSame(ir::Graph *g1, ir::Graph *g2) {
+  const auto &g1_nodes_set = g1->Nodes();
+  const auto &g2_nodes_set = g2->Nodes();
+
+  if (g1_nodes_set.size() != g2_nodes_set.size()) return false;
+
+  std::vector<ir::Node *> g1_nodes(g1_nodes_set.begin(), g1_nodes_set.end());
+  std::vector<ir::Node *> g2_nodes(g2_nodes_set.begin(), g2_nodes_set.end());
+
+  auto comp = [](ir::Node *n1, ir::Node *n2) {
+    return n1->Name() > n2->Name();
+  };
+  std::stable_sort(g1_nodes.begin(), g1_nodes.end(), comp);
+  std::stable_sort(g2_nodes.begin(), g2_nodes.end(), comp);
+
+  for (size_t i = 0; i < g1_nodes.size(); ++i) {
+    const auto &n1 = g1_nodes[i];
+    const auto &n2 = g2_nodes[i];
+
+    if (n1->NodeType() != n2->NodeType()) return false;
+    if (n1->Name() != n2->Name()) return false;
+
+    auto n1_inputs = n1->inputs;
+    auto n2_inputs = n2->inputs;
+    if (n1_inputs.size() != n2_inputs.size()) return false;
+
+    std::stable_sort(n1_inputs.begin(), n1_inputs.end(), comp);
+    std::stable_sort(n2_inputs.begin(), n2_inputs.end(), comp);
+    for (size_t i = 0; i < n1_inputs.size(); ++i) {
+      if (n1_inputs[i]->Name() != n2_inputs[i]->Name()) return false;
+    }
+
+    auto n1_outputs = n1->outputs;
+    auto n2_outputs = n2->outputs;
+    if (n1_outputs.size() != n2_outputs.size()) return false;
+
+    std::stable_sort(n1_outputs.begin(), n1_outputs.end(), comp);
+    std::stable_sort(n2_outputs.begin(), n2_outputs.end(), comp);
+    for (size_t i = 0; i < n1_outputs.size(); ++i) {
+      if (n1_outputs[i]->Name() != n2_outputs[i]->Name()) return false;
+    }
+
+    if (n1->IsVar()) {
+      const auto &var1 = n1->Var();
+      const auto &var2 = n2->Var();
+      if ((var1 == nullptr) != (var2 == nullptr)) return false;
+    }
+
+    if (n1->IsOp()) {
+      const auto &op1 = n1->Op();
+      const auto &op2 = n2->Op();
+      if ((op1 == nullptr) != (op2 == nullptr)) return false;
+
+      const auto &op1_input = op1->InputNames();
+      const auto &op2_input = op2->InputNames();
+      if (op1_input.size() != op2_input.size()) return false;
+      if (op1_input != op2_input) return false;
+
+      for (size_t i = 0; i < op1_input.size(); ++i) {
+        if (op1->Input(op1_input[i]) != op2->Input(op2_input[i])) return false;
+      }
+
+      const auto &op1_output = op1->OutputNames();
+      const auto &op2_output = op2->OutputNames();
+      if (op1_output.size() != op2_output.size()) return false;
+      if (op1_output != op2_output) return false;
+
+      for (size_t i = 0; i < op1_output.size(); ++i) {
+        if (op1->Output(op1_output[i]) != op2->Output(op2_output[i]))
+          return false;
+      }
+    }
+  }
+  return true;
+}
+
+inline bool CheckGraphSame(ir::Graph *g1, ir::Graph *g2) {
+  if (g1 == nullptr || g2 == nullptr) return true;
+
+  if (FLAGS_convert_all_blocks) {
+    if (g1->SubGraphsSize() != g2->SubGraphsSize()) return false;
+
+    for (size_t i = 0; i < g1->SubGraphsSize(); ++i) {
+      if (!CheckSubGraphSame(g1->GetSubGraph(i), g2->GetSubGraph(i)))
+        return false;
+    }
+  } else {
+    if (!CheckSubGraphSame(g1, g2)) return false;
+  }
+  return true;
+}
+
+TEST(BuildStrategy, Basic) {
+  BuildStrategy build_strategy;
+
+  ProgramDesc prog;
+  ir::Graph old_graph(prog), graph(prog);
+
+  BuildStrategyApply(&build_strategy, &graph);
+
+  ASSERT_TRUE(CheckGraphSame(&old_graph, &graph));
+}
+
+TEST(BuildStrategy, TestSingleGraph) {
+  BuildStrategy build_strategy;
+  auto graph = CreateGraph();
+  ir::Graph old_graph(graph->OriginProgram());
+
+  BuildStrategyApply(&build_strategy, graph.get());
+
+  // graph should not change for no pass here
+  ASSERT_TRUE(CheckGraphSame(&old_graph, graph.get()));
+}
+
+TEST(BuildStrategy, TestMultiGraph) {
+  // Set FLAGS_convert_all_blocks to true to make sure this test works.
+  bool flag_temp = FLAGS_convert_all_blocks;
+  FLAGS_convert_all_blocks = true;
+
+  BuildStrategy build_strategy;
+  auto graph = CreateMultiGraph();
+  ir::Graph old_graph(graph->OriginProgram());
+
+  BuildStrategyApply(&build_strategy, graph.get());
+
+  // graph should not change for no pass here
+  ASSERT_TRUE(CheckGraphSame(&old_graph, graph.get()));
+
+  // Recover FLAGS_convert_all_blocks.
+  FLAGS_convert_all_blocks = flag_temp;
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle

paddle/fluid/framework/framework.proto

Lines changed: 2 additions & 0 deletions
@@ -173,6 +173,8 @@ message VarDesc {
   // True if the variable is an input data and
   // have to check the feed data shape and dtype
   optional bool need_check_feed = 4 [ default = false ];
+  optional bool is_parameter = 5 [ default = false ];
+  optional bool stop_gradient = 6 [ default = false ];
 }
 
 message BlockDesc {
