Documentation is available here.
Training a flow can be done in a few lines of code:
```python
from jax import random
import jax.numpy as jnp

from flowjax.distributions import Normal
from flowjax.flows import BlockNeuralAutoregressiveFlow
from flowjax.train import fit_to_data

data_key, flow_key, train_key = random.split(random.PRNGKey(0), 3)

x = random.uniform(data_key, (10000, 3))  # Toy data
base_dist = Normal(jnp.zeros(x.shape[1]))
flow = BlockNeuralAutoregressiveFlow(flow_key, base_dist)
flow, losses = fit_to_data(train_key, flow, x, learning_rate=0.05)

# We can now evaluate the log-probability of arbitrary points
flow.log_prob(x)
```
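Continuing from the example above, new points can also be drawn from the fitted flow. This is a minimal sketch, assuming the distribution's `sample` method takes a PRNG key and a sample shape; check the documentation for the exact signature in your installed version.

```python
# Draw 1000 new samples from the trained flow (continues the example above).
sample_key = random.PRNGKey(1)
samples = flow.sample(sample_key, (1000,))
samples.shape  # (1000, 3)
```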
The package currently includes:
- Many simple bijections and distributions, implemented as Equinox modules.
- `CouplingFlow` (Dinh et al., 2017) and `MaskedAutoregressiveFlow` (Papamakarios et al., 2017) normalizing flow architectures.
  - These can be used with arbitrary bijections as transformers, such as `Affine` or `RationalQuadraticSpline` (the latter used in neural spline flows; Durkan et al., 2019); see the sketch after this list.
- `BlockNeuralAutoregressiveFlow`, as introduced by De Cao et al., 2019.
- `TriangularSplineFlow`, introduced here.
- Training scripts for fitting by maximum likelihood, variational inference, or using contrastive learning for sequential neural posterior estimation (Greenberg et al., 2019; Durkan et al., 2020).
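For example, a neural-spline-style flow can be built by combining a coupling flow with a spline transformer. The following is only a sketch: the `transformer` keyword and the `RationalQuadraticSpline` arguments (`knots`, `interval`) are assumptions about the constructor signatures, which may differ between versions; consult the documentation for the exact API.

```python
# Hedged sketch: a coupling flow with a rational quadratic spline transformer.
# The `transformer` keyword and the spline arguments (`knots`, `interval`) are
# assumed; check the flowjax documentation for the exact constructor signatures.
from jax import random
import jax.numpy as jnp

from flowjax.bijections import RationalQuadraticSpline
from flowjax.distributions import Normal
from flowjax.flows import CouplingFlow

key = random.PRNGKey(0)
base_dist = Normal(jnp.zeros(3))
flow = CouplingFlow(
    key,
    base_dist,
    transformer=RationalQuadraticSpline(knots=8, interval=4),
)
```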
The package can be installed from PyPI:

```bash
pip install flowjax
```
We can install a version for development as follows:

```bash
git clone https://github.com/danielward27/flowjax.git
cd flowjax
pip install -e .[dev]
sudo apt-get install pandoc  # Required for building documentation
```
This package is new and may have substantial breaking changes between major releases.
A few limitations / things that could be worth including in the future:
- Add ability to "reshape" bijections.
We make use of the Equinox package, which facilitates object-oriented programming with JAX.
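Because flows are Equinox modules (and therefore JAX PyTrees), standard Equinox utilities apply to them. Below is a minimal sketch of saving and reloading a trained flow with Equinox's serialisation helpers; the file name is illustrative.

```python
import equinox as eqx

# Save the trained flow's parameters to disk (file name is illustrative).
eqx.tree_serialise_leaves("flow.eqx", flow)

# Reload the parameters into a flow with the same structure; the second
# argument should be a flow constructed exactly as the original was.
flow = eqx.tree_deserialise_leaves("flow.eqx", flow)
```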
`flowjax` was written by Daniel Ward <danielward27@outlook.com>.