Skip to content

Commit 0d0a44e

Browse files
committed
Fix shard_map check_rep for multiple outputs.
1 parent 1bd0c58 commit 0d0a44e

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

jax/experimental/shard_map.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,8 @@ def write(v: core.Var, val: RepType) -> None:
658658
rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive))
659659
out_rep = rule(mesh, *map(read, e.invars), **e.params)
660660
if e.primitive.multiple_results:
661-
out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
661+
if isinstance(out_rep, (set, frozenset)):
662+
out_rep = [out_rep] * len(e.outvars)
662663
foreach(write, e.outvars, out_rep)
663664
else:
664665
write(e.outvars[0], out_rep)
@@ -942,7 +943,7 @@ def to_val_rep_pair(self, val):
942943
raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
943944
else:
944945
val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh)
945-
return val_, None
946+
return val_, frozenset(self.mesh.axis_names) - self.auto
946947

947948
def process_primitive(self, prim, tracers, params):
948949
in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers))
@@ -958,7 +959,8 @@ def process_primitive(self, prim, tracers, params):
958959
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
959960
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
960961
if prim.multiple_results:
961-
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
962+
if isinstance(out_rep, (set, frozenset)):
963+
out_rep = [out_rep] * len(out_vals)
962964
return map(partial(ShardMapTracer, self), out_rep, out_vals)
963965
return ShardMapTracer(self, out_rep, out_vals)
964966

tests/shard_map_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,19 @@ def f3():
685685
f3()
686686
jax.jit(f3)()
687687

688+
def test_multiple_result_primitive_with_none_sharding(self):
689+
# https://github.com/jax-ml/jax/issues/27673
690+
xs = jnp.arange(20).reshape(2, 10)
691+
mesh = jtu.create_mesh((2,), ("i",))
692+
y = shard_map(
693+
lambda x: jnp.split(x.squeeze(), 2),
694+
mesh=mesh,
695+
in_specs=(None,),
696+
out_specs=P("i"),
697+
)(xs)
698+
expected = jnp.repeat(xs, 2, axis=0).reshape(2, 2, 10)
699+
self.assertArraysEqual(y, expected)
700+
688701
def test_vmap_spmd_axis_name(self):
689702
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
690703

0 commit comments

Comments
 (0)