60 add vi helper #63

Merged · 24 commits · Jan 18, 2023
Changes from 1 commit
Commits (24)
3883c16
Add vi fit function
Tennessee-Wallaceh Jan 9, 2023
187c862
Merge branch 'main' of github.com:Tennessee-Wallaceh/flowjax into 60-…
Tennessee-Wallaceh Jan 13, 2023
0ff775e
Update the VI example notebook
Tennessee-Wallaceh Jan 13, 2023
104642e
Remove "safe" elbo loss
Tennessee-Wallaceh Jan 16, 2023
3237e85
Revert flows and update to new sample API
Tennessee-Wallaceh Jan 16, 2023
2f1a220
Add slightly more docstring
Tennessee-Wallaceh Jan 16, 2023
7133954
Tidy imports
Tennessee-Wallaceh Jan 16, 2023
bf9bb2f
Update examples
Tennessee-Wallaceh Jan 16, 2023
3a7f6b4
Update docs
Tennessee-Wallaceh Jan 16, 2023
5d231ac
Update docs
Tennessee-Wallaceh Jan 16, 2023
513fe22
Move test_train_utils to new test dir
Tennessee-Wallaceh Jan 16, 2023
8238b17
Add elbo loss test
Tennessee-Wallaceh Jan 16, 2023
7c89354
Add test for VI function
Tennessee-Wallaceh Jan 16, 2023
7b35961
Merge branch 'danielward27:main' into 60-add-vi-helper
Tennessee-Wallaceh Jan 18, 2023
c7826e9
Update to use fn rather than fcn abbreviation
Tennessee-Wallaceh Jan 18, 2023
55bbc96
Remove recorder from training script
Tennessee-Wallaceh Jan 18, 2023
fdbdeb1
Update naming of train scripts
Tennessee-Wallaceh Jan 18, 2023
1aaf226
Update tests
Tennessee-Wallaceh Jan 18, 2023
a249b99
Update docs
Tennessee-Wallaceh Jan 18, 2023
f864ac2
Remove distribution mock from VI testing
Tennessee-Wallaceh Jan 18, 2023
4e58a9f
Add VI example to toctree
Tennessee-Wallaceh Jan 18, 2023
9210c67
Update README
Tennessee-Wallaceh Jan 18, 2023
88a687e
Remove type aliases
Tennessee-Wallaceh Jan 18, 2023
3e57c6c
fix docstring
Tennessee-Wallaceh Jan 18, 2023
Update naming of train scripts
Tennessee-Wallaceh committed Jan 18, 2023
commit fdbdeb16dbbfc8bead667bbb4ba81d81d3c6baa0
6 changes: 3 additions & 3 deletions docs/examples/conditional.ipynb
@@ -35,7 +35,7 @@
"import numpy as onp\n",
"from flowjax.flows import BlockNeuralAutoregressiveFlow\n",
"from flowjax.distributions import Normal\n",
"from flowjax.train import train_flow\n",
"from flowjax.train import fit_to_data\n",
"import matplotlib.pyplot as plt"
]
},
@@ -83,7 +83,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 88%|████████▊ | 44/50 [00:07<00:01, 5.65it/s, train=1.27, val=1.47 (Max patience reached)]\n"
" 88%|████████▊ | 44/50 [00:09<00:01, 4.85it/s, train=1.27, val=1.47 (Max patience reached)]\n"
]
}
],
@@ -97,7 +97,7 @@
")\n",
"\n",
"key, subkey = jr.split(key)\n",
"flow, losses = train_flow(subkey, flow, x, u, learning_rate=1e-2, max_patience=10)"
"flow, losses = fit_to_data(subkey, flow, x, u, learning_rate=1e-2, max_patience=10)"
]
},
{
6 changes: 3 additions & 3 deletions docs/examples/unconditional.ipynb
@@ -28,7 +28,7 @@
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"from flowjax.flows import MaskedAutoregressiveFlow\n",
"from flowjax.train import train_flow\n",
"from flowjax.train import fit_to_data\n",
"from flowjax.distributions import Normal\n",
"from flowjax.bijections import RationalQuadraticSpline\n",
"import matplotlib.pyplot as plt\n",
@@ -109,7 +109,7 @@
],
"source": [
"key, subkey = jr.split(key)\n",
"flow, losses = train_flow(subkey, flow, x, learning_rate=1e-3)"
"flow, losses = fit_to_data(subkey, flow, x, learning_rate=1e-3)"
]
},
{
@@ -128,7 +128,7 @@
{
"data": {
"text/plain": [
"DeviceArray([-6.6872644, -6.6872644, -6.6872644, -6.6872644, -6.6872644], dtype=float32)"
"Array([-6.6872644, -6.6872644, -6.6872644, -6.6872644, -6.6872644], dtype=float32)"
]
},
"execution_count": 5,
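
For orientation, here is a minimal usage sketch (not part of the diff) of the renamed data-fitting call, matching the usage shown in the two notebook diffs above. To keep it self-contained it fits a plain Normal rather than the notebooks' flows; the Normal(jnp.zeros(2)) constructor and the toy data are illustrative assumptions. The conditional notebook simply passes the conditioning variables as an extra positional argument, i.e. fit_to_data(subkey, flow, x, u, ...).

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.train import fit_to_data  # previously train_flow

key, subkey = jr.split(jr.PRNGKey(0))
x = jr.normal(subkey, (1000, 2)) + 1.0  # toy training data

# A plain Normal stands in for the notebooks' flows here; any flowjax
# Distribution with trainable (inexact array) leaves can be fitted this way.
dist = Normal(jnp.zeros(2))

key, subkey = jr.split(key)
dist, losses = fit_to_data(subkey, dist, x, learning_rate=1e-3)
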
20 changes: 10 additions & 10 deletions docs/examples/variational_inference.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions flowjax/train/__init__.py
@@ -1,2 +1,2 @@
from flowjax.train.data_fit import train_flow
from flowjax.train.variational_fit import variational_fit
from flowjax.train.data_fit import fit_to_data
from flowjax.train.variational_fit import fit_to_variational_target
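
Equivalently, downstream code now imports the helpers under their new names (a two-line illustration of the import surface defined just above):

# New names exported from flowjax.train after this commit; the old names
# (train_flow, variational_fit) are replaced rather than aliased.
from flowjax.train import fit_to_data, fit_to_variational_target
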
2 changes: 1 addition & 1 deletion flowjax/train/data_fit.py
@@ -13,7 +13,7 @@
from flowjax.utils import Array


def train_flow(
def fit_to_data(
key: KeyArray,
dist: Distribution,
x: Array,
26 changes: 12 additions & 14 deletions flowjax/train/variational_fit.py
@@ -21,19 +21,19 @@

@eqx.filter_jit
def elbo_loss(dist: Distribution, target: VariationalTarget, key: random.KeyArray, elbo_samples: int = 500):
samples = dist.sample(key, sample_shape=(elbo_samples,))
approx_density = dist.log_prob(samples).reshape(-1)
target_density = target(samples).reshape(-1)
samples, approx_density = dist.sample_and_log_prob(key, sample_shape=(elbo_samples,))
target_density = target(samples)
losses = approx_density - target_density
return losses.mean()

def variational_fit(
def fit_to_variational_target(
key: random.KeyArray,
dist: Distribution,
target: VariationalTarget,
loss_fn: VariationalLoss = elbo_loss,
learning_rate: float = 5e-4,
clip_norm: float = 0.5,
num_epochs: int = 100,
optimizer: Optional[optax.GradientTransformation] = None,
show_progress: bool = True,
):
"""
@@ -44,9 +44,9 @@ def variational_fit(
dist (Distribution): Distribution object, trainable parameters are found using equinox.is_inexact_array.
target (VariationalTarget): The target (usually) unormalized log posterior.
loss_fcn (VariationalLoss, optional): Loss function. Defaults to elbo_loss.
learning_rate (float, optional): Adam learning rate. Defaults to 5e-4.
clip_norm (float, optional): Maximum gradient norm before clipping occurs. Defaults to 0.5.
num_epochs (int, optional): The number of training steps to run. Defaults to 100.
optimizer (Optional[optax.Optimizer], optional): An optax optimizer (optimizers are implemented as GradientTransformation objects).
Defaults to an adam optimizer with learning rate 5e-4.
show_progress (bool, optional): Whether to show progress bar. Defaults to True.
"""
@eqx.filter_jit
@@ -56,12 +56,10 @@ def step(dist, target, key, optimizer, opt_state):
dist = eqx.apply_updates(dist, updates)
return dist, opt_state, loss_val

# Set up a default optimizer if None is provided
if optimizer is None:
optimizer = optax.chain(
optax.clip_by_global_norm(0.5),
optax.adam(learning_rate=5e-4)
)
optimizer = optax.chain(
optax.clip_by_global_norm(clip_norm),
optax.adam(learning_rate=learning_rate)
)

trainable_params, _ = eqx.partition(dist, eqx.is_inexact_array)
opt_state = optimizer.init(trainable_params)
@@ -78,4 +76,4 @@ def step(dist, target, key, optimizer, opt_state):
if show_progress:
loop.set_postfix({'loss': losses[-1]})

return dist, losses
return dist, losses, optimizer
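
A minimal usage sketch (not part of the diff) for the renamed variational helper. It relies only on the positional arguments common to both sides of the diff above, since the optional keyword arguments and the number of returned values differ between the two versions shown. The loss minimised is the Monte Carlo ELBO estimate in elbo_loss, i.e. the mean of log q(x) − log p̃(x) over samples x ~ q. The toy target and the plain Normal approximation are illustrative assumptions.

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.train import fit_to_variational_target

def target_log_prob(x):
    # Unnormalised log density of a unit Gaussian centred at 1 (toy target).
    return -0.5 * jnp.sum((x - 1.0) ** 2, axis=-1)

# Any flowjax Distribution can be used; its inexact array leaves are treated
# as the variational parameters.
dist = Normal(jnp.zeros(2))

# Unpack defensively: one side of the diff returns (dist, losses), the other
# (dist, losses, optimizer).
dist, losses, *rest = fit_to_variational_target(jr.PRNGKey(0), dist, target_log_prob)
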
@@ -8,7 +8,7 @@
from flowjax.distributions import Normal, Transformed
from flowjax.train.data_fit import (
count_fruitless,
train_flow,
fit_to_data,
train_val_split,
)

@@ -40,7 +40,7 @@ def test_train_flow_filter_spec():
# All params should change by default
before = eqx.filter(flow, eqx.is_inexact_array)
x = random.normal(random.PRNGKey(0), (100, dim))
flow, _ = train_flow(random.PRNGKey(0), flow, x, max_epochs=1, batch_size=50)
flow, _ = fit_to_data(random.PRNGKey(0), flow, x, max_epochs=1, batch_size=50)
after = eqx.filter(flow, eqx.is_inexact_array)

assert jnp.all(before.base_dist.bijection.loc != after.base_dist.bijection.loc)
@@ -50,7 +50,7 @@ def test_train_flow_filter_spec():
before = eqx.filter(flow, eqx.is_inexact_array)
filter_spec = jtu.tree_map(lambda x: eqx.is_inexact_array(x), flow)
filter_spec = eqx.tree_at(lambda tree: tree.base_dist, filter_spec, replace=False)
flow, _ = train_flow(
flow, _ = fit_to_data(
random.PRNGKey(0), flow, x, max_epochs=1, batch_size=50, filter_spec=filter_spec
)
after = eqx.filter(flow, eqx.is_inexact_array)
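
Finally, the test above exercises the filter_spec mechanism for freezing part of a model during fit_to_data: a boolean pytree mirroring the model marks which leaves are trainable, and eqx.tree_at flips a chosen subtree to False. Below is a compact, self-contained sketch of the same pattern (not part of the diff), using a Normal in place of the test's flow; the bijection.loc attribute path is taken from the assertions in the test above.

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
from flowjax.distributions import Normal
from flowjax.train import fit_to_data

dist = Normal(jnp.zeros(2))
x = jr.normal(jr.PRNGKey(0), (100, 2)) + 1.0

# Start from "everything trainable"...
filter_spec = jtu.tree_map(eqx.is_inexact_array, dist)
# ...then freeze the location parameter of the underlying affine bijection.
filter_spec = eqx.tree_at(lambda tree: tree.bijection.loc, filter_spec, replace=False)

dist, _ = fit_to_data(
    jr.PRNGKey(0), dist, x, max_epochs=1, batch_size=50, filter_spec=filter_spec
)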