44
55struct HMCState{
66 TV<: AbstractVarInfo ,
7- TTraj <: AHMC.AbstractTrajectory ,
7+ TKernel <: AHMC.HMCKernel ,
88 THam<: AHMC.Hamiltonian ,
99 PhType<: AHMC.PhasePoint ,
1010 TAdapt<: AHMC.Adaptation.AbstractAdaptor ,
1111}
1212 vi:: TV
1313 i:: Int
14- traj :: TTraj
14+ kernel :: TKernel
1515 hamiltonian:: THam
1616 z:: PhType
1717 adaptor:: TAdapt
@@ -190,18 +190,18 @@ function DynamicPPL.initialstep(
190190 ϵ = spl. alg. ϵ
191191 end
192192
193- # Generate a trajectory .
194- traj = gen_traj (spl. alg, ϵ)
193+ # Generate a kernel .
194+ kernel = make_ahmc_kernel (spl. alg, ϵ)
195195
196196 # Create initial transition and state.
197197 # Already perform one step since otherwise we don't get any statistics.
198- t = AHMC. step (rng, hamiltonian, traj , z)
198+ t = AHMC. transition (rng, hamiltonian, kernel , z)
199199
200200 # Adaptation
201201 adaptor = AHMCAdaptor (spl. alg, hamiltonian. metric; ϵ= ϵ)
202202 if spl. alg isa AdaptiveHamiltonian
203- hamiltonian, traj , _ =
204- AHMC. adapt! (hamiltonian, traj , adaptor,
203+ hamiltonian, kernel , _ =
204+ AHMC. adapt! (hamiltonian, kernel , adaptor,
205205 1 , nadapts, t. z. θ, t. stat. acceptance_rate)
206206 end
207207
@@ -215,7 +215,7 @@ function DynamicPPL.initialstep(
215215 end
216216
217217 transition = HMCTransition (vi, t)
218- state = HMCState (vi, 1 , traj , hamiltonian, t. z, adaptor)
218+ state = HMCState (vi, 1 , kernel , hamiltonian, t. z, adaptor)
219219
220220 return transition, state
221221end
@@ -234,16 +234,16 @@ function AbstractMCMC.step(
234234 # Compute transition.
235235 hamiltonian = state. hamiltonian
236236 z = state. z
237- t = AHMC. step (rng, hamiltonian, state. traj , z)
237+ t = AHMC. transition (rng, hamiltonian, state. kernel , z)
238238
239239 # Adaptation
240240 i = state. i + 1
241241 if spl. alg isa AdaptiveHamiltonian
242- hamiltonian, traj , _ =
243- AHMC. adapt! (hamiltonian, state. traj , state. adaptor,
242+ hamiltonian, kernel , _ =
243+ AHMC. adapt! (hamiltonian, state. kernel , state. adaptor,
244244 i, nadapts, t. z. θ, t. stat. acceptance_rate)
245245 else
246- traj = state. traj
246+ kernel = state. kernel
247247 end
248248
249249 # Update variables
@@ -255,7 +255,7 @@ function AbstractMCMC.step(
255255
256256 # Compute next transition and state.
257257 transition = HMCTransition (vi, t)
258- newstate = HMCState (vi, i, traj , hamiltonian, t. z, state. adaptor)
258+ newstate = HMCState (vi, i, kernel , hamiltonian, t. z, state. adaptor)
259259
260260 return transition, newstate
261261end
@@ -459,9 +459,9 @@ function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
459459 return AHMC. renew (state. hamiltonian. metric, AHMC. getM⁻¹ (state. adaptor. pc))
460460end
461461
462- gen_traj (alg:: HMC , ϵ) = AHMC. StaticTrajectory (AHMC. Leapfrog (ϵ), alg. n_leapfrog)
463- gen_traj (alg:: HMCDA , ϵ) = AHMC. HMCDA (AHMC. Leapfrog (ϵ), alg. λ)
464- gen_traj (alg:: NUTS , ϵ) = AHMC. NUTS (AHMC. Leapfrog (ϵ), alg. max_depth, alg. Δ_max)
462+ make_ahmc_kernel (alg:: HMC , ϵ) = AHMC. StaticTrajectory (AHMC. Leapfrog (ϵ), alg. n_leapfrog)
463+ make_ahmc_kernel (alg:: HMCDA , ϵ) = AHMC. HMCDA (AHMC. Leapfrog (ϵ), alg. λ)
464+ make_ahmc_kernel (alg:: NUTS , ϵ) = AHMC. NUTS (AHMC. Leapfrog (ϵ), alg. max_depth, alg. Δ_max)
465465
466466# ###
467467# ### Compiler interface, i.e. tilde operators.
@@ -584,14 +584,14 @@ function HMCState(
584584 ϵ = spl. alg. ϵ
585585 end
586586
587- # Generate a trajectory .
588- traj = gen_traj (spl. alg, ϵ)
587+ # Generate a kernel .
588+ kernel = make_ahmc_kernel (spl. alg, ϵ)
589589
590590 # Generate a phasepoint. Replaced during sample_init!
591591 h, t = AHMC. sample_init (rng, h, θ_init) # this also ensure AHMC has the same dim as θ.
592592
593593 # Unlink everything.
594594 invlink! (vi, spl)
595595
596- return HMCState (vi, 0 , 0 , traj , h, AHMCAdaptor (spl. alg, metric; ϵ= ϵ), t. z)
596+ return HMCState (vi, 0 , 0 , kernel . τ , h, AHMCAdaptor (spl. alg, metric; ϵ= ϵ), t. z)
597597end
0 commit comments