Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blackjax SMC from pymc models #267

Conversation

ciguaran
Copy link
Contributor

Allows to sample Pymc's models using Blackjax's Sequential Monte Carlo implementations. Apart from getting a Jax-based implementation for Blackjax, this PR allows for using HMC and NUTS as kernels, which aren't available in the existing PyMC implementation of SMC. Moreover, diagnosis are exposed and stored in the resulting arviz.InferenceData

In order to sample using BJ SMC, we need

  • code to compute jaxified logprior and loglikelihood functions, applyable over SMC Particles.
  • code to compute diagnosis over the SMC run.
  • code to build an arviz.InferenceData object from the sampler's output.

@ciguaran
Copy link
Contributor Author

@junpenglao it seems that blackjax is not included in windows-environment-test.yml although is present in environment-test.yml do you know the reason for that?

@ricardoV94
Copy link
Member

@ciguaran because JAX isn't/wasn't compatible with Windows

pymc_experimental/inference/smc/sampling.py Outdated Show resolved Hide resolved
pymc_experimental/inference/smc/sampling.py Outdated Show resolved Hide resolved
inverse_mass_matrix=jnp.eye(posterior_dimensions),
num_integration_steps=inner_kernel_params["integration_steps"],
)
elif kernel == "NUTS":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add also a random walk metropolis so we can compare with the SMC in PyMC easier?
Also NUTS-SMC is not very well tested IIUC.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the idea, but it has been difficult to use the algorithm without any tuning on the inner kernel parameters between SMC runs. I am planning on merging the optimization PR on blackjax first and then use it from here. The SMC in Pymc does apply those optimizations.

@ciguaran
Copy link
Contributor Author

@ciguaran because JAX isn't/wasn't compatible with Windows

gotcha, it seems they have an experimental version for windows, I've added it in the reqs to see what happens, otherwise we can remove the test from the windows suite if you agree @ricardoV94

@jessegrabowski
Copy link
Member

Can you add a notebook example showing how to use this?

Copy link
Member

@aloctavodia aloctavodia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few changes to the docstrings and this is good to go.

pymc_experimental/inference/smc/sampling.py Outdated Show resolved Hide resolved
pymc_experimental/inference/smc/sampling.py Outdated Show resolved Hide resolved
pymc_experimental/inference/smc/sampling.py Outdated Show resolved Hide resolved
pymc_experimental/inference/smc/sampling.py Outdated Show resolved Hide resolved
@aloctavodia aloctavodia merged commit 5fc0463 into pymc-devs:main Nov 22, 2023
7 checks passed
@aloctavodia
Copy link
Member

thanks @ciguaran!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants