@@ -341,7 +341,7 @@ def _create_jaxpr(init):
341
341
# If the body forwards an input carry to an output carry, that input is
342
342
# read-only and can be moved to be a const. Doing so can lead to efficiency
343
343
# wins, e.g. if the scan is inside a cond with a batched predicate.
344
- carry_fwd , _ = split_list (pe ._jaxpr_forwarding (jaxpr .jaxpr ), [num_carry ])
344
+ carry_fwd , in_fwd = split_list (pe ._jaxpr_forwarding (jaxpr .jaxpr ), [num_carry ])
345
345
move_to_const = [len (consts ) + i == f for i , f in enumerate (carry_fwd )]
346
346
if any (move_to_const ):
347
347
jaxpr = pe .prune_closed_jaxpr_outputs (
@@ -352,12 +352,30 @@ def _create_jaxpr(init):
352
352
consts = [* new_consts , * consts ]
353
353
num_carry -= len (new_consts )
354
354
355
+ # When an extensive output is forwarded from an extensive input, we can
356
+ # avoid copying it by pruning it from the jaxpr and forwarding manually. We
357
+ # don't need to update the indexing based on the optimization above since it
358
+ # doesn't change the total number of consts and carries combined. But, we do
359
+ # remove the number of consts from the index since `_jaxpr_forwarding`
360
+ # includes the Jaxpr's constvars when indexing.
361
+ in_fwd = [in_idx - len (consts ) if in_idx is not None and
362
+ in_idx >= num_carry + len (consts ) else None for in_idx in in_fwd ]
363
+ jaxpr = pe .prune_closed_jaxpr_outputs (
364
+ jaxpr , [True ] * num_carry + [i is None for i in in_fwd ])
365
+
355
366
out = scan_p .bind (* consts , * in_flat ,
356
367
reverse = reverse , length = length , jaxpr = jaxpr ,
357
368
num_consts = len (consts ), num_carry = num_carry ,
358
369
linear = (False ,) * (len (consts ) + len (in_flat )),
359
370
unroll = unroll , _split_transpose = _split_transpose )
360
371
372
+ # Apply input to output forwarding that was computed above.
373
+ consts_out , out = split_list (out , [num_carry ])
374
+ out_ = iter (out )
375
+ out = [next (out_ ) if f is None else _maybe_put (in_flat [f ]) for f in in_fwd ]
376
+ assert next (out_ , None ) is None
377
+ out = [* consts_out , * out ]
378
+
361
379
if any (move_to_const ):
362
380
out = pe .merge_lists (move_to_const + [False ] * num_ys , out , new_consts )
363
381
0 commit comments