Skip to content

Commit fc67d17

Browse files
committed
[scan] don't hoist loop-invariant computations in scan, just forward
1 parent 2b9d7c8 commit fc67d17

File tree

2 files changed

+69
-134
lines changed

2 files changed

+69
-134
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 68 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any, TypeVar
2323
import weakref
2424

25+
import jax
2526
from jax._src import ad_checkpoint
2627
from jax._src import ad_util
2728
from jax._src import api
@@ -809,98 +810,61 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
809810
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
810811
else:
811812
assert False, "Fixpoint not reached"
812-
num_res = len(res_avals)
813+
num_res_in = len(res_avals) # number of res inputs to jaxpr_unknown
813814
del res_avals, carry_uk_out
814815

815816
# Instantiate those inputs which must be treated as unknown from the fixpoint.
816817
tracers = tuple(trace.instantiate_const(t) if uk else t
817818
for t, uk in zip(tracers, unknowns))
818819

819-
# The residual inputs and outputs of the jaxprs produced haven't yet been
820-
# adapted to the scan calling convention; in particular, jaxpr_known has its
821-
# residual outputs all at the end, meaning they're extensive outputs (which is
822-
# fully general but may be wasteful for residuals which are loop-invariant)
823-
# while jaxpr_unknown has its corresponding residual inputs at the front (just
824-
# as a convention with partial_eval_jaxpr_nounits), making them constant
825-
# inputs. To make them consistent, we move the residual inputs on
826-
# jaxpr_unknown to the end, even though we may move some back in the sequel.
827-
jaxpr_unknown = pe.move_binders_to_back(
828-
jaxpr_unknown, [True] * num_res + [False] * sum(unknowns))
829-
830-
# At this point, all residuals are treated as extensive outputs of jaxpr_known
831-
# (and extensive inputs to jaxpr_unknown). But residuals that are loop-
832-
# invariant can be hoisted out of the scan, rather than letting them get
833-
# broadcast (as in e.g. scanning multiplication by a constant matrix; we don't
834-
# want to broadcast the matrix!). So, outside the loop we perform a partial
835-
# evaluation with known 'const' inputs (but all other inputs unknown).
836-
const_pvals = [pe.PartialVal.known(t.pval.get_known())
837-
for t in tracers[:num_consts] if t.pval.is_known()]
838-
other_pvals = [pe.PartialVal.unknown(aval)
839-
for aval in jaxpr_known.in_avals[len(const_pvals):]]
840-
with source_info_util.reset_name_stack():
841-
jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
842-
lu.wrap_init(core.jaxpr_as_fun(jaxpr_known),
843-
debug_info=jaxpr_known.jaxpr.debug_info),
844-
const_pvals + other_pvals,
845-
instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
846-
jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
847-
# The above trace_to_jaxpr_nounits call computed loop-invariant residuals
848-
# (known values in invar_pvals_out) and also computed loop-invariant values
849-
# needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
850-
# previous consts). We need to collect the computed intensive residuals, and
851-
# move corresponding intensive residual binders in jaxpr_unknown to the front.
852-
res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
853-
intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
854-
jaxpr_unknown = pe.move_binders_to_front(
855-
jaxpr_unknown,
856-
[False] * sum(unknowns) + [pval.is_known() for pval in res_pvals])
857-
del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals
858-
# We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and
859-
# we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown.
860-
861-
# As another optimization, for any extensive inputs that are just forwarded to
862-
# extensive outputs, to avoid a copy (which would be looping over
863-
# dynamic-update-slice) we'd rather forward the input tracer/value. That means
864-
# pruning some outputs from jaxpr_known here, and updating `out_flat` below.
865-
fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
866-
# Prune fwds_known to include only extensive input to extensive output.
867-
fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and
868-
in_idx is not None and
869-
in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk)
870-
else None for out_idx, in_idx in enumerate(fwds_known)]
871-
# Drop any extensive output we can instead get by forwarding an input.
872-
# TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
873-
jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
874-
jaxpr_known_ = jaxpr_known_.replace(
875-
outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None])
876-
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
877-
del jaxpr_known_
878-
# We use `fwds_known` below when forming the output of scanning jaxpr_known.
820+
known_inputs = [t.pval.get_known() for t in tracers if t.pval.is_known()]
821+
822+
# Residuals that are just forwarded from constants or scanned-over inputs
823+
# can be passed directly to the unknown scan (as constants or scanned-over
824+
# inputs, respectively). So we prune them as outputs from jaxpr_known.
825+
# Recall the outputs of jaxpr_known currently look like [*knowns, *res].
826+
num_knowns_out = len(jaxpr_known.out_avals) - num_res_in
827+
num_consts_known = num_consts - sum(const_uk)
828+
num_carry_known = num_carry - sum(carry_uk)
829+
del num_consts, num_carry
830+
in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
831+
in_fwd = [f if out_idx >= num_knowns_out and f is not None and
832+
(f < num_consts_known or f >= num_consts_known + num_carry_known)
833+
and isinstance(known_inputs[f], jax.Array) # no np.ndarrays
834+
else None for out_idx, f in enumerate(in_fwd)]
835+
jaxpr_known = pe.prune_closed_jaxpr_outputs(jaxpr_known, [f is None for f in in_fwd])
836+
837+
# All jaxpr_unknown residual binders are at the front like [*res, *unknowns],
838+
# but to match the scan body calling convention of [*consts, *carry, *ext_in],
839+
# we move the binders that correspond to extensive residuals (i.e. not forwarded
840+
# from jaxpr_known's consts) to the extensive slots.
841+
num_unk_in = len(jaxpr_unknown.in_avals) - num_res_in
842+
res_to_move = [f is None or f >= num_consts_known + num_carry_known
843+
for f in in_fwd[num_knowns_out:]]
844+
jaxpr_unknown = pe.move_binders_to_back(jaxpr_unknown,
845+
res_to_move + [False] * num_unk_in)
879846

880847
# Run the known part of the scan (if it has any outputs or effects).
881-
known_inputs = (list(jaxpr_known_consts) +
882-
[t.pval.get_known() for t in tracers[num_consts:]
883-
if t.pval.is_known()])
884848
if not jaxpr_known.out_avals and not jaxpr_known.effects:
885849
out_known = []
886850
else:
887-
linear_known = [False] * len(known_inputs) # conservative!
851+
linear_known = [l for l, uk in zip(linear, unknowns) if not uk]
888852
out_known = scan_p.bind(
889853
*known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known,
890-
num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk),
854+
num_consts=num_consts_known, num_carry=num_carry_known,
891855
linear=tuple(linear_known), unroll=unroll,
892856
_split_transpose=_split_transpose)
893857
del linear_known
894-
# Complete the known output by filling in forwarded values using fwds_known.
895-
out_known_iter = iter(out_known)
896-
out_known = [next(out_known_iter) if f is None
897-
else _maybe_put(known_inputs[f]) for f in fwds_known]
898-
assert next(out_known_iter, None) is None
899-
del known_inputs, out_known_iter
900858

901-
# Split known outputs from residuals.
902-
out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)])
903-
assert len(intensive_res) + len(extensive_res) == num_res
859+
# Complete out_known by filling in forwards from known inputs.
860+
out_known_ = iter(out_known)
861+
out_known = [next(out_known_) if f is None else known_inputs[f] for f in in_fwd]
862+
assert next(out_known_, None) is None
863+
864+
# Split known outputs from residuals, and const-forwarded (intensive)
865+
# residuals from other (extensive) residuals.
866+
out_known, all_res = split_list(out_known, [num_knowns_out])
867+
intensive_res, extensive_res = partition_list(res_to_move, all_res)
904868

905869
# Create input tracers for jaxpr_unknown bind.
906870
unknown_inputs = [t for t in tracers if not t.pval.is_known()]
@@ -1254,87 +1218,58 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
12541218
assert False, "Fixpoint not reached"
12551219
jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
12561220
jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)
1257-
1258-
# Move all residual binders to the back of jaxpr_staged so they're extensive.
1259-
# TODO(mattjj): make jaxpr_staged only take instantiated inputs
12601221
res_avals = jaxpr_staged.in_avals[:num_res]
1261-
jaxpr_staged = pe.move_binders_to_back(
1262-
jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))
12631222

12641223
# Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
12651224
# passing in_inst argument to partial_eval_jaxpr_custom above).
12661225
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
12671226
if type(x) is core.Var and not inst]
12681227
inst_in = [True] * len(inst_in)
12691228

