-
Notifications
You must be signed in to change notification settings - Fork 3.2k
[mutable-arrays] make remat work with external refs #33050
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Summary of ChangesHello @mattjj, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly improves the interaction between JAX's Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces support for jax.remat (or jax.checkpoint) with functions that have side-effects on mutable array references (jax.Ref). The implementation correctly establishes a sensible default behavior where the side-effects on external refs are executed only during the forward pass, and not during the re-computation on the backward pass.
The core logic is sound:
- In
remat_partial_eval, the initial values of anyRefarguments are captured. - A new helper,
_replace_arg_refs_with_vals, transforms the jaxpr for the rematerialization pass. This new jaxpr accepts the initial values of the refs (instead of the refs themselves) and creates fresh, internal refs for the re-computation.
This approach correctly isolates the side-effects to the forward pass while allowing rematerialization to work with the correct initial state.
The accompanying tests are thorough, covering cases where refs are created internally, passed as arguments, or closed over, with and without JIT compilation. The removal of the previous effect-disallowing check in remat_partial_eval is also appropriate given this new capability.
Overall, this is a well-implemented and well-tested enhancement to JAX's rematerialization feature. I have no specific comments for code changes.
Functions that only internally use refs are pure. jax-ml#31389 added support for those, though it didn't add tests. This PR adds a couple simple tests for remat-decorated functions that use only internal refs. With external refs, passed to a remat-decorated function either as explicit arguments or by closure, the semantics are ambiguous: do we want the effect only to happen once, or every time the function is re-run? Given that ambiguity, we plan to let the user be explicit about what behavior they want. But there's a sensible default behavior: only run the effects the first time. That is, running this program should print "2.0" and not "3.0": ```python @jax.remat def f(y, x_ref): out = y * x_ref[...] x_ref[...] += 1 return out x_ref = jax.new_ref(1.) jax.grad(f)(1., x_ref) print(x_ref) # 2.0 ``` We capture the initial value of the ref as a residual, and then use it to initialize a fresh (internal-only) ref used in the rematerializing pass.
Functions that only internally use refs are pure. #31389 added support for those, though it didn't add tests. This PR adds a couple simple tests for remat-decorated functions that use only internal refs.
With external refs, passed to a remat-decorated function either as explicit arguments or by closure, the semantics are ambiguous: do we want the effect only to happen once, or every time the function is re-run? Given that ambiguity, we plan to let the user be explicit about what behavior they want.
But there's a sensible default behavior: only run the effects the first time. That is, running this program should print "2.0" and not "3.0":
We capture the initial value of the ref as a residual, and then use it to initialize a fresh (internal-only) ref used in the rematerializing pass.
The implementation currently causes retracing the body of the remat. We could avoid that.