Description
Consider the following probabilistic model:
import aesara.tensor as at
X_at = at.matrix('X')
srng = at.random.RandomStream(0)
tau_rv = srng.halfcauchy(0, 1)
lambda_rv = srng.halfcauchy(0, 1, size=X_at.shape[-1])
beta_rv = srng.normal(0, tau_rv * lambda_rv, size=X_at.shape[-1])
eta = X_at @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.bernoulli(p)
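For readers who prefer notation, the graph above can be read as the following model. This is only a restatement: the symbols n and d (the number of rows and columns of X) are introduced here for illustration, and the second argument of normal and halfcauchy is assumed to be a scale parameter.

```latex
\begin{aligned}
\tau &\sim \operatorname{HalfCauchy}(0, 1) \\
\lambda_j &\sim \operatorname{HalfCauchy}(0, 1), \quad j = 1, \dots, d \\
\beta_j &\sim \operatorname{Normal}(0, \tau \lambda_j) \\
\eta &= X \beta \\
p_i &= \operatorname{sigmoid}(-\eta_i), \quad i = 1, \dots, n \\
Y_i &\sim \operatorname{Bernoulli}(p_i)
\end{aligned}
```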
In a typical Bayesian modeling workflow, we want to be able to generate and use three functions:
1. A function that samples from the prior joint distribution (to initialize the values)
2. A function that computes the model's log-density
3. A function that draws posterior predictive samples
In a workflow that uses JAX, we typically want to be able to use jax.vmap with (1) and (3), and jax.grad with (2). While it is possible to do this with Aesara-compiled functions, we need to go through unnecessary levels of indirection.
Current behavior
First, to use jax.grad we must use the vm.jit_fn attribute of the Aesara-compiled function and wrap the function so that it returns a single value instead of a 1-element tuple:
import aesara
import aeppl
import jax
import numpy as np
logprob, vvs = aeppl.joint_logprob(tau_rv, lambda_rv, beta_rv, Y_rv)
logdensity_fn = aesara.function([X_at] + list(vvs), logprob, mode="JAX")
try:
    jax.grad(logdensity_fn)(np.ones((3,2)), 1., np.ones(2), np.ones(2), np.ones(2))
except Exception as e:
    print(e)
# Bad input argument to aesara function with name "<stdin>:22" at index 0 (0-based).
# Backtrace when that variable is created:
# File "<stdin>", line 3, in <module>
# The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[1. 1.]
# [1. 1.]
# [1. 1.]], dtype=float64)>with<JVPTrace(level=2/0)> with
# primal = array([[1., 1.],
# [1., 1.],
# [1., 1.]])
# tangent = Traced<ShapedArray(float64[3,2])>with<JaxprTrace(level=1/0)> with
# pval = (ShapedArray(float64[3,2]), None)
# recipe = LambdaBinding()
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
logdensity_jit_fn = aesara.function([X_at] + list(vvs), logprob, mode="JAX").vm.jit_fn
try:
    jax.grad(logdensity_jit_fn)(np.ones((3,2)), np.array(1.), np.ones(2), np.ones(2), np.ones(3))
except Exception as e:
    print(e)
# Gradient only defined for scalar-output functions. Output was (Array(-12.65285075, dtype=float64),).
def logdensity_squeezed_jit_fn(*x):
    return logdensity_jit_fn(*x)[0]
print(jax.grad(logdensity_squeezed_jit_fn)(np.ones((3,2)), np.array(1.), np.ones(2), np.ones(2), np.ones(3)))
# [[-0.88079708 -0.88079708]
# [-0.88079708 -0.88079708]
# [-0.88079708 -0.88079708]]
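The squeezing wrapper above can be generalized into a small reusable helper. The following is a minimal sketch, not part of Aesara: the name unwrap_single_output and the two stand-in functions are hypothetical, and the stand-ins merely mimic the "always return a tuple" behavior of the jit-compiled function.

```python
import functools


def unwrap_single_output(fn):
    """Hypothetical helper: return the sole element of a 1-element tuple/list output."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        out = fn(*args, **kwargs)
        # Only unwrap when there is exactly one output; leave other outputs untouched.
        if isinstance(out, (tuple, list)) and len(out) == 1:
            return out[0]
        return out

    return wrapped


# Stand-ins for compiled functions that always return tuples (assumption).
def one_output(x):
    return (x * 2.0,)


def two_outputs(x):
    return (x, x + 1.0)


scalar = unwrap_single_output(one_output)(3.0)
pair = unwrap_single_output(two_outputs)(3.0)
```

Because the wrapper is plain Python, it should be traceable by jax.grad in the same way as logdensity_squeezed_jit_fn, for example jax.grad(unwrap_single_output(logdensity_jit_fn)).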
To be able to use jax.vmap to sample multiple values from the prior distribution or to do posterior predictive sampling, we must also use the jit-compiled function directly. In addition, we must pass one PRNGKey per random variable in the graph (which does not reflect the RandomStream mechanism), wrap each key in a dictionary with the same structure as the internal random state, and pass these dictionaries as the last arguments (in JAX it is idiomatic to pass the key as the first argument):
prior_sample_fn = aesara.function([X_at], [tau_rv, lambda_rv, beta_rv], mode="JAX").vm.jit_fn
rng_key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(rng_key, 3)
print(prior_sample_fn(np.ones((2,3)), {"jax_state": key1}, {"jax_state": key2}, {"jax_state": key3}))
Expected behavior
I would expect to be able to use the compiled function just like any JAX function:
logdensity_fn = aesara.function([X_at] + list(vvs), logprob, mode="JAX")
jax.grad(logdensity_fn)(np.ones((3,2)), 1., np.ones(2), np.ones(2), np.ones(3))
With the assumption that, for random functions, the first argument must be a JAX PRNGKey:
prior_sample_fn = aesara.function([X_at], [tau_rv, lambda_rv, beta_rv], mode="JAX").vm.jit_fn
rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, 100)
samples = jax.vmap(prior_sample_fn, in_axes=(0, None))(keys, np.ones((2,3)))
Proposals
To make the compiled functions truly compatible with JAX and the rest of its ecosystem I suggest the following changes:
- Function.__call__ should be immediately compatible with JAX; we shouldn't have to fetch the vm.jit_fn attribute. An Aesara function is supposed to manage other things like updates of SharedVariables, but I am not sure this is necessary for JAX-compiled functions. Furthermore, if it is supposed to do something that normally cannot be done with JAX, then it would be preferable to fail during transpilation rather than output a function that cannot be composed with the rest of the JAX ecosystem;
- Function.__call__ should not return a tuple when there is a single output;
- The internal random state should be represented by the rng_key directly, and not by a dictionary that contains the key;
- When there are several random variables in the graph, we should only require one PRNG key;
- The PRNG key should be passed as the first input;
- Do not return the updated PRNG keys with the functions that take PRNG keys as inputs.
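Until something like this exists, the key-related proposals can be approximated in user code. The sketch below is only an illustration under stated assumptions: key_first and fake_sampler are hypothetical names, and fake_sampler is a pure-JAX stand-in for a .vm.jit_fn that expects its {"jax_state": key} dictionaries as trailing arguments; it is not produced by Aesara.

```python
import jax
import jax.numpy as jnp


def key_first(jit_fn, n_rngs):
    """Hypothetical adapter: take one PRNG key first, fan it out to per-variable states."""

    def wrapped(rng_key, *args):
        # Split the single key into one subkey per random variable.
        subkeys = jax.random.split(rng_key, n_rngs)
        # Wrap each subkey in the dictionary structure used in the example above.
        states = [{"jax_state": subkeys[i]} for i in range(n_rngs)]
        # Pass the states as trailing arguments, as the jit-compiled function expects.
        return jit_fn(*args, *states)

    return wrapped


# Stand-in for the jit-compiled sampler of a graph with two random variables (assumption).
def fake_sampler(X, state_a, state_b):
    a = jax.random.normal(state_a["jax_state"])
    b = jax.random.normal(state_b["jax_state"], (X.shape[-1],))
    return a, b


sample = key_first(fake_sampler, n_rngs=2)
keys = jax.random.split(jax.random.PRNGKey(0), 100)
a, b = jax.vmap(sample, in_axes=(0, None))(keys, jnp.ones((2, 3)))
```

With the key as the first argument, the call site matches the jax.vmap usage shown under "Expected behavior"; the adapter does not address the updated keys returned by the real compiled function.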
Related to #1194