
Commit 7c6d7d7

Revert "【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True" (#63637)
This reverts commit 64cad15.
1 parent 8d530db commit 7c6d7d7

2 files changed: +53 additions, -102 deletions

python/paddle/distributed/fleet/recompute/recompute.py (7 additions, 15 deletions)
@@ -14,7 +14,6 @@
 
 import contextlib
 import copy
-import inspect
 import weakref
 
 import paddle
@@ -524,23 +523,16 @@ def recompute(function, *args, **kwargs):
 
         return static_auto_recompute(function)(*args, **kwargs)
 
+    if kwargs and use_reentrant:
+        raise ValueError(
+            "Error, if you want to send kwargs(dict parameter) to function, please set use_reentrant=False."
+        )
+
     if framework._dygraph_tracer()._has_grad:
-        check_args = list(args)
-        check_args.extend(list(kwargs.values()))
-        check_recompute_necessary(check_args)
+        check_recompute_necessary(args)
 
     if use_reentrant:
-        input_args = args
-        # rearrange `position-args + keyword-args` into `position-args`
-        if isinstance(function, paddle.nn.Layer):
-            dyfunc_sig = inspect.signature(function.forward)
-        else:
-            dyfunc_sig = inspect.signature(function)
-
-        bound_args = dyfunc_sig.bind(*args, **kwargs)
-        bound_args.apply_defaults()
-        input_args = list(bound_args.arguments.values())
-        return RecomputeFunction.apply(function, preserve, *input_args)
+        return RecomputeFunction.apply(function, preserve, *args)
     else:
         return _recompute_without_reentrant(function, preserve, *args, **kwargs)
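For context: the removed reentrant branch used the inspect module to flatten mixed positional and keyword arguments into a purely positional tuple before handing them to RecomputeFunction.apply. A minimal, Paddle-free sketch of that binding trick (the function demo is hypothetical, purely for illustration):

    import inspect

    def demo(x, pos, scale=1.0):
        return (x + pos) * scale

    # Bind mixed positional/keyword arguments against the signature,
    # fill in defaults, then read everything back as positional values.
    sig = inspect.signature(demo)
    bound = sig.bind(1.0, pos=2.0)
    bound.apply_defaults()
    input_args = list(bound.arguments.values())  # [1.0, 2.0, 1.0]
    assert demo(*input_args) == demo(1.0, pos=2.0)  # both evaluate to 3.0

Note that bound.arguments collects any **kwargs-style parameter into a single dict entry, so replaying values() positionally is only safe for plain positional-or-keyword parameters; that fragility is inherent to the flattening scheme shown above.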

test/collective/fleet/test_dygraph_recompute_for_eager.py (46 additions, 87 deletions)
@@ -75,7 +75,6 @@ def __init__(
         use_raw_recompute=False,
         recompute_kwargs={},
         raise_value_error=False,
-        recompute_use_kwargs_as_inputs=False,
     ):
         super().__init__()
         self.recompute_blocks = recompute_blocks
@@ -116,7 +115,6 @@ def __init__(
                 self.runfunc2, self.runfunc3, self.runfunc4
             ),
         ]
-        self.recompute_use_kwargs_as_inputs = recompute_use_kwargs_as_inputs
 
     def forward(self, inputs):
         if self.use_fleet_sq and not self.use_raw_recompute:
@@ -137,14 +135,9 @@ def forward(self, inputs):
             )
         for i in range(len(self.layers)):
             if i in self.recompute_blocks:
-                if self.recompute_use_kwargs_as_inputs:
-                    inputs = recompute(
-                        self.layers[i], pos=pos, x=inputs, **recompute_kwargs
-                    )
-                else:
-                    inputs = recompute(
-                        self.layers[i], inputs, pos, **recompute_kwargs
-                    )
+                inputs = recompute(
+                    self.layers[i], inputs, pos, **recompute_kwargs
+                )
             else:
                 inputs = self.layers[i](inputs, pos)
 
@@ -160,7 +153,6 @@ def run_model(
     segments=1,
     enable_autocast=False,
     pure_fp16=False,
-    recompute_use_kwargs_as_inputs=False,
 ):
     gen = paddle.seed(10)
     gen.manual_seed(10)
@@ -176,7 +168,6 @@ def run_model(
         segments=segments,
         recompute_kwargs=recompute_kwargs,
         raise_value_error=raise_value_error,
-        recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
     )
 
     if pure_fp16:
@@ -217,12 +208,7 @@ def run_model(
 
 
 class TestRecompute(unittest.TestCase):
-    def test_base_case(
-        self,
-        enable_autocast=False,
-        pure_fp16=False,
-        recompute_use_kwargs_as_inputs=False,
-    ):
+    def test_base_case(self, enable_autocast=False, pure_fp16=False):
         def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
             self.assertEqual(loss_ref, loss)
             self.assertEqual(param_ref, param)
@@ -245,7 +231,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -255,7 +240,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -265,7 +249,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
            )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -275,7 +258,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -286,7 +268,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
@@ -310,42 +291,31 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
 
     def test_fc_net_with_dropout(self):
         self.test_base_case()
-        self.test_base_case(recompute_use_kwargs_as_inputs=True)
 
     def test_fc_net_without_restore_rng(self):
         for flag in [True, False]:
-            for recompute_use_kwargs_as_inputs in [True, False]:
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[2],
-                    recompute_kwargs={
-                        "preserve_rng_state": False,
-                        "use_reentrant": flag,
-                    },
-                    enable_autocast=True,
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+            loss_ref, param_ref, grad_ref = run_model(
+                recompute_block=[2],
+                recompute_kwargs={
+                    "preserve_rng_state": False,
+                    "use_reentrant": flag,
+                },
+                enable_autocast=True,
+            )
 
     def test_fc_net_with_amp(self):
         self.test_base_case(enable_autocast=True)
-        self.test_base_case(
-            enable_autocast=True, recompute_use_kwargs_as_inputs=True
-        )
 
     def test_fc_net_with_fp16(self):
         self.test_base_case(enable_autocast=True, pure_fp16=True)
-        self.test_base_case(
-            enable_autocast=True,
-            pure_fp16=True,
-            recompute_use_kwargs_as_inputs=True,
-        )
 
     def test_recompute_kwargs(self):
         paddle.set_device("gpu")
         pos = paddle.randn(shape=[10, 10], dtype="float32")
        pos.stop_gradient = False
 
         kwargs = {"pos": pos, "use_reentrant": True}
-        with self.assertRaises(TypeError):
+        with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
                 recompute_block=[2],
                 recompute_kwargs=kwargs,
@@ -358,57 +328,46 @@ def test_recompute_kwargs(self):
             )
 
     def test_recompute_inputs_with_param(self):
-        for flag in [True, False]:
-            for recompute_use_kwargs_as_inputs in [True, False]:
-                pos = paddle.randn(shape=[10, 10], dtype="float32")
-                new_pos = EagerParamBase(
-                    shape=pos.shape, dtype=pos.dtype, name=pos.name
-                )
-                pos._share_buffer_to(new_pos)
-                new_pos.stop_gradient = False
+        pos = paddle.randn(shape=[10, 10], dtype="float32")
+        new_pos = EagerParamBase(
+            shape=pos.shape, dtype=pos.dtype, name=pos.name
+        )
+        pos._share_buffer_to(new_pos)
+        new_pos.stop_gradient = False
 
-                loss, param, grad = run_model(
-                    recompute_block=[2, 4],
-                    recompute_kwargs={"pos": new_pos, "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss, param, grad = run_model(
+            recompute_block=[], recompute_kwargs={"pos": new_pos}
+        )
 
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[1, 2, 3],
-                    recompute_kwargs={"pos": new_pos, "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[1, 2, 3], recompute_kwargs={"pos": new_pos}
+        )
 
-                self.assertEqual(loss_ref, loss)
-                self.assertEqual(param_ref, param)
-                self.assertEqual(grad_ref, grad)
+        self.assertEqual(loss_ref, loss)
+        self.assertEqual(param_ref, param)
+        self.assertEqual(grad_ref, grad)
 
     def test_recompute_inputs_with_tuple(self):
-        for flag in [True, False]:
-            for recompute_use_kwargs_as_inputs in [True, False]:
-                pos = paddle.randn(shape=[10, 10], dtype="float32")
-                new_pos = EagerParamBase(
-                    shape=pos.shape, dtype=pos.dtype, name=pos.name
-                )
-                pos._share_buffer_to(new_pos)
-                pos.stop_gradient = False
-                new_pos.stop_gradient = False
-
-                loss, param, grad = run_model(
-                    recompute_block=[2, 4],
-                    recompute_kwargs={"pos": (pos,), "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        pos = paddle.randn(shape=[10, 10], dtype="float32")
+        new_pos = EagerParamBase(
+            shape=pos.shape, dtype=pos.dtype, name=pos.name
+        )
+        pos._share_buffer_to(new_pos)
+        pos.stop_gradient = False
+        new_pos.stop_gradient = False
 
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[1, 2, 3],
-                    recompute_kwargs={"pos": (new_pos,), "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss, param, grad = run_model(
+            recompute_block=[2, 4], recompute_kwargs={"pos": (pos,)}
+        )
+
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[1, 2, 3],
+            recompute_kwargs={"pos": (new_pos,)},
+        )
 
-                self.assertEqual(loss_ref, loss)
-                self.assertEqual(param_ref, param)
-                self.assertEqual(grad_ref, grad)
+        self.assertEqual(loss_ref, loss)
+        self.assertEqual(param_ref, param)
+        self.assertEqual(grad_ref, grad)
 
 
 if __name__ == '__main__':
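With the revert in place, the contract these tests pin down is the original one: keyword arguments for the wrapped function are only supported when use_reentrant=False, and combining them with use_reentrant=True raises the reinstated ValueError (rather than the TypeError the reverted change asserted). A small usage sketch against the public paddle.distributed.fleet.utils.recompute API; the Linear layer, fn, and tensor shapes here are illustrative, not taken from the test file:

    import paddle
    from paddle.distributed.fleet.utils import recompute

    block = paddle.nn.Linear(10, 10)  # illustrative layer
    x = paddle.randn([4, 10])
    x.stop_gradient = False  # recompute expects at least one input requiring grad

    def fn(inp, scale=1.0):
        return block(inp) * scale

    # Positional-only inputs work in either mode.
    out = recompute(fn, x, use_reentrant=True)

    # Keyword arguments for fn are only supported with use_reentrant=False ...
    out = recompute(fn, x, scale=2.0, use_reentrant=False)

    # ... and combining them with use_reentrant=True raises the reinstated ValueError.
    try:
        recompute(fn, x, scale=2.0, use_reentrant=True)
    except ValueError as err:
        print(err)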
