Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions jax/_src/hijax.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ def batch(self, axis_data, args, dims):
raise NotImplementedError(f"for vmap support, subclass {type(self)} must "
"implement `batch`")

def linearize_ncnp(self, trace, *args):
return trace.default_process_primitive(call_hi_primitive_p, args, dict(prim=self))

def __call__(self, *args):
args_flat = tree_leaves_checked(self.in_tree, args)
ans_flat = call_hi_primitive_p.bind(*args_flat, prim=self)
Expand Down Expand Up @@ -428,6 +431,10 @@ def fake_linear_op(prim, rs, *tangents):

ad.primitive_linearizations[call_hi_primitive_p] = _call_hi_primitive_linearize

def _call_hi_primitive_linearize(trace, *args, prim):
return prim.linearize_ncnp(trace, *args)
ad.fancy_linearizations[call_hi_primitive_p] = _call_hi_primitive_linearize

call_hi_primitive_linearized_p = core.Primitive("call_hi_primitive_linearized")
call_hi_primitive_linearized_p.multiple_results = True
call_hi_primitive_linearized_p.is_high = lambda *args, prim, residuals_tree: True # type: ignore
Expand Down
165 changes: 135 additions & 30 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from jax._src import config
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_flatten, tree_unflatten,
from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure,
register_pytree_node, Partial, PyTreeDef)
from jax._src import mesh as mesh_lib
from jax._src import core
Expand Down Expand Up @@ -98,6 +98,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(debug_info, auto_dce=True)
tangent_trace.tag = _tag
breakpoint()
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
tracers = [LinearizeTracer(linearize_trace, p,
tangent_trace.new_arg(get_aval(p).to_tangent_aval(),
Expand Down Expand Up @@ -151,6 +152,82 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
store.store(aux_primals)
return out_primals, out_tangents

def linearize_jaxpr3(
jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool],
num_remats: int,
) -> tuple[list[core.ClosedJaxpr], list[tuple[PyTreeDef, PyTreeDef]],
list[list[bool]]]:
dbg = jaxpr.jaxpr.debug_info
primal_trace = pe.DynamicJaxprTrace(dbg)
remat_traces = [pe.DynamicJaxprTrace(dbg.with_unknown_names())
for _ in range(num_remats)]
tangent_trace = pe.DynamicJaxprTrace(dbg.with_unknown_names(), auto_dce=True)
lin_trace = LinearizeTrace([primal_trace, *remat_traces], tangent_trace)
tangent_trace.tag = lin_trace.tag
source_info = source_info_util.current()

def new_arg(primal_aval, nz):
primal = primal_trace.new_arg(primal_aval, source_info)
remats = [t.new_arg(primal_aval, source_info) for t in remat_traces]
tangent_aval = primal_aval.to_tangent_aval()
tangent = (tangent_trace.new_arg(tangent_aval, source_info)
if nz else Zero(tangent_aval))
return LinearizeTracer(lin_trace, [primal, *remats], tangent)

tracers = map(new_arg, jaxpr.in_aval_qdds, nonzeros)
in_primals_, in_tangents = unzip2((t.primals, t.tangent) for t in tracers)
in_primals, *in_remats = zip(*in_primals_)

with core.set_current_trace(lin_trace, check_leaks=True):
ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers)
out_primals_, out_tangents = unzip2(map(lin_trace.to_primals_tangent_list, ans))
out_primals, *out_remats = zip(*out_primals_)
# TODO instantiate zeros when we add the instantiate arg
del lin_trace, ans, new_arg, tracers

nzs_out = [type(t) is not Zero for t in out_tangents]
out_tangents_nz = [tangent_trace.to_jaxpr_tracer(t, source_info)
for (nz, t) in zip(nzs_out, out_tangents) if nz]
tangent_jaxpr, residuals = tangent_trace.to_jaxpr(
out_tangents_nz, dbg.with_unknown_names(), source_info)
tangent_trace.invalidate()
tangent_jaxpr, residuals = _dce_consts(tangent_jaxpr, residuals)
tangent_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(tangent_jaxpr))

tangent_trees = (tree_structure((residuals, in_tangents)),
tree_structure(out_tangents))

remat_jaxprs = []
remat_fwds = []
remat_trees = []
for trace in reversed(remat_traces):
# set `residuals`
breakpoint()

fwd_inputs = [*jaxpr.consts, *in_primals]
id_map = {id(x):i for i, x in enumerate(fwd_inputs)}
primal_fwds = [id_map.get(id(c)) for c in residuals]
reduced_residuals = [c for c, f in zip(residuals, primal_fwds) if f is None]

