
Improve the interface to JAX-compiled functions #1385

@rlouf


Consider the following probabilistic model:

```python
import aesara.tensor as at

X_at = at.matrix('X')

srng = at.random.RandomStream(0)

tau_rv = srng.halfcauchy(0, 1)
lambda_rv = srng.halfcauchy(0, 1, size=X_at.shape[-1])
beta_rv = srng.normal(0, tau_rv * lambda_rv, size=X_at.shape[-1])

eta = X_at @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.bernoulli(p)
```

In a typical Bayesian modeling workflow, we want to be able to generate and use three functions:

  1. A function that samples from the joint prior distribution (e.g. to initialize the values)
  2. A function that computes the model's log-density
  3. A function that performs posterior predictive sampling

In a workflow that uses JAX we typically want to apply jax.vmap to (1) and (3), and jax.grad to (2). While this is possible with Aesara-compiled functions, it requires going through unnecessary levels of indirection.
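For reference, here is roughly what the three functions look like when written by hand in plain JAX, where jax.vmap and jax.grad compose with them directly. This is a minimal sketch, not Aesara output: the function names are hypothetical, splitting one key into per-variable sub-keys is our own choice, and prior draws of beta stand in for posterior draws in the predictive step.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


def prior_sample(rng_key, X):
    # (1) Draw (tau, lambda, beta) from the prior; the key comes first, as is idiomatic in JAX
    key_tau, key_lam, key_beta = jax.random.split(rng_key, 3)
    n_features = X.shape[-1]
    tau = jnp.abs(jax.random.cauchy(key_tau))  # halfcauchy(0, 1) = |Cauchy(0, 1)|
    lam = jnp.abs(jax.random.cauchy(key_lam, shape=(n_features,)))
    beta = tau * lam * jax.random.normal(key_beta, shape=(n_features,))
    return tau, lam, beta


def logdensity(X, tau, lam, beta, y):
    # (2) Joint log-density; it returns a scalar, so jax.grad applies directly
    log2 = jnp.log(2.0)
    lp = stats.cauchy.logpdf(tau) + log2  # half-Cauchy density is 2 * Cauchy density on x >= 0
    lp += jnp.sum(stats.cauchy.logpdf(lam) + log2)
    lp += jnp.sum(stats.norm.logpdf(beta, 0.0, tau * lam))
    p = jax.nn.sigmoid(-(X @ beta))
    return lp + jnp.sum(stats.bernoulli.logpmf(y, p))


def posterior_predictive(rng_key, X, beta):
    # (3) Sample Y given one draw of beta
    return jax.random.bernoulli(rng_key, jax.nn.sigmoid(-(X @ beta)))


X = jnp.ones((3, 2))
keys = jax.random.split(jax.random.PRNGKey(0), 100)
# Batch over the keys only; X is shared across the batch (in_axes=None)
taus, lams, betas = jax.vmap(prior_sample, in_axes=(0, None))(keys, X)

# Gradient of the log-density with respect to beta (argument index 3)
grad_beta = jax.grad(logdensity, argnums=3)(X, 1.0, jnp.ones(2), jnp.ones(2), jnp.ones(3))

# Fresh keys for predictive sampling; prior betas stand in for posterior draws here
pred_keys = jax.random.split(jax.random.PRNGKey(1), 100)
ys = jax.vmap(posterior_predictive, in_axes=(0, None, 0))(pred_keys, X, betas)
print(taus.shape, lams.shape, betas.shape, grad_beta.shape, ys.shape)
```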

Current behavior

First, to use jax.grad we must use the vm.jit_fn attribute of the Aesara-compiled function and wrap the function so it returns a single value instead of a 1-element tuple:

```python
import aesara
import aeppl
import jax
import numpy as np

logprob, vvs = aeppl.joint_logprob(tau_rv, lambda_rv, beta_rv, Y_rv)

# Passing the Aesara Function itself to jax.grad fails: the tracer is converted to a NumPy array
logdensity_fn = aesara.function([X_at] + list(vvs), logprob, mode="JAX")
try:
    jax.grad(logdensity_fn)(np.ones((3,2)), 1., np.ones(2), np.ones(2), np.ones(3))
except Exception as e:
    print(e)
# Bad input argument to aesara function with name "<stdin>:22" at index 0 (0-based).
# Backtrace when that variable is created:

#   File "<stdin>", line 3, in <module>
# The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[1. 1.]
#  [1. 1.]
#  [1. 1.]], dtype=float64)>with<JVPTrace(level=2/0)> with
#   primal = array([[1., 1.],
#        [1., 1.],
#        [1., 1.]])
#   tangent = Traced<ShapedArray(float64[3,2])>with<JaxprTrace(level=1/0)> with
#     pval = (ShapedArray(float64[3,2]), None)
#     recipe = LambdaBinding()
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

# The underlying JAX function is traceable, but it returns a 1-element tuple
logdensity_jit_fn = aesara.function([X_at] + list(vvs), logprob, mode="JAX").vm.jit_fn
try:
    jax.grad(logdensity_jit_fn)(np.ones((3,2)), np.array(1.), np.ones(2), np.ones(2), np.ones(3))
except Exception as e:
    print(e)
# Gradient only defined for scalar-output functions. Output was (Array(-12.65285075, dtype=float64),).

# Unwrapping the single output finally works
def logdensity_squeezed_jit_fn(*x):
    return logdensity_jit_fn(*x)[0]
print(jax.grad(logdensity_squeezed_jit_fn)(np.ones((3,2)), np.array(1.), np.ones(2), np.ones(2), np.ones(3)))
# [[-0.88079708 -0.88079708]
#  [-0.88079708 -0.88079708]
#  [-0.88079708 -0.88079708]]
```
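Until Function.__call__ returns a bare value for single-output graphs (the second proposal), the logdensity_squeezed_jit_fn workaround can be generalized into a small decorator. This is only a sketch: squeeze_single_output is a hypothetical helper, and fake_jit_fn is a stand-in that merely mimics the 1-element tuple returned by vm.jit_fn.

```python
import functools

import jax
import jax.numpy as jnp


def squeeze_single_output(fn):
    """Wrap `fn` so that a 1-element tuple/list output is returned as its sole element."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, (tuple, list)) and len(out) == 1:
            return out[0]
        return out  # multi-output results are passed through unchanged

    return wrapped


# Hypothetical stand-in for `.vm.jit_fn`: returns a 1-element tuple
def fake_jit_fn(x, y):
    return (jnp.sum(x * y),)


# d/dx sum(x * y) = y
g = jax.grad(squeeze_single_output(fake_jit_fn))(jnp.ones(3), jnp.arange(3.0))
print(g)  # [0. 1. 2.]
```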

To use jax.vmap to draw multiple samples from the prior distribution, or to do posterior predictive sampling, we must also call the jit-compiled function directly. In addition, we must pass one PRNGKey per random variable in the graph (which does not reflect the RandomStream mechanism), wrap each key in a dictionary with the same structure as the internal random state, and pass these dictionaries as the last arguments, whereas the idiomatic place for the key in JAX is the first argument:

```python
prior_sample_fn = aesara.function([X_at], [tau_rv, lambda_rv, beta_rv], mode="JAX").vm.jit_fn
rng_key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(rng_key, 3)
print(prior_sample_fn(np.ones((2,3)), {"jax_state": key1}, {"jax_state": key2}, {"jax_state": key3}))
```
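In the meantime, the calling-convention mismatch can be papered over with an adapter that takes a single key first, splits it into one key per random variable, wraps each in a {"jax_state": key} dictionary and appends them as trailing arguments. This is a sketch under assumptions: key_first and fake_prior_sample_fn are hypothetical, the stand-in assumes the dictionary layout shown above, and it ignores any updated RNG states the real compiled function may also return.

```python
import jax
import jax.numpy as jnp


def key_first(jit_fn, n_rngs):
    """Hypothetical adapter: one key as first argument -> n_rngs trailing {"jax_state": key} dicts."""

    def wrapped(rng_key, *args):
        keys = jax.random.split(rng_key, n_rngs)
        states = [{"jax_state": keys[i]} for i in range(n_rngs)]
        return jit_fn(*args, *states)

    return wrapped


# Hypothetical stand-in mimicking the calling convention of `prior_sample_fn` above
def fake_prior_sample_fn(X, state1, state2, state3):
    n = X.shape[-1]
    tau = jnp.abs(jax.random.cauchy(state1["jax_state"]))
    lam = jnp.abs(jax.random.cauchy(state2["jax_state"], shape=(n,)))
    beta = tau * lam * jax.random.normal(state3["jax_state"], shape=(n,))
    return tau, lam, beta


sample = key_first(fake_prior_sample_fn, n_rngs=3)
keys = jax.random.split(jax.random.PRNGKey(0), 100)
taus, lams, betas = jax.vmap(sample, in_axes=(0, None))(keys, jnp.ones((2, 3)))
print(betas.shape)  # (100, 3)
```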

Expected behavior

I would expect to be able to use the compiled function just like any JAX function:

```python
logdensity_fn = aesara.function([X_at] + list(vvs), logprob, mode="JAX")
jax.grad(logdensity_fn)(np.ones((3,2)), 1., np.ones(2), np.ones(2), np.ones(3))
```

For functions with random outputs, assuming the first argument is always a JAX PRNGKey:

```python
prior_sample_fn = aesara.function([X_at], [tau_rv, lambda_rv, beta_rv], mode="JAX")

rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, 100)
samples = jax.vmap(prior_sample_fn, in_axes=(0, None))(keys, np.ones((2,3)))
```

Proposals

To make the compiled functions truly compatible with JAX and the rest of its ecosystem I suggest the following changes:

  1. Function.__call__ should be directly compatible with JAX transformations; we shouldn't have to fetch the vm.jit_fn attribute. An Aesara Function also manages other things, such as updates to SharedVariables, but I am not sure this is necessary for JAX-compiled functions. Furthermore, if the function is supposed to do something that cannot be expressed in JAX, it would be preferable to fail during transpilation rather than output a function that cannot be composed with the rest of the JAX ecosystem;
  2. Function.__call__ should not return a tuple when there is a single output;
  3. The internal random state should be represented by the rng_key directly, and not a dictionary that contains the key;
  4. When there are several random variables in the graph, we should only require one PRNG key;
  5. The PRNG key should be passed as the first input;
  6. Functions that take a PRNG key as input should not return the updated PRNG keys.
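One way an implementation could satisfy proposals 3 to 6 is to derive every per-variable key from the single user-supplied key inside the function, for instance with jax.random.fold_in and a per-variable index, so that the function stays pure and has no key to hand back. The sketch below is written by hand in plain JAX; per_variable_keys and prior_sample are hypothetical names, and fold_in is our assumed choice, not Aesara's mechanism.

```python
import jax
import jax.numpy as jnp


def per_variable_keys(rng_key, n_vars):
    # Hypothetical helper: variable i gets fold_in(rng_key, i), a function of the user key and i only
    return [jax.random.fold_in(rng_key, i) for i in range(n_vars)]


def prior_sample(rng_key, X):
    # Single key, passed first; no updated key is returned
    key_tau, key_lam, key_beta = per_variable_keys(rng_key, 3)
    n = X.shape[-1]
    tau = jnp.abs(jax.random.cauchy(key_tau))
    lam = jnp.abs(jax.random.cauchy(key_lam, shape=(n,)))
    beta = tau * lam * jax.random.normal(key_beta, shape=(n,))
    return tau, lam, beta


X = jnp.ones((2, 3))
key = jax.random.PRNGKey(0)
# Same key, same draws: the function is pure, so there is no state to return
a = prior_sample(key, X)
b = prior_sample(key, X)
# Fresh draws come from splitting the caller's key, not from state hidden in the function
key, subkey = jax.random.split(key)
c = prior_sample(subkey, X)
```

Because each variable's key depends only on the user key and that variable's index, the function also vmaps over a batch of keys exactly like the hand-written samplers in JAX code.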

Related to #1194

Labels: JAX (Involves JAX transpilation), enhancement (New feature or request)

    Type

    No type

    Projects

    Status

    Backends

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions