danielward27/flowjax

FlowJax: Normalizing Flows in Jax

Documentation

Available here.

Short example

Training a flow can be done in a few lines of code:

from flowjax.flows import BlockNeuralAutoregressiveFlow
from flowjax.train import fit_to_data
from flowjax.distributions import Normal
from jax import random
import jax.numpy as jnp

data_key, flow_key, train_key = random.split(random.PRNGKey(0), 3)

x = random.uniform(data_key, (10000, 3))  # Toy data
base_dist = Normal(jnp.zeros(x.shape[1]))
flow = BlockNeuralAutoregressiveFlow(flow_key, base_dist)
flow, losses = fit_to_data(train_key, flow, x, learning_rate=0.05)

# We can now evaluate the log-probability of arbitrary points
flow.log_prob(x)
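
To show what "fitting by maximum likelihood" and `log_prob` mean, here is a small sketch in plain JAX. It does not use FlowJax and is not how `fit_to_data` is implemented. It assumes a single elementwise affine bijection x = loc + exp(log_scale) * z with a standard normal base. The log-density comes from the change-of-variables formula, and the parameters are fitted by minimising the mean negative log-likelihood with plain gradient descent. The names `log_prob`, `loss`, `step` and `LEARNING_RATE`, and the step size itself, are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

LEARNING_RATE = 0.1  # hypothetical step size for plain gradient descent


def log_prob(params, x):
    """Log-density of x under x = loc + exp(log_scale) * z, with z ~ N(0, I).

    Change of variables: log p(x) = log p_z(f^{-1}(x)) + log|det J_{f^{-1}}(x)|,
    and for this elementwise affine map log|det J_{f^{-1}}| = -sum(log_scale).
    """
    loc, log_scale = params
    z = (x - loc) * jnp.exp(-log_scale)  # inverse transform f^{-1}(x)
    return norm.logpdf(z).sum(axis=-1) - log_scale.sum()


def loss(params, x):
    # Maximum likelihood: minimise the mean negative log-likelihood of the data
    return -log_prob(params, x).mean()


@jax.jit
def step(params, x):
    # One gradient-descent update applied to every array in the params tuple
    grads = jax.grad(loss)(params, x)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


x = jax.random.uniform(jax.random.PRNGKey(0), (10000, 3))  # toy data
params = (jnp.zeros(3), jnp.zeros(3))  # (loc, log_scale)
for _ in range(500):
    params = step(params, x)

print(log_prob(params, x[:5]))
```

For this one-layer affine flow the maximum-likelihood solution is the per-dimension sample mean and standard deviation, so the result is easy to check. The architectures listed below replace the fixed affine map with more expressive, learned bijections.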

The package currently includes:

  • Many simple bijections and distributions, implemented as Equinox modules.
  • CouplingFlow (Dinh et al., 2017) and MaskedAutoregressiveFlow (Papamakarios et al., 2017) normalizing flow architectures.
    • These can be used with arbitrary bijections as transformers, such as Affine or RationalQuadraticSpline (the latter used in neural spline flows; Durkan et al., 2019).
  • BlockNeuralAutoregressiveFlow (De Cao et al., 2019).
  • TriangularSplineFlow, introduced here.
  • Training scripts for fitting by maximum likelihood, by variational inference, or by contrastive learning for sequential neural posterior estimation (Greenberg et al., 2019; Durkan et al., 2020).
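
A coupling flow (Dinh et al., 2017) splits its input in two. The first part passes through unchanged. The second part is transformed by a bijection (the "transformer"), whose parameters a conditioner computes from the first part. The sketch below shows a single affine coupling layer in plain JAX. It is not FlowJax's `CouplingFlow` implementation. The linear `conditioner`, the tanh bound on the log-scale, and the function names are hypothetical stand-ins for a neural network. The code checks that the inverse recovers the input and that the cheap log-determinant matches one computed from the full Jacobian with `jax.jacfwd`.

```python
import jax
import jax.numpy as jnp


def conditioner(params, x1):
    # Hypothetical conditioner: one linear layer mapping x1 to (shift, log_scale)
    # for the transformed half. A real flow would use a neural network here.
    W, b = params
    shift, log_scale = jnp.split(x1 @ W + b, 2, axis=-1)
    return shift, jnp.tanh(log_scale)  # bounded log-scale for numerical stability


def coupling_forward(params, x, split):
    x1, x2 = x[..., :split], x[..., split:]
    shift, log_scale = conditioner(params, x1)
    y2 = x2 * jnp.exp(log_scale) + shift  # affine transformer applied to x2
    return jnp.concatenate([x1, y2], axis=-1), log_scale.sum(axis=-1)


def coupling_inverse(params, y, split):
    y1, y2 = y[..., :split], y[..., split:]
    shift, log_scale = conditioner(params, y1)  # y1 == x1, so same parameters
    x2 = (y2 - shift) * jnp.exp(-log_scale)
    return jnp.concatenate([y1, x2], axis=-1), -log_scale.sum(axis=-1)


dim, split = 4, 2
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    jax.random.normal(k1, (split, 2 * (dim - split))),  # W
    jax.random.normal(k2, (2 * (dim - split),)),  # b
)
x = jax.random.normal(k3, (dim,))

y, log_det = coupling_forward(params, x, split)
x_rec, inv_log_det = coupling_inverse(params, y, split)

# The Jacobian is block-triangular, so log|det J| is just the sum of log-scales.
jac = jax.jacfwd(lambda v: coupling_forward(params, v, split)[0])(x)
print(jnp.allclose(x_rec, x, atol=1e-5))
print(jnp.allclose(log_det, jnp.linalg.slogdet(jac)[1], atol=1e-4))
```

A coupling flow stacks several such layers and permutes dimensions between them, so every dimension is eventually transformed. Replacing the affine transformer with a rational quadratic spline gives a neural spline flow.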

Installation

pip install flowjax

Development

To install a development version:

git clone https://github.com/danielward27/flowjax.git
cd flowjax
pip install -e ".[dev]"  # quotes stop shells such as zsh from expanding the brackets
sudo apt-get install pandoc  # Required for building documentation

Warning

This package is new and may have substantial breaking changes between major releases.

TODO

Current limitations, and features that could be worth adding in the future:

  • Add the ability to "reshape" bijections.

Related

We make use of the Equinox package, which facilitates object-oriented programming with Jax.

Authors

flowjax was written by Daniel Ward <danielward27@outlook.com>.