Skip to content

[hijax] prototype hijax pieces #28781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions jax/_src/ad_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
map = safe_map

def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array:
ty = core.typeof(x)
if hasattr(ty, 'vspace_add'): # TODO(mattjj,dougalm): revise away hasattr
return ty.vspace_add(x, y)
x, y = core.standard_insert_pvary(x, y)
return add_jaxvals_p.bind(x, y)

Expand All @@ -48,6 +51,8 @@ def add_abstract(x, y):
return x

def zeros_like_aval(aval: core.AbstractValue) -> Array:
  """Build a zero value for the abstract value ``aval``.

  Abstract types that define ``vspace_zero`` construct their own zero;
  otherwise the zero-maker registered in ``aval_zeros_likers`` is used.
  """
  vspace_zero = getattr(aval, 'vspace_zero', None)
  if vspace_zero is not None:  # TODO(mattjj,dougalm): revise away attribute probing
    return vspace_zero()
  zeros_maker = aval_zeros_likers[type(aval)]
  return zeros_maker(aval)
aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

Expand Down
24 changes: 13 additions & 11 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,17 +540,18 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
f"but got {aval.dtype.name}.")
if (dtypes.issubdtype(aval.dtype, dtypes.extended) or
dtypes.issubdtype(aval.dtype, np.integer) or
dtypes.issubdtype(aval.dtype, np.bool_)):
if not allow_int:
raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. "
"If you want to use Boolean- or integer-valued inputs, use vjp "
"or set allow_int to True.")
elif not dtypes.issubdtype(aval.dtype, np.inexact):
raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a "
f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.")
if isinstance(aval, ShapedArray):
if (dtypes.issubdtype(aval.dtype, dtypes.extended) or
dtypes.issubdtype(aval.dtype, np.integer) or
dtypes.issubdtype(aval.dtype, np.bool_)):
if not allow_int:
raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. "
"If you want to use Boolean- or integer-valued inputs, use vjp "
"or set allow_int to True.")
elif not dtypes.issubdtype(aval.dtype, np.inexact):
raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a "
f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.")
_check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")

def _check_output_dtype_revderiv(name, holomorphic, x):
Expand Down Expand Up @@ -1873,6 +1874,7 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
f"structure; primals have tree structure {tree_def} whereas tangents have "
f"tree structure {tree_def_2}.")
for p, t in zip(ps_flat, ts_flat):
if not isinstance(core.typeof(p), ShapedArray): continue
if core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t):
raise TypeError("primal and tangent arguments to jax.jvp do not match; "
"dtypes must be equal, or in case of int/bool primal dtype "
Expand Down
36 changes: 33 additions & 3 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,16 @@

class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
'_effects', '_debug_info']
'_effects', '_debug_info', '_is_high', '_mut_types']

_constvars: list[Var]
_invars: list[Var]
_outvars: list[Atom]
_eqns: list[JaxprEqn]
_effects: Effects
_debug_info: DebugInfo
_is_high: bool
_mut_types: dict[Var, Any]

@property
def constvars(self) -> list[Var]:
Expand All @@ -121,13 +123,23 @@ def effects(self) -> Effects:
def debug_info(self) -> DebugInfo:
return self._debug_info

@property
def is_high(self) -> bool:
return self._is_high

@property
def mut_types(self) -> dict[Var, Any]:
return self._mut_types

