Barker proposal #723
-
Hello,
-
Well, I'd argue their implementation follows ours based on timelines :D. It looks similar, but I don't have time to comment on the specifics.

As a design rule, we do not implement warm-up/adaptation as part of kernels but as outer utilities, so the answer is neither. Looking at their code, they use HMC adaptation, so you can also do this here with very little trouble. I'm rather sure that if you change `nuts` for `barker` below, it will just work:

```python
import jax
import blackjax

warmup = blackjax.window_adaptation(blackjax.nuts, joint_logdensity)

# we use 4 chains for sampling
n_chains = 4
rng_key, init_key, warmup_key = jax.random.split(rng_key, 3)
init_keys = jax.random.split(init_key, n_chains)
init_params = jax.vmap(init_param_fn)(init_keys)


@jax.vmap
def call_warmup(seed, param):
    (initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
    return initial_states, tuned_params


warmup_keys = jax.random.split(warmup_key, n_chains)
initial_states, tuned_params = jax.jit(call_warmup)(warmup_keys, init_params)
```

(The example is from https://blackjax-devs.github.io/sampling-book/models/change_of_variable_hmc.html.)

This question makes me think that Zanella and friends have derived results on calibration which we may want to have a look at eventually: https://academic.oup.com/biomet/article/110/3/579/6764577
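For concreteness, here is a minimal, self-contained sketch of the suggested swap. It assumes the Barker kernel is exposed in the top-level API as `blackjax.barker_proposal` (check the exact name in your BlackJAX version), and the `joint_logdensity` and `init_param_fn` definitions are hypothetical placeholders standing in for the model-specific functions of the linked example:

```python
# Sketch only, not a verified setup: swap blackjax.nuts for the Barker kernel
# in window adaptation. `blackjax.barker_proposal` is an assumed name, and
# `joint_logdensity` / `init_param_fn` are placeholder stand-ins for the
# model-specific functions from the sampling-book example.
import jax
import jax.numpy as jnp
import blackjax


def joint_logdensity(params):
    # Placeholder target: standard normal log-density over a 2-vector.
    return -0.5 * jnp.sum(params**2)


def init_param_fn(seed):
    # Placeholder per-chain initialisation: a random 2-vector.
    return jax.random.normal(seed, (2,))


n_chains = 4
rng_key = jax.random.PRNGKey(0)
rng_key, init_key, warmup_key = jax.random.split(rng_key, 3)

# The only change from the NUTS excerpt above: pass the Barker kernel to the
# window adaptation instead of blackjax.nuts.
warmup = blackjax.window_adaptation(blackjax.barker_proposal, joint_logdensity)

init_keys = jax.random.split(init_key, n_chains)
init_params = jax.vmap(init_param_fn)(init_keys)


@jax.vmap
def call_warmup(seed, param):
    (initial_state, tuned_params), _ = warmup.run(seed, param, 1000)
    return initial_state, tuned_params


warmup_keys = jax.random.split(warmup_key, n_chains)
initial_states, tuned_params = jax.jit(call_warmup)(warmup_keys, init_params)
```

The only line that actually changes is the first argument to `window_adaptation`; everything else is kept as in the excerpt above.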
-
Hello @AdrienCorenflos and @junpenglao