
Commit bbe5c15

Authored by: AdrienCorenflos, junpenglao, ciguaran, ksnxr, GaetanLepage
Merge remote to fork (#1)
* Update README.md (blackjax-devs#638)
  * Update README.md: Update citation.
  * Update README.md
* Indexing the notebook showing how to reproduce the GIF. (blackjax-devs#640)
  Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
* Bump python version (blackjax-devs#645)
  * Bump python version
  * update bool inverse
* SMC: allow each mutation kernel to have different parameters. (blackjax-devs#649)
  * vmapping over parameters in base
  * switch from mcmc_factory to just passing in parameters
  * pre-commit and typing
  * CRU and docs improvement
  * pre-commit
  * code review updates
  * pre-commit
  * rename test
* Migrate from deprecated `host_callback` to `io_callback` (blackjax-devs#651)
  Co-authored-by: George Necula <gnecula@users.noreply.github.com>
  * Format file
  * Fix bug
* Fix MALA transition energy (blackjax-devs#653)
  * Fix MALA transition energy
  * Use a different logic.
* Change variable names (blackjax-devs#654)
* Replace iterative RNG split and carry with `jax.random.fold_in` (blackjax-devs#656)
  * Replace iterative RNG split and carry with `jax.random.fold_in`
  * revert unintended change
  * file formatting
  * change `jax.tree_map` to `jax.tree.map`
  * revert unintended file
  * fiddle with rng_key
  * seed again
* Removal of Algorithm classes. (blackjax-devs#657)
  * more
  * removing export
  * removal of classes, tests passing
  * linter
  * fix on test
  * linter
  * removing parametrization on test
  * code review updates
  * exporting as_top_level_api in dynamic_hmc
  * linter
  * code review update: replace imports
* Fix deprecated call to jnp.clip (blackjax-devs#664)
* Update jax version requirements (blackjax-devs#666)
  Fix blackjax-devs#665
* Make tests pass on `aarch64-linux` (blackjax-devs#671)
* Enable filtering of AdaptationInfo (blackjax-devs#674)
  * enable AdaptationInfo filtering
  * revert progress_bar
  * fix pre-commit
  * fix empty sets
  * enable adapt info filtering for all adaptation algorithms
  * fix precommit / progressbar=True
  * change filter tuple to use tree_map
* Update `run_inference_algorithm` to split `initial_position` and `initial_state` (blackjax-devs#672)
  * UPDATE DOCSTRING
  * ADD STREAMING VERSION
  * UPDATE TESTS
  * ADD DOCSTRING
  * ADD TEST
  * REFACTOR RUN_INFERENCE_ALGORITHM
  * UPDATE DOCSTRING
  * Precommit
  * CLEAN TESTS
  * ADD INITIAL_POSITION
  * FIX TEST
  * RENAME O
  * FIX DOCSTRING
  * PUT EXPECTATION AFTER TRANSFORM
* Preconditioned mclmc (blackjax-devs#673)
  * TESTS
  * TESTS
  * UPDATE DOCSTRING
  * ADD STREAMING VERSION
  * ADD PRECONDITIONING TO MCLMC
  * ADD PRECONDITIONING TO TUNING FOR MCLMC
  * UPDATE GITIGNORE
  * UPDATE GITIGNORE
  * UPDATE TESTS
  * UPDATE TESTS
  * ADD DOCSTRING
  * ADD TEST
  * STREAMING AVERAGE
  * ADD TEST
  * REFACTOR RUN_INFERENCE_ALGORITHM
  * UPDATE DOCSTRING
  * Precommit
  * CLEAN TESTS
  * GITIGNORE
  * PRECOMMIT CLEAN UP
  * ADD INITIAL_POSITION
  * FIX TEST
  * ADD TEST
  * REMOVE BENCHMARKS
  * BUG FIX
  * CHANGE PRECISION
  * CHANGE PRECISION
  * RENAME O
  * UPDATE STREAMING AVG
  * UPDATE PR
  * RENAME STD_MAT
* New integrator, and add some metadata to integrators.py (blackjax-devs#681)
  * TESTS
  * TESTS
  * UPDATE DOCSTRING
  * ADD STREAMING VERSION
  * ADD PRECONDITIONING TO MCLMC
  * ADD PRECONDITIONING TO TUNING FOR MCLMC
  * UPDATE GITIGNORE
  * UPDATE GITIGNORE
  * UPDATE TESTS
  * UPDATE TESTS
  * ADD DOCSTRING
  * ADD TEST
  * STREAMING AVERAGE
  * ADD TEST
  * REFACTOR RUN_INFERENCE_ALGORITHM
  * UPDATE DOCSTRING
  * Precommit
  * CLEAN TESTS
  * GITIGNORE
  * PRECOMMIT CLEAN UP
  * FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS
  * TEMPORARILY ADD BENCHMARKS
  * ADD INITIAL_POSITION
  * FIX TEST
  * CLEAN UP
  * REMOVE BENCHMARKS
  * ADD TEST
  * REMOVE BENCHMARKS
  * BUG FIX
  * CHANGE PRECISION
  * CHANGE PRECISION
  * ADD OMELYAN TEST
  * RENAME O
  * UPDATE STREAMING AVG
  * UPDATE PR
  * RENAME STD_MAT
  * MERGE MAIN
  * REMOVE COEFFICIENT EXPORTS
* Minor formatting (blackjax-devs#685)
  * Minor formatting
  * formatting
  * fix test
  * formatting
* MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT (blackjax-devs#687)
* FIX KWARG BUG (blackjax-devs#686)
  * FIX KWARG BUG
  * FIX KWARG BUG
* Change isokinetic_integrator generation API (blackjax-devs#689)
* Apply function on pytree directly. (blackjax-devs#692)
  * Apply function on pytree directly, avoiding unnecessary unpacking
  * Fix kwarg
* Fix sampling test. (blackjax-devs#693)
* Enable shared mcmc parameters with tempered smc (blackjax-devs#694)
  * add parameter filtering
  * fix parameter split + docstring
  * change extend_paramss
* convert to bit twiddling (blackjax-devs#696)
* Remove nightly release (blackjax-devs#699)
* Fix doc mistakes (blackjax-devs#701)
  * Fix equation formatting
  * Clarify JAX gradient error
  * Fix punctuation + capitalization
  * Fix grammar: should not begin a sentence with "i.e." in English.
  * Fix math formatting error
  * Fix typo: change parallel _ensample_ chain adaptation to parallel _ensemble_ chain adaptation.
  * Add SVGD citation to appear in doc: currently the SVGD paper is only cited in the `kernel` function, which is defined _within_ the `build_kernel` function. Because of this nested function format, the SVGD paper is _not_ cited in the documentation. To fix this, I added a citation to the SVGD paper in the `as_top_level_api` docstring.
  * Fix grammar + clarify doc
  * Fix typo
  Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
* Update index.md (blackjax-devs#711)
  The jitted step remained unused, leading to the example running with an uncompiled nuts.step. Changing this reduces the execution time by a factor of 30 on my system and showcases blackjax's speed.
* Enable progress bar under pmap (blackjax-devs#712)
  * enable pmap progbar
  * fix bar creation
  * add locking
  * fix formatting
  * switch to using chain state
* remove labels (blackjax-devs#716)
* Simplify `run_inference_algorithm` (blackjax-devs#714)
  * fix minor type errors
  * storing only expectation values
  * fixed memory efficient sampling
  * clean up
  * renaming vars
  * precommit fixes
  * fixing tests
  * fixing tests
  * fixing tests
  * fixing tests
  * fixing tests
  * merge main
  * burn in and fix tests
  * burn in and fix tests
  * minor fixes
  * minor fixes
  * minor fixes
  Co-authored-by: jakob.robnik@gmail.com <jakob.robnik@gmail.com>
* Harmonize Quickstart example (blackjax-devs#717)
* Update README.md (blackjax-devs#719)

Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
Co-authored-by: Carlos Iguaran <ciguaran@users.noreply.github.com>
Co-authored-by: ksnxr <70186663+ksnxr@users.noreply.github.com>
Co-authored-by: Gaétan Lepage <33058747+GaetanLepage@users.noreply.github.com>
Co-authored-by: Alberto Cabezas <a.cabezasgonzalez@lancaster.ac.uk>
Co-authored-by: andrewdipper <andrewpratt333@gmail.com>
Co-authored-by: Reuben <reubenharry@users.noreply.github.com>
Co-authored-by: Gilad Turok <36947659+gil2rok@users.noreply.github.com>
Co-authored-by: johannahaffner <38662446+johannahaffner@users.noreply.github.com>
Co-authored-by: jakob.robnik@gmail.com <jakob.robnik@gmail.com>
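PR blackjax-devs#656 above replaces the split-and-carry RNG pattern with `jax.random.fold_in`. A minimal sketch of the two patterns in plain JAX (not BlackJAX code), assuming a recent JAX with typed key arrays:

```python
import jax

rng_key = jax.random.key(0)

# Old pattern: split the key each iteration and carry the remainder through the loop.
carry = rng_key
split_keys = []
for _ in range(3):
    carry, subkey = jax.random.split(carry)
    split_keys.append(subkey)

# New pattern: derive a per-iteration key directly from the loop index,
# with no carried state.
fold_keys = [jax.random.fold_in(rng_key, i) for i in range(3)]
```

With `fold_in`, each iteration's key is a pure function of `(rng_key, i)`, which makes the loop trivially order-independent and easier to checkpoint or parallelize.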
1 parent 2e7f024 commit bbe5c15


67 files changed (+2172, −1585 lines)

.github/workflows/nightly.yml (−48 lines)

This file was deleted.

.github/workflows/publish_documentation.yml (+2, −2)

```diff
@@ -14,10 +14,10 @@ jobs:
         with:
           persist-credentials: false
 
-      - name: Set up Python 3.9
+      - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: 3.9
+          python-version: 3.11
 
       - name: Build the documentation with Sphinx
         run: |
```

.github/workflows/release.yml (+2, −2)

```diff
@@ -14,7 +14,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: 3.9
+          python-version: 3.11
       - name: Build sdist and wheel
         run: |
           python -m pip install -U pip
@@ -51,7 +51,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: 3.9
+          python-version: 3.11
       - name: Give PyPI some time to update the index
         run: sleep 240
       - name: Attempt install from PyPI
```

.github/workflows/test.yml (+2, −2)

```diff
@@ -14,7 +14,7 @@ jobs:
       - uses: actions/checkout@v3
       - uses: actions/setup-python@v4
         with:
-          python-version: 3.9
+          python-version: 3.11
       - uses: pre-commit/action@v3.0.0
 
   test:
@@ -24,7 +24,7 @@ jobs:
       - style
     strategy:
       matrix:
-        python-version: [ '3.9', '3.11']
+        python-version: ['3.11', '3.12']
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
```

README.md (+11, −15)

````diff
@@ -41,12 +41,6 @@ or via conda-forge:
 conda install -c conda-forge blackjax
 ```
 
-Nightly builds (bleeding edge) of Blackjax can also be installed using `pip`:
-
-```bash
-pip install blackjax-nightly
-```
-
 BlackJAX is written in pure Python but depends on XLA via JAX. By default, the
 version of JAX that will be installed along with BlackJAX will make your code
 run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow
@@ -81,9 +75,10 @@ state = nuts.init(initial_position)
 
 # Iterate
 rng_key = jax.random.key(0)
-for _ in range(100):
-    rng_key, nuts_key = jax.random.split(rng_key)
-    state, _ = nuts.step(nuts_key, state)
+step = jax.jit(nuts.step)
+for i in range(100):
+    nuts_key = jax.random.fold_in(rng_key, i)
+    state, _ = step(nuts_key, state)
 ```
 
 See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.
@@ -138,12 +133,13 @@ Please follow our [short guide](https://github.com/blackjax-devs/blackjax/blob/m
 To cite this repository:
 
 ```
-@software{blackjax2020github,
-  author = {Cabezas, Alberto, Lao, Junpeng, and Louf, R\'emi},
-  title = {{B}lackjax: A sampling library for {JAX}},
-  url = {http://github.com/blackjax-devs/blackjax},
-  version = {<insert current release tag>},
-  year = {2023},
+@misc{cabezas2024blackjax,
+  title={BlackJAX: Composable {B}ayesian inference in {JAX}},
+  author={Alberto Cabezas and Adrien Corenflos and Junpeng Lao and Rémi Louf},
+  year={2024},
+  eprint={2402.10797},
+  archivePrefix={arXiv},
+  primaryClass={cs.MS}
 }
 ```
 In the above bibtex entry, names are in alphabetical order, the version number should be the last tag on the `main` branch.
````
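The quickstart change above compiles `nuts.step` once with `jax.jit` and reuses it, instead of re-tracing an uncompiled step on every call. The same loop shape can be sketched with a toy transition kernel standing in for `nuts.step` (the kernel here is a made-up random walk, for illustration only, so the example does not require BlackJAX):

```python
import jax
import jax.numpy as jnp


def toy_step(rng_key, state):
    # Hypothetical stand-in for nuts.step: a plain random-walk update.
    noise = jax.random.normal(rng_key, state.shape)
    return state + 0.1 * noise, None


step = jax.jit(toy_step)  # compile once, reuse on every iteration

rng_key = jax.random.key(0)
state = jnp.zeros(2)
for i in range(100):
    step_key = jax.random.fold_in(rng_key, i)  # per-iteration key, no carry
    state, _ = step(step_key, state)
```

After the first call traces and compiles `toy_step`, subsequent calls dispatch the cached executable, which is where the large speedup reported in blackjax-devs#711 comes from.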

blackjax/__init__.py (+140, −48)

```diff
@@ -1,71 +1,163 @@
+import dataclasses
+from typing import Callable
+
 from blackjax._version import __version__
 
 from .adaptation.chees_adaptation import chees_adaptation
 from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
 from .adaptation.meads_adaptation import meads_adaptation
 from .adaptation.pathfinder_adaptation import pathfinder_adaptation
 from .adaptation.window_adaptation import window_adaptation
+from .base import SamplingAlgorithm, VIAlgorithm
 from .diagnostics import effective_sample_size as ess
 from .diagnostics import potential_scale_reduction as rhat
-from .mcmc.barker import barker_proposal
-from .mcmc.dynamic_hmc import dynamic_hmc
-from .mcmc.elliptical_slice import elliptical_slice
-from .mcmc.ghmc import ghmc
-from .mcmc.hmc import hmc
-from .mcmc.mala import mala
-from .mcmc.marginal_latent_gaussian import mgrad_gaussian
-from .mcmc.mclmc import mclmc
-from .mcmc.nuts import nuts
-from .mcmc.periodic_orbital import orbital_hmc
-from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
-from .mcmc.rmhmc import rmhmc
+from .mcmc import barker
+from .mcmc import dynamic_hmc as _dynamic_hmc
+from .mcmc import elliptical_slice as _elliptical_slice
+from .mcmc import ghmc as _ghmc
+from .mcmc import hmc as _hmc
+from .mcmc import mala as _mala
+from .mcmc import marginal_latent_gaussian
+from .mcmc import mclmc as _mclmc
+from .mcmc import nuts as _nuts
+from .mcmc import periodic_orbital, random_walk
+from .mcmc import rmhmc as _rmhmc
+from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk
+from .mcmc.random_walk import (
+    irmh_as_top_level_api,
+    normal_random_walk,
+    rmh_as_top_level_api,
+)
 from .optimizers import dual_averaging, lbfgs
-from .sgmcmc.csgld import csgld
-from .sgmcmc.sghmc import sghmc
-from .sgmcmc.sgld import sgld
-from .sgmcmc.sgnht import sgnht
-from .smc.adaptive_tempered import adaptive_tempered_smc
-from .smc.inner_kernel_tuning import inner_kernel_tuning
-from .smc.tempered import tempered_smc
-from .vi.meanfield_vi import meanfield_vi
-from .vi.pathfinder import pathfinder
-from .vi.schrodinger_follmer import schrodinger_follmer
-from .vi.svgd import svgd
+from .sgmcmc import csgld as _csgld
+from .sgmcmc import sghmc as _sghmc
+from .sgmcmc import sgld as _sgld
+from .sgmcmc import sgnht as _sgnht
+from .smc import adaptive_tempered
+from .smc import inner_kernel_tuning as _inner_kernel_tuning
+from .smc import tempered
+from .vi import meanfield_vi as _meanfield_vi
+from .vi import pathfinder as _pathfinder
+from .vi import schrodinger_follmer as _schrodinger_follmer
+from .vi import svgd as _svgd
+from .vi.pathfinder import PathFinderAlgorithm
+
+"""
+The above three classes exist as a backwards compatible way of exposing both the high level, differentiable
+factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower
+level to be mostly functional programming in nature and reducing boilerplate code.
+"""
+
+
+@dataclasses.dataclass
+class GenerateSamplingAPI:
+    differentiable: Callable
+    init: Callable
+    build_kernel: Callable
+
+    def __call__(self, *args, **kwargs) -> SamplingAlgorithm:
+        return self.differentiable(*args, **kwargs)
+
+    def register_factory(self, name, callable):
+        setattr(self, name, callable)
+
+
+@dataclasses.dataclass
+class GenerateVariationalAPI:
+    differentiable: Callable
+    init: Callable
+    step: Callable
+    sample: Callable
+
+    def __call__(self, *args, **kwargs) -> VIAlgorithm:
+        return self.differentiable(*args, **kwargs)
+
+
+@dataclasses.dataclass
+class GeneratePathfinderAPI:
+    differentiable: Callable
+    approximate: Callable
+    sample: Callable
+
+    def __call__(self, *args, **kwargs) -> PathFinderAlgorithm:
+        return self.differentiable(*args, **kwargs)
+
+
+def generate_top_level_api_from(module):
+    return GenerateSamplingAPI(
+        module.as_top_level_api, module.init, module.build_kernel
+    )
+
+
+# MCMC
+hmc = generate_top_level_api_from(_hmc)
+nuts = generate_top_level_api_from(_nuts)
+rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh)
+irmh = GenerateSamplingAPI(
+    irmh_as_top_level_api, random_walk.init, random_walk.build_irmh
+)
+dynamic_hmc = generate_top_level_api_from(_dynamic_hmc)
+rmhmc = generate_top_level_api_from(_rmhmc)
+mala = generate_top_level_api_from(_mala)
+mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian)
+orbital_hmc = generate_top_level_api_from(periodic_orbital)
+
+additive_step_random_walk = GenerateSamplingAPI(
+    _additive_step_random_walk, random_walk.init, random_walk.build_additive_step
+)
+
+additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)
+
+mclmc = generate_top_level_api_from(_mclmc)
+elliptical_slice = generate_top_level_api_from(_elliptical_slice)
+ghmc = generate_top_level_api_from(_ghmc)
+barker_proposal = generate_top_level_api_from(barker)
+
+hmc_family = [hmc, nuts]
+
+# SMC
+adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered)
+tempered_smc = generate_top_level_api_from(tempered)
+inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning)
+
+smc_family = [tempered_smc, adaptive_tempered_smc]
+"Step_fn returning state has a .particles attribute"
+
+# stochastic gradient mcmc
+sgld = generate_top_level_api_from(_sgld)
+sghmc = generate_top_level_api_from(_sghmc)
+sgnht = generate_top_level_api_from(_sgnht)
+csgld = generate_top_level_api_from(_csgld)
+svgd = generate_top_level_api_from(_svgd)
+
+# variational inference
+meanfield_vi = GenerateVariationalAPI(
+    _meanfield_vi.as_top_level_api,
+    _meanfield_vi.init,
+    _meanfield_vi.step,
+    _meanfield_vi.sample,
+)
+schrodinger_follmer = GenerateVariationalAPI(
+    _schrodinger_follmer.as_top_level_api,
+    _schrodinger_follmer.init,
+    _schrodinger_follmer.step,
+    _schrodinger_follmer.sample,
+)
+
+pathfinder = GeneratePathfinderAPI(
+    _pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample
+)
+
 
 __all__ = [
     "__version__",
     "dual_averaging",  # optimizers
     "lbfgs",
-    "hmc",  # mcmc
-    "dynamic_hmc",
-    "rmhmc",
-    "mala",
-    "mgrad_gaussian",
-    "nuts",
-    "orbital_hmc",
-    "additive_step_random_walk",
-    "rmh",
-    "irmh",
-    "mclmc",
-    "elliptical_slice",
-    "ghmc",
-    "barker_proposal",
-    "sgld",  # stochastic gradient mcmc
-    "sghmc",
-    "sgnht",
-    "csgld",
     "window_adaptation",  # mcmc adaptation
     "meads_adaptation",
     "chees_adaptation",
     "pathfinder_adaptation",
     "mclmc_find_L_and_step_size",  # mclmc adaptation
-    "adaptive_tempered_smc",  # smc
-    "tempered_smc",
-    "inner_kernel_tuning",
-    "meanfield_vi",  # variational inference
-    "pathfinder",
-    "schrodinger_follmer",
-    "svgd",
     "ess",  # diagnostics
     "rhat",
 ]
```
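The `GenerateSamplingAPI` wrapper in the diff above exposes a module's high-level factory as a callable while keeping the low-level `init` and `build_kernel` reachable as attributes. A self-contained sketch of that behavior, with a toy module's functions standing in for a real `blackjax.mcmc` module (the toy functions and their return values are made up for illustration; `register_factory` is omitted for brevity):

```python
import dataclasses
from typing import Callable


@dataclasses.dataclass
class GenerateSamplingAPI:
    differentiable: Callable
    init: Callable
    build_kernel: Callable

    def __call__(self, *args, **kwargs):
        # Calling the wrapper forwards to the high-level factory.
        return self.differentiable(*args, **kwargs)


# Hypothetical stand-ins for a kernel module's functions.
def init(position):
    return {"position": position}


def build_kernel():
    def kernel(state):
        return {"position": state["position"] + 1}

    return kernel


def as_top_level_api(step_size):
    return ("algorithm", step_size, build_kernel())


toy = GenerateSamplingAPI(as_top_level_api, init, build_kernel)

algo = toy(step_size=0.5)          # high-level: behaves like the factory
state = toy.init(0)                # low-level: init is still reachable
state = toy.build_kernel()(state)  # low-level: so is the kernel builder
```

The point of the design: one name (`blackjax.hmc`, say) serves both audiences, so user code that called the old function-style API keeps working while the functional building blocks stay accessible.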
