Skip to content

Commit 376d269

Browse files
committed
Apply extensive input to extensive output forwarding in scan.
1 parent f429162 commit 376d269

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def _create_jaxpr(init):
341341
# If the body forwards an input carry to an output carry, that input is
342342
# read-only and can be moved to be a const. Doing so can lead to efficiency
343343
# wins, e.g. if the scan is inside a cond with a batched predicate.
344-
carry_fwd, _ = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry])
344+
carry_fwd, in_fwd = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry])
345345
move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)]
346346
if any(move_to_const):
347347
jaxpr = pe.prune_closed_jaxpr_outputs(
@@ -352,12 +352,30 @@ def _create_jaxpr(init):
352352
consts = [*new_consts, *consts]
353353
num_carry -= len(new_consts)
354354

355+
# When an extensive output is forwarded from an extensive input, we can
356+
# avoid copying it by pruning it from the jaxpr and forwarding manually. We
357+
# don't need to update the indexing based on the optimization above since it
358+
# doesn't change the total number of consts and carries combined. But, we do
359+
# remove the number of consts from the index since `_jaxpr_forwarding`
360+
# includes the Jaxpr's constvars when indexing.
361+
in_fwd = [in_idx - len(consts) if in_idx is not None and
362+
in_idx >= num_carry + len(consts) else None for in_idx in in_fwd]
363+
jaxpr = pe.prune_closed_jaxpr_outputs(
364+
jaxpr, [True] * num_carry + [i is None for i in in_fwd])
365+
355366
out = scan_p.bind(*consts, *in_flat,
356367
reverse=reverse, length=length, jaxpr=jaxpr,
357368
num_consts=len(consts), num_carry=num_carry,
358369
linear=(False,) * (len(consts) + len(in_flat)),
359370
unroll=unroll, _split_transpose=_split_transpose)
360371

372+
# Apply input to output forwarding that was computed above.
373+
consts_out, out = split_list(out, [num_carry])
374+
out_ = iter(out)
375+
out = [next(out_) if f is None else _maybe_put(in_flat[f]) for f in in_fwd]
376+
assert next(out_, None) is None
377+
out = [*consts_out, *out]
378+
361379
if any(move_to_const):
362380
out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts)
363381

tests/lax_control_flow_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3311,6 +3311,14 @@ def g(x):
33113311
eqn_jaxpr = jaxpr.eqns[0].params["jaxpr"]
33123312
self.assertIn("debug_callback", [e.primitive.name for e in eqn_jaxpr.eqns])
33133313

3314+
def test_scan_input_to_output_forwarding(self):
3315+
def f(c, x):
3316+
return c + 1, x
3317+
def g(x):
3318+
return jax.lax.scan(f, 0, x)
3319+
jaxpr = jax.make_jaxpr(g)(jnp.arange(3.))
3320+
self.assertLen(jaxpr.eqns[0].params["jaxpr"].jaxpr.outvars, 1)
3321+
33143322

33153323
if __name__ == '__main__':
33163324
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)