primals_and_residuals = *out_primals, *reduced_residuals
to_jaxpr_tracer = partial(primal_trace.to_jaxpr_tracer,source_info=source_info)
primals_and_residuals = map(to_jaxpr_tracer, primals_and_residuals)
primal_jaxpr, primal_consts = primal_trace.to_jaxpr(
primals_and_residuals, dbg.with_unknown_names(), source_info)
primal_trace.invalidate()
primal_jaxpr, primal_consts = _dce_consts(primal_jaxpr, primal_consts)
primal_jaxpr = core.ClosedJaxpr(primal_jaxpr, primal_consts)

primal_trees = (tree_structure(in_primals),
tree_structure((out_primals, reduced_residuals)))

jaxprs = [primal_jaxpr, *remat_jaxprs, tangent_jaxpr]
fwds = [primal_fwds, *remat_fwds]
trees = [primal_trees, *remat_trees, tangent_trees]
return jaxprs, trees, fwds


def linearize_jaxpr(
jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool],
Expand All @@ -177,7 +254,7 @@ def _linearize_jaxpr(
dbg = jaxpr.jaxpr.debug_info
config.enable_checks.value and dbg.assert_arg_names(len(nonzeros))
primal_trace = pe.DynamicJaxprTrace(dbg)
tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True)
tangent_trace = pe.DynamicJaxprTrace(dbg.with_unknown_names(), auto_dce=True)
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
tangent_trace.tag = lin_trace.tag

Expand Down Expand Up @@ -217,7 +294,8 @@ def new_arg(trace, primal_aval, nz, source_info):

# pe._check_no_returned_refs(debug_info, out_primals)
primals_and_residuals = *out_primals, *tangent_consts
primals_and_residuals = map(partial(primal_trace.to_jaxpr_tracer, source_info=source_info),
primals_and_residuals = map(partial(primal_trace.to_jaxpr_tracer,
source_info=source_info),
primals_and_residuals)
primal_jaxpr, primal_consts = primal_trace.to_jaxpr(
primals_and_residuals, dbg.with_unknown_names(),
Expand Down Expand Up @@ -245,9 +323,9 @@ def direct_linearize(traceable: lu.WrappedFun, primals, kwargs, *,
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and isinstance(core.typeof(t), core.ShapedArray)
and dtype(t) == float0 else t for t in tangents]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
linearize_trace = LinearizeTrace([parent_trace], tangent_trace, tag=tag)
tangent_trace.tag = linearize_trace.tag
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
tracers = [LinearizeTracer(linearize_trace, [p], t) for p, t in zip(primals, tangents)]
tracers = [t.full_lower() for t in tracers]
with (core.set_current_trace(linearize_trace),
source_info_util.transform_name_stack('jvp')):
Expand Down Expand Up @@ -869,10 +947,10 @@ def _primal_tangent_shapes_match(primal, tangent):

class LinearizeTrace(Trace):

def __init__(self, parent_trace, tangent_trace, tag=None):
def __init__(self, parent_traces, tangent_trace, tag=None):
super().__init__()
self.tag = core.TraceTag() if tag is None else tag
self.parent_trace = parent_trace
self.parent_traces = parent_traces
self.tangent_trace = tangent_trace
self._name_stack_prefix_len = len(source_info_util.current_name_stack())
self.requires_low = False
Expand All @@ -882,40 +960,63 @@ def _name_stack_suffix(self):

def to_primal_tangent_pair(self, val):
if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
return (val.primal, val.tangent)
primal, = val.primals
return (primal, val.tangent)
else:
tangent_zero = Zero.from_primal_value(val)
return (val, tangent_zero)

def to_primals_tangent_list(self, val):
if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
return (val.primals, val.tangent)
else:
tangent_zero = Zero.from_primal_value(val)
primals = (val,) * len(self.parent_traces)
return primals, tangent_zero

def process_primitive(self, primitive, args, params):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
fancy_rule = fancy_linearizations.get(primitive)
if fancy_rule:
return fancy_rule(self, *args, **params)
return self.default_process_primitive(primitive, args, params)

def default_process_primitive(self, primitive, args, params):
primals_in_, tangents_in = unzip2(map(self.to_primals_tangent_list, args))
primals_in = zip(*primals_in_)
tangent_nzs = [type(t) is not Zero for t in tangents_in]
if (all(type(t) is Zero for t in tangents_in) and
primitive is not core.ref_p and
not any(isinstance(core.typeof(x), AbstractRef) for x in primals_in)):
return primitive.bind_with_trace(self.parent_trace, primals_in, params)
not any(isinstance(core.typeof(x), AbstractRef) for x in primals_in[0])):
return primitive.bind_with_trace(self.parent_traces[0], primals_in[0], params)

remat_outs = [primitive.bind_with_trace(t, xs, params)
for t, xs in zip(self.parent_traces[:-1], primals_in[:-1])]

fallback = partial(fallback_linearize_rule, primitive)
lin = primitive_linearizations.get(primitive, fallback)
with core.set_current_trace(self.parent_trace):
with core.set_current_trace(self.parent_traces[-1]):
primal_out, tangent_nzs_out, residuals, linearized = lin(
tangent_nzs, *primals_in, **params)
tangent_nzs, *primals_in[-1], **params)
with (core.set_current_trace(self.tangent_trace),
source_info_util.set_name_stack(self._name_stack_suffix())):
tangent_out = linearized(residuals, *tangents_in)
if primitive.multiple_results:
return [maybe_linearize_tracer(self, x, nz, t)
for x, nz, t in zip(primal_out, tangent_nzs_out, tangent_out)]
remat_outs = zip(*remat_outs)
return [maybe_linearize_tracer(self, [*xs, x], nz, t) for xs, x, nz, t
in zip(remat_outs, primal_out, tangent_nzs_out, tangent_out)]
else:
return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out)
return maybe_linearize_tracer(self, [*remat_outs, primal_out],
tangent_nzs_out, tangent_out)

