Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2582,6 +2582,7 @@ class InternalMutableArrayEffect(effects.Effect):
pass
# One shared InternalMutableArrayEffect instance, bound to two module-level
# names (`array_ref_effect` and `internal_mutable_array_effect`).
array_ref_effect = internal_mutable_array_effect = InternalMutableArrayEffect()
# Add the effect *type* to the allow-lists kept in `effects`: first the
# control-flow one, then the remat one.
effects.control_flow_allowed_effects.add_type(InternalMutableArrayEffect)
effects.remat_allowed_effects.add_type(InternalMutableArrayEffect)

@array_ref_p.def_effectful_abstract_eval
def array_ref_abstract_eval(init_aval, *, memory_space: Any):
Expand Down
1 change: 0 additions & 1 deletion jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,6 @@ def has_effects(effects) -> bool:
out_unknowns = map(op.or_, out_unknowns, ensure_out_unknowns)
out_inst = map(op.or_, out_inst, ensure_out_inst)


ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
outs_known, _ = partition_list(out_unknowns, jaxpr.outvars)
ref_res_is_input = [r in ins_known for r in residual_refs]
Expand Down
1 change: 0 additions & 1 deletion jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,6 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
def _scan_state_partial_discharge_rule(
should_discharge, in_avals, out_avals, *args, jaxpr, num_consts, num_carry,
linear, unroll, reverse, length, _split_transpose):
if jaxpr.consts: raise NotImplementedError("open an issue!") # TODO(mattjj)
# jaxpr: [*consts, *pure_carry, *xs] -> [*pure_carry, *pure_ys]
# jaxpr_: [*consts, *pure_carry, *xs] -> [*pure_carry, *pure_ys, *ref_outs]
discharged_jaxpr = state_discharge.discharge_state2(jaxpr, should_discharge)
Expand Down
15 changes: 15 additions & 0 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,3 +1517,18 @@ def lower_as_mlir(
# Registry keyed by type; each value is a callable that takes one argument and
# returns a jax_core.AbstractValue. It starts out empty.
# NOTE(review): presumably populated elsewhere to convert `out_shape` entries
# into avals -- confirm against the registration sites.
_out_shape_to_aval_mapping: dict[
type[Any], Callable[[Any], jax_core.AbstractValue]
] = {}


def _core_map_partial_eval_custom(saveable, unks_in, inst_in, eqn):
assert all(inst_in)
if all(unks_in):
return None, eqn, [], [], [] # purely unknown
elif not any(unks_in):
return eqn, eqn, [], [], [] # full remat
else:
# Some values, e.g. empty refs or refs initialized to constant zero, can be
# 'known', but really they belong in the staged/tangent computation. We
# encounter them here as known inputs mixed in with unknown/tangent inputs,
# which tells us that this core_map is really a purely tangent computation.
return None, eqn, [], [], []
# Install _core_map_partial_eval_custom as the entry for core_map_p in
# pe.partial_eval_jaxpr_custom_rules.
pe.partial_eval_jaxpr_custom_rules[core_map_p] = _core_map_partial_eval_custom
41 changes: 26 additions & 15 deletions jax/_src/state/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,21 +598,32 @@ def addupdate_transpose(cts_in, ref, x, *idx, **params):

## get/swap/addupdate partial_eval_custom rules

def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn):
if any(unks_in):
res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
return None, eqn, [True] * len(eqn.outvars), [True] * len(eqn.outvars), res
elif saveable(prim, *[var.aval for var in eqn.invars], **eqn.params):
return eqn, None, [False] * len(eqn.outvars), [False] * len(eqn.outvars), []
res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
return eqn, eqn, [False] * len(eqn.outvars), [True] * len(eqn.outvars), res

# Register the shared rule for get_p, swap_p and addupdate_p. `partial` binds
# each primitive as the rule's leading `prim` argument, so the stored callable
# takes the remaining (saveable, unks_in, inst_in, eqn) arguments.
pe.partial_eval_jaxpr_custom_rules[get_p] = partial(_state_partial_eval_custom,
get_p)
pe.partial_eval_jaxpr_custom_rules[swap_p] = partial(_state_partial_eval_custom,
swap_p)
pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial(
_state_partial_eval_custom, addupdate_p)
def _array_ref_partial_eval_custom(saveable, unks_in, inst_in, eqn):
del saveable # ignored, always full remat array_ref on known input
unk, = unks_in
inst, = inst_in
invar, = eqn.invars
res = [invar] if not inst else []
if unk:
return None, eqn, [True], [True], res # tangent operation
else:
return eqn, eqn, [False], [True], res # full remat
# Install _array_ref_partial_eval_custom as the entry for core.array_ref_p in
# pe.partial_eval_jaxpr_custom_rules.
pe.partial_eval_jaxpr_custom_rules[core.array_ref_p] = _array_ref_partial_eval_custom

def _state_partial_eval_custom(saveable, unks_in, inst_in, eqn):
del saveable # ignored, always full remat state ops on known inputs
ref_unk, *_ = unks_in
ref_inst, *inst_in = inst_in
_, *val_vars = eqn.invars
assert ref_inst
res = [v for v, inst in zip(val_vars, inst_in) if not inst]
if ref_unk:
return None, eqn, [True], [True], res # tangent operation
else:
return eqn, eqn, [False], [True], res # full remat
# get_p, swap_p and addupdate_p all use the same partial-eval rule.
pe.partial_eval_jaxpr_custom_rules[get_p] = _state_partial_eval_custom
pe.partial_eval_jaxpr_custom_rules[swap_p] = _state_partial_eval_custom
pe.partial_eval_jaxpr_custom_rules[addupdate_p] = _state_partial_eval_custom

## get/swap/addupdate batching rules

Expand Down
Loading