|
| 1 | +import dataclasses |
| 2 | +from typing import Callable |
| 3 | + |
1 | 4 | from blackjax._version import __version__
|
2 | 5 |
|
3 | 6 | from .adaptation.chees_adaptation import chees_adaptation
|
4 | 7 | from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
|
5 | 8 | from .adaptation.meads_adaptation import meads_adaptation
|
6 | 9 | from .adaptation.pathfinder_adaptation import pathfinder_adaptation
|
7 | 10 | from .adaptation.window_adaptation import window_adaptation
|
| 11 | +from .base import SamplingAlgorithm, VIAlgorithm |
8 | 12 | from .diagnostics import effective_sample_size as ess
|
9 | 13 | from .diagnostics import potential_scale_reduction as rhat
|
10 |
| -from .mcmc.barker import barker_proposal |
11 |
| -from .mcmc.dynamic_hmc import dynamic_hmc |
12 |
| -from .mcmc.elliptical_slice import elliptical_slice |
13 |
| -from .mcmc.ghmc import ghmc |
14 |
| -from .mcmc.hmc import hmc |
15 |
| -from .mcmc.mala import mala |
16 |
| -from .mcmc.marginal_latent_gaussian import mgrad_gaussian |
17 |
| -from .mcmc.mclmc import mclmc |
18 |
| -from .mcmc.nuts import nuts |
19 |
| -from .mcmc.periodic_orbital import orbital_hmc |
20 |
| -from .mcmc.random_walk import additive_step_random_walk, irmh, rmh |
21 |
| -from .mcmc.rmhmc import rmhmc |
| 14 | +from .mcmc import barker |
| 15 | +from .mcmc import dynamic_hmc as _dynamic_hmc |
| 16 | +from .mcmc import elliptical_slice as _elliptical_slice |
| 17 | +from .mcmc import ghmc as _ghmc |
| 18 | +from .mcmc import hmc as _hmc |
| 19 | +from .mcmc import mala as _mala |
| 20 | +from .mcmc import marginal_latent_gaussian |
| 21 | +from .mcmc import mclmc as _mclmc |
| 22 | +from .mcmc import nuts as _nuts |
| 23 | +from .mcmc import periodic_orbital, random_walk |
| 24 | +from .mcmc import rmhmc as _rmhmc |
| 25 | +from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk |
| 26 | +from .mcmc.random_walk import ( |
| 27 | + irmh_as_top_level_api, |
| 28 | + normal_random_walk, |
| 29 | + rmh_as_top_level_api, |
| 30 | +) |
22 | 31 | from .optimizers import dual_averaging, lbfgs
|
23 |
| -from .sgmcmc.csgld import csgld |
24 |
| -from .sgmcmc.sghmc import sghmc |
25 |
| -from .sgmcmc.sgld import sgld |
26 |
| -from .sgmcmc.sgnht import sgnht |
27 |
| -from .smc.adaptive_tempered import adaptive_tempered_smc |
28 |
| -from .smc.inner_kernel_tuning import inner_kernel_tuning |
29 |
| -from .smc.tempered import tempered_smc |
30 |
| -from .vi.meanfield_vi import meanfield_vi |
31 |
| -from .vi.pathfinder import pathfinder |
32 |
| -from .vi.schrodinger_follmer import schrodinger_follmer |
33 |
| -from .vi.svgd import svgd |
| 32 | +from .sgmcmc import csgld as _csgld |
| 33 | +from .sgmcmc import sghmc as _sghmc |
| 34 | +from .sgmcmc import sgld as _sgld |
| 35 | +from .sgmcmc import sgnht as _sgnht |
| 36 | +from .smc import adaptive_tempered |
| 37 | +from .smc import inner_kernel_tuning as _inner_kernel_tuning |
| 38 | +from .smc import tempered |
| 39 | +from .vi import meanfield_vi as _meanfield_vi |
| 40 | +from .vi import pathfinder as _pathfinder |
| 41 | +from .vi import schrodinger_follmer as _schrodinger_follmer |
| 42 | +from .vi import svgd as _svgd |
| 43 | +from .vi.pathfinder import PathFinderAlgorithm |
| 44 | + |
| 45 | +""" |
| 46 | +The above three classes exist as a backwards compatible way of exposing both the high level, differentiable |
| 47 | +factory and the low level components, which may not be differentiable. Moreover, this design allows for the lower |
| 48 | +level to be mostly functional programming in nature and reducing boilerplate code. |
| 49 | +""" |
| 50 | + |
| 51 | + |
| 52 | +@dataclasses.dataclass |
| 53 | +class GenerateSamplingAPI: |
| 54 | + differentiable: Callable |
| 55 | + init: Callable |
| 56 | + build_kernel: Callable |
| 57 | + |
| 58 | + def __call__(self, *args, **kwargs) -> SamplingAlgorithm: |
| 59 | + return self.differentiable(*args, **kwargs) |
| 60 | + |
| 61 | + def register_factory(self, name, callable): |
| 62 | + setattr(self, name, callable) |
| 63 | + |
| 64 | + |
| 65 | +@dataclasses.dataclass |
| 66 | +class GenerateVariationalAPI: |
| 67 | + differentiable: Callable |
| 68 | + init: Callable |
| 69 | + step: Callable |
| 70 | + sample: Callable |
| 71 | + |
| 72 | + def __call__(self, *args, **kwargs) -> VIAlgorithm: |
| 73 | + return self.differentiable(*args, **kwargs) |
| 74 | + |
| 75 | + |
| 76 | +@dataclasses.dataclass |
| 77 | +class GeneratePathfinderAPI: |
| 78 | + differentiable: Callable |
| 79 | + approximate: Callable |
| 80 | + sample: Callable |
| 81 | + |
| 82 | + def __call__(self, *args, **kwargs) -> PathFinderAlgorithm: |
| 83 | + return self.differentiable(*args, **kwargs) |
| 84 | + |
| 85 | + |
| 86 | +def generate_top_level_api_from(module): |
| 87 | + return GenerateSamplingAPI( |
| 88 | + module.as_top_level_api, module.init, module.build_kernel |
| 89 | + ) |
| 90 | + |
| 91 | + |
| 92 | +# MCMC |
| 93 | +hmc = generate_top_level_api_from(_hmc) |
| 94 | +nuts = generate_top_level_api_from(_nuts) |
| 95 | +rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh) |
| 96 | +irmh = GenerateSamplingAPI( |
| 97 | + irmh_as_top_level_api, random_walk.init, random_walk.build_irmh |
| 98 | +) |
| 99 | +dynamic_hmc = generate_top_level_api_from(_dynamic_hmc) |
| 100 | +rmhmc = generate_top_level_api_from(_rmhmc) |
| 101 | +mala = generate_top_level_api_from(_mala) |
| 102 | +mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian) |
| 103 | +orbital_hmc = generate_top_level_api_from(periodic_orbital) |
| 104 | + |
| 105 | +additive_step_random_walk = GenerateSamplingAPI( |
| 106 | + _additive_step_random_walk, random_walk.init, random_walk.build_additive_step |
| 107 | +) |
| 108 | + |
| 109 | +additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) |
| 110 | + |
| 111 | +mclmc = generate_top_level_api_from(_mclmc) |
| 112 | +elliptical_slice = generate_top_level_api_from(_elliptical_slice) |
| 113 | +ghmc = generate_top_level_api_from(_ghmc) |
| 114 | +barker_proposal = generate_top_level_api_from(barker) |
| 115 | + |
| 116 | +hmc_family = [hmc, nuts] |
| 117 | + |
| 118 | +# SMC |
| 119 | +adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered) |
| 120 | +tempered_smc = generate_top_level_api_from(tempered) |
| 121 | +inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) |
| 122 | + |
| 123 | +smc_family = [tempered_smc, adaptive_tempered_smc] |
| 124 | +"Step_fn returning state has a .particles attribute" |
| 125 | + |
| 126 | +# stochastic gradient mcmc |
| 127 | +sgld = generate_top_level_api_from(_sgld) |
| 128 | +sghmc = generate_top_level_api_from(_sghmc) |
| 129 | +sgnht = generate_top_level_api_from(_sgnht) |
| 130 | +csgld = generate_top_level_api_from(_csgld) |
| 131 | +svgd = generate_top_level_api_from(_svgd) |
| 132 | + |
| 133 | +# variational inference |
| 134 | +meanfield_vi = GenerateVariationalAPI( |
| 135 | + _meanfield_vi.as_top_level_api, |
| 136 | + _meanfield_vi.init, |
| 137 | + _meanfield_vi.step, |
| 138 | + _meanfield_vi.sample, |
| 139 | +) |
| 140 | +schrodinger_follmer = GenerateVariationalAPI( |
| 141 | + _schrodinger_follmer.as_top_level_api, |
| 142 | + _schrodinger_follmer.init, |
| 143 | + _schrodinger_follmer.step, |
| 144 | + _schrodinger_follmer.sample, |
| 145 | +) |
| 146 | + |
| 147 | +pathfinder = GeneratePathfinderAPI( |
| 148 | + _pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample |
| 149 | +) |
| 150 | + |
34 | 151 |
|
35 | 152 | __all__ = [
|
36 | 153 | "__version__",
|
37 | 154 | "dual_averaging", # optimizers
|
38 | 155 | "lbfgs",
|
39 |
| - "hmc", # mcmc |
40 |
| - "dynamic_hmc", |
41 |
| - "rmhmc", |
42 |
| - "mala", |
43 |
| - "mgrad_gaussian", |
44 |
| - "nuts", |
45 |
| - "orbital_hmc", |
46 |
| - "additive_step_random_walk", |
47 |
| - "rmh", |
48 |
| - "irmh", |
49 |
| - "mclmc", |
50 |
| - "elliptical_slice", |
51 |
| - "ghmc", |
52 |
| - "barker_proposal", |
53 |
| - "sgld", # stochastic gradient mcmc |
54 |
| - "sghmc", |
55 |
| - "sgnht", |
56 |
| - "csgld", |
57 | 156 | "window_adaptation", # mcmc adaptation
|
58 | 157 | "meads_adaptation",
|
59 | 158 | "chees_adaptation",
|
60 | 159 | "pathfinder_adaptation",
|
61 | 160 | "mclmc_find_L_and_step_size", # mclmc adaptation
|
62 |
| - "adaptive_tempered_smc", # smc |
63 |
| - "tempered_smc", |
64 |
| - "inner_kernel_tuning", |
65 |
| - "meanfield_vi", # variational inference |
66 |
| - "pathfinder", |
67 |
| - "schrodinger_follmer", |
68 |
| - "svgd", |
69 | 161 | "ess", # diagnostics
|
70 | 162 | "rhat",
|
71 | 163 | ]
|
0 commit comments