Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ These are the release notes for JAX.
### Breaking changes

* The minimum jaxlib version is now 0.1.38.
* Simplified `Jaxpr` by removing the `Jaxpr.freevars` and changing the
representation of `Jaxpr.bound_subjaxprs` to drop the environment values.
* Simplified `Jaxpr` by removing the `Jaxpr.freevars` and
`Jaxpr.bound_subjaxprs`. The call primitives (`xla_call`, `xla_pmap`,
`sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a
fully-closed (no `constvars`) JAXPR.

### New features

Expand Down
62 changes: 51 additions & 11 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ def __str__(self):
return str(pp_jaxpr(self))
__repr__ = __str__


def subjaxprs(jaxpr):
"""Generator for all subjaxprs found in the params of jaxpr.eqns.
Does not descend recursively into the found subjaxprs.
"""
for eqn in jaxpr.eqns:
for param in eqn.params.values():
if type(param) is Jaxpr:
yield param
elif type(param) is TypedJaxpr:
yield param.jaxpr


class TypedJaxpr(object):
def __init__(self, jaxpr, literals, in_avals, out_avals):
assert type(jaxpr) is Jaxpr
Expand All @@ -84,8 +97,8 @@ def jaxpr_as_fun(typed_jaxpr, *args):
return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, *args)


JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive',
'bound_subjaxpr', 'params'])

JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive', 'params'])
JaxprEqn.__repr__ = JaxprEqn.__str__ = lambda eqn: str(pp_eqn(eqn)).rstrip()
new_jaxpr_eqn = JaxprEqn

Expand Down Expand Up @@ -149,6 +162,8 @@ def __repr__(self):

class Primitive(object):
multiple_results = False # override for multi-output primitives
call_primitive = False # override for higher-order primitives that are
# processed in final style.

def __init__(self, name):
self.name = name
Expand Down Expand Up @@ -193,6 +208,24 @@ def abstract_eval(self, *args, **kwargs):

# -------------------- lifting --------------------

# TODO(necula): this belongs next to pe.new_eqn_recipe, but is needed in
# core.py. Plan to move all these utilities to jaxpr.py.
def extract_call_jaxpr(primitive, params):
"""Extract the call primitive subjaxpr from the params.

Params:
params: a parameter dictionary for a primitive.
Returns: the subjaxpr and the params without the "jaxpr" value. If this is
not a call primitive then returns (None, params).
"""
if not primitive.call_primitive:
return (None, params)
else:
assert "call_jaxpr" in params
new_params = dict(params)
del new_params["call_jaxpr"]
return (params["call_jaxpr"], new_params)


def eval_jaxpr(jaxpr, consts, *args):
def read(v):
Expand All @@ -210,11 +243,12 @@ def write(v, val):
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.invars)
if eqn.bound_subjaxpr:
subfuns = [lu.wrap_init(partial(eval_jaxpr, eqn.bound_subjaxpr, ()))]
call_jaxpr, params = extract_call_jaxpr(eqn.primitive, eqn.params)
if call_jaxpr:
subfuns = [lu.wrap_init(partial(eval_jaxpr, call_jaxpr, ()))]
else:
subfuns = []
ans = eqn.primitive.bind(*(subfuns + in_vals), **eqn.params)
ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
if eqn.primitive.multiple_results:
map(write, eqn.outvars, ans)
else:
Expand Down Expand Up @@ -626,6 +660,10 @@ def call_impl(f, *args, **params):
# ------------------- Jaxpr printed representation -------------------

def check_jaxpr(jaxpr):
"""Checks well-formedness of a jaxpr.

Specifically it checks that all variabled used are previously defined.
"""
def context():
return "\njaxpr:\n{}\n".format(jaxpr)

Expand All @@ -646,10 +684,16 @@ def write_env(env, v):
map(write, jaxpr.constvars)
map(write, jaxpr.invars)
for eqn in jaxpr.eqns:
if eqn.primitive.call_primitive:
if "call_jaxpr" not in eqn.params:
raise Exception("Call primitive {} should have a 'call_jaxpr' parameter"
.format(eqn.primitive))
map(read, eqn.invars)
if eqn.bound_subjaxpr:
check_jaxpr(eqn.bound_subjaxpr)
map(write, eqn.outvars)

for subjaxpr in subjaxprs(jaxpr):
check_jaxpr(subjaxpr)

map(read, jaxpr.outvars)


Expand All @@ -664,10 +708,6 @@ def pp_eqn_compact(primitive_name, params):
def pp_eqn(eqn):
lhs = pp_vars(eqn.outvars)
pp_subexpr = pp('')
if eqn.bound_subjaxpr:
pp_subexpr = pp_subexpr + (
pp_jaxpr(eqn.bound_subjaxpr).indent(2)
>> pp(' [ ]'))
return (pp('{} = '.format(lhs)) >>
pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items()))
>> pp(' ') >> pp(pp_vars(eqn.invars))) + pp_subexpr
Expand Down
28 changes: 14 additions & 14 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def is_linear(var):

linear_eqns = []
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxpr:
if not eqn.primitive.call_primitive:
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
else:
Expand All @@ -183,20 +183,20 @@ def is_linear(var):
else:
write_primal(eqn.outvars[0], ans)
else:
subjaxpr = eqn.bound_subjaxpr
call_jaxpr = eqn.params["call_jaxpr"]
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
elif eqn.primitive is not pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr,
eqn.primitive, call_jaxpr,
map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)

# we special-case remat_call here because it can be mixed linear /
# nonlinear, so we always evaluate it even if it has a linear part
if eqn.primitive is pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr,
eqn.primitive, call_jaxpr,
map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)

Expand All @@ -208,10 +208,10 @@ def is_linear(var):
cts_in = map(read_cotangent, eqn.outvars)
else:
cts_in, = map(read_cotangent, eqn.outvars)
if eqn.bound_subjaxpr:
subjaxpr = eqn.bound_subjaxpr
if eqn.primitive.call_primitive:
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
cts_out = get_primitive_transpose(eqn.primitive)(
eqn.params, subjaxpr, invals, cts_in)
params, call_jaxpr, invals, cts_in)
else:
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
Expand Down Expand Up @@ -251,7 +251,7 @@ def is_linear(var):
assert not jaxpr.constvars
map(write_primal, jaxpr.invars, args)
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxpr:
if not eqn.primitive.call_primitive:
if not any(is_linear(v) for v in eqn.invars):
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
Expand All @@ -260,11 +260,11 @@ def is_linear(var):
else:
write_primal(eqn.outvars[0], ans)
else:
subjaxpr = eqn.bound_subjaxpr
call_jaxpr = eqn.params["call_jaxpr"]
if (eqn.primitive is pe.remat_call_p or
not any(is_linear(v) for v in eqn.invars)):
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr,
eqn.primitive, call_jaxpr,
map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
return map(read_primal, jaxpr.outvars)
Expand Down Expand Up @@ -537,19 +537,19 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
yield out_flat, tree_def


def call_transpose(primitive, params, jaxpr, args, ct):
def call_transpose(primitive, params, call_jaxpr, args, ct):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
params = dict(params, name=wrap_name(params['name'], 'transpose'))
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)

def map_transpose(primitive, params, jaxpr, args, ct):
def map_transpose(primitive, params, call_jaxpr, args, ct):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
params = dict(params, name=wrap_name(params['name'], 'transpose'))
out_flat = primitive.bind(fun, *all_args, **params)
Expand Down
53 changes: 23 additions & 30 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ def process_call(self, call_primitive, f, tracers, params):
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
out_tracers = [JaxprTracer(self, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
new_params = dict(params, call_jaxpr=lifted_jaxpr)
# The `jaxpr` already contains the env_vars at start of invars
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, tracers)),
out_tracers, call_primitive, params,
subjaxpr=lifted_jaxpr)
out_tracers, call_primitive, new_params)
for t in out_tracers:
t.recipe = eqn
return out_tracers
Expand All @@ -162,10 +162,10 @@ def process_map(self, map_primitive, f, tracers, params):
new_params = dict(params,
mapped_invars=tuple([True] * len(const_tracers) +
[False] * len(env_tracers) +
[True] * len(tracers)))
[True] * len(tracers)),
call_jaxpr=lifted_jaxpr)
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers, tracers)),
out_tracers, map_primitive, new_params,
subjaxpr=lifted_jaxpr)
out_tracers, map_primitive, new_params)
for t in out_tracers:
t.recipe = eqn
return out_tracers
Expand All @@ -187,10 +187,10 @@ def todo(x):
lifted_jaxpr = convert_constvars_jaxpr(jaxpr)
out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None)
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
new_params = dict(params, call_jaxpr=lifted_jaxpr)
# The `jaxpr` already contains the env_vars at start of invars
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, env_tracers)),
out_tracers, call_primitive, params,
subjaxpr=lifted_jaxpr)
out_tracers, call_primitive, new_params)
for t in out_tracers:
t.recipe = eqn
return out_tracers
Expand All @@ -215,11 +215,11 @@ def todo(x):
for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)]
new_params = dict(params,
mapped_invars=tuple([True] * len(const_tracers) +
[False] * len(env)))
[False] * len(env)),
call_jaxpr=lifted_jaxpr)
env_tracers = map(trace.full_raise, env)
eqn = new_eqn_recipe(it.chain(const_tracers, env_tracers),
out_tracers, map_primitive, new_params,
subjaxpr=lifted_jaxpr)
out_tracers, map_primitive, new_params)
for t in out_tracers:
t.recipe = eqn
return out_tracers
Expand Down Expand Up @@ -383,38 +383,31 @@ def instantiate_const_at(trace, instantiate, tracer):
ConstVar = namedtuple('ConstVar', ['val'])
LambdaBinding = namedtuple('LambdaBinding', [])
JaxprEqnRecipe = namedtuple('JaxprEqnRecipe',
['eqn_id', 'invars', 'outvars', 'primitive',
'bound_subjaxpr', 'params'])
['eqn_id', 'invars', 'outvars', 'primitive', 'params'])


def new_eqn_recipe(invars, outvars, primitive, params,
subjaxpr=None):
def new_eqn_recipe(invars, outvars, primitive, params):
"""Constructs a new JaxEqnRecipe.

Params:
invars: the tracers for the primitive inputs.
outvars: the tracers for the primitive outputs.
primitive: the primitive.
params: the primitive params
subjaxpr: (optional) a sub-Jaxpr, used only for `xla_call` or `xla_pmap`.
If present, then `subjaxpr.invars` correspond to `invars.
"""
if subjaxpr is not None:
assert len(subjaxpr.constvars) == 0
assert len(subjaxpr.invars) == len(tuple(invars))
bound_subjaxpr = subjaxpr
else:
bound_subjaxpr = None

if primitive.call_primitive:
# TODO(necula): move these checks to core.check_jaxpr, and call it
# in more places.
assert "call_jaxpr" in params
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
bound_subjaxpr, params)
params)


def recipe_to_eqn(unused_var, getvar, recipe):
_, in_tracers, out_tracer_refs, primitive, bound_subjaxpr, params = recipe
_, in_tracers, out_tracer_refs, primitive, params = recipe
out_tracers = [t_ref() for t_ref in out_tracer_refs]
invars = [getvar(t) for t in in_tracers]
outvars = [unused_var() if t is None else getvar(t) for t in out_tracers]
return new_jaxpr_eqn(invars, outvars, primitive, bound_subjaxpr, params)
return new_jaxpr_eqn(invars, outvars, primitive, params)

def tracers_to_jaxpr(in_tracers, out_tracers):
"""Constructs Jaxpr given tracers for inputs and outputs.
Expand Down Expand Up @@ -520,6 +513,7 @@ def _split_aval(unknown, aval):


remat_call_p = core.Primitive('remat_call')
remat_call_p.call_primitive = True
remat_call = partial(core.call_bind, remat_call_p)
remat_call_p.def_custom_bind(remat_call)
remat_call_p.def_impl(core.call_impl)
Expand Down Expand Up @@ -593,10 +587,9 @@ def _remat_partial_eval(trace, f, tracers, params):
const_tracers = map(trace.new_instantiated_const, consts)
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
new_params = dict(params, call_jaxpr=lifted_jaxpr)
eqn = new_eqn_recipe(tuple(it.chain(const_tracers, instantiated_tracers)),
out_tracers, remat_call_p,
params,
subjaxpr=lifted_jaxpr)
out_tracers, remat_call_p, new_params)
for t in out_tracers: t.recipe = eqn
return out_tracers
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
Expand Down
8 changes: 5 additions & 3 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,14 +640,16 @@ def execute_replicated(compiled, backend, in_handler, out_handler, *args):


xla_pmap_p = core.Primitive('xla_pmap')
xla_pmap_p.call_primitive = True
xla_pmap_p.multiple_results = True
xla_pmap = partial(core.call_bind, xla_pmap_p)
xla_pmap_p.def_custom_bind(xla_pmap)
xla_pmap_p.def_impl(xla_pmap_impl)

def _pmap_translation_rule(c, jaxpr, axis_env,
def _pmap_translation_rule(c, axis_env,
in_nodes, name_stack, axis_name, axis_size,
global_axis_size, devices, name, backend=None,
global_axis_size, devices, name,
call_jaxpr, backend=None,
mapped_invars=None):
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
Expand All @@ -662,7 +664,7 @@ def _pmap_translation_rule(c, jaxpr, axis_env,
for in_node, in_node_mapped in zip(in_nodes, mapped_invars))

sharded_outs = xla.jaxpr_subcomp(
c, jaxpr, backend, new_env, (),
c, call_jaxpr, backend, new_env, (),
extend_name_stack(name_stack, wrap_name(name, 'pmap')), *in_nodes_sharded)
outs = [_xla_unshard(c, new_env, shard) for shard in sharded_outs]
return c.Tuple(*outs)
Expand Down
Loading