Skip to content

Fix shard_map for primitives with multiple outputs and None specs #27720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,8 @@ def write(v: core.Var, val: RepType) -> None:
rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive))
out_rep = rule(mesh, *map(read, e.invars), **e.params)
if e.primitive.multiple_results:
out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
if isinstance(out_rep, (set, frozenset)):
out_rep = [out_rep] * len(e.outvars)
foreach(write, e.outvars, out_rep)
else:
write(e.outvars[0], out_rep)
Expand Down Expand Up @@ -942,7 +943,7 @@ def to_val_rep_pair(self, val):
raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
else:
val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh)
return val_, None
return val_, frozenset(self.mesh.axis_names) - self.auto

def process_primitive(self, prim, tracers, params):
in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers))
Expand All @@ -958,7 +959,8 @@ def process_primitive(self, prim, tracers, params):
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
if prim.multiple_results:
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
if isinstance(out_rep, (set, frozenset)):
out_rep = [out_rep] * len(out_vals)
return map(partial(ShardMapTracer, self), out_rep, out_vals)
return ShardMapTracer(self, out_rep, out_vals)

Expand Down
13 changes: 13 additions & 0 deletions tests/shard_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,19 @@ def f3():
f3()
jax.jit(f3)()

def test_multiple_result_primitive_with_none_sharding(self):
# https://github.com/jax-ml/jax/issues/27673
xs = jnp.arange(20).reshape(2, 10)
mesh = jtu.create_mesh((2,), ("i",))
y = shard_map(
lambda x: jnp.split(x.squeeze(), 2),
mesh=mesh,
in_specs=(None,),
out_specs=P("i"),
)(xs)
expected = jnp.repeat(xs, 2, axis=0).reshape(2, 2, 10)
self.assertArraysEqual(y, expected)

def test_vmap_spmd_axis_name(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))

Expand Down
Loading