22 changes: 7 additions & 15 deletions python/paddle/distributed/fleet/recompute/recompute.py
@@ -14,7 +14,6 @@
 
 import contextlib
 import copy
-import inspect
 import weakref
 
 import paddle
@@ -524,23 +523,16 @@ def recompute(function, *args, **kwargs):
 
         return static_auto_recompute(function)(*args, **kwargs)
 
+    if kwargs and use_reentrant:
+        raise ValueError(
+            "Error, if you want to send kwargs(dict parameter) to function, please set use_reentrant=False."
+        )
+
     if framework._dygraph_tracer()._has_grad:
-        check_args = list(args)
-        check_args.extend(list(kwargs.values()))
-        check_recompute_necessary(check_args)
+        check_recompute_necessary(args)
 
     if use_reentrant:
-        input_args = args
-        # rearrange `position-args + keyword-args` into `position-args`
-        if isinstance(function, paddle.nn.Layer):
-            dyfunc_sig = inspect.signature(function.forward)
-        else:
-            dyfunc_sig = inspect.signature(function)
-
-        bound_args = dyfunc_sig.bind(*args, **kwargs)
-        bound_args.apply_defaults()
-        input_args = list(bound_args.arguments.values())
-        return RecomputeFunction.apply(function, preserve, *input_args)
+        return RecomputeFunction.apply(function, preserve, *args)
     else:
         return _recompute_without_reentrant(function, preserve, *args, **kwargs)

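In effect, the reentrant path accepts positional inputs only after this change, and keyword inputs to the wrapped function must go through the non-reentrant path. A minimal sketch of that behaviour (not part of the PR; `Block`, `x`, and `scale` are illustrative names, the import assumes the documented `paddle.distributed.fleet.utils.recompute` alias for this function, and a GPU device is assumed, as in the tests below):

```python
import paddle
from paddle.distributed.fleet.utils import recompute

paddle.set_device("gpu")  # mirrors paddle.set_device("gpu") in the tests below


class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(8, 8)

    def forward(self, x, scale=1.0):
        return self.fc(x) * scale


block = Block()
x = paddle.randn([4, 8])
x.stop_gradient = False  # at least one input needs grad for recompute to be useful

# Positional inputs work on the default (reentrant) path.
out = recompute(block, x, use_reentrant=True)

# Keyword inputs to the wrapped function are only supported without reentrance.
out = recompute(block, x, scale=2.0, use_reentrant=False)

# Combining kwargs with use_reentrant=True raises the ValueError from the guard above.
try:
    recompute(block, x, scale=2.0, use_reentrant=True)
except ValueError as err:
    print(err)  # "Error, if you want to send kwargs(dict parameter) to function, ..."
```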
133 changes: 46 additions & 87 deletions test/collective/fleet/test_dygraph_recompute_for_eager.py
@@ -75,7 +75,6 @@ def __init__(
         use_raw_recompute=False,
         recompute_kwargs={},
         raise_value_error=False,
-        recompute_use_kwargs_as_inputs=False,
     ):
         super().__init__()
         self.recompute_blocks = recompute_blocks
@@ -116,7 +115,6 @@ def __init__(
                 self.runfunc2, self.runfunc3, self.runfunc4
             ),
         ]
-        self.recompute_use_kwargs_as_inputs = recompute_use_kwargs_as_inputs
 
     def forward(self, inputs):
         if self.use_fleet_sq and not self.use_raw_recompute:
@@ -137,14 +135,9 @@ def forward(self, inputs):
         )
         for i in range(len(self.layers)):
             if i in self.recompute_blocks:
-                if self.recompute_use_kwargs_as_inputs:
-                    inputs = recompute(
-                        self.layers[i], pos=pos, x=inputs, **recompute_kwargs
-                    )
-                else:
-                    inputs = recompute(
-                        self.layers[i], inputs, pos, **recompute_kwargs
-                    )
+                inputs = recompute(
+                    self.layers[i], inputs, pos, **recompute_kwargs
+                )
             else:
                 inputs = self.layers[i](inputs, pos)

@@ -160,7 +153,6 @@ def run_model(
     segments=1,
     enable_autocast=False,
     pure_fp16=False,
-    recompute_use_kwargs_as_inputs=False,
 ):
     gen = paddle.seed(10)
     gen.manual_seed(10)
@@ -176,7 +168,6 @@ def run_model(
         segments=segments,
         recompute_kwargs=recompute_kwargs,
         raise_value_error=raise_value_error,
-        recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
     )
 
     if pure_fp16:
@@ -217,12 +208,7 @@ def run_model(
 
 
 class TestRecompute(unittest.TestCase):
-    def test_base_case(
-        self,
-        enable_autocast=False,
-        pure_fp16=False,
-        recompute_use_kwargs_as_inputs=False,
-    ):
+    def test_base_case(self, enable_autocast=False, pure_fp16=False):
         def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
             self.assertEqual(loss_ref, loss)
             self.assertEqual(param_ref, param)
@@ -245,7 +231,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

@@ -255,7 +240,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

@@ -265,7 +249,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

@@ -275,7 +258,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

@@ -286,7 +268,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
                 enable_autocast=enable_autocast,
                 pure_fp16=pure_fp16,
                 recompute_kwargs={"use_reentrant": flag},
-                recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
             )
             check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

@@ -310,42 +291,31 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
 
     def test_fc_net_with_dropout(self):
         self.test_base_case()
-        self.test_base_case(recompute_use_kwargs_as_inputs=True)
 
     def test_fc_net_without_restore_rng(self):
         for flag in [True, False]:
-            for recompute_use_kwargs_as_inputs in [True, False]:
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[2],
-                    recompute_kwargs={
-                        "preserve_rng_state": False,
-                        "use_reentrant": flag,
-                    },
-                    enable_autocast=True,
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+            loss_ref, param_ref, grad_ref = run_model(
+                recompute_block=[2],
+                recompute_kwargs={
+                    "preserve_rng_state": False,
+                    "use_reentrant": flag,
+                },
+                enable_autocast=True,
+            )
 
     def test_fc_net_with_amp(self):
         self.test_base_case(enable_autocast=True)
-        self.test_base_case(
-            enable_autocast=True, recompute_use_kwargs_as_inputs=True
-        )
 
     def test_fc_net_with_fp16(self):
         self.test_base_case(enable_autocast=True, pure_fp16=True)
-        self.test_base_case(
-            enable_autocast=True,
-            pure_fp16=True,
-            recompute_use_kwargs_as_inputs=True,
-        )
 
     def test_recompute_kwargs(self):
         paddle.set_device("gpu")
         pos = paddle.randn(shape=[10, 10], dtype="float32")
         pos.stop_gradient = False
 
         kwargs = {"pos": pos, "use_reentrant": True}
-        with self.assertRaises(TypeError):
+        with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
                 recompute_block=[2],
                 recompute_kwargs=kwargs,
@@ -358,57 +328,46 @@ def test_recompute_kwargs(self):
             )
 
     def test_recompute_inputs_with_param(self):
-        for flag in [True, False]:
-            for recompute_use_kwargs_as_inputs in [True, False]:
-                pos = paddle.randn(shape=[10, 10], dtype="float32")
-                new_pos = EagerParamBase(
-                    shape=pos.shape, dtype=pos.dtype, name=pos.name
-                )
-                pos._share_buffer_to(new_pos)
-                new_pos.stop_gradient = False
+        pos = paddle.randn(shape=[10, 10], dtype="float32")
+        new_pos = EagerParamBase(
+            shape=pos.shape, dtype=pos.dtype, name=pos.name
+        )
+        pos._share_buffer_to(new_pos)
+        new_pos.stop_gradient = False
 
-                loss, param, grad = run_model(
-                    recompute_block=[2, 4],
-                    recompute_kwargs={"pos": new_pos, "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss, param, grad = run_model(
+            recompute_block=[], recompute_kwargs={"pos": new_pos}
+        )
 
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[1, 2, 3],
-                    recompute_kwargs={"pos": new_pos, "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[1, 2, 3], recompute_kwargs={"pos": new_pos}
+        )
 
-                self.assertEqual(loss_ref, loss)
-                self.assertEqual(param_ref, param)
-                self.assertEqual(grad_ref, grad)
+        self.assertEqual(loss_ref, loss)
+        self.assertEqual(param_ref, param)
+        self.assertEqual(grad_ref, grad)
 
     def test_recompute_inputs_with_tuple(self):
-        for flag in [True, False]:
-            for recompute_use_kwargs_as_inputs in [True, False]:
-                pos = paddle.randn(shape=[10, 10], dtype="float32")
-                new_pos = EagerParamBase(
-                    shape=pos.shape, dtype=pos.dtype, name=pos.name
-                )
-                pos._share_buffer_to(new_pos)
-                pos.stop_gradient = False
-                new_pos.stop_gradient = False
+        pos = paddle.randn(shape=[10, 10], dtype="float32")
+        new_pos = EagerParamBase(
+            shape=pos.shape, dtype=pos.dtype, name=pos.name
+        )
+        pos._share_buffer_to(new_pos)
+        pos.stop_gradient = False
+        new_pos.stop_gradient = False
 
-                loss, param, grad = run_model(
-                    recompute_block=[2, 4],
-                    recompute_kwargs={"pos": (pos,), "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss, param, grad = run_model(
+            recompute_block=[2, 4], recompute_kwargs={"pos": (pos,)}
+        )
 
-                loss_ref, param_ref, grad_ref = run_model(
-                    recompute_block=[1, 2, 3],
-                    recompute_kwargs={"pos": (new_pos,), "use_reentrant": flag},
-                    recompute_use_kwargs_as_inputs=recompute_use_kwargs_as_inputs,
-                )
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[1, 2, 3],
+            recompute_kwargs={"pos": (new_pos,)},
+        )
 
-                self.assertEqual(loss_ref, loss)
-                self.assertEqual(param_ref, param)
-                self.assertEqual(grad_ref, grad)
+        self.assertEqual(loss_ref, loss)
+        self.assertEqual(param_ref, param)
+        self.assertEqual(grad_ref, grad)
 
 
 if __name__ == '__main__':
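For orientation, the invariant these tests keep checking is that recomputing a block reproduces the exact loss and gradients of a plain forward/backward pass. A condensed sketch of that check (hypothetical `TinyNet` and `run` names rather than the test file's own helpers; it assumes the documented `paddle.distributed.fleet.utils.recompute` alias and, as in the tests, a GPU device):

```python
import paddle
from paddle.distributed.fleet.utils import recompute

paddle.set_device("gpu")  # matches the test setup above


class TinyNet(paddle.nn.Layer):
    def __init__(self, use_recompute):
        super().__init__()
        self.block1 = paddle.nn.Linear(8, 8)
        self.block2 = paddle.nn.Linear(8, 8)
        self.use_recompute = use_recompute

    def forward(self, x):
        x = self.block1(x)
        if self.use_recompute:
            # Positional inputs only, per the change above.
            x = recompute(self.block2, x)
        else:
            x = self.block2(x)
        return x


def run(use_recompute):
    paddle.seed(10)  # identical init and a dropout-free net make the two runs comparable
    net = TinyNet(use_recompute)
    x = paddle.randn([4, 8])
    x.stop_gradient = False
    loss = net(x).sum()
    loss.backward()
    return loss.numpy(), net.block2.weight.grad.numpy()


loss_ref, grad_ref = run(use_recompute=False)
loss, grad = run(use_recompute=True)
assert (loss_ref == loss).all() and (grad_ref == grad).all()
```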