Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve readthedocs behavior for jax.remat / jax.checkpoint #24064

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Automatic differentiation
custom_gradient
closure_convert
checkpoint
remat

``custom_jvp``
~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@

from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint as checkpoint
from jax._src.ad_checkpoint import remat as remat
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as _deprecated_clear_backends
from jax._src.api import clear_caches as clear_caches
Expand Down Expand Up @@ -122,7 +123,6 @@
from jax._src.xla_bridge import process_index as process_index
from jax._src.xla_bridge import process_indices as process_indices
from jax._src.callback import pure_callback as pure_callback
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
Expand Down
74 changes: 12 additions & 62 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
import functools
from functools import partial
import logging
from typing import Any
Expand Down Expand Up @@ -168,6 +167,7 @@ def policy(prim, *args, **params):
def checkpoint(fun: Callable, *, prevent_cse: bool = True,
policy: Callable[..., bool] | None = None,
static_argnums: int | tuple[int, ...] = (),
concrete: bool = False,
) -> Callable:
"""Make ``fun`` recompute internal linearization points when differentiated.

Expand Down Expand Up @@ -222,6 +222,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
returns a boolean indicating whether the corresponding output value(s) can
be saved as residuals (or instead must be recomputed in the (co)tangent
computation if needed).
concrete: Ignored vestigial argument. It does nothing.

Returns:
A function (callable) with the same input/output behavior as ``fun`` but
Expand Down Expand Up @@ -309,6 +310,8 @@ def foo(x, y):
``jax.ensure_compile_time_eval``), it may be easier to compute some values
outside the :func:`jax.checkpoint`-decorated function and then close over them.
"""
del concrete # Ignored.

@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
Expand All @@ -322,7 +325,14 @@ def fun_remat(*args, **kwargs):
return tree_unflatten(out_tree, out_flat)
return fun_remat

remat = checkpoint # alias
def remat(fun: Callable, *, prevent_cse: bool = True,
policy: Callable[..., bool] | None = None,
static_argnums: int | tuple[int, ...] = (),
) -> Callable:
"""Alias of :func:`~jax.checkpoint`."""
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)


# This function is similar to api_util.argnums_partial, except the error
# messages are specific to jax.remat (and thus more actionable), the
Expand Down Expand Up @@ -855,65 +865,5 @@ def name_batcher(args, dims, *, name):
return name_p.bind(x, name=name), d
batching.primitive_batchers[name_p] = name_batcher


@functools.wraps(checkpoint)
def checkpoint_wrapper(
fun: Callable,
*,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: int | tuple[int, ...] = (),
policy: Callable[..., bool] | None = None,
) -> Callable:
if concrete:
msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
"in its place, you can use its `static_argnums` option, and if "
"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
"\n"
"For example, if using `concrete=True` for an `is_training` flag:\n"
"\n"
" from functools import partial\n"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, is_training):\n"
" if is_training:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"replace it with a use of `static_argnums`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, is_training):\n"
" ...\n"
"\n"
"If jax.numpy operations need to be performed on static arguments, "
"we can use the `jax.ensure_compile_time_eval()` context manager. "
"For example, we can replace this use of `concrete=True`\n:"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, y):\n"
" if y > 0:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"with this combination of `static_argnums` and "
"`jax.ensure_compile_time_eval()`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, y):\n"
" with jax.ensure_compile_time_eval():\n"
" y_pos = y > 0\n"
" if y_pos:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
raise NotImplementedError(msg)
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)

# TODO(phawkins): update users to refer to the public name.
_optimization_barrier = lax_internal.optimization_barrier