-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
60 add vi helper #63
60 add vi helper #63
Conversation
Nice, I can add in the sample_and_log_prob method. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added some comments, thanks again for adding this :)
I've added the |
By the way, I have no strong feelings about "train" vs "fit", but feel that we should be consistent and use one or the other and not both. |
How about |
Initial solution to #60.
Key points:
Creates a new
flowjax.train
submodule, for organisational purposes. The old train utils are moved toflowjax.train.data_fit
New VI helper is in
flowjax.train.variational_fit
Doesn't include a
sample_and_log_prob
method, which can come in separate work to keep the PRs simplerI am currently adding some tests