Skip to content

Apply extensive input to extensive output forwarding in scan. #28985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _create_jaxpr(init):
# If the body forwards an input carry to an output carry, that input is
# read-only and can be moved to be a const. Doing so can lead to efficiency
# wins, e.g. if the scan is inside a cond with a batched predicate.
carry_fwd, _ = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry])
carry_fwd, ext_fwd = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry])
move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)]
if any(move_to_const):
jaxpr = pe.prune_closed_jaxpr_outputs(
Expand All @@ -352,12 +352,32 @@ def _create_jaxpr(init):
consts = [*new_consts, *consts]
num_carry -= len(new_consts)

# When an extensive output is forwarded from an extensive input, we can
# avoid copying it by pruning it from the jaxpr and forwarding manually. We
# don't need to update the indexing based on the optimization above since it
# doesn't change the total number of consts and carries combined, and
# `ext_fwd` already only includes the extensive outputs. But, we do remove
# the number of consts from the index since we're going to use it to index
# into `in_flat`, which doesn't include consts.
ext_to_ext_fwd = [
in_idx - len(consts) if in_idx is not None and
in_idx >= num_carry + len(consts) else None for in_idx in ext_fwd]
jaxpr = pe.prune_closed_jaxpr_outputs(
jaxpr, [True] * num_carry + [i is None for i in ext_to_ext_fwd])

out = scan_p.bind(*consts, *in_flat,
reverse=reverse, length=length, jaxpr=jaxpr,
num_consts=len(consts), num_carry=num_carry,
linear=(False,) * (len(consts) + len(in_flat)),
unroll=unroll, _split_transpose=_split_transpose)

# Apply input to output forwarding that was computed above.
carry_out, out = split_list(out, [num_carry])
out_ = iter(out)
out = [next(out_) if f is None else _maybe_put(in_flat[f]) for f in ext_to_ext_fwd]
assert next(out_, None) is None
out = [*carry_out, *out]

if any(move_to_const):
out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts)

Expand Down
61 changes: 61 additions & 0 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3299,6 +3299,59 @@ def body_fun(c, _):
outs_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0]
self.assertAllClose(outs, outs_ref, check_dtypes=False)

@parameterized.parameters(itertools.product(range(3), repeat=4))
@jtu.run_on_devices("cpu")
def test_scan_forwarding_correctness(
self,
seed,
num_body_consts,
num_const_fwds,
num_input_fwds):

num_carry = num_const_fwds + 4
num_xs = num_input_fwds + 2
num_ys = num_xs + 1

rng = np.random.RandomState(seed)
carry_perm = rng.permutation(num_carry)
carry_iperm = np.argsort(carry_perm)

xs_perm = rng.permutation(num_xs)
ys_perm = rng.permutation(num_ys)
f = np.arange(num_xs)
f = [f[i] if idx < num_input_fwds else None for idx, i in enumerate(xs_perm)]
f += [None]
in_fwd = [f[i] for i in ys_perm]

body_consts = [rng.randn(3) for _ in range(num_body_consts)]
init_vals = list(rng.uniform(size=num_carry))

def body_fun(c, x):
c = [c[i] for i in carry_iperm]
carry_fwds, carry_dont_fwd = split_list(c, [num_const_fwds])
carry_dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts)
for x in carry_dont_fwd]
new_c_perm = [*carry_fwds, *carry_dont_fwd]
new_c = [new_c_perm[i] for i in carry_perm]

x = [x[i] for i in xs_perm]
x_fwd, x_dont_fwd = split_list(x, [num_input_fwds])
x_dont_fwd = [jnp.cos(x) * sum(jnp.sum(c) for c in body_consts)
for x in x_dont_fwd]
y = [*x_fwd, *x_dont_fwd, 0]
y = [y[i] for i in ys_perm]

return new_c, y

xs = list(rng.uniform(size=(num_xs, 2)))
final, outs = jax.lax.scan(body_fun, init_vals, xs)
for f, y in zip(in_fwd, outs):
if f is not None:
self.assertAllClose(y, xs[f])

final_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0]
self.assertAllClose(final, final_ref, check_dtypes=False)

def test_scan_diff_of_print(self):
# ref: https://github.com/jax-ml/jax/issues/28738
def f(c, _):
Expand All @@ -3311,6 +3364,14 @@ def g(x):
eqn_jaxpr = jaxpr.eqns[0].params["jaxpr"]
self.assertIn("debug_callback", [e.primitive.name for e in eqn_jaxpr.eqns])

def test_scan_input_to_output_forwarding(self):
def f(c, x):
return c + 1, x
def g(x):
return jax.lax.scan(f, 0, x)
jaxpr = jax.make_jaxpr(g)(jnp.arange(3.))
self.assertLen(jaxpr.eqns[0].params["jaxpr"].jaxpr.outvars, 1)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
Loading