
Torch gradient_checkpoint_scope #1559

Merged: albertz merged 46 commits into master from albert-torch-grad-checkpoint on Jul 9, 2024

Conversation

@albertz (Member) commented on Jul 1, 2024

Gradient checkpointing for PyTorch.

Fix #1552.

This implements a new gradient checkpointing API for the user, gradient_checkpoint_scope, as a better alternative to torch.utils.checkpoint. Under the hood, it uses torch.autograd.graph.saved_tensors_hooks and TorchDispatchMode, and it also handles the RNG and AMP state.

gradient_checkpoint_scope creates a gradient checkpoint scope: tensors created within this scope are not stored for backpropagation but are recomputed on the fly during backpropagation.

Example:

a = ...
b = ...
c = ...
with gradient_checkpoint_scope():
    x = a + b
y = x * c

In this example, the tensor x will not be stored for backpropagation, i.e. the computation x = a + b will be recomputed during backpropagation.

Internally, this uses the PyTorch torch.autograd.graph.saved_tensors_hooks mechanism to override what is stored for backpropagation and how it is recomputed, and the PyTorch TorchDispatchMode to intercept all operations within the scope. Note that the usage of torch.autograd.graph.saved_tensors_hooks is tricky here, as we need it beyond the scope of the gradient_checkpoint_scope, specifically for all future usages of the tensor x in the example. See the code documentation for more details on this.
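
To illustrate the idea only, here is a minimal sketch of how these two mechanisms can combine. This is not the actual RETURNN implementation; all names are made up, and it omits chained recomputation, per-op RNG/AMP state, thread handling, and automatic hook cleanup:

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.weak import WeakTensorKeyDictionary

class RecordOps(TorchDispatchMode):
    """Records, for every tensor created inside the scope, the op that made it."""
    def __init__(self):
        super().__init__()
        # Weak keys: recording must not keep the tensors alive by itself.
        self.recorded = WeakTensorKeyDictionary()  # tensor -> (func, args, kwargs)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        if isinstance(out, torch.Tensor):
            self.recorded[out] = (func, args, kwargs or {})
        return out

a = torch.randn(100, requires_grad=True)
b = torch.randn(100)
c = torch.randn(100, requires_grad=True)

recorder = RecordOps()
rng_state = torch.get_rng_state()  # stored so recomputing random ops stays deterministic
with recorder:
    x = a + b  # recorded: only the recipe, not the value, is kept for backward

def pack_hook(t):
    # Tensors created inside the scope are replaced by their recompute recipe.
    if t in recorder.recorded:
        return "recompute", recorder.recorded[t]
    return "keep", t

def unpack_hook(packed):
    kind, payload = packed
    if kind == "keep":
        return payload
    func, args, kwargs = payload
    prev = torch.get_rng_state()
    torch.set_rng_state(rng_state)  # replay under the recorded RNG state
    try:
        with torch.no_grad():
            return func(*args, **kwargs)  # recompute on the fly
    finally:
        torch.set_rng_state(prev)

# The hooks must cover *later* usages of x, outside the recording scope;
# gradient_checkpoint_scope keeps them registered automatically for as long
# as any recorded tensor may still be saved for backward.
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * c).sum()

del x  # without the hooks, autograd would keep x alive for the backward of the mul
y.backward()  # triggers unpack_hook, which recomputes a + b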

Note that torch.utils.checkpoint is different: you cannot easily specify what not to store, i.e. what to recompute. Instead, you specify start/end points of what to store for backpropagation, and PyTorch recomputes everything in between. For the example above, you define that y is the end point and will be stored. It looks like this:

a = ...
b = ...
c = ...
y = torch.utils.checkpoint.checkpoint(lambda: (a + b) * c, use_reentrant=False)

PyTorch will not recompute ... * c here, but it will recompute a + b.

We find this API more cumbersome to use and less flexible, because in many cases you know what you want to recompute, i.e. what you don't want to store. The PyTorch API is instead about what you want to store, recomputing everything else in between.

See also returnn.tf.util.gradient_checkpoint: the same API and logic in TF. It makes heavy use of the TF computation graph, i.e. graph mode, which makes this particular feature much easier to implement there.

Further references:
#1552
https://discuss.pytorch.org/t/gradient-checkpointing/205416
pytorch/pytorch#129867
https://gist.github.com/soulitzer/ec1049a947be046de7fbc2af61a4ee8c


Not a RETURNN user yet, but want to try this?

pip install returnn

And then:

from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope

...
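
For instance, a minimal sketch of the weight-noise use case from #1552 (the module, optimizer, and sizes here are made up for illustration):

import torch
from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope

model = torch.nn.Linear(128, 128)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(16, 128)

with gradient_checkpoint_scope():
    # The noisy weight is recomputed during backprop instead of being stored.
    noisy_weight = model.weight + 0.01 * torch.randn_like(model.weight)
y = torch.nn.functional.linear(x, noisy_weight, model.bias)
loss = y.sum()
loss.backward()
opt.step()
opt.zero_grad()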

@albertz force-pushed the albert-torch-grad-checkpoint branch from 7ea5536 to 5b43cfb on July 2, 2024 23:16
@albertz (Member, Author) commented on Jul 3, 2024

Note, I think the implementation is ready now. What's missing are some tests, as outlined in #1552. But anyway, I think you could already start reviewing.

If some code is unclear, please say so: either I thought wrong, or made a mistake, or, if neither, I should at least document/comment that part better.

@NeoLegends (Collaborator) left a comment

If it turns out this is very difficult to automatically test/verify, WDYT about having a setup w/ and w/o gradient checkpointing and designing it so that the one without goes OOM while the one using this functionality does not?

(2 review threads on returnn/torch/util/gradient_checkpoint.py, resolved; 1 marked outdated)
@albertz (Member, Author) commented on Jul 4, 2024

If it turns out this is very difficult to automatically test/verify, WDYT about having a setup w/ and w/o gradient checkpointing and designing it so that the one without goes OOM while the one using this functionality does not?

But this sounds very easy to test automatically? Have you seen the WIP test code here? It tests basically exactly that. Also, that's what I already wrote here: #1552 (comment)
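
For illustration only, a rough sketch of such a check (assuming a CUDA device; the sizes and names are made up, and the actual test in this PR asserts on a traced allocation log instead):

import torch
from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope

def held_after_forward(use_scope: bool) -> int:
    a = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    c = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    if use_scope:
        with gradient_checkpoint_scope():
            x = a + torch.randn_like(a)
    else:
        x = a + torch.randn_like(a)
    loss = (x * c).sum()
    del x  # without the scope, autograd keeps x alive for backward
    held = torch.cuda.memory_allocated()
    loss.backward()
    return held

# With the scope, x is not held between forward and backward.
assert held_after_forward(True) < held_after_forward(False)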

@albertz marked this pull request as ready for review on July 8, 2024 15:36
@albertz requested a review from a team as a code owner on July 8, 2024 15:36
@albertz (Member, Author) commented on Jul 8, 2024

For reference, this is the event trace I see now. Due to garbage-collection (GC) logic, it might not be 100% deterministic, and changes in PyTorch internals will probably also induce changes here. Lines marked with ✓ are the events the test checks for; the parenthesized lines are auxiliary.

No grad checkpointing:

(pycall {'caller_loc': '_pytest/python.py:194', 'caller_name': 'pytest_pyfunc_call', 'callsite_name': 'test_gradient_checkpoint_scope'})
(pycall {'caller_loc': 'test_torch_util.py:57', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__enter__'})
(pycall {'caller_loc': 'test_torch_util.py:58', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'test_torch_util.py:58', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__enter__'})
(torchop {'name': 'train_step_no_grad_ckpt'})
pycall {'caller_loc': 'test_torch_util.py:59', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': 'demo_run'} ✓
(pycall {'caller_loc': 'test_torch_util.py:46', 'caller_name': 'demo_run', 'callsite_name': '__getattr__'})
(pycall {'caller_loc': 'test_torch_util.py:47', 'caller_name': 'demo_run', 'callsite_name': '_wrapped_call_impl'})
pycall {'caller_loc': 'torch/nn/modules/module.py:1527', 'caller_name': '_call_impl', 'callsite_name': 'forward'} ✓
(pycall {'caller_loc': 'test_torch_util.py:39', 'caller_name': 'forward', 'callsite_name': '__getattr__'})
pycall {'caller_loc': 'test_torch_util.py:39', 'caller_name': 'forward', 'callsite_name': 'get_var_noise'} ✓
(torchop {'name': 'aten::randn'})
(torchop {'name': 'aten::empty'})
alloc {'id': 3, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/get_var_noise/aten::randn', 'size': 41612, 'total_alloc': 41612} ✓
(torchop {'name': 'aten::normal_'})
(torchop {'name': 'aten::add'})
alloc {'id': 4, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/aten::add', 'size': 41612, 'total_alloc': 83224} ✓
dealloc {'id': 3, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/get_var_noise/aten::randn', 'size': -41612, 'total_alloc': 41612} ✓
(torchop {'name': 'aten::mul'})
alloc {'id': 5, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/aten::mul', 'size': 41612, 'total_alloc': 83224} ✓
(torchop {'name': 'aten::sum'})
(torchop {'name': 'aten::sum'})
(alloc {'id': 6, 'name': 'test_gradient_checkpoint_scope/demo_run/aten::sum', 'size': 4, 'total_alloc': 83228})
(torchop {'name': 'aten::as_strided'})
(torchop {'name': 'aten::fill_'})
dealloc {'id': 5, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/aten::mul', 'size': -41612, 'total_alloc': 41616} ✓
pycall {'caller_loc': 'test_torch_util.py:50', 'caller_name': 'demo_run', 'callsite_name': 'backward'} ✓
(torchop {'name': 'aten::ones_like'})
(torchop {'name': 'aten::empty_like'})
(torchop {'name': 'aten::empty_strided'})
(alloc {'id': 7, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/aten::ones_like', 'size': 4, 'total_alloc': 41620})
(torchop {'name': 'aten::fill_'})
(torchop {'name': 'autograd::engine::evaluate_function: SumBackward0'})
(torchop {'name': 'SumBackward0'})
(torchop {'name': 'aten::expand'})
(torchop {'name': 'aten::as_strided'})
(torchop {'name': 'autograd::engine::evaluate_function: MulBackward0'})
(torchop {'name': 'MulBackward0'})
(torchop {'name': 'aten::mul'})
alloc {'id': 8, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/autograd::engine::evaluate_function: MulBackward0', 'size': 41612, 'total_alloc': 83232} ✓
(torchop {'name': 'aten::mul'})
alloc {'id': 9, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/autograd::engine::evaluate_function: MulBackward0', 'size': 41612, 'total_alloc': 124844} ✓
dealloc {'id': 4, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/aten::add', 'size': -41612, 'total_alloc': 83232} ✓
(torchop {'name': 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad'})
(torchop {'name': 'torch::autograd::AccumulateGrad'})
(torchop {'name': 'aten::detach'})
(torchop {'name': 'detach'})
(torchop {'name': 'autograd::engine::evaluate_function: AddBackward0'})
(torchop {'name': 'AddBackward0'})
(torchop {'name': 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad'})
(torchop {'name': 'torch::autograd::AccumulateGrad'})
(torchop {'name': 'aten::detach'})
(torchop {'name': 'detach'})
(dealloc {'id': 7, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/aten::ones_like', 'size': -4, 'total_alloc': 83228})
(dealloc {'id': 6, 'name': 'test_gradient_checkpoint_scope/demo_run/aten::sum', 'size': -4, 'total_alloc': 83224})
(pycall {'caller_loc': 'test_torch_util.py:52', 'caller_name': 'demo_run', 'callsite_name': 'wrapper'})
torchop {'name': 'Optimizer.step#SGD.step'} ✓
(torchop {'name': 'aten::add_'})
(torchop {'name': 'aten::add_'})
(pycall {'caller_loc': 'test_torch_util.py:53', 'caller_name': 'demo_run', 'callsite_name': 'inner'})
torchop {'name': 'Optimizer.zero_grad#SGD.zero_grad'} ✓
dealloc {'id': 9, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/autograd::engine::evaluate_function: MulBackward0', 'size': -41612, 'total_alloc': 41612} ✓
dealloc {'id': 8, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/autograd::engine::evaluate_function: MulBackward0', 'size': -41612, 'total_alloc': 0} ✓
(pycall {'caller_loc': 'test_torch_util.py:58', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__exit__'})
(pycall {'caller_loc': 'test_torch_util.py:57', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__exit__'})

With gradient checkpointing:

(pycall {'caller_loc': '_pytest/python.py:194', 'caller_name': 'pytest_pyfunc_call', 'callsite_name': 'test_gradient_checkpoint_scope'})
(pycall {'caller_loc': 'test_torch_util.py:94', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__enter__'})
(pycall {'caller_loc': 'test_torch_util.py:95', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'test_torch_util.py:95', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__enter__'})
(torchop {'name': 'train_step_grad_ckpt'})
pycall {'caller_loc': 'test_torch_util.py:96', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': 'demo_run'} ✓
(pycall {'caller_loc': 'test_torch_util.py:46', 'caller_name': 'demo_run', 'callsite_name': '__getattr__'})
(pycall {'caller_loc': 'test_torch_util.py:47', 'caller_name': 'demo_run', 'callsite_name': '_wrapped_call_impl'})
pycall {'caller_loc': 'torch/nn/modules/module.py:1527', 'caller_name': '_call_impl', 'callsite_name': 'forward'} ✓
(pycall {'caller_loc': 'test_torch_util.py:41', 'caller_name': 'forward', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:137', 'caller_name': '__init__', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:234', 'caller_name': '__init__', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:235', 'caller_name': '__init__', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:140', 'caller_name': '__init__', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'test_torch_util.py:41', 'caller_name': 'forward', 'callsite_name': '__enter__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:147', 'caller_name': '__enter__', 'callsite_name': '__enter__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:148', 'caller_name': '__enter__', 'callsite_name': '__enter__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:150', 'caller_name': '__enter__', 'callsite_name': 'current_thread'})
(pycall {'caller_loc': 'test_torch_util.py:42', 'caller_name': 'forward', 'callsite_name': '__getattr__'})
pycall {'caller_loc': 'test_torch_util.py:42', 'caller_name': 'forward', 'callsite_name': 'get_var_noise'} ✓
(torchop {'name': 'aten::randn'})
pycall {'caller_loc': 'test_torch_util.py:35', 'caller_name': 'get_var_noise', 'callsite_name': '__torch_dispatch__'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:240', 'caller_name': '__torch_dispatch__', 'callsite_name': 'maybe_store_rng_state'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:300', 'caller_name': 'maybe_store_rng_state', 'callsite_name': '_get_dev_rng_state'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:397', 'caller_name': '_get_dev_rng_state', 'callsite_name': 'get_rng_state'})
alloc {'id': 3, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/get_var_noise/maybe_store_rng_state/_get_dev_rng_state/get_rng_state', 'size': 5056, 'total_alloc': 5056} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:241', 'caller_name': '__torch_dispatch__', 'callsite_name': 'maybe_store_amp_state'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:311', 'caller_name': 'maybe_store_amp_state', 'callsite_name': '_get_dev_amp_state'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:242', 'caller_name': '__torch_dispatch__', 'callsite_name': 'tree_map'})
(pycall {'caller_loc': 'torch/utils/_pytree.py:283', 'caller_name': '<listcomp>', 'callsite_name': 'maybe_store_rng_state'})
(pycall {'caller_loc': 'torch/utils/_pytree.py:283', 'caller_name': '<listcomp>', 'callsite_name': 'maybe_store_rng_state'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:243', 'caller_name': '__torch_dispatch__', 'callsite_name': 'tree_map'})
(pycall {'caller_loc': 'torch/utils/_pytree.py:283', 'caller_name': '<listcomp>', 'callsite_name': 'maybe_store_rng_state'})
(pycall {'caller_loc': 'torch/utils/_pytree.py:283', 'caller_name': '<listcomp>', 'callsite_name': 'maybe_store_rng_state'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:244', 'caller_name': '__torch_dispatch__', 'callsite_name': '__call__'})
(torchop {'name': 'aten::randn'})
(torchop {'name': 'aten::empty'})
alloc {'id': 4, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/get_var_noise/aten::randn', 'size': 41612, 'total_alloc': 46668} ✓
(torchop {'name': 'aten::normal_'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:245', 'caller_name': '__torch_dispatch__', 'callsite_name': 'record_op'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:273', 'caller_name': 'record_op', 'callsite_name': 'tree_flatten'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:274', 'caller_name': 'record_op', 'callsite_name': 'tree_map_only'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:275', 'caller_name': 'record_op', 'callsite_name': 'tree_map_only'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:276', 'caller_name': 'record_op', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:286', 'caller_name': 'record_op', 'callsite_name': '__contains__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:288', 'caller_name': 'record_op', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:289', 'caller_name': 'record_op', 'callsite_name': '__setitem__'})
(torchop {'name': 'aten::add'})
(pycall {'caller_loc': 'test_torch_util.py:42', 'caller_name': 'forward', 'callsite_name': '__torch_dispatch__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:240', 'caller_name': '__torch_dispatch__', 'callsite_name': 'maybe_store_rng_state'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:241', 'caller_name': '__torch_dispatch__', 'callsite_name': 'maybe_store_amp_state'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:242', 'caller_name': '__torch_dispatch__', 'callsite_name': 'tree_map'})
(pycall {'caller_loc': 'torch/utils/_pytree.py:283', 'caller_name': '<listcomp>', 'callsite_name': 'maybe_store_rng_state'})
(pycall {'caller_loc': 'torch/utils/_pytree.py:283', 'caller_name': '<listcomp>', 'callsite_name': 'maybe_store_rng_state'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:243', 'caller_name': '__torch_dispatch__', 'callsite_name': 'tree_map'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:244', 'caller_name': '__torch_dispatch__', 'callsite_name': '__call__'})
(torchop {'name': 'aten::add'})
alloc {'id': 5, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/aten::add', 'size': 41612, 'total_alloc': 88280} ✓
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:245', 'caller_name': '__torch_dispatch__', 'callsite_name': 'record_op'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:273', 'caller_name': 'record_op', 'callsite_name': 'tree_flatten'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:274', 'caller_name': 'record_op', 'callsite_name': 'tree_map_only'})
(pycall {'caller_loc': 'torch/utils/_pytree.py:334', 'caller_name': 'inner', 'callsite_name': 'maybe_map_raw_tensor_to_graph_tensor'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:315', 'caller_name': 'maybe_map_raw_tensor_to_graph_tensor', 'callsite_name': 'get'})
(pycall {'caller_loc': 'torch/utils/_pytree.py:334', 'caller_name': 'inner', 'callsite_name': 'maybe_map_raw_tensor_to_graph_tensor'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:315', 'caller_name': 'maybe_map_raw_tensor_to_graph_tensor', 'callsite_name': 'get'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:275', 'caller_name': 'record_op', 'callsite_name': 'tree_map_only'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:276', 'caller_name': 'record_op', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:286', 'caller_name': 'record_op', 'callsite_name': '__contains__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:288', 'caller_name': 'record_op', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:289', 'caller_name': 'record_op', 'callsite_name': '__setitem__'})
(pycall {'caller_loc': 'test_torch_util.py:42', 'caller_name': 'forward', 'callsite_name': 'remove'})
dealloc {'id': 4, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/get_var_noise/aten::randn', 'size': -41612, 'total_alloc': 46668} ✓
(pycall {'caller_loc': 'test_torch_util.py:41', 'caller_name': 'forward', 'callsite_name': '__exit__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:154', 'caller_name': '__exit__', 'callsite_name': '__exit__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:155', 'caller_name': '__exit__', 'callsite_name': 'is_any_recorded_tensor_alive'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:269', 'caller_name': 'is_any_recorded_tensor_alive', 'callsite_name': '__len__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:161', 'caller_name': '__exit__', 'callsite_name': '_register_custom_saved_tensors_hooks'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:493', 'caller_name': '_register_custom_saved_tensors_hooks', 'callsite_name': 'current_thread'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:495', 'caller_name': '_register_custom_saved_tensors_hooks', 'callsite_name': '__contains__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:504', 'caller_name': '_register_custom_saved_tensors_hooks', 'callsite_name': 'add'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:505', 'caller_name': '_register_custom_saved_tensors_hooks', 'callsite_name': '__len__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:163', 'caller_name': '__exit__', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:162', 'caller_name': '__exit__', 'callsite_name': '_register_custom_saved_tensors_hooks_thread_local_callback'})
(torchop {'name': 'aten::mul'})
pycall {'caller_loc': 'test_torch_util.py:43', 'caller_name': 'forward', 'callsite_name': '_pack_hook'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:199', 'caller_name': '_pack_hook', 'callsite_name': 'is_any_recorded_tensor_alive'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:269', 'caller_name': 'is_any_recorded_tensor_alive', 'callsite_name': '__len__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:204', 'caller_name': '_pack_hook', 'callsite_name': 'get'})
(pycall {'caller_loc': 'test_torch_util.py:43', 'caller_name': 'forward', 'callsite_name': '_pack_hook'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:199', 'caller_name': '_pack_hook', 'callsite_name': 'is_any_recorded_tensor_alive'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:269', 'caller_name': 'is_any_recorded_tensor_alive', 'callsite_name': '__len__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:204', 'caller_name': '_pack_hook', 'callsite_name': 'get'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:206', 'caller_name': '_pack_hook', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:206', 'caller_name': '_pack_hook', 'callsite_name': '__init__'})
alloc {'id': 6, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/aten::mul', 'size': 41612, 'total_alloc': 88280} ✓
(pycall {'caller_loc': 'torch/nn/modules/module.py:1527', 'caller_name': '_call_impl', 'callsite_name': '__del__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:458', 'caller_name': '__del__', 'callsite_name': '__call__'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:473', 'caller_name': '__call__', 'callsite_name': '_tensor_del_hook'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:220', 'caller_name': '_tensor_del_hook', 'callsite_name': '_maybe_exit_saved_tensors_hooks_scope'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:174', 'caller_name': '_maybe_exit_saved_tensors_hooks_scope', 'callsite_name': 'current_thread'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:175', 'caller_name': '_maybe_exit_saved_tensors_hooks_scope', 'callsite_name': 'is_any_recorded_tensor_alive'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:269', 'caller_name': 'is_any_recorded_tensor_alive', 'callsite_name': '__len__'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:176', 'caller_name': '_maybe_exit_saved_tensors_hooks_scope', 'callsite_name': 'exit_saved_tensors_hooks_scope'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:190', 'caller_name': 'exit_saved_tensors_hooks_scope', 'callsite_name': 'current_thread'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:196', 'caller_name': 'exit_saved_tensors_hooks_scope', 'callsite_name': '_custom_saved_tensors_hooks_exit'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:543', 'caller_name': '_custom_saved_tensors_hooks_exit', 'callsite_name': '_custom_saved_tensors_hooks_call_callbacks'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:573', 'caller_name': '_custom_saved_tensors_hooks_call_callbacks', 'callsite_name': '<listcomp>'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:574', 'caller_name': '<listcomp>', 'callsite_name': '__call__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:473', 'caller_name': '__call__', 'callsite_name': '_custom_saved_tensors_hooks_callback'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:223', 'caller_name': '_custom_saved_tensors_hooks_callback', 'callsite_name': 'current_thread'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:225', 'caller_name': '_custom_saved_tensors_hooks_callback', 'callsite_name': 'is_any_recorded_tensor_alive'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:269', 'caller_name': 'is_any_recorded_tensor_alive', 'callsite_name': '__len__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:228', 'caller_name': '_custom_saved_tensors_hooks_callback', 'callsite_name': 'exit_saved_tensors_hooks_scope'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:190', 'caller_name': 'exit_saved_tensors_hooks_scope', 'callsite_name': 'current_thread'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:552', 'caller_name': '_custom_saved_tensors_hooks_exit', 'callsite_name': '__exit__'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:557', 'caller_name': '_custom_saved_tensors_hooks_exit', 'callsite_name': '_unregister_custom_saved_tensors_hooks'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:511', 'caller_name': '_unregister_custom_saved_tensors_hooks', 'callsite_name': 'current_thread'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:513', 'caller_name': '_unregister_custom_saved_tensors_hooks', 'callsite_name': '__contains__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:520', 'caller_name': '_unregister_custom_saved_tensors_hooks', 'callsite_name': 'remove'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:521', 'caller_name': '_unregister_custom_saved_tensors_hooks', 'callsite_name': '__len__'})
dealloc {'id': 5, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/aten::add', 'size': -41612, 'total_alloc': 46668} ✓
(torchop {'name': 'aten::sum'})
(torchop {'name': 'aten::sum'})
(alloc {'id': 7, 'name': 'test_gradient_checkpoint_scope/demo_run/aten::sum', 'size': 4, 'total_alloc': 46672})
(torchop {'name': 'aten::as_strided'})
(torchop {'name': 'aten::fill_'})
dealloc {'id': 6, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/aten::mul', 'size': -41612, 'total_alloc': 5060} ✓
pycall {'caller_loc': 'test_torch_util.py:50', 'caller_name': 'demo_run', 'callsite_name': 'backward'} ✓
(torchop {'name': 'aten::ones_like'})
(torchop {'name': 'aten::empty_like'})
(torchop {'name': 'aten::empty_strided'})
(alloc {'id': 8, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/aten::ones_like', 'size': 4, 'total_alloc': 5064})
(torchop {'name': 'aten::fill_'})
(torchop {'name': 'autograd::engine::evaluate_function: SumBackward0'})
(torchop {'name': 'SumBackward0'})
(torchop {'name': 'aten::expand'})
(torchop {'name': 'aten::as_strided'})
(torchop {'name': 'autograd::engine::evaluate_function: MulBackward0'})
(torchop {'name': 'MulBackward0'})
pycall {'caller_loc': 'torch/autograd/__init__.py:251', 'caller_name': 'backward', 'callsite_name': '_unpack_hook'} ✓
(pycall {'caller_loc': 'torch/autograd/__init__.py:251', 'caller_name': 'backward', 'callsite_name': '_unpack_hook'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:212', 'caller_name': '_unpack_hook', 'callsite_name': '_maybe_exit_saved_tensors_hooks_scope'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:213', 'caller_name': '_unpack_hook', 'callsite_name': 'maybe_recompute'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:336', 'caller_name': 'maybe_recompute', 'callsite_name': 'helper'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:336', 'caller_name': 'maybe_recompute', 'callsite_name': '__enter__'})
(pycall {'caller_loc': 'contextlib.py:137', 'caller_name': '__enter__', 'callsite_name': '_reset_rng_states_scope'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:385', 'caller_name': '_reset_rng_states_scope', 'callsite_name': '<dictcomp>'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:385', 'caller_name': '<dictcomp>', 'callsite_name': '_get_dev_rng_state'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:397', 'caller_name': '_get_dev_rng_state', 'callsite_name': 'get_rng_state'} ✓
alloc {'id': 9, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/_unpack_hook/maybe_recompute/_reset_rng_states_scope/<dictcomp>/_get_dev_rng_state/get_rng_state', 'size': 5056, 'total_alloc': 10120} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:388', 'caller_name': '_reset_rng_states_scope', 'callsite_name': '_set_dev_rng_state'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:404', 'caller_name': '_set_dev_rng_state', 'callsite_name': 'set_rng_state'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:336', 'caller_name': 'maybe_recompute', 'callsite_name': 'helper'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:336', 'caller_name': 'maybe_recompute', 'callsite_name': '__enter__'})
(pycall {'caller_loc': 'contextlib.py:137', 'caller_name': '__enter__', 'callsite_name': '_reset_amp_states_scope'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:412', 'caller_name': '_reset_amp_states_scope', 'callsite_name': '__init__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:412', 'caller_name': '_reset_amp_states_scope', 'callsite_name': '__enter__'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:341', 'caller_name': 'maybe_recompute', 'callsite_name': 'recompute'} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:357', 'caller_name': 'recompute', 'callsite_name': 'tree_map_only'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:358', 'caller_name': 'recompute', 'callsite_name': 'tree_map_only'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:359', 'caller_name': 'recompute', 'callsite_name': '__call__'})
(torchop {'name': 'aten::randn'})
(torchop {'name': 'aten::empty'})
alloc {'id': 10, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/_unpack_hook/maybe_recompute/recompute/aten::randn', 'size': 41612, 'total_alloc': 51732} ✓
(torchop {'name': 'aten::normal_'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:360', 'caller_name': 'recompute', 'callsite_name': 'tree_flatten'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:341', 'caller_name': 'maybe_recompute', 'callsite_name': 'recompute'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:357', 'caller_name': 'recompute', 'callsite_name': 'tree_map_only'})
(pycall {'caller_loc': 'torch/utils/_pytree.py:334', 'caller_name': 'inner', 'callsite_name': 'get_recomputed'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:358', 'caller_name': 'recompute', 'callsite_name': 'tree_map_only'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:359', 'caller_name': 'recompute', 'callsite_name': '__call__'})
(torchop {'name': 'aten::add'})
alloc {'id': 11, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/_unpack_hook/maybe_recompute/recompute/aten::add', 'size': 41612, 'total_alloc': 93344} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:360', 'caller_name': 'recompute', 'callsite_name': 'tree_flatten'})
dealloc {'id': 10, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/_unpack_hook/maybe_recompute/recompute/aten::randn', 'size': -41612, 'total_alloc': 51732} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:336', 'caller_name': 'maybe_recompute', 'callsite_name': '__exit__'})
(pycall {'caller_loc': 'contextlib.py:144', 'caller_name': '__exit__', 'callsite_name': '_reset_amp_states_scope'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:412', 'caller_name': '_reset_amp_states_scope', 'callsite_name': '__exit__'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:336', 'caller_name': 'maybe_recompute', 'callsite_name': '__exit__'})
(pycall {'caller_loc': 'contextlib.py:144', 'caller_name': '__exit__', 'callsite_name': '_reset_rng_states_scope'})
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:392', 'caller_name': '_reset_rng_states_scope', 'callsite_name': '_set_dev_rng_state'})
pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:404', 'caller_name': '_set_dev_rng_state', 'callsite_name': 'set_rng_state'} ✓
dealloc {'id': 9, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/_unpack_hook/maybe_recompute/_reset_rng_states_scope/<dictcomp>/_get_dev_rng_state/get_rng_state', 'size': -5056, 'total_alloc': 46676} ✓
dealloc {'id': 3, 'name': 'test_gradient_checkpoint_scope/demo_run/forward/get_var_noise/maybe_store_rng_state/_get_dev_rng_state/get_rng_state', 'size': -5056, 'total_alloc': 41620} ✓
(pycall {'caller_loc': 'returnn/torch/util/gradient_checkpoint.py:214', 'caller_name': '_unpack_hook', 'callsite_name': 'get_recomputed'})
(torchop {'name': 'aten::mul'})
alloc {'id': 12, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/autograd::engine::evaluate_function: MulBackward0', 'size': 41612, 'total_alloc': 83232} ✓
(torchop {'name': 'aten::mul'})
alloc {'id': 13, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/autograd::engine::evaluate_function: MulBackward0', 'size': 41612, 'total_alloc': 124844} ✓
dealloc {'id': 11, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/_unpack_hook/maybe_recompute/recompute/aten::add', 'size': -41612, 'total_alloc': 83232} ✓
(torchop {'name': 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad'})
(torchop {'name': 'torch::autograd::AccumulateGrad'})
(torchop {'name': 'aten::detach'})
(torchop {'name': 'detach'})
(torchop {'name': 'autograd::engine::evaluate_function: AddBackward0'})
(torchop {'name': 'AddBackward0'})
(torchop {'name': 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad'})
(torchop {'name': 'torch::autograd::AccumulateGrad'})
(torchop {'name': 'aten::detach'})
(torchop {'name': 'detach'})
(dealloc {'id': 8, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/aten::ones_like', 'size': -4, 'total_alloc': 83228})
(dealloc {'id': 7, 'name': 'test_gradient_checkpoint_scope/demo_run/aten::sum', 'size': -4, 'total_alloc': 83224})
(pycall {'caller_loc': 'test_torch_util.py:52', 'caller_name': 'demo_run', 'callsite_name': 'wrapper'})
torchop {'name': 'Optimizer.step#SGD.step'} ✓
(torchop {'name': 'aten::add_'})
(torchop {'name': 'aten::add_'})
(pycall {'caller_loc': 'test_torch_util.py:53', 'caller_name': 'demo_run', 'callsite_name': 'inner'})
torchop {'name': 'Optimizer.zero_grad#SGD.zero_grad'} ✓
dealloc {'id': 13, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/autograd::engine::evaluate_function: MulBackward0', 'size': -41612, 'total_alloc': 41612} ✓
dealloc {'id': 12, 'name': 'test_gradient_checkpoint_scope/demo_run/backward/autograd::engine::evaluate_function: MulBackward0', 'size': -41612, 'total_alloc': 0} ✓
(pycall {'caller_loc': 'test_torch_util.py:95', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__exit__'})
(pycall {'caller_loc': 'test_torch_util.py:94', 'caller_name': 'test_gradient_checkpoint_scope', 'callsite_name': '__exit__'})

@albertz (Member, Author) commented on Jul 8, 2024

Ok I think this is ready now.
@NeoLegends please review.

@NeoLegends (Collaborator) left a comment

Much better w/ the docs! I have some questions left, but nothing major.

(8 review threads on returnn/torch/util/gradient_checkpoint.py, all resolved; 4 marked outdated)
@NeoLegends (Collaborator) left a comment

🚀

@albertz merged commit a0f8be5 into master on Jul 9, 2024 (56 checks passed)
@albertz deleted the albert-torch-grad-checkpoint branch on July 9, 2024 10:22
Linked issue (may be closed by merging this pull request): Gradient checkpointing for weight noise etc in PyTorch (#1552)
2 participants