@mattjj mattjj commented Aug 27, 2025

A few ingredients were needed:

  1. in core.py, mark InternalMutableArrayEffect as allowed under remat (just whitelisting it)
  2. remove a stale NotImplementedError in the scan discharge rule; the underlying issue was solved when we adopted the aptly-named state_discharge2
  3. adapt stateful ops remat partial eval rules, basically to do full remat regardless of remat policy
  4. add such a rule for pallas core_map

@dougalm observed a downside to item 3: remat with an everything-saveable policy is no longer equivalent to classic AD (without the remat decorator), because we always rematerialize these internally-stateful operations regardless of policy. The alternative is to always save such computations regardless of policy, but then we can't recover full remat. We don't want to risk splitting a stateful operation across the saved and rematerialized sides, so it's sufficient to always put it entirely on one side or the other. One way around this would be to statically switch on everything-saveable and/or nothing-saveable policies, which would require representing them as something other than an opaque callback.
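
For reference, here is a minimal sketch of how the two policies mentioned above are selected on an ordinary pure function. It uses no refs, so it illustrates only the policy knob itself, not the always-remat behavior of internally-stateful ops described in this PR; the function `f` and the variable names are made up for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
  # A small pure function with one intermediate value (the inner sin).
  return jnp.sin(jnp.sin(x))

# Same function under the two extreme remat policies.
f_save_all = jax.checkpoint(
    f, policy=jax.checkpoint_policies.everything_saveable)
f_save_none = jax.checkpoint(
    f, policy=jax.checkpoint_policies.nothing_saveable)

g_plain = jax.grad(f)(3.)               # classic AD, no remat decorator
g_save_all = jax.grad(f_save_all)(3.)   # residuals may all be saved
g_save_none = jax.grad(f_save_none)(3.) # intermediates recomputed on the backward pass
```

All three gradients agree numerically; the policies differ only in which intermediates are stored versus recomputed.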

We still can't handle this code until we redo remat again on top of direct-linearize (remat3):

```python
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_checks', True)

@jax.remat
def f(x):
  x_ref = jax.array_ref(0.)
  x_ref[...] = x
  return x_ref[...]

jax.grad(f)(3.)
```

Co-authored-by: Sharad Vikram <sharadmv@google.com>
@mattjj mattjj requested review from dougalm and sharadmv August 27, 2025 22:25
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Aug 27, 2025
@copybara-service copybara-service bot merged commit ad55eb8 into jax-ml:main Aug 27, 2025
27 of 28 checks passed
@mattjj mattjj deleted the remat-internal-refs branch August 27, 2025 23:17
mattjj added a commit to mattjj/jax that referenced this pull request Oct 31, 2025
Functions that only internally use refs are pure. jax-ml#31389 added support for
those, though it didn't add tests. This PR adds a couple of simple tests for
remat-decorated functions that use only internal refs.

With external refs, passed to a remat-decorated function either as explicit
arguments or by closure, the semantics are ambiguous: do we want the effect
only to happen once, or every time the function is re-run? Given that
ambiguity, we plan to let the user be explicit about what behavior they want.

But there's a sensible default behavior: only run the effects the first time.
That is, running this program should print "2.0" and not "3.0":

```python
@jax.remat
def f(y, x_ref):
  out = y * x_ref[...]
  x_ref[...] += 1
  return out

x_ref = jax.new_ref(1.)
jax.grad(f)(1., x_ref)
print(x_ref)  # 2.0
```

We capture the initial value of the ref as a residual, and then use it to
initialize a fresh (internal-only) ref used in the rematerializing pass.
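
The residual-capture idea can be modeled in plain Python. This is only a toy sketch of the semantics, not the JAX implementation: `ToyRef`, `forward_with_residual`, and `rematerialize` are hypothetical names introduced here for illustration.

```python
class ToyRef:
  """Hypothetical stand-in for a mutable ref; not a JAX API."""

  def __init__(self, value):
    self.value = value

def f(y, x_ref):
  out = y * x_ref.value
  x_ref.value += 1
  return out

def forward_with_residual(y, x_ref):
  # Capture the ref's initial value before any effect runs.
  residual = x_ref.value
  # The effect on the external ref happens here, exactly once.
  out = f(y, x_ref)
  return out, residual

def rematerialize(y, residual):
  # Re-run f against a fresh, internal-only ref seeded from the residual,
  # so the external ref is not touched a second time.
  fresh_ref = ToyRef(residual)
  return f(y, fresh_ref)

x_ref = ToyRef(1.0)
out, residual = forward_with_residual(2.0, x_ref)
out_again = rematerialize(2.0, residual)
```

After both passes the external `x_ref` has been incremented only once, and the rematerialized output matches the forward output.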
mattjj added a commit to mattjj/jax that referenced this pull request Nov 2, 2025