Skip to content

Commit fb7e48f

Browse files
authored
Merge pull request jax-ml#2176 from gnecula/simple_jaxpr2
Simplify Jaxpr: remove the bound_subjaxpr field, all subjaxprs are in…
2 parents f4b946e + 20f9230 commit fb7e48f

File tree

10 files changed

+147
-119
lines changed

10 files changed

+147
-119
lines changed

CHANGELOG.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ These are the release notes for JAX.
77
### Breaking changes
88

99
* The minimum jaxlib version is now 0.1.38.
10-
* Simplified `Jaxpr` by removing the `Jaxpr.freevars` and changing the
11-
representation of `Jaxpr.bound_subjaxprs` to drop the environment values.
10+
* Simplified `Jaxpr` by removing the `Jaxpr.freevars` and
11+
`Jaxpr.bound_subjaxprs`. The call primitives (`xla_call`, `xla_pmap`,
12+
`sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a
13+
fully-closed (no `constvars`) JAXPR.
1214

1315
### New features
1416

jax/core.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@ def __str__(self):
5858
return str(pp_jaxpr(self))
5959
__repr__ = __str__
6060

61+
62+
def subjaxprs(jaxpr):
63+
"""Generator for all subjaxprs found in the params of jaxpr.eqns.
64+
Does not descend recursively into the found subjaxprs.
65+
"""
66+
for eqn in jaxpr.eqns:
67+
for param in eqn.params.values():
68+
if type(param) is Jaxpr:
69+
yield param
70+
elif type(param) is TypedJaxpr:
71+
yield param.jaxpr
72+
73+
6174
class TypedJaxpr(object):
6275
def __init__(self, jaxpr, literals, in_avals, out_avals):
6376
assert type(jaxpr) is Jaxpr
@@ -84,8 +97,8 @@ def jaxpr_as_fun(typed_jaxpr, *args):
8497
return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, *args)
8598

8699

87-
JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive',
88-
'bound_subjaxpr', 'params'])
100+
101+
JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive', 'params'])
89102
JaxprEqn.__repr__ = JaxprEqn.__str__ = lambda eqn: str(pp_eqn(eqn)).rstrip()
90103
new_jaxpr_eqn = JaxprEqn
91104

@@ -149,6 +162,8 @@ def __repr__(self):
149162

150163
class Primitive(object):
151164
multiple_results = False # override for multi-output primitives
165+
call_primitive = False # override for higher-order primitives that are
166+
# processed in final style.
152167

153168
def __init__(self, name):
154169
self.name = name
@@ -193,6 +208,24 @@ def abstract_eval(self, *args, **kwargs):
193208

194209
# -------------------- lifting --------------------
195210

211+
# TODO(necula): this belongs next to pe.new_eqn_recipe, but is needed in
212+
# core.py. Plan to move all these utilities to jaxpr.py.
213+
def extract_call_jaxpr(primitive, params):
214+
"""Extract the call primitive subjaxpr from the params.
215+
216+
Params:
217+
params: a parameter dictionary for a primitive.
218+
Returns: the subjaxpr and the params without the "jaxpr" value. If this is
219+
not a call primitive then returns (None, params).
220+
"""
221+
if not primitive.call_primitive:
222+
return (None, params)
223+
else:
224+
assert "call_jaxpr" in params
225+
new_params = dict(params)
226+
del new_params["call_jaxpr"]
227+
return (params["call_jaxpr"], new_params)
228+
196229

197230
def eval_jaxpr(jaxpr, consts, *args):
198231
def read(v):
@@ -210,11 +243,12 @@ def write(v, val):
210243
map(write, jaxpr.invars, args)
211244
for eqn in jaxpr.eqns:
212245
in_vals = map(read, eqn.invars)
213-
if eqn.bound_subjaxpr:
214-
subfuns = [lu.wrap_init(partial(eval_jaxpr, eqn.bound_subjaxpr, ()))]
246+
call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params)
247+
if call_jaxpr:
248+
subfuns = [lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))]
215249
else:
216250
subfuns = []
217-
ans = eqn.primitive.bind(*(subfuns + in_vals), **eqn.params)
251+
ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
218252
if eqn.primitive.multiple_results:
219253
map(write, eqn.outvars, ans)
220254
else:
@@ -626,6 +660,10 @@ def call_impl(f, *args, **params):
626660
# ------------------- Jaxpr printed representation -------------------
627661

628662
def check_jaxpr(jaxpr):
663+
"""Checks well-formedness of a jaxpr.
664+
665+
Specifically it checks that all variabled used are previously defined.
666+
"""
629667
def context():
630668
return "\njaxpr:\n{}\n".format(jaxpr)
631669

@@ -646,10 +684,16 @@ def write_env(env, v):
646684
map(write, jaxpr.constvars)
647685
map(write, jaxpr.invars)
648686
for eqn in jaxpr.eqns:
687+
if eqn.primitive.call_primitive:
688+
if "call_jaxpr" not in eqn.params:
689+
raise Exception("Call primitive {} should have a 'call_jaxpr' parameter"
690+
.format(eqn.primitive))
649691
map(read, eqn.invars)
650-
if eqn.bound_subjaxpr:
651-
check_jaxpr(eqn.bound_subjaxpr)
652692
map(write, eqn.outvars)
693+
694+
for subjaxpr in subjaxprs(jaxpr):
695+
check_jaxpr(subjaxpr)
696+
653697
map(read, jaxpr.outvars)
654698

655699

@@ -664,10 +708,6 @@ def pp_eqn_compact(primitive_name, params):
664708
def pp_eqn(eqn):
665709
lhs = pp_vars(eqn.outvars)
666710
pp_subexpr = pp('')
667-
if eqn.bound_subjaxpr:
668-
pp_subexpr = pp_subexpr + (
669-
pp_jaxpr(eqn.bound_subjaxpr).indent(2)
670-
>> pp(' [ ]'))
671711
return (pp('{} = '.format(lhs)) >>
672712
pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items()))
673713
>> pp(' ') >> pp(pp_vars(eqn.invars))) + pp_subexpr

jax/interpreters/ad.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def is_linear(var):
172172

173173
linear_eqns = []
174174
for eqn in jaxpr.eqns:
175-
if not eqn.bound_subjaxpr:
175+
if not eqn.primitive.call_primitive:
176176
if any(is_linear(v) for v in eqn.invars):
177177
linear_eqns.append(eqn)
178178
else:
@@ -183,20 +183,20 @@ def is_linear(var):
183183
else:
184184
write_primal(eqn.outvars[0], ans)
185185
else:
186-
subjaxpr = eqn.bound_subjaxpr
186+
call_jaxpr = eqn.params["call_jaxpr"]
187187
if any(is_linear(v) for v in eqn.invars):
188188
linear_eqns.append(eqn)
189189
elif eqn.primitive is not pe.remat_call_p:
190190
ans = _eval_subjaxpr_primals(
191-
eqn.primitive, subjaxpr,
191+
eqn.primitive, call_jaxpr,
192192
map(read_primal, eqn.invars), eqn.params)
193193
map(write_primal, eqn.outvars, ans)
194194

195195
# we special-case remat_call here because it can be mixed linear /
196196
# nonlinear, so we always evaluate it even if it has a linear part
197197
if eqn.primitive is pe.remat_call_p:
198198
ans = _eval_subjaxpr_primals(
199-
eqn.primitive, subjaxpr,
199+
eqn.primitive, call_jaxpr,
200200
map(read_primal, eqn.invars), eqn.params)
201201
map(write_primal, eqn.outvars, ans)
202202

@@ -208,10 +208,10 @@ def is_linear(var):
208208
cts_in = map(read_cotangent, eqn.outvars)
209209
else:
210210
cts_in, = map(read_cotangent, eqn.outvars)
211-
if eqn.bound_subjaxpr:
212-
subjaxpr = eqn.bound_subjaxpr
211+
if eqn.primitive.call_primitive:
212+
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
213213
cts_out = get_primitive_transpose(eqn.primitive)(
214-
eqn.params, subjaxpr, invals, cts_in)
214+
params, call_jaxpr, invals, cts_in)
215215
else:
216216
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
217217
cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
@@ -251,7 +251,7 @@ def is_linear(var):
251251
assert not jaxpr.constvars
252252
map(write_primal, jaxpr.invars, args)
253253
for eqn in jaxpr.eqns:
254-
if not eqn.bound_subjaxpr:
254+
if not eqn.primitive.call_primitive:
255255
if not any(is_linear(v) for v in eqn.invars):
256256
in_vals = map(read_primal, eqn.invars)
257257
ans = eqn.primitive.bind(*in_vals, **eqn.params)
@@ -260,11 +260,11 @@ def is_linear(var):
260260
else:
261261
write_primal(eqn.outvars[0], ans)
262262
else:
263-
subjaxpr = eqn.bound_subjaxpr
263+
call_jaxpr = eqn.params["call_jaxpr"]
264264
if (eqn.primitive is pe.remat_call_p or
265265
not any(is_linear(v) for v in eqn.invars)):
266266
ans = _eval_subjaxpr_primals(
267-
eqn.primitive, subjaxpr,
267+
eqn.primitive, call_jaxpr,
268268
map(read_primal, eqn.invars), eqn.params)
269269
map(write_primal, eqn.outvars, ans)
270270
return map(read_primal, jaxpr.outvars)
@@ -537,19 +537,19 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
537537
yield out_flat, tree_def
538538

539539

540-
def call_transpose(primitive, params, jaxpr, args, ct):
540+
def call_transpose(primitive, params, call_jaxpr, args, ct):
541541
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
542-
fun = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
542+
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
543543
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
544544
params = dict(params, name=wrap_name(params['name'], 'transpose'))
545545
out_flat = primitive.bind(fun, *all_args, **params)
546546
return tree_unflatten(out_tree(), out_flat)
547547
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
548548
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)
549549

550-
def map_transpose(primitive, params, jaxpr, args, ct):
550+
def map_transpose(primitive, params, call_jaxpr, args, ct):
551551
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
552-
fun = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
552+
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
553553
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
554554
params = dict(params, name=wrap_name(params['name'], 'transpose'))
555555
out_flat = primitive.bind(fun, *all_args, **params)

jax/interpreters/partial_eval.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,10 @@ def process_call(self, call_primitive, f, tracers, params):
136136
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
137137
out_tracers = [JaxprTracer(self, PartialVal((out_pv, out_pv_const)), None)
138138
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
139+
new_params = dict(params, call_jaxpr=lifted_jaxpr)
139140
# The `jaxpr` already contains the env_vars at start of invars
140141
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, tracers)),
141-
out_tracers, call_primitive, params,
142-
subjaxpr=lifted_jaxpr)
142+
out_tracers, call_primitive, new_params)
143143
for t in out_tracers:
144144
t.recipe = eqn
145145
return out_tracers
@@ -162,10 +162,10 @@ def process_map(self, map_primitive, f, tracers, params):
162162
new_params = dict(params,
163163
mapped_invars=tuple([True] * len(const_tracers) +
164164
[False] * len(env_tracers) +
165-
[True] * len(tracers)))
165+
[True] * len(tracers)),
166+
call_jaxpr=lifted_jaxpr)
166167
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, tracers)),
167-
out_tracers, map_primitive, new_params,
168-
subjaxpr=lifted_jaxpr)
168+
out_tracers, map_primitive, new_params)
169169
for t in out_tracers:
170170
t.recipe = eqn
171171
return out_tracers
@@ -187,10 +187,10 @@ def todo(x):
187187
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
188188
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
189189
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
190+
new_params = dict(params, call_jaxpr=lifted_jaxpr)
190191
# The `jaxpr` already contains the env_vars at start of invars
191192
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers)),
192-
out_tracers, call_primitive, params,
193-
subjaxpr=lifted_jaxpr)
193+
out_tracers, call_primitive, new_params)
194194
for t in out_tracers:
195195
t.recipe = eqn
196196
return out_tracers
@@ -215,11 +215,11 @@ def todo(x):
215215
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
216216
new_params = dict(params,
217217
mapped_invars=tuple([True] * len(const_tracers) +
218-
[False] * len(env)))
218+
[False] * len(env)),
219+
call_jaxpr=lifted_jaxpr)
219220
env_tracers = map(trace.full_raise, env)
220221
eqn = new_eqn_recipe(it.chain(const_tracers, env_tracers),
221-
out_tracers, map_primitive, new_params,
222-
subjaxpr=lifted_jaxpr)
222+
out_tracers, map_primitive, new_params)
223223
for t in out_tracers:
224224
t.recipe = eqn
225225
return out_tracers
@@ -383,38 +383,31 @@ def instantiate_const_at(trace, instantiate, tracer):
383383
ConstVar = namedtuple('ConstVar', ['val'])
384384
LambdaBinding = namedtuple('LambdaBinding', [])
385385
JaxprEqnRecipe = namedtuple('JaxprEqnRecipe',
386-
['eqn_id', 'invars', 'outvars', 'primitive',
387-
'bound_subjaxpr', 'params'])
386+
['eqn_id', 'invars', 'outvars', 'primitive', 'params'])
388387

389-
390-
def new_eqn_recipe(invars, outvars, primitive, params,
391-
subjaxpr=None):
388+
def new_eqn_recipe(invars, outvars, primitive, params):
392389
"""Constructs a new JaxEqnRecipe.
393390
394391
Params:
395392
invars: the tracers for the primitive inputs.
396393
outvars: the tracers for the primitive outputs.
397394
primitive: the primitive.
398395
params: the primitive params
399-
subjaxpr: (optional) a sub-Jaxpr, used only for `xla_call` or `xla_pmap`.
400-
If present, then `subjaxpr.invars` correspond to `invars.
401396
"""
402-
if subjaxpr is not None:
403-
assert len(subjaxpr.constvars) == 0
404-
assert len(subjaxpr.invars) == len(tuple(invars))
405-
bound_subjaxpr = subjaxpr
406-
else:
407-
bound_subjaxpr = None
408-
397+
if primitive.call_primitive:
398+
# TODO(necula): move these checks to core.check_jaxpr, and call it
399+
# in more places.
400+
assert "call_jaxpr" in params
409401
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
410-
bound_subjaxpr, params)
402+
params)
403+
411404

412405
def recipe_to_eqn(unused_var, getvar, recipe):
413-
_, in_tracers, out_tracer_refs, primitive, bound_subjaxpr, params = recipe
406+
_, in_tracers, out_tracer_refs, primitive, params = recipe
414407
out_tracers = [t_ref() for t_ref in out_tracer_refs]
415408
invars = [getvar(t) for t in in_tracers]
416409
outvars = [unused_var() if t is None else getvar(t) for t in out_tracers]
417-
return new_jaxpr_eqn(invars, outvars, primitive, bound_subjaxpr, params)
410+
return new_jaxpr_eqn(invars, outvars, primitive, params)
418411

419412
def tracers_to_jaxpr(in_tracers, out_tracers):
420413
"""Constructs Jaxpr given tracers for inputs and outputs.
@@ -520,6 +513,7 @@ def _split_aval(unknown, aval):
520513

521514

522515
remat_call_p = core.Primitive('remat_call')
516+
remat_call_p.call_primitive = True
523517
remat_call = partial(core.call_bind, remat_call_p)
524518
remat_call_p.def_custom_bind(remat_call)
525519
remat_call_p.def_impl(core.call_impl)
@@ -593,10 +587,9 @@ def _remat_partial_eval(trace, f, tracers, params):
593587
const_tracers = map(trace.new_instantiated_const, consts)
594588
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
595589
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
590+
new_params = dict(params, call_jaxpr=lifted_jaxpr)
596591
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, instantiated_tracers)),
597-
out_tracers, remat_call_p,
598-
params,
599-
subjaxpr=lifted_jaxpr)
592+
out_tracers, remat_call_p, new_params)
600593
for t in out_tracers: t.recipe = eqn
601594
return out_tracers
602595
call_partial_eval_rules[remat_call_p] = _remat_partial_eval