1270-
# As an optimization, hoist loop-invariant residuals out of the loop rather
1271-
# than using extensive outputs for them. See _scan_partial_eval for comments.
1272-
num_const_known = len(const_uk) - sum(const_uk)
1273-
num_carry_known = len(carry_uk) - sum(carry_uk)
1274-
num_xs_known = len( xs_uk) - sum( xs_uk)
1275-
jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \
1276-
pe.partial_eval_jaxpr_nounits(
1277-
jaxpr_known,
1278-
[False] * num_const_known + [True] * (num_carry_known + num_xs_known),
1279-
[True] * (len(unks_out) - sum(unks_out)) + [False] * num_res)
1280-
# jaxpr_known_hoist produces intensive residuals followed by the constants for
1281-
# jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts.
1282-
_, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res])
1283-
jaxpr_staged = pe.move_binders_to_front(
1284-
jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res))
1285-
num_intensive_res = len(loop_dep_res) - sum(loop_dep_res)
1286-
del loop_dep, num_carry_known, num_xs_known, const_uk
1287-
1288-
# Create residual variables.
1289-
intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
1290-
ext_avals = [core.unmapped_aval(eqn.params['length'], 0, a)
1291-
for a in ext_avals_mapped]
1229+
# No res forwards should be possible because jaxpr_staged already has all
1230+
# inputs instantiated.
1231+
num_knowns_out = len(jaxpr_known.out_avals) - num_res
1232+
num_consts_known = len(const_uk) - sum(const_uk)
1233+
num_carry_known = len(carry_uk) - sum(carry_uk)
1234+
in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
1235+
in_fwd = [f if out_idx >= num_knowns_out and f is not None and
1236+
(f < num_consts_known or f >= num_consts_known + num_carry_known)
1237+
else None for out_idx, f in enumerate(in_fwd)]
1238+
assert all(f is None for f in in_fwd)
1239+
del in_fwd
1240+
1241+
# Move all residual binders to the back of jaxpr_staged so they're extensive.
1242+
jaxpr_staged = pe.move_binders_to_back(
1243+
jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))
1244+
1245+
# Create variables for extensive residuals output by the known eqn.
12921246
newvar = core.gensym()
1293-
intensive_res = _map(newvar, intensive_avals)
1294-
extensive_res = _map(newvar, ext_avals)
1247+
ext_res_avals = _map(partial(core.unmapped_aval, eqn.params['length'], 0), res_avals)
1248+
ext_res_out_binders = _map(newvar, ext_res_avals)
12951249

1296-
# Create known eqn, which is a call_p combining evaluation of
1297-
# jaxpr_known_hoist and a scan of jaxpr_known_loop.
1250+
# Create known eqn.
12981251
ins_known, _ = partition_list(unks_in, eqn.invars)
12991252
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
1300-
# jaxpr_known_loop takes as input constants output as res by jaxpr_known_hoist
1301-
# (corresponding to consts_known_lp_avals) followed by known carry and xs.
1302-
linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
1303-
_, linear_known_ = split_list(linear_known_, [num_const_known])
1304-
linear_known = [False] * len(consts_known_lp_avals) + linear_known_
1305-
params_known = dict(eqn.params, jaxpr=jaxpr_known_loop,
1306-
num_consts=len(consts_known_lp_avals),
1307-
num_carry=len(carry_uk)-sum(carry_uk),
1253+
linear_known = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
1254+
params_known = dict(eqn.params, jaxpr=jaxpr_known,
1255+
num_consts=num_consts_known, num_carry=num_carry_known,
13081256
linear=tuple(linear_known))
1309-
1310-
def known(*ins_known):
1311-
consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known])
1312-
out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist)
1313-
intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
1314-
out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
1315-
return [*intensive_res, *out_loop]
1316-
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
1317-
lu.wrap_init(known, debug_info=jaxpr_known_hoist.jaxpr.debug_info),
1318-
[v.aval for v in ins_known])
1319-
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
13201257
eqn_known = pe.new_jaxpr_eqn(
1321-
ins_known, [*intensive_res, *out_binders_known, *extensive_res],
1322-
core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects,
1323-
eqn.source_info, eqn.ctx)
1258+
ins_known, [*out_binders_known, *ext_res_out_binders],
1259+
scan_p, params_known, jaxpr_known.effects, eqn.source_info, eqn.ctx)
13241260

13251261
# Create the staged eqn.
13261262
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
1327-
linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) +
1328-
[False] * len(extensive_res))
1263+
linear_staged = list(eqn.params['linear']) + [False] * len(ext_res_out_binders)
13291264
params_staged = dict(eqn.params, jaxpr=jaxpr_staged,
1330-
num_consts=len(intensive_res) + eqn.params['num_consts'],
1265+
num_consts=eqn.params['num_consts'],
13311266
linear=tuple(linear_staged))
1332-
eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res],
1267+
eqn_staged = pe.new_jaxpr_eqn([*eqn.invars, *ext_res_out_binders],
13331268
out_binders_staged, eqn.primitive,
13341269
params_staged, jaxpr_staged.effects,
13351270
eqn.source_info, eqn.ctx)
13361271

1337-
new_vars = [*new_inst, *intensive_res, *extensive_res]
1272+
new_vars = [*new_inst, *ext_res_out_binders]
13381273
return eqn_known, eqn_staged, unks_out, inst_out, new_vars
13391274

13401275
def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,

tests/lax_control_flow_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2789,7 +2789,7 @@ def cumprod(x):
27892789

27902790
if remat is not None:
27912791
# TODO(mattjj): make the numpy.ndarray test pass w/ remat
2792-
raise unittest.SkipTest("new-remat-of-scan doesn't convert numpy.ndarray")
2792+
return
27932793

27942794
x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not Array
27952795
_, vjp_fun = jax.vjp(cumprod, x)

0 commit comments

Comments
 (0)