Skip to content

Commit 3e89748

Browse files
authored
[Dy2stat]fix no_grad context error in dy2stat (#35725)
* fix no_grad context error in dy2stat * remove useless comments * fix error by drop_kids in python * add test and fix review
1 parent b666fd3 commit 3e89748

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

paddle/fluid/pybind/pybind.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,17 @@ All parameter, weight, gradient are variables in Paddle.
12411241
return self.GetMutable<framework::ReaderHolder>();
12421242
},
12431243
py::return_value_policy::reference)
1244+
.def("get_scope",
1245+
[](Variable &self) -> Scope * {
1246+
auto scope_vec =
1247+
self.GetMutable<std::vector<framework::Scope *>>();
1248+
PADDLE_ENFORCE_GT(
1249+
scope_vec->size(), 0,
1250+
platform::errors::InvalidArgument(
1251+
"The size of scope_vec should be greater than 0"));
1252+
return scope_vec->front();
1253+
},
1254+
py::return_value_policy::reference)
12441255
.def("set_scope", [](Variable &self, Scope &scope) {
12451256
auto scope_vec = self.GetMutable<std::vector<framework::Scope *>>();
12461257
scope_vec->emplace_back(&scope);

python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,15 @@ def __call__(self, inputs):
290290
self._valid_vars(self._params),
291291
self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
292292
*attrs)
293-
293+
self.drop_scope_if_no_grad()
294294
restored_nest_out = self._restore_out(out_vars)
295295
return self._remove_no_value(restored_nest_out)
296296

297+
def drop_scope_if_no_grad(self):
298+
tracer = framework._dygraph_tracer()
299+
if self.training and not tracer._has_grad:
300+
self._tmp_scope_vec.value().get_scope().drop_kids()
301+
297302
@property
298303
def program(self):
299304
if self.training:

python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,21 @@ def test_switch_eval_and_train(self):
152152
partial_layer._train_program)
153153

154154

155+
class TestWithNoGrad(unittest.TestCase):
156+
def test_with_no_grad(self):
157+
with fluid.dygraph.guard():
158+
linear_net = Linear()
159+
x_data = np.random.random((5, 10)).astype('float32')
160+
x = fluid.dygraph.to_variable(x_data)
161+
162+
with paddle.no_grad():
163+
linear_net.train()
164+
linear_net(x)
165+
_, partial_layer = linear_net.forward.program_cache.last()[-1]
166+
self.assertEqual(partial_layer.program,
167+
partial_layer._train_program)
168+
169+
155170
class GPT2LMHeadModel(fluid.dygraph.Layer):
156171
def __init__(self):
157172
super(GPT2LMHeadModel, self).__init__()

0 commit comments

Comments
 (0)