Skip to content

Apply extensive input to extensive output forwarding in scan. #28985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

dfm
Copy link
Contributor

@dfm dfm commented May 23, 2025

This PR applies a small optimization to scan for the case where some outputs are directly forwarded inputs. For example, in

def body(c, x):
  return c + 1, x

_, y = jax.lax.scan(body, 0, jnp.arange(10))

y would be computed by copying each element of x into a new array. After this change, y is returned as x without making a copy.

@dfm dfm requested a review from mattjj May 23, 2025 19:19
@dfm dfm self-assigned this May 23, 2025
@dfm dfm force-pushed the scan-fwd-ext-traceable branch from 376d269 to bbc2905 Compare May 23, 2025 19:21
Copy link
Collaborator

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful, thanks @dfm

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels May 27, 2025
@dfm dfm force-pushed the scan-fwd-ext-traceable branch 3 times, most recently from f617ece to 548a61d Compare May 28, 2025 13:33
@dfm dfm force-pushed the scan-fwd-ext-traceable branch from 548a61d to 0e86574 Compare May 28, 2025 16:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants