16
16
17
17
#include " paddle/fluid/framework/executor_gc_helper.h"
18
18
#include " paddle/fluid/framework/new_executor/data_transfer.h"
19
+ #include " paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
20
+ #include " paddle/fluid/operators/controlflow/recurrent_op_helper.h"
21
+ #include " paddle/fluid/operators/controlflow/while_op_helper.h"
19
22
20
23
namespace paddle {
21
24
namespace framework {
@@ -152,7 +155,7 @@ void build_variable_scope(const framework::BlockDesc& block,
152
155
}
153
156
154
157
void create_all_ops (const framework::BlockDesc& block,
155
- std::vector<std::shared_ptr <OperatorBase>>* ops) {
158
+ std::vector<std::unique_ptr <OperatorBase>>* ops) {
156
159
for (auto & op : block.AllOps ()) {
157
160
VLOG (3 ) << " CreateOp from : " << op->Type ();
158
161
@@ -167,7 +170,7 @@ void create_all_ops(const framework::BlockDesc& block,
167
170
}
168
171
auto op_base =
169
172
info.Creator ()(op->Type (), inputs_names, outputs_names, op_attr_map);
170
- ops->emplace_back (std::shared_ptr <OperatorBase>(op_base));
173
+ ops->emplace_back (std::unique_ptr <OperatorBase>(op_base));
171
174
}
172
175
}
173
176
@@ -263,10 +266,24 @@ void build_op_func_list(const platform::Place& place,
263
266
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope ()
264
267
: var_scope->GetMutableScope ();
265
268
auto & all_op_kernels = OperatorWithKernel::AllOpKernels ();
269
+ std::vector<std::unique_ptr<OperatorBase>>
270
+ ops_unique; // its elements will be moved to vec_func_list
271
+ // Step 1: create all ops for current block.
272
+ create_all_ops (block, &ops_unique);
273
+ // If gc is enabled and block size > 1
274
+ const ProgramDesc& main_program = *block.Program ();
275
+ operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp (
276
+ main_program, block.ID (), ops_unique);
277
+ operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp (
278
+ main_program, block.ID (), ops_unique);
279
+ operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp (
280
+ main_program, block.ID (), ops_unique);
281
+
266
282
std::vector<std::shared_ptr<OperatorBase>>
267
283
ops; // its elements will be moved to vec_func_list
268
- // Step 1: create all ops for current block.
269
- create_all_ops (block, &ops);
284
+ for (auto & op_unique : ops_unique) {
285
+ ops.emplace_back (std::move (op_unique));
286
+ }
270
287
auto unused_var_map = get_unused_vars (block, ops);
271
288
272
289
for (size_t i = 0 ; i < ops.size (); ++i) {
0 commit comments