Skip to content

Commit

Permalink
Merge pull request #23738 from keshavb96:disable_remat_pass
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681510015
  • Loading branch information
Google-ML-Automation committed Oct 2, 2024
2 parents 78b65dd + 8770fb2 commit cfb7541
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ def get_compile_options(
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False

if not config.enable_remat_opt_pass.value:
debug_options.xla_disable_hlo_passes = "rematerialization"

# XLA-AutoFDO profile version: precedence order is:
# 1. Whatever --jax_xla_profile_version is set to.
Expand Down
8 changes: 8 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,14 @@ def _update_disable_jit_thread_local(val):
default=True,
help=('Enables using optimization-barrier op for lowering remat.'))

enable_remat_opt_pass = bool_state(
name='jax_compiler_enable_remat_pass',
default=True,
help=('Config to enable / disable the rematerialization HLO pass. '
'Useful to allow XLA to automatically trade off memory and '
'compute when encountering OOM errors. However, you are '
'likely to get better results manually with jax.checkpoint'))

# TODO(sharadmv,mattjj): set default to True, then remove
eager_pmap = bool_state(
name='jax_eager_pmap',
Expand Down

0 comments on commit cfb7541

Please sign in to comment.