Skip to content

Commit 7cb94d6

Browse files
authored
Updated HMC implementation for new AHMC version (#1660)
* updated HMC implementation according to new AHMC interface * bump compat bound for AdvancedHMC * bumped patch version * disable GMM Gibbs conditional test to see if it fixes CI * include tests again * dont test non-AD samplers for every AD backend * added back a test * added back a test * removed some redundant tests and fixed a typo * added macro timed_testset * upper-bound Distributions.jl apparently fixes the test-freeze * hyphen compat specifies arent compatible with Julia 1.3 * removed the timed_testset stuff
1 parent d029198 commit 7cb94d6

File tree

7 files changed

+39
-38
lines changed

7 files changed

+39
-38
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.16.5"
3+
version = "0.16.6"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -34,7 +34,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3434

3535
[compat]
3636
AbstractMCMC = "3.2"
37-
AdvancedHMC = "0.2.24"
37+
AdvancedHMC = "0.3.0"
3838
AdvancedMH = "0.6"
3939
AdvancedPS = "0.2.4"
4040
AdvancedVI = "0.1"

src/inference/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ function gibbs_state(
126126
state.z.θ .= θ_old
127127
z = state.z
128128

129-
return HMCState(varinfo, state.i, state.traj, hamiltonian, z, state.adaptor)
129+
return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor)
130130
end
131131

132132
"""

src/inference/hmc.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
struct HMCState{
66
TV<:AbstractVarInfo,
7-
TTraj<:AHMC.AbstractTrajectory,
7+
TKernel<:AHMC.HMCKernel,
88
THam<:AHMC.Hamiltonian,
99
PhType<:AHMC.PhasePoint,
1010
TAdapt<:AHMC.Adaptation.AbstractAdaptor,
1111
}
1212
vi::TV
1313
i::Int
14-
traj::TTraj
14+
kernel::TKernel
1515
hamiltonian::THam
1616
z::PhType
1717
adaptor::TAdapt
@@ -190,18 +190,18 @@ function DynamicPPL.initialstep(
190190
ϵ = spl.alg.ϵ
191191
end
192192

193-
# Generate a trajectory.
194-
traj = gen_traj(spl.alg, ϵ)
193+
# Generate a kernel.
194+
kernel = make_ahmc_kernel(spl.alg, ϵ)
195195

196196
# Create initial transition and state.
197197
# Already perform one step since otherwise we don't get any statistics.
198-
t = AHMC.step(rng, hamiltonian, traj, z)
198+
t = AHMC.transition(rng, hamiltonian, kernel, z)
199199

200200
# Adaptation
201201
adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ)
202202
if spl.alg isa AdaptiveHamiltonian
203-
hamiltonian, traj, _ =
204-
AHMC.adapt!(hamiltonian, traj, adaptor,
203+
hamiltonian, kernel, _ =
204+
AHMC.adapt!(hamiltonian, kernel, adaptor,
205205
1, nadapts, t.z.θ, t.stat.acceptance_rate)
206206
end
207207

@@ -215,7 +215,7 @@ function DynamicPPL.initialstep(
215215
end
216216

217217
transition = HMCTransition(vi, t)
218-
state = HMCState(vi, 1, traj, hamiltonian, t.z, adaptor)
218+
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
219219

220220
return transition, state
221221
end
@@ -234,16 +234,16 @@ function AbstractMCMC.step(
234234
# Compute transition.
235235
hamiltonian = state.hamiltonian
236236
z = state.z
237-
t = AHMC.step(rng, hamiltonian, state.traj, z)
237+
t = AHMC.transition(rng, hamiltonian, state.kernel, z)
238238

239239
# Adaptation
240240
i = state.i + 1
241241
if spl.alg isa AdaptiveHamiltonian
242-
hamiltonian, traj, _ =
243-
AHMC.adapt!(hamiltonian, state.traj, state.adaptor,
242+
hamiltonian, kernel, _ =
243+
AHMC.adapt!(hamiltonian, state.kernel, state.adaptor,
244244
i, nadapts, t.z.θ, t.stat.acceptance_rate)
245245
else
246-
traj = state.traj
246+
kernel = state.kernel
247247
end
248248

249249
# Update variables
@@ -255,7 +255,7 @@ function AbstractMCMC.step(
255255

256256
# Compute next transition and state.
257257
transition = HMCTransition(vi, t)
258-
newstate = HMCState(vi, i, traj, hamiltonian, t.z, state.adaptor)
258+
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
259259

260260
return transition, newstate
261261
end
@@ -459,9 +459,9 @@ function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
459459
return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc))
460460
end
461461

462-
gen_traj(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog)
463-
gen_traj(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ)
464-
gen_traj(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max)
462+
make_ahmc_kernel(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog)
463+
make_ahmc_kernel(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ)
464+
make_ahmc_kernel(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max)
465465

466466
####
467467
#### Compiler interface, i.e. tilde operators.
@@ -584,14 +584,14 @@ function HMCState(
584584
ϵ = spl.alg.ϵ
585585
end
586586

587-
# Generate a trajectory.
588-
traj = gen_traj(spl.alg, ϵ)
587+
# Generate a kernel.
588+
kernel = make_ahmc_kernel(spl.alg, ϵ)
589589

590590
# Generate a phasepoint. Replaced during sample_init!
591591
h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ.
592592

593593
# Unlink everything.
594594
invlink!(vi, spl)
595595

596-
return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
596+
return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
597597
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ AdvancedPS = "0.2"
3434
AdvancedVI = "0.1"
3535
Clustering = "0.14"
3636
CmdStan = "6.0.8"
37-
Distributions = "0.23.8, 0.24, 0.25"
37+
Distributions = "< 0.25.11"
3838
DistributionsAD = "0.6.3"
3939
DynamicHMC = "2.1.6, 3.0"
4040
DynamicPPL = "0.12"

test/inference/Inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "io.jl" begin
1+
@testset "inference.jl" begin
22
# Only test threading if 1.3+.
33
if VERSION > v"1.2"
44
@testset "threaded sampling" begin

test/inference/gibbs.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,10 @@
5050
chain = sample(gdemo(1.5, 2.0), alg, 5_000)
5151
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)
5252

53-
setadsafe(true)
54-
5553
Random.seed!(200)
5654
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2))
5755
chain = sample(MoGtest_default, gibbs, 5_000)
5856
check_MoGtest_default(chain, atol=0.15)
59-
60-
setadsafe(false)
61-
62-
Random.seed!(200)
63-
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2))
64-
chain = sample(MoGtest_default, gibbs, 5_000)
65-
check_MoGtest_default(chain, atol=0.1)
6657
end
6758

6859
@turing_testset "transitions" begin

test/runtests.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,22 +46,27 @@ include("test_utils/AllUtils.jl")
4646
include("core/ad.jl")
4747
end
4848

49+
@testset "samplers (without AD)" begin
50+
include("inference/AdvancedSMC.jl")
51+
include("inference/emcee.jl")
52+
include("inference/ess.jl")
53+
include("inference/is.jl")
54+
end
55+
4956
Turing.setrdcache(false)
5057
for adbackend in (:forwarddiff, :tracker, :reversediff)
5158
Turing.setadbackend(adbackend)
59+
@info "Testing $(adbackend)"
60+
start = time()
5261
@testset "inference: $adbackend" begin
5362
@testset "samplers" begin
5463
include("inference/gibbs.jl")
5564
include("inference/gibbs_conditional.jl")
5665
include("inference/hmc.jl")
57-
include("inference/is.jl")
58-
include("inference/mh.jl")
59-
include("inference/ess.jl")
60-
include("inference/emcee.jl")
61-
include("inference/AdvancedSMC.jl")
6266
include("inference/Inference.jl")
6367
include("contrib/inference/dynamichmc.jl")
6468
include("contrib/inference/sghmc.jl")
69+
include("inference/mh.jl")
6570
end
6671
end
6772

@@ -72,6 +77,11 @@ include("test_utils/AllUtils.jl")
7277
@testset "modes" begin
7378
include("modes/ModeEstimation.jl")
7479
end
80+
81+
# Useful for
82+
# a) discovering performance regressions,
83+
# b) figuring out why CI is timing out.
84+
@info "Tests for $(adbackend) took $(time() - start) seconds"
7585
end
7686
@testset "variational optimisers" begin
7787
include("variational/optimisers.jl")

0 commit comments

Comments
 (0)