Skip to content

Commit f617ece

Browse files
committed
Apply extensive input to extensive output forwarding in scan.
1 parent 39f0906 commit f617ece

File tree

2 files changed

+103
-6
lines changed

2 files changed

+103
-6
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def _create_jaxpr(init):
341341
# If the body forwards an input carry to an output carry, that input is
342342
# read-only and can be moved to be a const. Doing so can lead to efficiency
343343
# wins, e.g. if the scan is inside a cond with a batched predicate.
344-
carry_fwd, _ = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry])
344+
carry_fwd, ext_fwd = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry])
345345
move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)]
346346
if any(move_to_const):
347347
jaxpr = pe.prune_closed_jaxpr_outputs(
@@ -352,11 +352,37 @@ def _create_jaxpr(init):
352352
consts = [*new_consts, *consts]
353353
num_carry -= len(new_consts)
354354

355-
out = scan_p.bind(*consts, *in_flat,
356-
reverse=reverse, length=length, jaxpr=jaxpr,
357-
num_consts=len(consts), num_carry=num_carry,
358-
linear=(False,) * (len(consts) + len(in_flat)),
359-
unroll=unroll, _split_transpose=_split_transpose)
355+
# When an extensive output is forwarded from an extensive input, we can
356+
# avoid copying it by pruning it from the jaxpr and forwarding manually. We
357+
# don't need to update the indexing based on the optimization above since it
358+
# doesn't change the total number of consts and carries combined, and
359+
# `ext_fwd` already only includes the extensive outputs. But we do subtract
360+
# the number of consts from the index since we're going to use it to index
361+
# into `in_flat`, which doesn't include consts.
362+
ext_to_ext_fwd = [
363+
in_idx - len(consts) if in_idx is not None and
364+
in_idx >= num_carry + len(consts) else None for in_idx in ext_fwd]
365+
jaxpr = pe.prune_closed_jaxpr_outputs(
366+
jaxpr, [True] * num_carry + [i is None for i in ext_to_ext_fwd])
367+
368+
if not jaxpr.eqns:
369+
# After all optimizations above, it's possible that the body becomes a
370+
# no-op without outputs or effects. In this case, we don't bind the scan at
371+
# all.
372+
out = []
373+
else:
374+
out = scan_p.bind(*consts, *in_flat,
375+
reverse=reverse, length=length, jaxpr=jaxpr,
376+
num_consts=len(consts), num_carry=num_carry,
377+
linear=(False,) * (len(consts) + len(in_flat)),
378+
unroll=unroll, _split_transpose=_split_transpose)
379+
380+
# Apply input to output forwarding that was computed above.
381+
carry_out, out = split_list(out, [num_carry])
382+
out_ = iter(out)
383+
out = [next(out_) if f is None else _maybe_put(in_flat[f]) for f in ext_to_ext_fwd]
384+
assert next(out_, None) is None
385+
out = [*carry_out, *out]
360386

361387
if any(move_to_const):
362388
out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts)

tests/lax_control_flow_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3299,6 +3299,59 @@ def body_fun(c, _):
32993299
outs_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0]
33003300
self.assertAllClose(outs, outs_ref, check_dtypes=False)
33013301

3302+
@parameterized.parameters(itertools.product(range(3), repeat=4))
3303+
@jtu.run_on_devices("cpu")
3304+
def test_scan_forwarding_correctness(
3305+
self,
3306+
seed,
3307+
num_body_consts,
3308+
num_const_fwds,
3309+
num_input_fwds):
3310+
3311+
num_carry = num_const_fwds + 4
3312+
num_xs = num_input_fwds + 2
3313+
num_ys = num_xs + 1
3314+
3315+
rng = np.random.RandomState(seed)
3316+
carry_perm = rng.permutation(num_carry)
3317+
carry_iperm = np.argsort(carry_perm)
3318+
3319+
xs_perm = rng.permutation(num_xs)
3320+
ys_perm = rng.permutation(num_ys)
3321+
f = np.arange(num_xs)
3322+
f = [f[i] if idx < num_input_fwds else None for idx, i in enumerate(xs_perm)]
3323+
f += [None]
3324+
in_fwd = [f[i] for i in ys_perm]
3325+
3326+
body_consts = [rng.randn(3) for _ in range(num_body_consts)]
3327+
init_vals = list(rng.uniform(size=num_carry))
3328+
3329+
def body_fun(c, x):
3330+
c = [c[i] for i in carry_iperm]
3331+
carry_fwds, carry_dont_fwd = split_list(c, [num_const_fwds])
3332+
carry_dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts)
3333+
for x in carry_dont_fwd]
3334+
new_c_perm = [*carry_fwds, *carry_dont_fwd]
3335+
new_c = [new_c_perm[i] for i in carry_perm]
3336+
3337+
x = [x[i] for i in xs_perm]
3338+
x_fwd, x_dont_fwd = split_list(x, [num_input_fwds])
3339+
x_dont_fwd = [jnp.cos(x) * sum(jnp.sum(c) for c in body_consts)
3340+
for x in x_dont_fwd]
3341+
y = [*x_fwd, *x_dont_fwd, 0]
3342+
y = [y[i] for i in ys_perm]
3343+
3344+
return new_c, y
3345+
3346+
xs = list(rng.uniform(size=(num_xs, 2)))
3347+
final, outs = jax.lax.scan(body_fun, init_vals, xs)
3348+
for f, y in zip(in_fwd, outs):
3349+
if f is not None:
3350+
self.assertAllClose(y, xs[f])
3351+
3352+
final_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0]
3353+
self.assertAllClose(final, final_ref, check_dtypes=False)
3354+
33023355
def test_scan_diff_of_print(self):
33033356
# ref: https://github.com/jax-ml/jax/issues/28738
33043357
def f(c, _):
@@ -3311,6 +3364,24 @@ def g(x):
33113364
eqn_jaxpr = jaxpr.eqns[0].params["jaxpr"]
33123365
self.assertIn("debug_callback", [e.primitive.name for e in eqn_jaxpr.eqns])
33133366

3367+
def test_scan_input_to_output_forwarding(self):
3368+
def f(c, x):
3369+
return c + 1, x
3370+
def g(x):
3371+
return jax.lax.scan(f, 0, x)
3372+
jaxpr = jax.make_jaxpr(g)(jnp.arange(3.))
3373+
self.assertLen(jaxpr.eqns[0].params["jaxpr"].jaxpr.outvars, 1)
3374+
3375+
def test_scan_only_forwarding(self):
3376+
def f(_, x):
3377+
return None, x
3378+
def g(x):
3379+
return jax.lax.scan(f, None, x)
3380+
x = jnp.arange(3)
3381+
jaxpr = jax.make_jaxpr(g)(x)
3382+
self.assertLen(jaxpr.eqns, 0)
3383+
self.assertArraysEqual(g(x)[1], x)
3384+
33143385

33153386
if __name__ == '__main__':
33163387
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)