def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects,
# We want all calls to pass a DebugInfo object, but for backwards
# compatibility we have to allow calls when the debug_info
# is missing.
debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment]
is_high: bool = False,
mut_types: dict | None = None,
):
"""
Args:
Expand All @@ -152,6 +164,8 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
# TODO(necula): re-enable these safety checks
# assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
self._is_high = is_high
self._mut_types = mut_types or {}

def __str__(self):
return str(self.pretty_print())
Expand All @@ -178,6 +192,8 @@ def replace(self, **kwargs):
eqns=kwargs.pop("eqns", self.eqns),
effects=kwargs.pop("effects", self.effects),
debug_info=kwargs.pop("debug_info", self.debug_info),
is_high=kwargs.pop("is_high", self.is_high),
mut_types=kwargs.pop("mut_types", self.mut_types),
)
if kwargs:
raise ValueError(f"Unknown keyword arguments: {kwargs}")
Expand Down Expand Up @@ -517,14 +533,18 @@ def _true_bind(self, *args, **params):
for arg in args:
if isinstance(arg, Tracer) and not arg._trace.is_valid():
raise escaped_tracer_error(arg)
# TODO: figure out how to handle function arguments
# TODO: figure out how to handle function arguments for this assert
# assert (not config.enable_checks.value or
# all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args

# This is equivalent to "with take_current_trace()", but the bind() code
# is called frequently and it's slightly faster to avoid using a context
# manager object.
prev_trace = trace_ctx.trace

if self.is_high(**params) and prev_trace.requires_low:
return self.to_lojax(*args, **params) # type: ignore

trace_ctx.set_trace(eval_trace)
try:
return self.bind_with_trace(prev_trace, args, params)
Expand Down Expand Up @@ -561,6 +581,9 @@ def abstract_eval(self, *args, **params):
def get_bind_params(self, params):
return [], params

def is_high(self, **params) -> bool:
return False


def _effect_free_abstract_eval(abstract_eval):
def abstract_eval_(*args, **kwargs):
Expand Down Expand Up @@ -627,12 +650,13 @@ def check_avals_context_mesh(avals, prim_name):
TracerType = TypeVar('TracerType', bound='Tracer')

class Trace(Generic[TracerType]):
__slots__ = ("__weakref__", "_invalidated", "_weakref")
__slots__ = ("__weakref__", "_invalidated", "_weakref", "requires_low")

def __init__(self):
self._invalidated = False
# We frequently need a weakref to a trace, so let's precompute one.
self._weakref = weakref.ref(self)
self.requires_low = True

def process_primitive(self, primitive, tracers, params):
raise NotImplementedError("must override")
Expand Down Expand Up @@ -1445,6 +1469,8 @@ def definitely_equal(x, y):

class AbstractValue:
__slots__: list[str] = []
is_high = False
mutable = False

def to_tangent_aval(self):
    """Return the abstract value of this aval's tangent space; subclasses must override."""
    raise NotImplementedError("must override")
Expand Down Expand Up @@ -1948,6 +1974,10 @@ def __init__(self, shape, dtype, weak_type=False, *, sharding=None,
self.sharding = get_sharding(sharding, self.shape)
self.vma = get_vma(vma, self.sharding.mesh)

def lower_val(self, val): return [val]  # a ShapedArray value is already low-level: lowering wraps it in a singleton list
def raise_val(self, val): return val  # inverse of lower_val: the single low-level value is the high-level value itself
def lo_ty(self): return [self]  # a ShapedArray's low-level type list is just itself

def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:
shape = self.shape
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
tag = core.TraceTag()
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and isinstance(core.typeof(t), core.ShapedArray)
and dtype(t) == float0 else t for t in tangents]
ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
else contextlib.nullcontext())
Expand Down Expand Up @@ -475,6 +476,7 @@ def __init__(self, parent_trace, tag):
super().__init__()
self.tag = tag
self.parent_trace = parent_trace
self.requires_low = False

def to_primal_tangent_pair(self, val):
if isinstance(val, JVPTracer) and val._trace.tag is self.tag:
Expand Down Expand Up @@ -606,7 +608,8 @@ def process_custom_transpose(self, prim, call, tracers, **params):
return map(partial(maybe_jvp_tracer, self), ps_out, ts_out)

def maybe_jvp_tracer(trace, primal, tangent):
if type(tangent) is Zero or dtype(tangent) == float0:
if (type(tangent) is Zero or
core.typeof(tangent) is core.ShapedArray and dtype(tangent) == float0):
return primal
else:
return JVPTracer(trace, primal, tangent)
Expand Down Expand Up @@ -641,6 +644,7 @@ def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = get_aval(primal).strip_weak_type()
tangent_aval = get_aval(tangent).strip_weak_type()
if not isinstance(primal_aval, core.ShapedArray): return # TODO(mattjj,dougalm)
assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape)
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
Expand Down
68 changes: 51 additions & 17 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t
self.name_stack = name_stack
self.tag = tag
self.parent_trace = parent_trace
self.requires_low = False

def to_jaxpr_tracer(self, x):
if isinstance(x, JaxprTracer) and x._trace.tag is self.tag:
Expand Down Expand Up @@ -899,9 +900,8 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
raise NotImplementedError
config.enable_checks.value and core.check_jaxpr(jaxpr)
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
converted_jaxpr = jaxpr.replace(constvars=jaxpr.constvars + env_vars,
invars=invars)
config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
return converted_jaxpr

Expand Down Expand Up @@ -1173,6 +1173,7 @@ def has_effects(effects) -> bool:
out_unknowns = map(op.or_, out_unknowns, ensure_out_unknowns)
out_inst = map(op.or_, out_inst, ensure_out_inst)


ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
outs_known, _ = partition_list(out_unknowns, jaxpr.outvars)
ref_res_is_input = [r in ins_known for r in residual_refs]
Expand All @@ -1181,18 +1182,25 @@ def has_effects(effects) -> bool:
known_outvars = [*outs_known, *residuals]
known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res,
known_outvars, known_eqns)
jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars,
known_eqns, known_effects, jaxpr.debug_info)
known_mut, staged_mut, ins_known_ = {}, {}, set(ins_known) # type: ignore
for v, t in jaxpr.mut_types.items():
[staged_mut, known_mut][v in ins_known_][v] = t

# TODO(mattjj,necula): debug info should be updated here
jaxpr_known = jaxpr.replace(
invars=ins_known_and_ref_res, outvars=known_outvars,
eqns=known_eqns, effects=known_effects, mut_types=known_mut)
config.enable_checks.value and core.check_jaxpr(jaxpr_known)

_, ins_staged = partition_list(in_inst, jaxpr.invars)
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
staged_invars = [*residuals, *non_input_res_refs, *ins_staged]
staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns)
jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns, staged_effects,
jaxpr.debug_info)
# TODO(mattjj,necula): debug info should be updated here
jaxpr_staged = jaxpr.replace(
invars=staged_invars, outvars=outs_staged, eqns=staged_eqns,
effects=staged_effects, mut_types=staged_mut)
config.enable_checks.value and core.check_jaxpr(jaxpr_staged)

return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),
Expand Down Expand Up @@ -1483,7 +1491,8 @@ def write(x: Atom, b: bool) -> None:
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.filter_arg_names(used_inputs),
jaxpr.debug_info.filter_result_paths(used_outputs))
new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg)
new_jaxpr = jaxpr.replace(invars=invars, outvars=outvars, eqns=eqns,
effects=jaxpr_effects, debug_info=dbg)
config.enable_checks.value and core.check_jaxpr(new_jaxpr)

return new_jaxpr, used_inputs
Expand Down Expand Up @@ -1561,9 +1570,8 @@ def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]
new_invars = _move_to_front(invars, to_move)
new_effs = _renumber_effects(
(*constvars, *new_invars), (*constvars, *invars), closed_jaxpr.jaxpr.effects)
new_jaxpr = Jaxpr(constvars, new_invars, closed_jaxpr.jaxpr.outvars,
closed_jaxpr.jaxpr.eqns, new_effs,
closed_jaxpr.jaxpr.debug_info)
new_jaxpr = closed_jaxpr.jaxpr.replace(
constvars=constvars, invars=new_invars, effects=new_effs)
new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
return new_closed_jaxpr

Expand Down Expand Up @@ -1704,6 +1712,7 @@ class JaxprStackFrame:
attrs_inits: list
attrs_vars: list[Var]
debug_info: core.DebugInfo
is_high: bool

def __init__(self, debug_info: core.DebugInfo):
self.gensym = core.gensym()
Expand All @@ -1718,6 +1727,7 @@ def __init__(self, debug_info: core.DebugInfo):
self.attrs_inits = []
self.attrs_vars = []
self.debug_info = debug_info
self.is_high = False

def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)
Expand All @@ -1743,8 +1753,9 @@ def to_jaxpr(
outvars = state_outvars + explicit_outvars
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
mut_types = {v: v.aval for v in invars if v.aval.mutable} if self.is_high else {}
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects,
debug_info)
debug_info, self.is_high, mut_types)
jaxpr, constvals = _drop_unused_vars(jaxpr, constvals)
init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]
return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
Expand Down Expand Up @@ -1831,8 +1842,9 @@ def vars(atom: Atom) -> list[Var]:
class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame", "tag")