jax/interpreters/pxla.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -640,14 +640,16 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args):
640640

641641

642642
xla_pmap_p = core.Primitive('xla_pmap')
643+
xla_pmap_p.call_primitive = True
643644
xla_pmap_p.multiple_results = True
644645
xla_pmap = partial(core.call_bind, xla_pmap_p)
645646
xla_pmap_p.def_custom_bind(xla_pmap)
646647
xla_pmap_p.def_impl(xla_pmap_impl)
647648

648-
def _pmap_translation_rule(c, jaxpr, axis_env,
649+
def _pmap_translation_rule(c, axis_env,
649650
in_nodes, name_stack, axis_name, axis_size,
650-
global_axis_size, devices, name, backend=None,
651+
global_axis_size, devices, name,
652+
call_jaxpr, backend=None,
651653
mapped_invars=None):
652654
# We in-line here rather than generating a Call HLO as in the xla_call
653655
# translation rule just because the extra tuple stuff is a pain.
@@ -662,7 +664,7 @@ def _pmap_translation_rule(c, jaxpr, axis_env,
662664
for in_node, in_node_mapped in zip(in_nodes, mapped_invars))
663665

664666
sharded_outs = xla.jaxpr_subcomp(
665-
c, jaxpr, backend, new_env, (),
667+
c, call_jaxpr, backend, new_env, (),
666668
extend_name_stack(name_stack, wrap_name(name, 'pmap')), *in_nodes_sharded)
667669
outs = [_xla_unshard(c, new_env, shard) for shard in sharded_outs]
668670
return c.Tuple(*outs)

0 commit comments

Comments
 (0)