Skip to content

Commit 070336b

Browse files
Micki-Doschulz
authored andcommitted
Use proposed samples from AdvancedHMC.jl for tuners in case of HMC Proposal
1 parent 34b50bf commit 070336b

File tree

8 files changed

+105
-37
lines changed

8 files changed

+105
-37
lines changed

ext/BATAdvancedHMCExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainS
2323
using BAT: MCMCBasicStats, push!, reweight_relative!
2424
using BAT: RAMTuning
2525
using BAT: MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering, NoMCMCTransformTuning
26-
using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples, current_sample_z, proposed_sample
26+
using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples, current_sample_z, proposed_sample_z, proposed_sample
2727
using BAT: AbstractTransformTarget, NoAdaptiveTransform, TriangularAffineTransform, valgrad_func
2828
using BAT: RNGPartition, get_rng, set_rng!
2929
using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio
@@ -37,6 +37,8 @@ using BAT: AHMCSampleID, AHMCSampleIDVector
3737
using BAT: HMCMetric, DiagEuclideanMetric, UnitEuclideanMetric, DenseEuclideanMetric
3838
using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanHMCTuning
3939

40+
using ChangesOfVariables: with_logabsdet_jacobian
41+
4042
using LinearAlgebra: cholesky
4143

4244
using MeasureBase: pullback

ext/ahmc_impl/ahmc_sampler_impl.jl

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function BAT.next_cycle!(mc_state::HMCState)
6969
mc_state.nsamples = 0
7070
mc_state.stepno = 0
7171

72-
#reset_rng_counters!(mc_state)
72+
reset_rng_counters!(mc_state)
7373

7474
resize!(mc_state.samples, 1)
7575

@@ -119,16 +119,15 @@ function BAT.mcmc_propose!!(mc_state::HMCState)
119119
z_phase = AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic))
120120
# Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics.
121121

122-
proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase)
123-
z_proposed[:] = proposal.transition.z.θ
124-
x_proposed[:] = f_transform(z_proposed)
125-
126-
proposed_log_posterior = logdensityof(target, x_proposed)
127-
samples.logd[proposed_x_idx] = proposed_log_posterior
122+
proposal.transition, z_proposed_hmc, p_accept = _bat_transition(rng, τ, hamiltonian, z_phase)
123+
accepted = z_current[:] != proposal.transition.z.θ
124+
z_proposed[:] = accepted ? proposal.transition.z.θ : z_proposed_hmc
125+
126+
p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate
128127

129-
accepted = z_current != z_proposed
130-
tstat = AdvancedHMC.stat(proposal.transition)
131-
p_accept = accepted ? tstat.acceptance_rate : 0.0
128+
x_proposed[:] = f_transform(z_proposed)
129+
logd_x_proposed = logdensityof(target, x_proposed)
130+
samples.logd[proposed_x_idx] = logd_x_proposed
132131

