@@ -657,8 +657,10 @@ def write(v: core.Var, val: RepType) -> None:
657
657
for e in jaxpr .eqns :
658
658
rule = _check_rules .get (e .primitive , partial (_rule_missing , e .primitive ))
659
659
out_rep = rule (mesh , * map (read , e .invars ), ** e .params )
660
+
660
661
if e .primitive .multiple_results :
661
- out_rep = [out_rep ] * len (e .outvars ) if type (out_rep ) is set else out_rep
662
+ if type (out_rep ) in [set , frozenset ]:
663
+ out_rep = [out_rep ] * len (e .outvars )
662
664
foreach (write , e .outvars , out_rep )
663
665
else :
664
666
write (e .outvars [0 ], out_rep )
@@ -942,7 +944,7 @@ def to_val_rep_pair(self, val):
942
944
raise Exception (f"Shouldn't have any non-shard_map tracers: { val } " )
943
945
else :
944
946
val_ = _unmatch_spec (self .mesh , {}, val , self .context_mesh )
945
- return val_ , None
947
+ return val_ , frozenset ( self . mesh . axis_names ) - self . auto
946
948
947
949
def process_primitive (self , prim , tracers , params ):
948
950
in_vals , in_rep = unzip2 (map (self .to_val_rep_pair , tracers ))
@@ -958,7 +960,8 @@ def process_primitive(self, prim, tracers, params):
958
960
rep_rule = _check_rules .get (prim , partial (_rule_missing , prim ))
959
961
out_rep = rep_rule (self .mesh , * in_rep , ** params ) if self .check else set ()
960
962
if prim .multiple_results :
961
- out_rep = [out_rep ] * len (out_vals ) if type (out_rep ) is set else out_rep
963
+ if type (out_rep ) in [set , frozenset ]:
964
+ out_rep = [out_rep ] * len (out_vals )
962
965
return map (partial (ShardMapTracer , self ), out_rep , out_vals )
963
966
return ShardMapTracer (self , out_rep , out_vals )
964
967
0 commit comments