@@ -98,22 +98,22 @@ def process_primitive(self, primitive, tracers, params):
98
98
return self .default_process_primitive (primitive , tracers , params )
99
99
100
100
def default_process_primitive (self , primitive , tracers , params ):
101
- pvs , consts = unzip2 (t .pval for t in tracers )
102
- if all (pv is None for pv in pvs ):
103
- return primitive .bind (* consts , ** params )
104
- tracers = map (self .instantiate_const , tracers )
105
- avals = [t .aval for t in tracers ]
106
- out_aval = primitive .abstract_eval (* avals , ** params )
107
- if primitive .multiple_results :
108
- out_tracers = [JaxprTracer (self , PartialVal ((aval , unit )), None )
109
- for aval in out_aval ]
110
- eqn = new_eqn_recipe (tracers , out_tracers , primitive , params )
111
- for t in out_tracers : t .recipe = eqn
112
- return out_tracers
113
- else :
114
- out_tracer = JaxprTracer (self , PartialVal ((out_aval , unit )), None )
115
- out_tracer .recipe = new_eqn_recipe (tracers , [out_tracer ], primitive , params )
116
- return out_tracer
101
+ pvs , consts = unzip2 (t .pval for t in tracers )
102
+ if all (pv is None for pv in pvs ):
103
+ return primitive .bind (* consts , ** params )
104
+ tracers = map (self .instantiate_const , tracers )
105
+ avals = [t .aval for t in tracers ]
106
+ out_aval = primitive .abstract_eval (* avals , ** params )
107
+ if primitive .multiple_results :
108
+ out_tracers = [JaxprTracer (self , PartialVal ((aval , unit )), None )
109
+ for aval in out_aval ]
110
+ eqn = new_eqn_recipe (tracers , out_tracers , primitive , params )
111
+ for t in out_tracers : t .recipe = eqn
112
+ return out_tracers
113
+ else :
114
+ out_tracer = JaxprTracer (self , PartialVal ((out_aval , unit )), None )
115
+ out_tracer .recipe = new_eqn_recipe (tracers , [out_tracer ], primitive , params )
116
+ return out_tracer
117
117
118
118
def process_call (self , call_primitive , f , tracers , params ):
119
119
name = params .get ('name' , f .__name__ )
0 commit comments