133132
return mc_state, accepted, p_accept
134133
end
@@ -142,7 +141,7 @@ function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float
142141
samples.info.sampletype[current] = ACCEPTED_SAMPLE
143142
samples.info.sampletype[proposed] = CURRENT_SAMPLE
144143
mc_state.nsamples += 1
145-
144+
146145
tstat = AdvancedHMC.stat(proposal.transition)
147146
samples.info.hamiltonian_energy[proposed] = tstat.hamiltonian_energy
148147
# ToDo: Handle proposal-dependent tstat (only NUTS has tree_depth):
@@ -176,3 +175,79 @@ function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Funct
176175
mc_state_new = @set mc_state_new.f_transform = f_transform_new
177176
return mc_state_new
178177
end
178+
179+
180+
# Copied from AdvancedHMC.jl, but also return proposed point
181+
function _bat_transition(
182+
rng::AbstractRNG,
183+
τ::AdvancedHMC.Trajectory{TS,I,TC},
184+
h::AdvancedHMC.Hamiltonian,
185+
z0::AdvancedHMC.PhasePoint,
186+
) where {
187+
TS<:AdvancedHMC.AbstractTrajectorySampler,
188+
I<:AdvancedHMC.AbstractIntegrator,
189+
TC<:AdvancedHMC.DynamicTerminationCriterion,
190+
}
191+
H0 = AdvancedHMC.energy(z0)
192+
tree = AdvancedHMC.BinaryTree(
193+
z0,
194+
z0,
195+
AdvancedHMC.TurnStatistic.termination_criterion, z0),
196+
zero(H0),
197+
zero(Int),
198+
zero(H0),
199+
)
200+
sampler = TS(rng, z0)
201+
termination = AdvancedHMC.Termination(false, false)
202+
zcand = z0
203+
proposed_zs = Vector[]
204+
205+
j = 0
206+
while !AdvancedHMC.isterminated(termination) && j < τ.termination_criterion.max_depth
207+
v = rand(rng, [-1, 1])
208+
if v == -1
209+
tree′, sampler′, termination′ =
210+
AdvancedHMC.build_tree(rng, τ, h, tree.zleft, sampler, v, j, H0)
211+
treeleft, treeright = tree′, tree
212+
else
213+
tree′, sampler′, termination′ =
214+
AdvancedHMC.build_tree(rng, τ, h, tree.zright, sampler, v, j, H0)
215+
treeleft, treeright = tree, tree′
216+
end
217+
if !AdvancedHMC.isterminated(termination′)
218+
j = j + 1
219+
if AdvancedHMC.mh_accept(rng, sampler, sampler′)
220+
zcand = sampler′.zcand
221+
end
222+
end
223+
push!(proposed_zs, sampler′.zcand.θ)
224+
225+
tree = AdvancedHMC.combine(treeleft, treeright)
226+
sampler = AdvancedHMC.combine(zcand, sampler, sampler′)
227+
termination =
228+
termination *
229+
termination′ *
230+
AdvancedHMC.isterminated.termination_criterion, h, tree, treeleft, treeright)
231+
end
232+
233+
H = AdvancedHMC.energy(zcand)
234+
tstat = AdvancedHMC.merge(
235+
(
236+
n_steps = tree.nα,
237+
is_accept = true,
238+
acceptance_rate = tree.sum_α / tree.nα,
239+
log_density = zcand.ℓπ.value,
240+
hamiltonian_energy = H,
241+
hamiltonian_energy_error = H - H0,
242+
max_hamiltonian_energy_error = tree.ΔH_max,
243+
tree_depth = j,
244+
numerical_error = termination.numerical,
245+
),
246+
AdvancedHMC.stat.integrator),
247+
)
248+
249+
z_proposed = proposed_zs[end]
250+
p_accept = tstat.acceptance_rate
251+
252+
return AdvancedHMC.Transition(zcand, tstat), z_proposed, p_accept
253+
end

ext/ahmc_impl/ahmc_stan_tuner_impl.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ function BAT.mcmc_tune_post_step!!(
5555
is_in_window = stan_state.i >= stan_state.window_start && stan_state.i <= stan_state.window_end
5656
is_window_end = stan_state.i in stan_state.window_splits
5757

58-
# What to append?
59-
is_in_window && BAT.push!(stats, proposed_sample(chain_state))
58+
if is_in_window
59+
BAT.push!(stats, proposed_sample(chain_state))
60+
end
6061

6162
if is_window_end
6263
A = chain_state.f_transform.A

src/samplers/mcmc/mcmc_algorithm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ function mcmc_iterate!!(
280280
while (
281281
(nsteps(mcmc_state) - start_nsteps) < max_nsteps &&
282282
(time() - start_time) < max_time
283-
)
283+
)
284284
mcmc_state = mcmc_step!!(mcmc_state)
285285

286286
if !isnothing(output)

src/samplers/mcmc/mcmc_sample.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,22 @@ function bat_sample_impl(m::BATMeasure, samplingalg::TransformedMCMC, context::B
8888
if !samplingalg.store_burnin
8989
chain_outputs .= DensitySampleVector.(mcmc_states)
9090
end
91-
91+
9292
mcmc_states = mcmc_burnin!(
9393
samplingalg.store_burnin ? chain_outputs : nothing,
9494
mcmc_states,
9595
samplingalg,
9696
samplingalg.store_burnin ? samplingalg.callback : nop_func
9797
)
98-
98+
9999
next_cycle!.(mcmc_states)
100-
100+
101101
mcmc_states = mcmc_iterate!!(
102102
chain_outputs,
103103
mcmc_states;
104104
max_nsteps = samplingalg.nsteps,
105105
nonzero_weights = samplingalg.nonzero_weights
106-
)
107-
106+
)
108107
samples_transformed = DensitySampleVector(first(mcmc_states))
109108
isempty(chain_outputs) || append!.(Ref(samples_transformed), chain_outputs)
110109

src/samplers/mcmc/mcmc_state.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ end
145145
function mcmc_step!!(mcmc_state::MCMCState)
146146
_cleanup_samples(mcmc_state)
147147

148-
#reset_rng_counters!(mcmc_state)
148+
reset_rng_counters!(mcmc_state)
149149

150150
chain_state = mcmc_state.chain_state
151151

@@ -156,7 +156,7 @@ function mcmc_step!!(mcmc_state::MCMCState)
156156
resize!(samples, size(samples, 1) + 1)
157157

158158
samples.info[lastindex(samples)] = _get_sample_id(proposal, chain_state.info.id, chain_state.info.cycle, chain_state.stepno, PROPOSED_SAMPLE)[1]
159-
159+
160160
chain_state, accepted, p_accept = mcmc_propose!!(chain_state)
161161

162162
mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept)
@@ -219,7 +219,7 @@ function next_cycle!(chain_state::MCMCChainState)
219219
chain_state.nsamples = 0
220220
chain_state.stepno = 0
221221

