flowjax


Normalising flow implementations in JAX. Training a flow takes only a few lines of code:

```python
from flowjax.flows import BlockNeuralAutoregressiveFlow
from flowjax.train_utils import train_flow
from jax import random

data_key, flow_key, train_key = random.split(random.PRNGKey(0), 3)

x = random.uniform(data_key, (10000, 3))  # Toy data
flow = BlockNeuralAutoregressiveFlow(flow_key, target_dim=3)
flow, losses = train_flow(train_key, flow, x, learning_rate=0.05)

# We can now evaluate the log-probability of arbitrary points
flow.log_prob(x)
```
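Evaluating `log_prob` on a normalising flow rests on the change-of-variables formula: if x = f(z), where z comes from a simple base distribution, then log p(x) = log p_z(f⁻¹(x)) + log |det J_{f⁻¹}(x)|. The sketch below shows this in plain JAX rather than flowjax's API. It assumes a standard normal base distribution and a single hypothetical elementwise affine transform x = loc + exp(log_scale) · z, with made-up `loc` and `log_scale` values. For this transform the result can be checked against a Gaussian with the same location and scale.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

# Hypothetical parameters of one elementwise affine bijection:
# x = loc + exp(log_scale) * z, with z ~ N(0, I) (assumed base distribution).
loc = jnp.array([0.5, -1.0, 2.0])
log_scale = jnp.array([0.1, -0.3, 0.7])


def inverse(x):
    """Map data-space points back to the base space."""
    return (x - loc) * jnp.exp(-log_scale)


def log_prob(x):
    """Change of variables: base log-density plus log|det J| of the inverse."""
    z = inverse(x)
    # The inverse Jacobian is diagonal with entries exp(-log_scale),
    # so its log-determinant is -sum(log_scale), independent of x.
    return jnp.sum(norm.logpdf(z), axis=-1) - jnp.sum(log_scale)


x = random.normal(random.PRNGKey(0), (4, 3))
print(log_prob(x))  # one log-density per row of x
```

Real flows replace this fixed affine map with learned, input-dependent bijections, but the log-density is assembled the same way.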

So far the package supports the following:

  • Affine coupling flows/RealNVP for conditional/unconditional density estimation (Dinh et al.)

  • Neural spline coupling flows for conditional/unconditional density estimation (Durkan et al.)

  • Block neural autoregressive flows for conditional/unconditional density estimation (De Cao et al.)
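The coupling flows listed above are built from layers that leave part of the input unchanged and use it to parameterise a transform of the rest. The following is a minimal sketch of a single RealNVP-style affine coupling layer in plain JAX. It is not flowjax's implementation. The one-hidden-layer MLP conditioner, the tanh bound on the log-scale, the standard normal base and all function names are hypothetical choices made for illustration. Conditional density estimation is handled here by concatenating a context vector onto the conditioner's input.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm


def init_conditioner(key, d_in, d_out, hidden=32):
    """Hypothetical MLP params mapping (x_a, context) to a shift and log-scale."""
    k1, k2 = random.split(key)
    return {
        "W1": random.normal(k1, (d_in, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "W2": random.normal(k2, (hidden, 2 * d_out)) * 0.1,
        "b2": jnp.zeros(2 * d_out),
    }


def conditioner(params, x_a, context):
    """Shift and log-scale for the transformed block, from x_a and context."""
    inp = jnp.concatenate([x_a, context], axis=-1)
    h = jnp.tanh(inp @ params["W1"] + params["b1"])
    shift, log_scale = jnp.split(h @ params["W2"] + params["b2"], 2, axis=-1)
    return shift, jnp.tanh(log_scale)  # tanh keeps the scale factor bounded


def coupling_forward(params, z, context, split):
    """Base -> data: x_a = z_a, x_b = z_b * exp(s) + t."""
    z_a, z_b = z[..., :split], z[..., split:]
    shift, log_scale = conditioner(params, z_a, context)
    return jnp.concatenate([z_a, z_b * jnp.exp(log_scale) + shift], axis=-1)


def coupling_inverse(params, x, context, split):
    """Data -> base, plus log|det J| of this inverse map."""
    x_a, x_b = x[..., :split], x[..., split:]
    shift, log_scale = conditioner(params, x_a, context)
    z_b = (x_b - shift) * jnp.exp(-log_scale)
    # The Jacobian is block-triangular, so only the diagonal of the x_b block counts.
    return jnp.concatenate([x_a, z_b], axis=-1), -jnp.sum(log_scale, axis=-1)


def log_prob(params, x, context, split):
    """Change of variables with an (assumed) standard normal base distribution."""
    z, log_det = coupling_inverse(params, x, context, split)
    return jnp.sum(norm.logpdf(z), axis=-1) + log_det


# Example: 3-D data, first coordinate passed through, 2-D context vector.
params = init_conditioner(random.PRNGKey(0), d_in=1 + 2, d_out=2)
x = random.normal(random.PRNGKey(1), (5, 3))
ctx = random.normal(random.PRNGKey(2), (5, 2))
print(log_prob(params, x, ctx, split=1))
```

For unconditional use, pass an empty context such as `jnp.zeros((5, 0))` and set `d_in=split`. A single coupling layer never changes the first `split` coordinates, so practical flows stack several layers and permute the dimensions between them.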

For more detailed examples, see examples.

Installation

```shell
pip install flowjax
```

Authors

flowjax was written by Daniel Ward <danielward27@outlook.com>.