Skip to content

Commit

Permalink
Merge pull request #179 from danielward27/danielward27-patch-1
Browse files Browse the repository at this point in the history
Update README.md
  • Loading branch information
danielward27 authored Sep 17, 2024
2 parents 1a63239 + 8aea42f commit 73163c9
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,23 +1,37 @@
# FlowJAX: Distributions, bijections and normalizing Flows in Jax

![FlowJAX](/docs/_static/logo_light.svg)

Distributions, bijections and normalizing flows using Equinox and JAX
-----------------------------------------------------------------------
- Includes a wide range of distributions and bijections.
- Distributions and bijections are PyTrees, registered through
[Equinox](https://github.com/patrick-kidger/equinox/) modules, making them
compatible with [JAX](https://github.com/google/jax) transformations.
- Includes many state of the art normalizing flow models.
- First class support for conditional distributions and density estimation.

## Documentation
Available [here](https://danielward27.github.io/flowjax/index.html).

## Short example
Training a flow can be done in a few lines of code:
As an example we will create and train a normalizing flow model to toy data in just a few lines of code:

```python
from flowjax.flows import block_neural_autoregressive_flow
from flowjax.train import fit_to_data
from flowjax.distributions import Normal
from jax import random
import jax.random as jr
import jax.numpy as jnp

data_key, flow_key, train_key, sample_key = random.split(random.key(0), 4)
data_key, flow_key, train_key, sample_key = jr.split(jr.key(0), 4)

x = jr.uniform(data_key, (5000, 2)) # Toy data

flow = block_neural_autoregressive_flow(
key=flow_key,
base_dist=Normal(jnp.zeros(x.shape[1])),
)

x = random.uniform(data_key, (5000, 2)) # Toy data
base_dist = Normal(jnp.zeros(x.shape[1]))
flow = block_neural_autoregressive_flow(flow_key, base_dist=base_dist)
flow, losses = fit_to_data(
key=train_key,
dist=flow,
Expand Down

0 comments on commit 73163c9

Please sign in to comment.