Skip to content

Commit

Permalink
[shard-map] relax vmap-of-shmap error with spmd_axis_name
Browse files Browse the repository at this point in the history
Because Matt W said so

Co-authored-by: Matthew Wiethoff <wiethoff@google.com>
  • Loading branch information
mattjj and Matthew Wiethoff committed Jul 25, 2024
1 parent 7de3c06 commit fd0fbe2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,9 +1313,9 @@ def _shard_map_batch(
for ax in names} for names, d in zip(in_names, in_dims)]
spmd_axis_name = trace.spmd_axis_name
if spmd_axis_name is not None:
used = {n for names in in_names for ns in names.values() for n in ns}
if set(spmd_axis_name) & used:
raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
for names, dim in zip(in_names, in_dims):
if set(spmd_axis_name) & set(names.values()):
raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore
else ns for ns, d in zip(new_in_names, in_dims)]
@as_hashable_function(closure=out_names_thunk)
Expand Down

0 comments on commit fd0fbe2

Please sign in to comment.