Skip to content

Commit

Permalink
improve readthedocs behavior for jax.remat / jax.checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Dec 18, 2024
1 parent fcfb0b7 commit 1d03e17
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 64 deletions.
1 change: 1 addition & 0 deletions docs/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Automatic differentiation
custom_gradient
closure_convert
checkpoint
remat

``custom_jvp``
~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@

from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint as checkpoint
from jax._src.ad_checkpoint import remat as remat
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as _deprecated_clear_backends
from jax._src.api import clear_caches as clear_caches
Expand Down Expand Up @@ -122,7 +123,6 @@
from jax._src.xla_bridge import process_index as process_index
from jax._src.xla_bridge import process_indices as process_indices
from jax._src.callback import pure_callback as pure_callback
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
Expand Down
74 changes: 12 additions & 62 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
import functools
from functools import partial
import logging
from typing import Any
Expand Down Expand Up @@ -168,6 +167,7 @@ def policy(prim, *args, **params):
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
policy: Callable[..., bool] | None = None,
static_argnums: int | tuple[int, ...] = (),
concrete: bool = False,
) -> Callable:
"""Make ``fun`` recompute internal linearization points when differentiated.
Expand Down Expand Up @@ -222,6 +222,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
returns a boolean indicating whether the corresponding output value(s) can
be saved as residuals (or instead must be recomputed in the (co)tangent
computation if needed).
concrete: Ignored vestigial argument. It does nothing.
Returns:
A function (callable) with the same input/output behavior as ``fun`` but
Expand Down Expand Up @@ -309,6 +310,8 @@ def foo(x, y):
``jax.ensure_compile_time_eval``), it may be easier to compute some values
outside the :func:`jax.checkpoint`-decorated function and then close over them.
"""
del concrete # Ignored.

@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
Expand All @@ -322,7 +325,14 @@ def fun_remat(*args, **kwargs):
return tree_unflatten(out_tree, out_flat)
return fun_remat

remat = checkpoint # alias
def remat(fun: Callable, *, prevent_cse: bool = True,
policy: Callable[..., bool] | None = None,
static_argnums: int | tuple[int, ...] = (),
) -> Callable:
"""Alias of :func:`~jax.checkpoint`."""
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)


# This function is similar to api_util.argnums_partial, except the error
# messages are specific to jax.remat (and thus more actionable), the
Expand Down Expand Up @@ -855,65 +865,5 @@ def name_batcher(args, dims, *, name):
return name_p.bind(x, name=name), d
batching.primitive_batchers[name_p] = name_batcher


@functools.wraps(checkpoint)
def checkpoint_wrapper(
fun: Callable,
*,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: int | tuple[int, ...] = (),
policy: Callable[..., bool] | None = None,
) -> Callable:
if concrete:
msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
"in its place, you can use its `static_argnums` option, and if "
"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
"\n"
"For example, if using `concrete=True` for an `is_training` flag:\n"
"\n"
" from functools import partial\n"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, is_training):\n"
" if is_training:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"replace it with a use of `static_argnums`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, is_training):\n"
" ...\n"
"\n"
"If jax.numpy operations need to be performed on static arguments, "
"we can use the `jax.ensure_compile_time_eval()` context manager. "
"For example, we can replace this use of `concrete=True`\n:"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, y):\n"
" if y > 0:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"with this combination of `static_argnums` and "
"`jax.ensure_compile_time_eval()`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, y):\n"
" with jax.ensure_compile_time_eval():\n"
" y_pos = y > 0\n"
" if y_pos:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
raise NotImplementedError(msg)
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)

# TODO(phawkins): update users to refer to the public name.
_optimization_barrier = lax_internal.optimization_barrier

0 comments on commit 1d03e17

Please sign in to comment.