Conversation

@mattjj
Collaborator

@mattjj mattjj commented Oct 31, 2025

Functions that use refs only internally are pure. #31389 added support for them but did not add tests. This PR adds a couple of simple tests for remat-decorated functions that use only internal refs.

With external refs, passed to a remat-decorated function either as explicit arguments or by closure, the semantics are ambiguous: should the effect happen only once, or every time the function is re-run? Given that ambiguity, we plan to let users state explicitly which behavior they want.

But there's a sensible default behavior: only run the effects the first time. That is, running this program should print "2.0" and not "3.0":

import jax

@jax.remat
def f(y, x_ref):
  out = y * x_ref[...]
  x_ref[...] += 1
  return out

x_ref = jax.new_ref(1.)
jax.grad(f)(1., x_ref)
print(x_ref)  # 2.0

We capture the initial value of the ref as a residual, and then use it to initialize a fresh (internal-only) ref used in the rematerializing pass.
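
The difference between the two candidate semantics, and the residual-capture approach described above, can be modeled in plain Python. This is only a toy sketch: `Ref`, `body`, `naive_remat`, and `residual_remat` are hypothetical names introduced for illustration, the "backward pass" is reduced to a second call of the body, and none of it reflects how JAX implements remat internally.

```python
class Ref:
    """Toy stand-in for a mutable reference (hypothetical; not jax.Ref)."""

    def __init__(self, value):
        self.value = value


def body(y, x_ref):
    # Same shape as `f` in the example: read the ref, then bump it.
    out = y * x_ref.value
    x_ref.value += 1
    return out


def naive_remat(y, x_ref):
    """Re-run the body on the caller's ref, so the effect happens twice."""
    out = body(y, x_ref)           # forward pass
    recomputed = body(y, x_ref)    # recomputation sees the mutated ref
    return out, recomputed


def residual_remat(y, x_ref):
    """Save the ref's initial value and recompute on a fresh private ref."""
    x0 = x_ref.value               # residual captured before the forward pass
    out = body(y, x_ref)           # forward pass: the only visible effect
    recomputed = body(y, Ref(x0))  # recomputation mutates a throwaway ref
    return out, recomputed


naive_ref = Ref(1.0)
naive_out, naive_recomputed = naive_remat(1.0, naive_ref)

residual_ref = Ref(1.0)
residual_out, residual_recomputed = residual_remat(1.0, residual_ref)

print(naive_ref.value, naive_out, naive_recomputed)           # 3.0 1.0 2.0
print(residual_ref.value, residual_out, residual_recomputed)  # 2.0 1.0 1.0
```

In the toy model, the naive variant both applies the increment twice and recomputes a different output than the forward pass produced. The residual variant leaves the caller's ref at 2.0 and reproduces the forward output.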

The implementation currently retraces the body of the remat-decorated function. We could avoid that.

@mattjj mattjj requested a review from dougalm October 31, 2025 22:49
@mattjj mattjj self-assigned this Oct 31, 2025
@mattjj mattjj added the "pull ready" (Ready for copybara import and testing) label Oct 31, 2025
@gemini-code-assist
Contributor

Summary of Changes

Hello @mattjj, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request defines how JAX's remat (checkpointing) mechanism interacts with mutable references created with jax.new_ref. Effects on refs used by remat-decorated functions now follow a consistent default during rematerialization, so automatic differentiation behaves predictably with stateful operations on mutable arrays, an area whose semantics were previously ambiguous.

Highlights

  • Enhanced remat with mutable references: The jax.remat (checkpointing) mechanism now correctly handles jax.new_ref (mutable arrays) within functions, ensuring predictable behavior during automatic differentiation.
  • Predictable effect handling: For remat-decorated functions with external refs, effects run only the first time. The initial ref value is captured as a residual so that rematerialization starts from the same state as the forward pass.
  • Ref argument transformation: A new internal mechanism has been introduced to replace AbstractRef arguments with their corresponding value arguments during partial evaluation for remat, streamlining state management.
  • Expanded test coverage: Several new tests have been added to validate remat's interaction with mutable references, covering scenarios where refs are internal, passed as arguments, or closed over, including nested remat applications.
  • Removed effect restriction: An earlier NotImplementedError check that prevented certain effects from being used in remat_partial_eval has been removed, allowing for greater flexibility.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for jax.remat (or jax.checkpoint) with functions that have side-effects on mutable array references (jax.Ref). The implementation correctly establishes a sensible default behavior where the side-effects on external refs are executed only during the forward pass, and not during the re-computation on the backward pass.

The core logic is sound:

  1. In remat_partial_eval, the initial values of any Ref arguments are captured.
  2. A new helper, _replace_arg_refs_with_vals, transforms the jaxpr for the rematerialization pass. This new jaxpr accepts the initial values of the refs (instead of the refs themselves) and creates fresh, internal refs for the re-computation.

This approach correctly isolates the side-effects to the forward pass while allowing rematerialization to work with the correct initial state.
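
The two steps above can be sketched on ordinary Python callables. JAX's `_replace_arg_refs_with_vals` operates on jaxprs and its signature is not shown in this conversation, so the helper below, `replace_ref_args_with_vals`, along with `Ref`, `scale_and_bump`, and the `ref_argnums` parameter, is a hypothetical analogue for illustration only.

```python
import functools


class Ref:
    """Toy stand-in for a mutable reference (hypothetical; not jax.Ref)."""

    def __init__(self, value):
        self.value = value


def replace_ref_args_with_vals(fn, ref_argnums):
    """Return a variant of `fn` that takes plain values at `ref_argnums`.

    Each such value is wrapped in a fresh Ref that only the call can see,
    so any writes made by `fn` never reach a ref owned by the caller.
    """
    ref_argnums = frozenset(ref_argnums)

    @functools.wraps(fn)
    def wrapped(*args):
        new_args = [Ref(a) if i in ref_argnums else a
                    for i, a in enumerate(args)]
        return fn(*new_args)

    return wrapped


def scale_and_bump(y, x_ref):
    # Reads the ref, then mutates it as a side effect.
    out = y * x_ref.value
    x_ref.value += 1
    return out


caller_ref = Ref(1.0)
initial = caller_ref.value                     # step 1: capture the initial value
forward_out = scale_and_bump(3.0, caller_ref)  # forward pass mutates caller_ref

# Step 2: the recomputation takes the captured value, not the ref itself.
recompute = replace_ref_args_with_vals(scale_and_bump, ref_argnums=[1])
recomputed_out = recompute(3.0, initial)

print(forward_out, recomputed_out, caller_ref.value)  # 3.0 3.0 2.0
```

Taking argument positions as a parameter keeps the wrapper generic over which inputs are refs; a jaxpr-level version would presumably identify them from the abstract types of the inputs instead.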

The accompanying tests are thorough, covering cases where refs are created internally, passed as arguments, or closed over, with and without JIT compilation. The removal of the previous effect-disallowing check in remat_partial_eval is also appropriate given this new capability.

Overall, this is a well-implemented and well-tested enhancement to JAX's rematerialization feature. I have no specific comments for code changes.

@mattjj mattjj changed the title from "[mutable-arrays] make one-level remat work with refs" to "[mutable-arrays] make remat work with external refs" Nov 1, 2025