def cur_qdd(self, x):
p, _ = self.to_primal_tangent_pair(x)
with core.set_current_trace(self.parent_trace):
with core.set_current_trace(self.parent_traces[0]):
return core.cur_qdd(p)

def process_custom_jvp_call(self, prim, fun: lu.WrappedFun,
f_jvp: lu.WrappedFun, tracers, *,
symbolic_zeros: bool):
breakpoint()
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
if all(type(t) is Zero for t in tangents_in):
return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
Expand Down Expand Up @@ -945,6 +1046,7 @@ def process_custom_vjp_call(self, prim, fun, fwd,
out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]],
symbolic_zeros: bool):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
breakpoint()
if all(type(t) is Zero for t in tangents_in):
return prim.bind_with_trace(self.parent_trace,
(fun, fwd, bwd, *primals_in),
Expand Down Expand Up @@ -972,6 +1074,7 @@ def process_custom_vjp_call(self, prim, fun, fwd,
return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
breakpoint()
assert call_primitive.multiple_results
primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not Zero for t in tangents)
Expand Down Expand Up @@ -1043,13 +1146,13 @@ def f_tangent(*args):
# that's handled in process_call.
process_map = process_call

def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
def maybe_linearize_tracer(trace, primals, is_nonzero, tangent):
if is_nonzero:
assert not type(tangent) is Zero
return LinearizeTracer(trace, primal, tangent)
return LinearizeTracer(trace, primals, tangent)
else:
assert type(tangent) is Zero
return primal
return primals[0]

def fallback_linearize_rule(_prim: core.Primitive,
_nonzeros: Sequence[bool], *primals, **params):
Expand Down Expand Up @@ -1140,40 +1243,42 @@ def linearized(residuals, *tangents):
return out_primal, out_nz, out_consts, linearized

class LinearizeTracer(Tracer):
__slots__ = ['primal', 'tangent']
__slots__ = ['primals', 'tangent']

def __init__(self, trace, primal, tangent):
if config.enable_checks.value:
_primal_tangent_shapes_match(primal, tangent)
def __init__(self, trace, primals, tangent):
# TODO check all primals have same aval
# if config.enable_checks.value: # TODO DO NOT SUBMIT
# _primal_tangent_shapes_match(primal, tangent)
self._trace = trace
self.primal = primal
self.primals = primals
self.tangent = tangent

@property
def aval(self):
return get_aval(self.primal)
return get_aval(self.primals[0])

def full_lower(self):
if type(self.tangent) is Zero:
return core.full_lower(self.primal)
return core.full_lower(self.primals[0])
else:
return self

def to_concrete_value(self):
return core.to_concrete_value(self.primal)
return core.to_concrete_value(self.primals[0])

def get_referent(self):
return core.get_referent(self.primal)
return core.get_referent(self.primals[0])

def cur_qdd(self):
return core.cur_qdd(self.primal)
return core.cur_qdd(self.primals[0])


# -------------------- Primitives --------------------

primitive_jvps : dict[core.Primitive, Callable] = {}
primitive_transposes: dict[core.Primitive, Callable] = {}
primitive_linearizations : dict[core.Primitive, Callable] = {}
fancy_linearizations : dict[core.Primitive, Callable] = {}

def deflinear(primitive, transpose_rule):
primitive_jvps[primitive] = partial(linear_jvp, primitive)
Expand Down
Loading