@@ -657,8 +657,10 @@ def write(v: core.Var, val: RepType) -> None:
657
657
for e in jaxpr .eqns :
658
658
rule = _check_rules .get (e .primitive , partial (_rule_missing , e .primitive ))
659
659
out_rep = rule (mesh , * map (read , e .invars ), ** e .params )
660
+
660
661
if e .primitive .multiple_results :
661
- out_rep = [out_rep ] * len (e .outvars ) if type (out_rep ) is set else out_rep
662
+ if type (out_rep ) is set or out_rep is None :
663
+ out_rep = [out_rep ] * len (e .outvars )
662
664
foreach (write , e .outvars , out_rep )
663
665
else :
664
666
write (e .outvars [0 ], out_rep )
@@ -958,7 +960,8 @@ def process_primitive(self, prim, tracers, params):
958
960
rep_rule = _check_rules .get (prim , partial (_rule_missing , prim ))
959
961
out_rep = rep_rule (self .mesh , * in_rep , ** params ) if self .check else set ()
960
962
if prim .multiple_results :
961
- out_rep = [out_rep ] * len (out_vals ) if type (out_rep ) is set else out_rep
963
+ if type (out_rep ) is set or out_rep is None :
964
+ out_rep = [out_rep ] * len (out_vals )
962
965
return map (partial (ShardMapTracer , self ), out_rep , out_vals )
963
966
return ShardMapTracer (self , out_rep , out_vals )
964
967
0 commit comments