From 8770fb283baf9a5cfae5f9b4620c9e7e9eb77f33 Mon Sep 17 00:00:00 2001 From: Keshav Date: Fri, 20 Sep 2024 11:48:41 -0700 Subject: [PATCH] set default value to True --- jax/_src/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index ee258b93fb1c..3941a1b9fd58 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1530,7 +1530,7 @@ def _update_disable_jit_thread_local(val): enable_remat_opt_pass = bool_state( name='jax_compiler_enable_remat_pass', - default=False, + default=True, help=('Config to enable / disable the rematerialization HLO pass. ' 'Useful to allow XLA to automatically trade off memory and ' 'compute when encountering OOM errors. However, you are '