From e954dfd55ac69b1a768ac1b3a822bf8e884812ed Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 15 Oct 2024 10:15:24 -0700 Subject: [PATCH] Update jax/_src/ad_checkpoint.py Co-authored-by: Jake Vanderplas --- jax/_src/ad_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 53398f9c1dd9..d27996add9a2 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -330,7 +330,6 @@ def remat(fun: Callable, *, prevent_cse: bool = True, return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums) -remat = checkpoint # alias # This function is similar to api_util.argnums_partial, except the error # messages are specific to jax.remat (and thus more actionable), the