222-
#reset_rng_counters!(chain_state)
222+
reset_rng_counters!(chain_state)
223223

224224
resize!(chain_state.samples, 1)
225225

@@ -277,7 +277,6 @@ end
277277

278278
function mcmc_update_z_position!!(mc_state::MCMCChainState)
279279
f_transform = mc_state.f_transform
280-
sample_z = mc_state.sample_z
281280

282281
current_sample_x = current_sample(mc_state)
283282
proposed_sample_x = proposed_sample(mc_state)

src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,34 +76,28 @@ function mcmc_tune_post_step!!(
7676
mc_state::MCMCChainState,
7777
p_accept::Real,
7878
)
79-
# TODO: MD: Discuss; apparently the RandomWalk sampler wants the trafo to be tuned even if p_accept = 0. If not, the burnin does not converge.
80-
if iszero(p_accept) && !(mc_state isa MHChainState)
81-
return mc_state, tuner_state
82-
end
83-
8479
(; f_transform, sample_z) = mc_state
8580
(; target_acceptance, gamma) = tuner_state.tuning
8681
b = f_transform.b
87-
82+
8883
tuner_state_new = @set tuner_state.nsteps = tuner_state.nsteps + 1
89-
84+
9085
n_dims = size(sample_z.v[1], 1)
9186
η = min(1, n_dims * tuner_state.nsteps^(-gamma))
9287

9388
s_L = f_transform.A
9489

9590
u = sample_z.v[2] - sample_z.v[1] # proposed - current
96-
M = s_L * (I + η * (p_accept - target_acceptance) * (u * u') / norm(u)^2 ) * s_L'
97-
91+
M = s_L * (I + η * (p_accept - target_acceptance) * (u * u') / norm(u)^2) * s_L'
9892
new_s_L = oftype(s_L, cholesky(Positive, M).L)
99-
93+
10094
x = mc_state.samples[_proposed_sample_idx(mc_state)] # proposed in x-space
10195
mean_update_rate = η / 10 # heuristic
10296
α = mean_update_rate * p_accept
10397
new_b = oftype(b, (1- α) * b + α * x.v)
10498

10599
f_transform_new = MulAdd(new_s_L, new_b)
106-
100+
107101
mc_state_new = set_mc_state_transform!!(mc_state, f_transform_new)
108102
mc_state_new = mcmc_update_z_position!!(mc_state_new)
109103

src/samplers/mcmc/mh_sampler.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,6 @@ function _accept_reject!(mc_state::MHChainState, accepted::Bool, p_accept::Float
144144
samples.info.sampletype[proposed] = CURRENT_SAMPLE
145145

146146
mc_state.nsamples += 1
147-
148-
mc_state.sample_z[1] = deepcopy(proposed_sample_z(mc_state))
149147
else
150148
samples.info.sampletype[proposed] = REJECTED_SAMPLE
151149
end

0 commit comments

Comments
 (0)