Skip to content

Commit ad55eb8

Browse files
Merge pull request #31389 from mattjj:remat-internal-refs
PiperOrigin-RevId: 800212283
2 parents 36d6858 + 1e2c56e commit ad55eb8

File tree

5 files changed

+42
-17
lines changed

5 files changed

+42
-17
lines changed

jax/_src/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2582,6 +2582,7 @@ class InternalMutableArrayEffect(effects.Effect):
25822582
pass
25832583
# One shared InternalMutableArrayEffect instance, bound to two module-level names.
array_ref_effect = internal_mutable_array_effect = InternalMutableArrayEffect()
25842584
# Add the effect type to the control-flow allow-list.
effects.control_flow_allowed_effects.add_type(InternalMutableArrayEffect)
2585+
# Line added by this commit: also add the effect type to the remat allow-list.
effects.remat_allowed_effects.add_type(InternalMutableArrayEffect)
25852586

25862587
@array_ref_p.def_effectful_abstract_eval
25872588
def array_ref_abstract_eval(init_aval, *, memory_space: Any):

jax/_src/interpreters/partial_eval.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,6 @@ def has_effects(effects) -> bool:
11931193
out_unknowns = map(op.or_, out_unknowns, ensure_out_unknowns)
11941194
out_inst = map(op.or_, out_inst, ensure_out_inst)
11951195

1196-
11971196
ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
11981197
outs_known, _ = partition_list(out_unknowns, jaxpr.outvars)
11991198
ref_res_is_input = [r in ins_known for r in residual_refs]

jax/_src/lax/control_flow/loops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1465,7 +1465,6 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
14651465
def _scan_state_partial_discharge_rule(
14661466
should_discharge, in_avals, out_avals, *args, jaxpr, num_consts, num_carry,
14671467
linear, unroll, reverse, length, _split_transpose):
1468-
if jaxpr.consts: raise NotImplementedError("open an issue!") # TODO(mattjj)
14691468
# jaxpr: [*consts, *pure_carry, *xs] -> [*pure_carry, *pure_ys]
14701469
# jaxpr_: [*consts, *pure_carry, *xs] -> [*pure_carry, *pure_ys, *ref_outs]
14711470
discharged_jaxpr = state_discharge.discharge_state2(jaxpr, should_discharge)

jax/_src/pallas/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,3 +1517,18 @@ def lower_as_mlir(
15171517
# Registry keyed by a type; each value is a callable that returns a
# jax_core.AbstractValue. It starts out empty here.
# NOTE(review): presumably filled in by registrations elsewhere -- not visible in this hunk.
_out_shape_to_aval_mapping: dict[
15181518
type[Any], Callable[[Any], jax_core.AbstractValue]
15191519
] = {}
1520+
1521+
1522+
def _core_map_partial_eval_custom(saveable, unks_in, inst_in, eqn):
1523+
assert all(inst_in)
1524+
if all(unks_in):
1525+
return None, eqn, [], [], [] # purely unknown
1526+
elif not any(unks_in):
1527+
return eqn, eqn, [], [], [] # full remat
1528+
else:
1529+
# Some values, e.g. empty refs or refs initialized to constant zero, can be
1530+
# 'known', but really they belong in the staged/tangent computation. We
1531+
# encounter them here as known inputs mixed in with unknown/tangent inputs,
1532+
# which tells us that this core_map is really a purely tangent computation.
1533+
return None, eqn, [], [], []
1534+
# Register _core_map_partial_eval_custom as the partial_eval_jaxpr_custom rule for core_map_p.
pe.partial_eval_jaxpr_custom_rules[core_map_p] = _core_map_partial_eval_custom

jax/_src/state/primitives.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -598,21 +598,32 @@ def addupdate_transpose(cts_in, ref, x, *idx, **params):
598598

599599
## get/swap/addupdate partial_eval_custom rules
600600

601-
# Pre-change version of the rule (these are the lines this commit removes,
# marked `-`). It took the primitive as an extra leading argument and
# consulted the `saveable` policy.
def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn):
602-
# Any unknown input: nothing in the first slot, eqn in the second; every
# output is flagged True in both masks. Inputs whose inst flag is falsy are
# returned in the last slot.
if any(unks_in):
603-
res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
604-
return None, eqn, [True] * len(eqn.outvars), [True] * len(eqn.outvars), res
605-
# No unknown inputs and the policy accepts the op: eqn in the first slot only.
elif saveable(prim, *[var.aval for var in eqn.invars], **eqn.params):
606-
return eqn, None, [False] * len(eqn.outvars), [False] * len(eqn.outvars), []
607-
# No unknown inputs and the policy declines: eqn goes in both slots.
res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
608-
return eqn, eqn, [False] * len(eqn.outvars), [True] * len(eqn.outvars), res
609-
610-
# The same function was registered for all three primitives, with the
# primitive bound as the first argument via partial.
pe.partial_eval_jaxpr_custom_rules[get_p] = partial(_state_partial_eval_custom,
611-
get_p)
612-
pe.partial_eval_jaxpr_custom_rules[swap_p] = partial(_state_partial_eval_custom,
613-
swap_p)
614-
pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial(
615-
_state_partial_eval_custom, addupdate_p)
601+
def _array_ref_partial_eval_custom(saveable, unks_in, inst_in, eqn):
602+
del saveable # ignored, always full remat array_ref on known input
603+
unk, = unks_in
604+
inst, = inst_in
605+
invar, = eqn.invars
606+
res = [invar] if not inst else []
607+
if unk:
608+
return None, eqn, [True], [True], res # tangent operation
609+
else:
610+
return eqn, eqn, [False], [True], res # full remat
611+
# Register _array_ref_partial_eval_custom as the partial_eval_jaxpr_custom rule for core.array_ref_p.
pe.partial_eval_jaxpr_custom_rules[core.array_ref_p] = _array_ref_partial_eval_custom
612+
613+
def _state_partial_eval_custom(saveable, unks_in, inst_in, eqn):
614+
del saveable # ignored, always full remat state ops on known inputs
615+
ref_unk, *_ = unks_in
616+
ref_inst, *inst_in = inst_in
617+
_, *val_vars = eqn.invars
618+
assert ref_inst
619+
res = [v for v, inst in zip(val_vars, inst_in) if not inst]
620+
if ref_unk:
621+
return None, eqn, [True], [True], res # tangent operation
622+
else:
623+
return eqn, eqn, [False], [True], res # full remat
624+
# get, swap and addupdate all share _state_partial_eval_custom as their
# partial_eval_jaxpr_custom rule.
pe.partial_eval_jaxpr_custom_rules[get_p] = _state_partial_eval_custom
625+
pe.partial_eval_jaxpr_custom_rules[swap_p] = _state_partial_eval_custom
626+
pe.partial_eval_jaxpr_custom_rules[addupdate_p] = _state_partial_eval_custom
616627

617628
## get/swap/addupdate batching rules
618629

0 commit comments

Comments
 (0)