Torch gradient_checkpoint_scope
#1559
Conversation
force-pushed from a55823e to 67afe72
(Could be wrong when run in different thread...!)
force-pushed from 7ea5536 to 5b43cfb
Note, I think the implementation is ready now. What's missing are some tests, as outlined in #1552. But anyway, I think you could already start reviewing. If some code is unclear, please say so: either I have thought wrong, or made a mistake, or, if not, I should at least better document/comment that part.
If it turns out this is very difficult to automatically test/verify, WDYT about having a setup w/ and w/o gradient checkpointing and designing it so that the one without goes OOM while the one using this functionality does not?
But this sounds very easy to automatically test? Have you seen the WIP test code here? It tests basically exactly that. Also, that's what I wrote already here: #1552 (comment)
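For illustration only, not the PR's actual test code: the OOM-style setup suggested above could look roughly like the sketch below. The sizes, depth, and memory fraction are placeholders that would need tuning so that only the variant without gradient checkpointing exceeds the limit, and the import path is an assumption.

```python
import pytest
import torch
from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope  # assumed path


def _forward_backward(use_grad_ckpt: bool):
    w = torch.randn(4096, 4096, device="cuda", requires_grad=True)
    x = torch.randn(4096, 4096, device="cuda")

    def block(h):
        for _ in range(16):
            h = torch.tanh(h @ w)  # each tanh output is normally saved for backward
        return h

    if use_grad_ckpt:
        with gradient_checkpoint_scope():
            h = block(x)
    else:
        h = block(x)
    h.sum().backward()


def test_grad_checkpointing_avoids_oom():
    torch.cuda.set_per_process_memory_fraction(0.2)  # placeholder limit, needs tuning
    with pytest.raises(torch.cuda.OutOfMemoryError):
        _forward_backward(use_grad_ckpt=False)
    torch.cuda.empty_cache()
    _forward_backward(use_grad_ckpt=True)  # expected to fit within the same limit
```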
For reference, this is the event trace I see now. Due to GC logic, it might not be 100% deterministic, and changes in PyTorch internals will probably also change it. Without gradient checkpointing:
With gradient checkpointing:
Ok, I think this is ready now.
Much better w/ the docs! I have some questions left, but nothing major.
🚀
gradient_checkpoint_scope
Gradient checkpointing for PyTorch.
Fix #1552.
This implements a new gradient checkpointing API for the user, `gradient_checkpoint_scope`, as a better alternative to `torch.utils.checkpoint`, using `torch.autograd.graph.saved_tensors_hooks` and `TorchDispatchMode` under the hood, and also handling the RNG and AMP state.

`gradient_checkpoint_scope` creates a gradient checkpoint scope. All tensors created within this scope will not be stored for backpropagation, but will be recomputed on the fly during backpropagation.

Example:
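The example code block did not survive extraction here. Based on the surrounding description (`x = a + b` inside the scope, `y` computed from `x` afterwards), it was presumably something like this sketch; the module path in the import is an assumption.

```python
import torch
# Assumed module path; check the RETURNN source for the exact location.
from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)

with gradient_checkpoint_scope():
    x = a + b  # x is not stored for backprop; it will be recomputed in the backward pass

y = x * c  # using x outside the scope is fine
y.sum().backward()
```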
In this example, the tensor `x` will not be stored for backpropagation, i.e. the computation `x = a + b` will be recomputed during backpropagation.

Internally, this uses the PyTorch `torch.autograd.graph.saved_tensors_hooks` mechanism to override what we store for backpropagation, and how to recompute it. And we use the PyTorch `TorchDispatchMode` to intercept all operations within the scope. Note that the usage of `torch.autograd.graph.saved_tensors_hooks` is tricky here, as we need it beyond the scope of the `gradient_checkpoint_scope`, specifically for all future usages of the tensor `x` in the example. See the code documentation for more details on this.
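To make the two PyTorch primitives mentioned above concrete, here is a minimal, self-contained sketch of each: pass-through saved-tensors hooks and an op-logging dispatch mode. This is not the PR's implementation, which combines them to record the ops in the scope and replay them in the unpack hook.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


def pack_hook(t: torch.Tensor):
    # Decides what autograd stores instead of the tensor itself.
    # The real implementation would store a recompute recipe here.
    return t


def unpack_hook(packed):
    # Called when backward needs the tensor again.
    # The real implementation would recompute the tensor here.
    return packed


class OpLogger(TorchDispatchMode):
    """Intercepts every ATen op issued while this mode is active."""
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("dispatched:", func)
        return func(*args, **(kwargs or {}))


a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
with OpLogger(), torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (a * b).sum()
y.backward()
```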
Note, `torch.utils.checkpoint` is different: you cannot easily specify what not to store / what to recompute. You rather specify a start/end point of what to store for backpropagation, and then PyTorch will recompute everything in between. For the example above, you define that `y` is the end point and will be stored. It looks like this:
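The code block for this comparison is missing here. Based on the description, it was presumably along these lines; the exact function wrapping is an assumption.

```python
import torch
from torch.utils.checkpoint import checkpoint

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)


def calc_y(a, b, c):
    x = a + b
    return x * c


# Only the inputs and the output y are kept; intermediates inside calc_y
# (here x) are recomputed during backward.
y = checkpoint(calc_y, a, b, c, use_reentrant=False)
y.sum().backward()
```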
PyTorch will not recompute `... * c` here, but it will recompute `a + b`.

We find this API more cumbersome to use and less flexible, because in many cases, you know what you want to recompute, i.e. what you don't want to store. The PyTorch API is more about what you want to store, and then it recomputes everything else in between.
See also `returnn.tf.util.gradient_checkpoint`: same API and logic in TF, although it heavily makes use of the TF computation graph, i.e. graph mode, which makes this particular feature much easier to implement.

Further references:
#1552
https://discuss.pytorch.org/t/gradient-checkpointing/205416
pytorch/pytorch#129867
https://gist.github.com/soulitzer/ec1049a947be046de7fbc2af61a4ee8c
You are not a RETURNN user yet but just want to try this?
And then:
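The snippets that belong here are missing; presumably an install step (e.g. `pip install returnn`) preceded this, and a minimal usage snippet followed, roughly like the sketch below. The module path is an assumption.

```python
import torch
from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope  # assumed path

a = torch.randn(8, requires_grad=True)
b = torch.randn(8, requires_grad=True)

with gradient_checkpoint_scope():
    x = a + b  # not stored for backprop, recomputed on demand during backward

(x * 2.0).sum().backward()
```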