@@ -69,7 +69,7 @@ function BAT.next_cycle!(mc_state::HMCState)
69
69
mc_state. nsamples = 0
70
70
mc_state. stepno = 0
71
71
72
- # reset_rng_counters!(mc_state)
72
+ reset_rng_counters! (mc_state)
73
73
74
74
resize! (mc_state. samples, 1 )
75
75
@@ -119,16 +119,15 @@ function BAT.mcmc_propose!!(mc_state::HMCState)
119
119
z_phase = AdvancedHMC. phasepoint (hamiltonian, vec (z_current[:]), rand (rng, hamiltonian. metric, hamiltonian. kinetic))
120
120
# Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics.
121
121
122
- proposal. transition = AdvancedHMC. transition (rng, τ, hamiltonian, z_phase)
123
- z_proposed[:] = proposal. transition. z. θ
124
- x_proposed[:] = f_transform (z_proposed)
125
-
126
- proposed_log_posterior = logdensityof (target, x_proposed)
127
- samples. logd[proposed_x_idx] = proposed_log_posterior
122
+ proposal. transition, z_proposed_hmc, p_accept = _bat_transition (rng, τ, hamiltonian, z_phase)
123
+ accepted = z_current[:] != proposal. transition. z. θ
124
+ z_proposed[:] = accepted ? proposal. transition. z. θ : z_proposed_hmc
125
+
126
+ p_accept = AdvancedHMC. stat (proposal. transition). acceptance_rate
128
127
129
- accepted = z_current != z_proposed
130
- tstat = AdvancedHMC . stat (proposal . transition )
131
- p_accept = accepted ? tstat . acceptance_rate : 0.0
128
+ x_proposed[:] = f_transform (z_proposed)
129
+ logd_x_proposed = logdensityof (target, x_proposed )
130
+ samples . logd[proposed_x_idx] = logd_x_proposed
132
131
133
132
return mc_state, accepted, p_accept
134
133
end
@@ -142,7 +141,7 @@ function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float
142
141
samples. info. sampletype[current] = ACCEPTED_SAMPLE
143
142
samples. info. sampletype[proposed] = CURRENT_SAMPLE
144
143
mc_state. nsamples += 1
145
-
144
+
146
145
tstat = AdvancedHMC. stat (proposal. transition)
147
146
samples. info. hamiltonian_energy[proposed] = tstat. hamiltonian_energy
148
147
# ToDo: Handle proposal-dependent tstat (only NUTS has tree_depth):
@@ -176,3 +175,79 @@ function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Funct
176
175
mc_state_new = @set mc_state_new. f_transform = f_transform_new
177
176
return mc_state_new
178
177
end
178
+
179
+
180
+ # Copied from AdvancedHMC.jl, but also return proposed point
181
+ function _bat_transition (
182
+ rng:: AbstractRNG ,
183
+ τ:: AdvancedHMC.Trajectory{TS,I,TC} ,
184
+ h:: AdvancedHMC.Hamiltonian ,
185
+ z0:: AdvancedHMC.PhasePoint ,
186
+ ) where {
187
+ TS<: AdvancedHMC.AbstractTrajectorySampler ,
188
+ I<: AdvancedHMC.AbstractIntegrator ,
189
+ TC<: AdvancedHMC.DynamicTerminationCriterion ,
190
+ }
191
+ H0 = AdvancedHMC. energy (z0)
192
+ tree = AdvancedHMC. BinaryTree (
193
+ z0,
194
+ z0,
195
+ AdvancedHMC. TurnStatistic (τ. termination_criterion, z0),
196
+ zero (H0),
197
+ zero (Int),
198
+ zero (H0),
199
+ )
200
+ sampler = TS (rng, z0)
201
+ termination = AdvancedHMC. Termination (false , false )
202
+ zcand = z0
203
+ proposed_zs = Vector[]
204
+
205
+ j = 0
206
+ while ! AdvancedHMC. isterminated (termination) && j < τ. termination_criterion. max_depth
207
+ v = rand (rng, [- 1 , 1 ])
208
+ if v == - 1
209
+ tree′, sampler′, termination′ =
210
+ AdvancedHMC. build_tree (rng, τ, h, tree. zleft, sampler, v, j, H0)
211
+ treeleft, treeright = tree′, tree
212
+ else
213
+ tree′, sampler′, termination′ =
214
+ AdvancedHMC. build_tree (rng, τ, h, tree. zright, sampler, v, j, H0)
215
+ treeleft, treeright = tree, tree′
216
+ end
217
+ if ! AdvancedHMC. isterminated (termination′)
218
+ j = j + 1
219
+ if AdvancedHMC. mh_accept (rng, sampler, sampler′)
220
+ zcand = sampler′. zcand
221
+ end
222
+ end
223
+ push! (proposed_zs, sampler′. zcand. θ)
224
+
225
+ tree = AdvancedHMC. combine (treeleft, treeright)
226
+ sampler = AdvancedHMC. combine (zcand, sampler, sampler′)
227
+ termination =
228
+ termination *
229
+ termination′ *
230
+ AdvancedHMC. isterminated (τ. termination_criterion, h, tree, treeleft, treeright)
231
+ end
232
+
233
+ H = AdvancedHMC. energy (zcand)
234
+ tstat = AdvancedHMC. merge (
235
+ (
236
+ n_steps = tree. nα,
237
+ is_accept = true ,
238
+ acceptance_rate = tree. sum_α / tree. nα,
239
+ log_density = zcand. ℓπ. value,
240
+ hamiltonian_energy = H,
241
+ hamiltonian_energy_error = H - H0,
242
+ max_hamiltonian_energy_error = tree. ΔH_max,
243
+ tree_depth = j,
244
+ numerical_error = termination. numerical,
245
+ ),
246
+ AdvancedHMC. stat (τ. integrator),
247
+ )
248
+
249
+ z_proposed = proposed_zs[end ]
250
+ p_accept = tstat. acceptance_rate
251
+
252
+ return AdvancedHMC. Transition (zcand, tstat), z_proposed, p_accept
253
+ end
0 commit comments