Skip to content

Commit bbb0c51

Browse files
committed
Fix shard_map check_rep for multiple outputs.
1 parent 1bd0c58 commit bbb0c51

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

jax/experimental/shard_map.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,8 +657,10 @@ def write(v: core.Var, val: RepType) -> None:
657657
for e in jaxpr.eqns:
658658
rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive))
659659
out_rep = rule(mesh, *map(read, e.invars), **e.params)
660+
660661
if e.primitive.multiple_results:
661-
out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
662+
if type(out_rep) is set or out_rep is None:
663+
out_rep = [out_rep] * len(e.outvars)
662664
foreach(write, e.outvars, out_rep)
663665
else:
664666
write(e.outvars[0], out_rep)
@@ -958,7 +960,8 @@ def process_primitive(self, prim, tracers, params):
958960
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
959961
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
960962
if prim.multiple_results:
961-
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
963+
if type(out_rep) is set or out_rep is None:
964+
out_rep = [out_rep] * len(out_vals)
962965
return map(partial(ShardMapTracer, self), out_rep, out_vals)
963966
return ShardMapTracer(self, out_rep, out_vals)
964967

tests/shard_map_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,17 @@ def f3():
685685
f3()
686686
jax.jit(f3)()
687687

688+
def test_multiple_result_primitive_with_none_sharding(self):
689+
# https://github.com/jax-ml/jax/issues/27673
690+
xs = jnp.arange(20).reshape(2, 10)
691+
mesh = jtu.create_mesh((2,), ("i",))
692+
shard_map(
693+
lambda x: jnp.split(x.squeeze(), 2),
694+
mesh=mesh,
695+
in_specs=(None,),
696+
out_specs=P("i"),
697+
)(xs)
698+
688699
def test_vmap_spmd_axis_name(self):
689700
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
690701

0 commit comments

Comments
 (0)