Avoid doing DCE of effectful ops and reordering in partial eval. #28955

Merged
merged 1 commit on May 27, 2025
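In brief: equation recipes now carry a per-trace sequence number, equations with ref effects are recorded on the trace as EffectHandle objects, and tracers_to_jaxpr both seeds its reachability walk with those handles and emits equations in recipe order, so effectful equations are neither dead-code-eliminated nor reordered during partial evaluation. A minimal illustration of the kind of program this protects, modeled on the test added at the end of this diff (it relies on jax._src.core.mutable_array and core.freeze, which are internal APIs, so treat it as a sketch rather than supported usage):

import jax
from jax._src import core

def f(x):
  x_ = core.mutable_array(x)   # allocate a ref holding x
  x_[()] = x_[()] + x_[()]     # in-place update: a ref effect, not a pure data dependency
  return core.freeze(x_)       # read the final value back out

for fn in (f, jax.jit(f)):
  # Partial evaluation under grad must keep the effectful write and its order;
  # dropping or reordering it would make the gradient come out wrong.
  assert jax.grad(fn)(1.0) == 2.0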
2 changes: 1 addition & 1 deletion jax/_src/ad_checkpoint.py
@@ -578,7 +578,7 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
for x in jaxpr_unknown.outvars]
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
- recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
+ recipe = pe.new_eqn_recipe(trace, in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
new_params, jaxpr_unknown.effects,
source_info_util.current())

2 changes: 1 addition & 1 deletion jax/_src/debugging.py
@@ -127,7 +127,7 @@ def debug_callback_jvp_rule(primals, tangents, **params):
ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule

def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any],
- effect: DebugEffect):
+ effect: DebugEffect, partitioned):
del flat_args, callback, effect
raise ValueError("Transpose doesn't support debugging callbacks.")
ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule
2 changes: 1 addition & 1 deletion jax/_src/interpreters/ad.py
@@ -885,7 +885,7 @@ def make_zero(aval):
out_nz_tracers = [trace.to_jaxpr_tracer(r)
for (r, nz) in zip(out_tangents, out_nzs) if nz]
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
- jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info)
+ jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, [], jvp.debug_info)
jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
jaxpr, [True] * len(jaxpr.outvars),
[False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
46 changes: 34 additions & 12 deletions jax/_src/interpreters/partial_eval.py
@@ -16,6 +16,7 @@
from collections import namedtuple
from collections.abc import Callable, Sequence, Hashable
import contextlib
+ from dataclasses import dataclass
from functools import partial
import itertools as it
import operator as op
@@ -42,7 +43,7 @@
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.source_info_util import SourceInfo
- from jax._src.state.types import AbstractRef, ReadEffect
+ from jax._src.state.types import AbstractRef, ReadEffect, RefEffect
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten,
tree_structure, register_static)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@@ -147,6 +148,10 @@ def get_aval(self) -> AbstractValue:
else:
return self[0]

+ @dataclass(frozen=True)
+ class EffectHandle:
+   parents : list[Tracer]
+   recipe : JaxprEqnRecipe

class JaxprTrace(Trace['JaxprTracer']):

@@ -156,6 +161,8 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t
self.tag = tag
self.parent_trace = parent_trace
self.requires_low = False
+ self.effect_handles : list[EffectHandle] = []
+ self.counter = it.count()

def to_jaxpr_tracer(self, x):
if isinstance(x, JaxprTracer) and x._trace.tag is self.tag:
@@ -239,14 +246,19 @@ def default_process_primitive(self, primitive, tracers, params):
if primitive.multiple_results:
out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
for aval in out_aval]
- eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects,
+ eqn = new_eqn_recipe(self, tracers, out_tracers, primitive, params, effects,
source)
+ if any(isinstance(e, RefEffect) for e in effects):
+   self.effect_handles.append(EffectHandle(tracers, eqn))
for t in out_tracers: t.recipe = eqn
return out_tracers
else:
out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
- out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
- params, effects, source)
+ eqn = new_eqn_recipe(self, tracers, [out_tracer], primitive,
+ params, effects, source)
+ if any(isinstance(e, RefEffect) for e in effects):
+   self.effect_handles.append(EffectHandle(tracers, eqn))
+ out_tracer.recipe = eqn
return out_tracer

def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
@@ -321,7 +333,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
for a in out_type]
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
- eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers),
+ eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *unknown_arg_tracers),
out_tracers, primitive, staged_params, jaxpr.effects,
source)
for t in out_tracers: t.recipe = eqn
@@ -390,7 +402,7 @@ def const_out_axes_thunk():
for a in out_avals]
effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']})
src_info = source_info_util.current()
- eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
+ eqn = new_eqn_recipe(self, (*const_tracers, *env_tracers, *unknown_arg_tracers),
out_tracers, primitive, staged_params, effs, src_info)
for t in out_tracers: t.recipe = eqn

@@ -425,7 +437,7 @@ def process_custom_transpose(self, prim, call, tracers, **params):
for aval in params['out_types']]
in_tracers = map(self.instantiate_const, tracers)
new_params = dict(params, call=call)
- eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
+ eqn = new_eqn_recipe(self, in_tracers, out_tracers, prim, new_params,
core.no_effects, source_info_util.current())
for t in out_tracers: t.recipe = eqn
return out_tracers
@@ -470,7 +482,7 @@ def fwd_jaxpr_thunk(*zeros):
out_trees=out_trees,
symbolic_zeros=symbolic_zeros
)
- eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers),
+ eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *tracers),
out_tracers, prim, params, jaxpr.effects, source)
for t in out_tracers: t.recipe = eqn
return out_tracers
@@ -657,7 +669,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
out_tracers = [trace.instantiate_const(t) if inst else t
for inst, t in zip(instantiate, out_tracers)]
out_tracers_ = [t for t in out_tracers if not t.is_known()]
- jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info)
+ jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, trace.effect_handles, debug_info)
return out_tracers, jaxpr, out_consts, env

# The below variant implements an optimization where residuals which are also
@@ -739,7 +751,8 @@ class JaxprEqnRecipe(NamedTuple):
source_info: source_info_util.SourceInfo
ctx: JaxprEqnContext

- def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
+ def new_eqn_recipe(trace: JaxprTrace,
+                    in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer],
primitive: Primitive,
params: dict[str, Any],
@@ -762,7 +775,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
config.threefry_partitionable.value,
xla_metadata_lib.current_xla_metadata(),
)
- return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
+ return JaxprEqnRecipe(next(trace.counter), tuple(in_tracers), map(ref, out_tracers),
out_avals, primitive, params, effects, source_info,
ctx)

@@ -780,6 +793,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer],
+ effect_handles: Sequence[Any],
debug_info: core.DebugInfo,
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
"""Constructs Jaxpr given tracers for inputs and outputs.
@@ -821,7 +835,15 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:

processed_eqn_ids = set()
eqns: list[core.JaxprEqn] = []
- for t in toposort((*in_tracers, *out_tracers)):
+
+ reachable = toposort
+ tracers = reachable((*in_tracers, *out_tracers, *effect_handles))
+ def sort_key(t):
+   r = t.recipe
+   return r.eqn_id if isinstance(r, JaxprEqnRecipe) else -1
+ tracers = sorted(tracers, key=sort_key)
+
+ for t in tracers:
r = t.recipe
if isinstance(r, JaxprEqnRecipe):
# TODO broadcast_in_dim can create a new tracer, not present in parents
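The heart of the change is the tracers_to_jaxpr hunk above: the reachability walk is seeded with the recorded effect handles, so a ref-effectful equation survives even when no output tracer depends on it, and the reachable equations are then emitted in the order of their per-trace counter ids rather than whatever order the toposort produced. A toy sketch of that idea, deliberately independent of JAX's real data structures (Eqn, reachable_from, and finalize below are invented for illustration):

from dataclasses import dataclass, field
from itertools import count

_ids = count()

@dataclass
class Eqn:
  name: str
  parents: list            # equations whose outputs this one reads
  effectful: bool = False
  eqn_id: int = field(default_factory=lambda: next(_ids))  # trace order

def reachable_from(roots):
  # Everything an output (or effect root) transitively depends on.
  seen, stack, found = set(), list(roots), []
  while stack:
    e = stack.pop()
    if id(e) not in seen:
      seen.add(id(e))
      found.append(e)
      stack.extend(e.parents)
  return found

def finalize(outputs, effect_handles):
  # Seed reachability with effectful equations too (no DCE of effects),
  # then emit in original trace order (no reordering).
  eqns = reachable_from([*outputs, *effect_handles])
  return [e.name for e in sorted(eqns, key=lambda e: e.eqn_id)]

ref_eqn = Eqn("alloc_ref", [])
write = Eqn("ref_write", [ref_eqn], effectful=True)   # its output is never read below
read = Eqn("ref_read", [ref_eqn])
print(finalize([read], [write]))  # ['alloc_ref', 'ref_write', 'ref_read']
print(finalize([read], []))       # ['alloc_ref', 'ref_read']: the write would be dropped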
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/conditionals.py
@@ -617,7 +617,7 @@ def _cond_partial_eval(trace, *tracers, branches, **params):
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe(
- [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
+ trace, [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
core.join_effects(*(j.effects for j in branches_unknown)), source)
for t in out_tracers: t.recipe = eqn
return util.merge_lists(out_uks, out_consts, out_tracers)
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/for_loop.py
@@ -498,7 +498,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,

assert len(unknown_inputs) == len(res_ref_unknown_outputs)
assert len(unknown_inputs) == len(jaxpr_unknown.invars) - 1
- eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
+ eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs,
for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps,
reverse=reverse,
which_linear=which_linear_unknown,
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
@@ -920,7 +920,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
assert len(out_tracers) == len(jaxpr_unknown.out_avals)
- eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res],
+ eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res],
out_tracers, scan_p,
dict(reverse=reverse, length=length, unroll=unroll,
jaxpr=jaxpr_unknown, linear=linear_unknown,
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
@@ -6550,7 +6550,7 @@ def _broadcast_in_dim_partial_eval(
out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type)
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
eqn = pe.new_eqn_recipe(
- [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
+ trace, [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
dict(shape=shape, broadcast_dimensions=broadcast_dimensions,
sharding=None),
core.no_effects, source_info_util.current())
16 changes: 3 additions & 13 deletions jax/_src/pjit.py
@@ -2324,18 +2324,8 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,

known_ins = tuple(pv.is_known() for pv in in_pvals)
unknown_ins = tuple(not k for k in known_ins)
- if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect))
-        for e in jaxpr.effects):
-   known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \
-     pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins,
-                                    False, False, None)
-   if num_res_ref: raise NotImplementedError
-   known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts)
-   unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts)
-   res_avals = unknown_jaxpr.in_avals[:num_res_val]
- else:
-   known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
-     pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
+ known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
+   pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
unknown_outs = tuple(unknown_outs) # type: ignore[assignment]
known_outs = tuple(not uk for uk in unknown_outs)
num_residuals = len(res_avals)
@@ -2431,7 +2421,7 @@ def keep_where(l, should_keep):
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
for aval in unknown_out_avals
]
- eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
+ eqn = pe.new_eqn_recipe(trace, (*unknown_tracers_in, *residual_tracers),
unknown_tracers_out,
pjit_p,
unknown_params,
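With effectful equations now preserved by tracers_to_jaxpr itself, _pjit_partial_eval no longer needs the partial_eval_jaxpr_stateful special case and always splits the closed jaxpr with partial_eval_jaxpr_nounits. A usage sketch of that splitting call on a toy function (this pokes at jax._src internals, and the example function and known/unknown pattern are made up):

import jax
from jax._src.interpreters import partial_eval as pe

closed_jaxpr = jax.make_jaxpr(lambda x, y: x * 2.0 + y)(1.0, 2.0)
unknown_ins = (False, True)   # x is known at trace time, y is not

known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
    pe.partial_eval_jaxpr_nounits(closed_jaxpr, unknown_ins, instantiate=False)

# known_jaxpr computes everything that depends only on x and forwards any
# residuals; unknown_jaxpr consumes the residuals plus y to finish the output.
print(unknown_outs, [a.shape for a in res_avals])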
2 changes: 1 addition & 1 deletion jax/_src/shard_map.py
@@ -1369,7 +1369,7 @@ def known_out_specs():
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in out_avals]
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
- eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers),
+ eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers),
out_tracers, shard_map_p, unk_params,
effs, source_info_util.current())
for t in out_tracers: t.recipe = eqn
2 changes: 1 addition & 1 deletion jax/_src/state/discharge.py
@@ -828,7 +828,7 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
is_initialized=(True,) * len(jaxpr_unknown.invars))
_, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs],
**uk_params)
- eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
+ eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs,
run_state_p, uk_params,
eqn_effects, source)
for t in res_ref_unknown_outputs: t.recipe = eqn
8 changes: 6 additions & 2 deletions tests/mutable_array_test.py
@@ -192,14 +192,18 @@ def f():
x = f()
self.assertArraysEqual(x, jnp.zeros(8))

- def test_grad_mutable_array(self):
-   @jax.jit
+ @parameterized.parameters([False, True])
+ def test_grad_mutable_array(self, jit):
+
def f(x):
x_ = core.mutable_array(x)
x_[()] = x_[()] + x_[()]
y = core.freeze(x_)
return y

+ if jit:
+   f = jax.jit(f)
+
ans = jax.grad(f)(1.)
expected = 2.0
self.assertAllClose(ans, expected, check_dtypes=False)
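To run just the updated test locally, a standard pytest invocation should work (exact flags depend on your environment):

python -m pytest tests/mutable_array_test.py -k test_grad_mutable_array -q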