Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

60 add vi helper #63

Merged
merged 24 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3883c16
Add vi fit function
Tennessee-Wallaceh Jan 9, 2023
187c862
Merge branch 'main' of github.com:Tennessee-Wallaceh/flowjax into 60-…
Tennessee-Wallaceh Jan 13, 2023
0ff775e
Update the VI example notebook
Tennessee-Wallaceh Jan 13, 2023
104642e
Remove "safe" elbo loss
Tennessee-Wallaceh Jan 16, 2023
3237e85
Revert flows and update to new sample API
Tennessee-Wallaceh Jan 16, 2023
2f1a220
Add slightly more docstring
Tennessee-Wallaceh Jan 16, 2023
7133954
Tidy imports
Tennessee-Wallaceh Jan 16, 2023
bf9bb2f
Update examples
Tennessee-Wallaceh Jan 16, 2023
3a7f6b4
Update docs
Tennessee-Wallaceh Jan 16, 2023
5d231ac
Update docs
Tennessee-Wallaceh Jan 16, 2023
513fe22
Move test_train_utils to new test dir
Tennessee-Wallaceh Jan 16, 2023
8238b17
Add elbo loss test
Tennessee-Wallaceh Jan 16, 2023
7c89354
Add test for VI function
Tennessee-Wallaceh Jan 16, 2023
7b35961
Merge branch 'danielward27:main' into 60-add-vi-helper
Tennessee-Wallaceh Jan 18, 2023
c7826e9
Update to use fn rather than fcn abbreviation
Tennessee-Wallaceh Jan 18, 2023
55bbc96
Remove recorder from training script
Tennessee-Wallaceh Jan 18, 2023
fdbdeb1
Update naming of train scripts
Tennessee-Wallaceh Jan 18, 2023
1aaf226
Update tests
Tennessee-Wallaceh Jan 18, 2023
a249b99
Update docs
Tennessee-Wallaceh Jan 18, 2023
f864ac2
Remove distribution mock from VI testing
Tennessee-Wallaceh Jan 18, 2023
4e58a9f
Add VI example to toctree
Tennessee-Wallaceh Jan 18, 2023
9210c67
Update README
Tennessee-Wallaceh Jan 18, 2023
88a687e
Remove type aliases
Tennessee-Wallaceh Jan 18, 2023
3e57c6c
fix docstring
Tennessee-Wallaceh Jan 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update tests
  • Loading branch information
Tennessee-Wallaceh committed Jan 18, 2023
commit 1aaf226c47d2d844c2a5190b2a30caaa820af1d4
2 changes: 1 addition & 1 deletion flowjax/train/variational_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,4 @@ def step(dist, target, key, optimizer, opt_state):
if show_progress:
loop.set_postfix({'loss': losses[-1]})

return dist, losses, optimizer
return dist, losses
8 changes: 3 additions & 5 deletions tests/train/test_variational_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from flowjax.train.variational_fit import (
elbo_loss,
variational_fit,
fit_to_variational_target,
VariationalLoss,
VariationalTarget
)
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_elbo_loss(mocker, distribution, target, shape):
assert loss.shape == () # expect scalar loss
assert jnp.isfinite(loss) # expect finite loss

def test_variational_fit_e2e():
def test_fit_to_variational_target_e2e():
# A simple E2E test to make sure that the function runs without error
flow_random_key = random.PRNGKey(10)
flow = MaskedAutoregressiveFlow(
Expand All @@ -74,16 +74,14 @@ def test_variational_fit_e2e():
target = target_dist.log_prob

train_random_key = random.PRNGKey(0)
trained_flow, losses, record = variational_fit(
trained_flow, losses = fit_to_variational_target(
key=train_random_key,
dist=flow,
target=target,
num_epochs=10,
show_progress=False,
)

assert record is None

# Check that we have trained the flow
initial_params, initial_static = eqx.partition(flow, eqx.is_inexact_array)
trained_params, trained_static = eqx.partition(trained_flow, eqx.is_inexact_array)
Expand Down