@@ -658,7 +658,8 @@ def write(v: core.Var, val: RepType) -> None:
658
658
rule = _check_rules .get (e .primitive , partial (_rule_missing , e .primitive ))
659
659
out_rep = rule (mesh , * map (read , e .invars ), ** e .params )
660
660
if e .primitive .multiple_results :
661
- out_rep = [out_rep ] * len (e .outvars ) if type (out_rep ) is set else out_rep
661
+ if isinstance (out_rep , (set , frozenset )):
662
+ out_rep = [out_rep ] * len (e .outvars )
662
663
foreach (write , e .outvars , out_rep )
663
664
else :
664
665
write (e .outvars [0 ], out_rep )
@@ -942,7 +943,7 @@ def to_val_rep_pair(self, val):
942
943
raise Exception (f"Shouldn't have any non-shard_map tracers: { val } " )
943
944
else :
944
945
val_ = _unmatch_spec (self .mesh , {}, val , self .context_mesh )
945
- return val_ , None
946
+ return val_ , frozenset ( self . mesh . axis_names ) - self . auto
946
947
947
948
def process_primitive (self , prim , tracers , params ):
948
949
in_vals , in_rep = unzip2 (map (self .to_val_rep_pair , tracers ))
@@ -958,7 +959,8 @@ def process_primitive(self, prim, tracers, params):
958
959
rep_rule = _check_rules .get (prim , partial (_rule_missing , prim ))
959
960
out_rep = rep_rule (self .mesh , * in_rep , ** params ) if self .check else set ()
960
961
if prim .multiple_results :
961
- out_rep = [out_rep ] * len (out_vals ) if type (out_rep ) is set else out_rep
962
+ if isinstance (out_rep , (set , frozenset )):
963
+ out_rep = [out_rep ] * len (out_vals )
962
964
return map (partial (ShardMapTracer , self ), out_rep , out_vals )
963
965
return ShardMapTracer (self , out_rep , out_vals )
964
966
0 commit comments