
Commit cbdee1c

1. fix control flow cases 2. fix calc_gradient
1 parent c4fe699 commit cbdee1c
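
Note: the commit message covers two things, the new executor's handling of programs that contain control-flow ops, and the test for calc_gradient. As a rough illustration of the first point, the sketch below is my own example, not part of this commit; it assumes the standard fluid.layers.cond / fluid.Executor APIs and only shows the kind of program whose block holds a conditional_block op.

    # Illustrative sketch only, not from the commit: a static-graph program whose
    # block contains a conditional_block op (created by fluid.layers.cond).
    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.data(name='x', shape=[1], dtype='float32')
        pred = fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)
        # cond() lowers to a conditional_block op in main's global block.
        y = fluid.layers.cond(pred, lambda: x * 2.0, lambda: x * 3.0)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup)
    out = exe.run(main, feed={'x': np.ones([1], dtype='float32')}, fetch_list=[y])
    print(out[0])  # expect [2.] -- the true branch ran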

2 files changed: +35 −16 lines changed

paddle/fluid/framework/new_executor/interpretercore_util.cc (+21 −4)
@@ -16,6 +16,9 @@

#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
+#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
+#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"

namespace paddle {
namespace framework {

@@ -152,7 +155,7 @@ void build_variable_scope(const framework::BlockDesc& block,
}

void create_all_ops(const framework::BlockDesc& block,
-                    std::vector<std::shared_ptr<OperatorBase>>* ops) {
+                    std::vector<std::unique_ptr<OperatorBase>>* ops) {
  for (auto& op : block.AllOps()) {
    VLOG(3) << "CreateOp from : " << op->Type();

@@ -167,7 +170,7 @@ void create_all_ops(const framework::BlockDesc& block,
    }
    auto op_base =
        info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
-    ops->emplace_back(std::shared_ptr<OperatorBase>(op_base));
+    ops->emplace_back(std::unique_ptr<OperatorBase>(op_base));
  }
}

@@ -263,10 +266,24 @@ void build_op_func_list(const platform::Place& place,
  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
                                       : var_scope->GetMutableScope();
  auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
+  std::vector<std::unique_ptr<OperatorBase>>
+      ops_unique;  // its elements will be moved to vec_func_list
+  // Step 1: create all ops for current block.
+  create_all_ops(block, &ops_unique);
+  // If gc is enabled and block size > 1
+  const ProgramDesc& main_program = *block.Program();
+  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
+      main_program, block.ID(), ops_unique);
+  operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+      main_program, block.ID(), ops_unique);
+  operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
+      main_program, block.ID(), ops_unique);
+
  std::vector<std::shared_ptr<OperatorBase>>
      ops;  // its elements will be moved to vec_func_list
-  // Step 1: create all ops for current block.
-  create_all_ops(block, &ops);
+  for (auto& op_unique : ops_unique) {
+    ops.emplace_back(std::move(op_unique));
+  }
  auto unused_var_map = get_unused_vars(block, ops);

  for (size_t i = 0; i < ops.size(); ++i) {
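
Note: the last hunk first builds the block's ops into a temporary vector of unique_ptr, apparently because the PrepareSafeEagerDeletionOn* helper passes shared with the legacy executor operate on unique_ptr op lists; those passes record which variables the conditional_block/while/recurrent ops and their grad ops still need, so eager garbage collection does not free them too early. Only then are the ops moved into the shared_ptr vector the rest of build_op_func_list expects. The sketch below is my own example, not part of the commit (it assumes the usual fluid.layers.while_loop and fluid.gradients APIs); it just shows the kind of program that puts while / while_grad ops in a block.

    # Illustrative sketch only, not from the commit: a while loop plus a backward
    # pass puts while and while_grad ops in the block; the eager-deletion helper
    # passes called above mark the loop-carried variables those ops still need.
    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.data(name='x', shape=[1], dtype='float32')
        x.stop_gradient = False
        i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
        limit = fluid.layers.fill_constant(shape=[1], dtype='int64', value=3)

        def cond(i, s):
            return fluid.layers.less_than(i, limit)

        def body(i, s):
            # Double s on every iteration and advance the counter.
            return [fluid.layers.increment(i), s * 2.0]

        _, y = fluid.layers.while_loop(cond, body, [i, x])
        x_grad = fluid.gradients(y, x)[0]

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup)
    out = exe.run(main, feed={'x': np.ones([1], dtype='float32')}, fetch_list=[x_grad])
    print(out[0])  # expect [8.] -- d(2**3 * x)/dx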

python/paddle/fluid/tests/unittests/test_calc_gradient.py (+14 −12)
@@ -14,6 +14,7 @@

from __future__ import print_function

+import paddle
import unittest
import numpy as np
import paddle.fluid as fluid

@@ -83,19 +84,20 @@ def test2(self):

class TestGradientWithPrune(unittest.TestCase):
    def test_prune(self):
-        x = fluid.data(name='x', shape=[3], dtype='float32')
-        x.stop_gradient = False
-        x1, x2, x3 = fluid.layers.split(x, dim=0, num_or_sections=3)
-        y = x1 * 2
-        x1_grad = fluid.gradients(y, x)
+        with paddle.fluid.scope_guard(paddle.static.Scope()):
+            x = fluid.data(name='x', shape=[3], dtype='float32')
+            x.stop_gradient = False
+            x1, x2, x3 = fluid.layers.split(x, dim=0, num_or_sections=3)
+            y = x1 * 2
+            x1_grad = fluid.gradients(y, x)

-        exe = fluid.Executor(fluid.CPUPlace())
-        main = fluid.default_main_program()
-        exe.run(fluid.default_startup_program())
-        out = exe.run(main,
-                      feed={'x': np.ones([3]).astype('float32')},
-                      fetch_list=[x1_grad])
-        self.assertTrue(np.array_equal(out[0], [2., 0., 0.]))
+            exe = fluid.Executor(fluid.CPUPlace())
+            main = fluid.default_main_program()
+            exe.run(fluid.default_startup_program())
+            out = exe.run(main,
+                          feed={'x': np.ones([3]).astype('float32')},
+                          fetch_list=[x1_grad])
+            self.assertTrue(np.array_equal(out[0], [2., 0., 0.]))


if __name__ == "__main__":
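
Note: the calc_gradient half of the fix is confined to the test. Building and running the pruned-gradient program inside scope_guard with a fresh paddle.static.Scope keeps its variables out of the process-wide default scope, so state shared with other tests cannot skew the [2., 0., 0.] check. A minimal sketch of the same isolation pattern, my own example rather than part of the commit:

    # Illustrative sketch only: scope_guard(Scope()) keeps a program's variables
    # in a private scope instead of the process-wide global scope.
    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    with paddle.fluid.scope_guard(paddle.static.Scope()):
        main, startup = fluid.Program(), fluid.Program()
        with fluid.program_guard(main, startup):
            x = fluid.data(name='x', shape=[3], dtype='float32')
            y = x * 2.0
        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(startup)
        out = exe.run(main, feed={'x': np.ones([3], dtype='float32')}, fetch_list=[y])
    # Back outside the guard, the global scope never held 'x' or 'y'.
    print(out[0])  # expect [2. 2. 2.]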
