-
-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Blackjax SMC from pymc models #267
Blackjax SMC from pymc models #267
Conversation
@junpenglao it seems that blackjax is not included in |
@ciguaran because JAX isn't/wasn't compatible with Windows |
inverse_mass_matrix=jnp.eye(posterior_dimensions), | ||
num_integration_steps=inner_kernel_params["integration_steps"], | ||
) | ||
elif kernel == "NUTS": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add also a random walk metropolis so we can compare with the SMC in PyMC easier?
Also NUTS-SMC is not very well tested IIUC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's the idea, but it has been difficult to use the algorithm without any tuning on the inner kernel parameters between SMC runs. I am planning on merging the optimization PR on blackjax first and then use it from here. The SMC in Pymc does apply those optimizations.
gotcha, it seems they have an experimental version for windows, I've added it in the reqs to see what happens, otherwise we can remove the test from the windows suite if you agree @ricardoV94 |
Can you add a notebook example showing how to use this? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a few changes to the docstrings and this is good to go.
thanks @ciguaran! |
Allows to sample Pymc's models using Blackjax's Sequential Monte Carlo implementations. Apart from getting a Jax-based implementation for Blackjax, this PR allows for using HMC and NUTS as kernels, which aren't available in the existing PyMC implementation of SMC. Moreover, diagnosis are exposed and stored in the resulting arviz.InferenceData
In order to sample using BJ SMC, we need