Skip to content

Commit f4b946e

Browse files
authored
Merge pull request jax-ml#2199 from sharadmv/patch-1
Fix inconsistent indentation in `JaxprTrace.default_process_primitive`.
2 parents 28e802c + 76d77bf commit f4b946e

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

jax/interpreters/partial_eval.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,22 @@ def process_primitive(self, primitive, tracers, params):
9898
return self.default_process_primitive(primitive, tracers, params)
9999

100100
def default_process_primitive(self, primitive, tracers, params):
101-
pvs, consts = unzip2(t.pval for t in tracers)
102-
if all(pv is None for pv in pvs):
103-
return primitive.bind(*consts, **params)
104-
tracers = map(self.instantiate_const, tracers)
105-
avals = [t.aval for t in tracers]
106-
out_aval = primitive.abstract_eval(*avals, **params)
107-
if primitive.multiple_results:
108-
out_tracers = [JaxprTracer(self, PartialVal((aval, unit)), None)
109-
for aval in out_aval]
110-
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params)
111-
for t in out_tracers: t.recipe = eqn
112-
return out_tracers
113-
else:
114-
out_tracer = JaxprTracer(self, PartialVal((out_aval, unit)), None)
115-
out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, params)
116-
return out_tracer
101+
pvs, consts = unzip2(t.pval for t in tracers)
102+
if all(pv is None for pv in pvs):
103+
return primitive.bind(*consts, **params)
104+
tracers = map(self.instantiate_const, tracers)
105+
avals = [t.aval for t in tracers]
106+
out_aval = primitive.abstract_eval(*avals, **params)
107+
if primitive.multiple_results:
108+
out_tracers = [JaxprTracer(self, PartialVal((aval, unit)), None)
109+
for aval in out_aval]
110+
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params)
111+
for t in out_tracers: t.recipe = eqn
112+
return out_tracers
113+
else:
114+
out_tracer = JaxprTracer(self, PartialVal((out_aval, unit)), None)
115+
out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, params)
116+
return out_tracer
117117

118118
def process_call(self, call_primitive, f, tracers, params):
119119
name = params.get('name', f.__name__)

0 commit comments

Comments
 (0)