Skip to content

[scan] don't hoist loop-invariant computations in scan, just forward #28993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented May 23, 2025

No description provided.

@mattjj mattjj requested review from dfm and dougalm May 23, 2025 23:18
@mattjj mattjj self-assigned this May 23, 2025
@mattjj mattjj added the pull ready Ready for copybara import and testing label May 23, 2025
@mattjj mattjj force-pushed the scan-dont-hoist branch from 9ba357a to fc67d17 Compare May 24, 2025 00:03
in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
in_fwd = [f if out_idx >= num_knowns_out and f is not None and
(f < num_consts_known or f >= num_consts_known + num_carry_known)
and isinstance(known_inputs[f], jax.Array) # no np.ndarrays
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the other forwarding logic, we use the _maybe_put helper to (I think?) support forwarding numpy arrays, but casting them to CPU device arrays. It seems like that should work here too, but I could well be missing something. What's the motivation for one or the other?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this "dont forward numpy.ndarrays" line because this PR caused a failure in some Pathways test without it. I think _maybe_put was perhaps hardcoding the Arrays to be on host, rather than on default device, and that caused a really low-level error.

Out of laziness I decided to try just skipping the forwarding for numpy.ndarrays because I suspect we don't really need it. Another fix might be to improve _maybe_put, but I haven't looked into that at all.

@mattjj mattjj force-pushed the scan-dont-hoist branch 2 times, most recently from 6e08f0d to 08bc676 Compare May 28, 2025 17:39
@mattjj mattjj marked this pull request as ready for review May 28, 2025 21:40
@mattjj mattjj force-pushed the scan-dont-hoist branch from 08bc676 to aa91b06 Compare May 29, 2025 00:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants