Skip to content

[shard-map] in eager shmap, handle all rep rule output cases #27797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 7, 2025

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Apr 7, 2025

By convention, rep_rules can return three kinds of thing:

  1. a sequence (tuple or list),
  2. a single set, or
  3. a single None.

Even rules for primitives with multiple results can return single objects rather than sequences; the reason is that it's convenient not ot have to infer the number of outputs for higher-order primitives.

In the latter two cases we rely on the caller (in this case, ShardMapTrace.process_primitive) to 'broadcast' the singleton result to a list of results equal to the number of outputs.

Previously, the code was checking if type(out_rep) is set, which doesn't handle case 3.

(We briefly tried another fix direction where we don't allow case 3, because we don't have case 3 in the upcoming VMA type system which replaces this stuff. But until that lands the easiest fix is just to handle all cases correctly.)

fixes #26148, fixes #27673

This PR replaces #27720

@mattjj mattjj requested a review from justinjfu April 7, 2025 19:18
Copy link

google-cla bot commented Apr 7, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

By convention, rep_rules can return three kinds of thing:
1. a sequence (tuple or list),
2. a single set, or
3. a single None.

Even rules for primitives with multiple results can return single objects rather than sequences; the reason is that it's convenient not ot have to infer the number of outputs for higher-order primitives.

In the latter two cases we rely on the caller (in this case, ShardMapTrace.process_primitive) to 'broadcast' the singleton result to a list of results equal to the number of outputs.

Previously, the code was checking `if type(out_rep) is set`, which doesn't handle case 3.

(We briefly tried another fix direction where we don't allow case 3, because we don't have case 3 in the upcoming VMA type system which replaces this stuff. But until that lands the easiest fix is just to handle all cases correctly.)

fixes jax-ml#26148, fixes jax-ml#27673

Co-authored-by: Justin Fu <justinfu@google.com>
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 7, 2025
@copybara-service copybara-service bot merged commit 0d06731 into jax-ml:main Apr 7, 2025
26 of 27 checks passed
@mattjj mattjj deleted the shmap-fix branch April 7, 2025 20:20
mattjj added a commit to mattjj/jax that referenced this pull request Apr 9, 2025
…k_rep

this was essentially another instance of the jax-ml#27797 fix

fixes jax-ml#24762
charleshofer pushed a commit to ROCm/jax that referenced this pull request Apr 30, 2025
…k_rep

this was essentially another instance of the jax-ml#27797 fix

fixes jax-ml#24762
charleshofer pushed a commit to ROCm/jax that referenced this pull request May 1, 2025
…k_rep

this was essentially another instance of the jax-ml#27797 fix

fixes jax-ml#24762
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

JAX 0.4.38 breaks jax.numpy.split in shard_map shard_map errors when composing with jax.jacrev or jax.jacfwd with jax 0.5.0
2 participants