Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2672,7 +2672,7 @@ def __init__(self, aval, buf):
pytype_aval_mappings[Ref] = lambda x: x._aval
dtypes.canonicalize_value_handlers[Ref] = lambda x: x

def new_ref(init_val, *, memory_space: Any = None):
def new_ref(init_val, *, memory_space: Any = None, kind: Any = None):
"""Create a mutable array reference with initial value ``init_val``.

For more discussion, see the `Ref guide`_.
Expand All @@ -2687,19 +2687,18 @@ def new_ref(init_val, *, memory_space: Any = None):

.. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html
"""
return ref_p.bind(init_val, memory_space=memory_space)
return ref_p.bind(init_val, memory_space=memory_space, kind=kind)
ref_p = Primitive('new_ref')
ref_p.is_effectful = lambda params: True # type: ignore
ref_p.ref_primitive = True

ref_p.is_high = lambda aval, *, memory_space: aval.is_high # type: ignore
def _ref_to_lojax(init_val, *, memory_space):
ref_p.is_high = lambda aval, *, memory_space, kind: aval.is_high # type: ignore
def _ref_to_lojax(init_val, *, memory_space, kind):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
val_ty = typeof(init_val)
hival_of_refs = val_ty.raise_val(*map(new_ref, val_ty.lower_val(init_val))) # type: ignore
aval = AbstractRef(typeof(init_val))
return Ref(AbstractRef(val_ty), hival_of_refs)
# return Ref(
ref_p.to_lojax = _ref_to_lojax # type: ignore


Expand All @@ -2710,19 +2709,19 @@ class InternalMutableArrayEffect(effects.Effect):
effects.remat_allowed_effects.add_type(InternalMutableArrayEffect)

@ref_p.def_effectful_abstract_eval
def array_ref_abstract_eval(init_aval, *, memory_space: Any):
def _ref_abstract_eval(init_aval, *, memory_space: Any, kind: Any):
from jax._src.state.types import AbstractRef # pytype: disable=import-error
return (AbstractRef(init_aval, memory_space=memory_space),
return (AbstractRef(init_aval, memory_space=memory_space, kind=kind),
{internal_mutable_array_effect})

@ref_p.def_impl
def _array_ref_impl(init_val, *, memory_space: Any):
def _ref_impl(init_val, *, memory_space: Any, kind: Any):
if memory_space is not None:
raise NotImplementedError(
"array ref with memory space only works inside of a `jit`.")
from jax._src.state.types import AbstractRef # pytype: disable=import-error
from jax._src.lax.lax import _array_copy # pytype: disable=import-error
aval = AbstractRef(typeof(init_val))
aval = AbstractRef(typeof(init_val), kind=kind)
return Ref(aval, ArrayRefImpl(aval, _array_copy(init_val)))

def freeze(ref: Ref) -> Array:
Expand Down
43 changes: 24 additions & 19 deletions jax/_src/state/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,8 @@ def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any):
out_primal = swap_p.bind(ref_primal, x_primal, *idx, **params)
if isinstance(ref_tangent, ad_util.Zero) and isinstance(x_tangent, ad_util.Zero):
out_tangent = ad_util.Zero(core.typeof(out_primal).to_tangent_aval())
elif ref_tangent.aval.kind == "anselm_ref":
out_tangent = ad_util.Zero(core.typeof(out_primal).to_tangent_aval())
else:
if isinstance(ref_tangent, ad_util.Zero):
raise Exception("performing a set/swap operation with a differentiated "
Expand All @@ -610,8 +612,9 @@ def addupdate_jvp_rule(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, x_primal, *idx = primals
ref_tangent, x_tangent, *_ = tangents
x_tangent = ad_util.instantiate(x_tangent)
addupdate_p.bind(ref_primal, x_primal, *idx, **params)
addupdate_p.bind(ref_tangent, x_tangent, *idx, **params)
if ref_tangent.aval.kind != "anselm_ref":
addupdate_p.bind(ref_primal, x_primal, *idx, **params)
addupdate_p.bind(ref_tangent, x_tangent, *idx, **params)
return [], []
ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule

Expand Down Expand Up @@ -675,16 +678,16 @@ def _array_ref_partial_eval_custom(saveable, unks_in, inst_in, eqn):
return eqn, eqn, [False], [True], res # full remat
pe.partial_eval_jaxpr_custom_rules[core.ref_p] = _array_ref_partial_eval_custom

def _array_ref_batched(axis_data, vals_in, dims_in, memory_space):
def _array_ref_batched(axis_data, vals_in, dims_in, memory_space, kind):
val, = vals_in
dim, = dims_in
if dim is None:
# We defensively batch the ref, b/c it could later be hit with a batched val
val2 = batching.broadcast(val, axis_data.size, 0,
axis_data.explicit_mesh_axis)
return core.ref_p.bind(val2, memory_space=memory_space), 0
return core.ref_p.bind(val2, memory_space=memory_space, kind=kind), 0
else:
return core.ref_p.bind(val, memory_space=memory_space), dim
return core.ref_p.bind(val, memory_space=memory_space, kind=kind), dim
batching.fancy_primitive_batchers[core.ref_p] = _array_ref_batched

def _freeze_batched(axis_data, vals_in, dims_in):
Expand All @@ -695,13 +698,16 @@ def _freeze_batched(axis_data, vals_in, dims_in):

def _state_partial_eval_custom(saveable, unks_in, inst_in, eqn):
del saveable # ignored, always full remat state ops on known inputs
# (except for anselm_ref)
ref_unk, *_ = unks_in
ref_inst, *inst_in = inst_in
_, *val_vars = eqn.invars
assert ref_inst
res = [v for v, inst in zip(val_vars, inst_in) if not inst]
if ref_unk:
return None, eqn, [True], [True], res # tangent operation
elif eqn.invars[0].aval.kind == "anselm_ref":
return eqn, None, [False], [False], res
else:
return eqn, eqn, [False], [True], res # full remat
pe.partial_eval_jaxpr_custom_rules[get_p] = _state_partial_eval_custom
Expand Down Expand Up @@ -1069,27 +1075,26 @@ def _broadcast_to_abstract_eval(aval, *, shape):

# === AD rules for mutable arrays ===

def _mut_jvp(primals, tangents, *, memory_space):
(init_val,), (init_val_dot,) = primals, tangents
primal_out = core.ref_p.bind(init_val, memory_space=memory_space)
if type(init_val_dot) is ad_util.Zero:
tangent_out = core.ref_p.bind(
ad_util.zeros_like_aval(init_val_dot.aval), memory_space=memory_space)
def _ref_jvp(primals, tangents, *, memory_space, kind):
(init_val,), (init_dot,) = primals, tangents
primal_out = core.ref_p.bind(init_val, memory_space=memory_space, kind=kind)
if type(init_dot) is ad_util.Zero:
zero = ad_util.zeros_like_aval(init_dot.aval)
tangent_out = core.ref_p.bind(zero, memory_space=memory_space, kind=kind)
else:
tangent_out = core.ref_p.bind(init_val_dot,
memory_space=memory_space)
tangent_out = core.ref_p.bind(init_dot, memory_space=memory_space, kind=kind)
return primal_out, tangent_out

def _mut_lin(nzs, x, *, memory_space):
def _ref_lin(nzs, x, *, memory_space, kind):
nz, = nzs
x_ref = core.ref_p.bind(x, memory_space=memory_space)
x_ref = core.ref_p.bind(x, memory_space=memory_space, kind=kind)
def mut_lin(_, x_dot):
return core.ref_p.bind(ad_util.instantiate(x_dot),
memory_space=memory_space)
zero = ad_util.instantiate(x_dot)
return core.ref_p.bind(zero, memory_space=memory_space, kind=kind)
return x_ref, True, None, mut_lin

ad.primitive_jvps[core.ref_p] = _mut_jvp
ad.primitive_linearizations[core.ref_p] = _mut_lin
ad.primitive_jvps[core.ref_p] = _ref_jvp
ad.primitive_linearizations[core.ref_p] = _ref_lin
# TODO(mattjj): lin rule for freeze and accum_grad_in_ref?
ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))
ad.defjvp(core.accum_grad_in_ref_p, lambda g, _: core.accum_grad_in_ref_p.bind(g))
Expand Down
9 changes: 6 additions & 3 deletions jax/_src/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class AccumEffect(RefEffect):
effects.custom_derivatives_allowed_effects.add_type(RefEffect)
effects.custom_derivatives_allowed_effects.add_type(core.InternalMutableArrayEffect)
effects.partial_eval_kept_effects.add_type(RefEffect)
effects.remat_allowed_effects.add_type(RefEffect)

StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]

Expand Down Expand Up @@ -395,11 +396,13 @@ class AbstractRef(core.AbstractValue):

.. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html
"""
__slots__ = ["inner_aval", "memory_space"]
__slots__ = ["inner_aval", "memory_space", "kind"]

def __init__(self, inner_aval: core.AbstractValue, memory_space: Any = None):
def __init__(self, inner_aval: core.AbstractValue, memory_space: Any = None,
kind: Any = None):
self.inner_aval = inner_aval
self.memory_space = memory_space
self.kind = kind

@property
def is_high(self):
Expand Down Expand Up @@ -548,7 +551,7 @@ def __repr__(self) -> str:
__str__ = __repr__

def to_tangent_aval(self):
return AbstractRef(self.inner_aval.to_tangent_aval(), self.memory_space)
return AbstractRef(self.inner_aval.to_tangent_aval(), self.memory_space, kind=self.kind)

def __eq__(self, other):
return (type(self) is type(other) and self.inner_aval == other.inner_aval
Expand Down
65 changes: 59 additions & 6 deletions tests/mutable_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,16 +748,69 @@ def body(c, x):
jtu.check_grads(f, (mut_const_vals, pure_consts, init_carry, xs),
2, ['fwd', 'rev'], rtol=1.5e-2)

def test_remat_basic_errors(self):
@parameterized.parameters([False, True])
def test_remat_basic_internal(self, jit):
@jax.remat
def f(x_ref, y):
def f(y, x):
x_ref = jax.new_ref(x)
out = y * x_ref[...]
x_ref[...] += 1
return y
return out

x_ref = core.new_ref(0)
if jit:
f = jax.jit(f)

g = jax.grad(f)(2., 1.)
self.assertAllClose(g, 1.)

@parameterized.parameters([False, True])
def test_remat_basic_arg(self, jit):
@jax.remat
def f(y, x_ref):
out = y * y
x_ref[...] += out
return out

if jit:
f = jax.jit(f)

x_ref = core.new_ref(1., kind='anselm_ref')
g = jax.grad(f)(2., x_ref)
self.assertAllClose(x_ref[...], 5.)
self.assertAllClose(g, 4.)

with self.assertRaises(NotImplementedError):
jax.grad(f, 1)(x_ref, 3.14)
@parameterized.parameters([False, True])
def test_remat_basic_closed_over(self, jit):
@jax.remat
def f(y):
out = y * x_ref[...]
x_ref[...] += 1
return out

if jit:
f = jax.jit(f)

x_ref = core.new_ref(1., kind='anselm_ref')
g = jax.grad(f)(2.)
self.assertAllClose(x_ref[...], 2.)
self.assertAllClose(g, 1.)

def test_remat_basic_closed_over_nested(self):
@jax.remat
@partial(jax.remat, policy=lambda *_, **__: False)
@jax.remat
def f(y):
jax.debug.callback(lambda _: lst.append('hi'), y)
out = y * x_ref[...]
x_ref[...] += 1
return jnp.sin(out)

lst = []
x_ref = core.new_ref(1., kind='anselm_ref')
g = jax.grad(f)(2.)
self.assertAllClose(x_ref[...], 2.)
self.assertAllClose(g, jnp.cos(2.))
self.assertLen(lst, 4)

def test_remat_grad_stats_plumbing_basic(self):
@jax.remat
Expand Down
Loading