def __init__(self, debug_info: core.DebugInfo):
def __init__(self, debug_info: core.DebugInfo, lower: bool = False):
    """Create a trace that stages computations out to a jaxpr frame.

    Args:
      debug_info: debug metadata attached to the staged jaxpr.
      lower: if True, mark this trace as ``requires_low`` so that high-level
        (``is_high``) primitives lower themselves (via ``to_lojax``) instead of
        being staged directly — see ``Primitive._true_bind``.
    """
    super().__init__()
    self.requires_low = lower
    self.frame = JaxprStackFrame(debug_info)

def invalidate(self):
Expand Down Expand Up @@ -2193,10 +2205,11 @@ def trace_to_jaxpr_dynamic(
in_avals: Sequence[AbstractValue],
*,
keep_inputs: list[bool] | None = None,
lower: bool = False,
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]:
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
trace = DynamicJaxprTrace(fun.debug_info)
trace = DynamicJaxprTrace(fun.debug_info, lower=lower)
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
source_info = source_info_util.current()
in_tracers = _input_type_to_tracers(
Expand Down Expand Up @@ -2418,8 +2431,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]:
kept_outs = [False] * len(impl_outvars) + [True] * len(expl_outvars)
out_type = tuple(zip(out_avals, kept_outs))

new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns,
jaxpr.effects, jaxpr.debug_info)
new_jaxpr = jaxpr.replace(outvars=outvars)
config.enable_checks.value and core.check_jaxpr(jaxpr)
return new_jaxpr, out_type

Expand Down Expand Up @@ -2663,3 +2675,25 @@ def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, lis
_, jaxpr = f.f.closure
return convert_constvars_jaxpr(jaxpr), []
return jaxpr, consts


@weakref_lru_cache
def lower_jaxpr(hi_jaxpr):
  """Lower a closed jaxpr with high-level input types to one over low types.

  Cached by weak reference on ``hi_jaxpr``. Each high-level input aval is
  expanded to the flat list of low-level types given by its ``lo_ty()``.
  """
  in_avals = [lo_ty for t in hi_jaxpr.in_avals for lo_ty in t.lo_ty()]
  f = lu.wrap_init(partial(lower_traceable, hi_jaxpr),
                   debug_info=hi_jaxpr.jaxpr.debug_info)
  # lower=True makes the staging trace require low-level primitives, so any
  # is_high primitives encountered while retracing lower themselves.
  lo_jaxpr, _, consts, () = trace_to_jaxpr_dynamic(f, in_avals, lower=True)
  return core.ClosedJaxpr(lo_jaxpr, consts)

def lower_traceable(jaxpr, *lo_args):
  """Evaluate the high-level ``jaxpr`` on flattened low-level arguments.

  The low-level args are grouped per input aval (``len(t.lo_ty())`` values
  each) and raised to high-level values, the jaxpr is run, and the outputs
  are flattened back to low-level values. Final values of mutable inputs
  (those recorded in ``jaxpr.jaxpr.mut_types``) are prepended as extra outputs.
  """
  lo_args_ = iter(lo_args)
  # Consume exactly len(t.lo_ty()) low-level values for each high-level input.
  hi_args = [t.raise_val(*it.islice(lo_args_, len(t.lo_ty())))
             for t in jaxpr.in_avals]
  # Sanity check: every low-level argument must have been consumed.
  assert (problem := next(lo_args_, None)) is None
  hi_outs = core.jaxpr_as_fun(jaxpr)(*hi_args)
  in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)}
  # Read back the current value of each mutable input via its type's `get`
  # (assumed to yield low-level values, mirroring lower_val — TODO confirm).
  mut_outs = [lo_val for v, ty in jaxpr.jaxpr.mut_types.items()
              for lo_val in ty.get(hi_args[in_idx[v]])]
  lo_outs = [lo_val for t, hi_val in zip(jaxpr.out_avals, hi_outs)
             for lo_val in t.lower_val(hi_val)]
  return mut_outs + lo_outs
Loading
Loading