Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.16.5"
version = "0.16.6"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -34,7 +34,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "3.2"
AdvancedHMC = "0.2.24"
AdvancedHMC = "0.3.0"
AdvancedMH = "0.6"
AdvancedPS = "0.2.4"
AdvancedVI = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion src/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ function gibbs_state(
state.z.θ .= θ_old
z = state.z

return HMCState(varinfo, state.i, state.traj, hamiltonian, z, state.adaptor)
return HMCState(varinfo, state.i, state.kernel, hamiltonian, z, state.adaptor)
end

"""
Expand Down
38 changes: 19 additions & 19 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

struct HMCState{
TV<:AbstractVarInfo,
TTraj<:AHMC.AbstractTrajectory,
TKernel<:AHMC.HMCKernel,
THam<:AHMC.Hamiltonian,
PhType<:AHMC.PhasePoint,
TAdapt<:AHMC.Adaptation.AbstractAdaptor,
}
vi::TV
i::Int
traj::TTraj
kernel::TKernel
hamiltonian::THam
z::PhType
adaptor::TAdapt
Expand Down Expand Up @@ -190,18 +190,18 @@ function DynamicPPL.initialstep(
ϵ = spl.alg.ϵ
end

# Generate a trajectory.
traj = gen_traj(spl.alg, ϵ)
# Generate a kernel.
kernel = make_ahmc_kernel(spl.alg, ϵ)

# Create initial transition and state.
# Already perform one step since otherwise we don't get any statistics.
t = AHMC.step(rng, hamiltonian, traj, z)
t = AHMC.transition(rng, hamiltonian, kernel, z)

# Adaptation
adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ)
if spl.alg isa AdaptiveHamiltonian
hamiltonian, traj, _ =
AHMC.adapt!(hamiltonian, traj, adaptor,
hamiltonian, kernel, _ =
AHMC.adapt!(hamiltonian, kernel, adaptor,
1, nadapts, t.z.θ, t.stat.acceptance_rate)
end

Expand All @@ -215,7 +215,7 @@ function DynamicPPL.initialstep(
end

transition = HMCTransition(vi, t)
state = HMCState(vi, 1, traj, hamiltonian, t.z, adaptor)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)

return transition, state
end
Expand All @@ -234,16 +234,16 @@ function AbstractMCMC.step(
# Compute transition.
hamiltonian = state.hamiltonian
z = state.z
t = AHMC.step(rng, hamiltonian, state.traj, z)
t = AHMC.transition(rng, hamiltonian, state.kernel, z)

# Adaptation
i = state.i + 1
if spl.alg isa AdaptiveHamiltonian
hamiltonian, traj, _ =
AHMC.adapt!(hamiltonian, state.traj, state.adaptor,
hamiltonian, kernel, _ =
AHMC.adapt!(hamiltonian, state.kernel, state.adaptor,
i, nadapts, t.z.θ, t.stat.acceptance_rate)
else
traj = state.traj
kernel = state.kernel
end

# Update variables
Expand All @@ -255,7 +255,7 @@ function AbstractMCMC.step(

# Compute next transition and state.
transition = HMCTransition(vi, t)
newstate = HMCState(vi, i, traj, hamiltonian, t.z, state.adaptor)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)

return transition, newstate
end
Expand Down Expand Up @@ -459,9 +459,9 @@ function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc))
end

gen_traj(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog)
gen_traj(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ)
gen_traj(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max)
make_ahmc_kernel(alg::HMC, ϵ) = AHMC.StaticTrajectory(AHMC.Leapfrog(ϵ), alg.n_leapfrog)
make_ahmc_kernel(alg::HMCDA, ϵ) = AHMC.HMCDA(AHMC.Leapfrog(ϵ), alg.λ)
make_ahmc_kernel(alg::NUTS, ϵ) = AHMC.NUTS(AHMC.Leapfrog(ϵ), alg.max_depth, alg.Δ_max)

####
#### Compiler interface, i.e. tilde operators.
Expand Down Expand Up @@ -584,14 +584,14 @@ function HMCState(
ϵ = spl.alg.ϵ
end

# Generate a trajectory.
traj = gen_traj(spl.alg, ϵ)
# Generate a kernel.
kernel = make_ahmc_kernel(spl.alg, ϵ)

# Generate a phasepoint. Replaced during sample_init!
h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ.

# Unlink everything.
invlink!(vi, spl)

return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ AdvancedPS = "0.2"
AdvancedVI = "0.1"
Clustering = "0.14"
CmdStan = "6.0.8"
Distributions = "0.23.8, 0.24, 0.25"
Distributions = "< 0.25.11"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.12"
Expand Down
2 changes: 1 addition & 1 deletion test/inference/Inference.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "io.jl" begin
@testset "inference.jl" begin
# Only test threading if 1.3+.
if VERSION > v"1.2"
@testset "threaded sampling" begin
Expand Down
9 changes: 0 additions & 9 deletions test/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,10 @@
chain = sample(gdemo(1.5, 2.0), alg, 5_000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)

setadsafe(true)

Random.seed!(200)
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2))
chain = sample(MoGtest_default, gibbs, 5_000)
check_MoGtest_default(chain, atol=0.15)

setadsafe(false)

Random.seed!(200)
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, gibbs, 5_000)
check_MoGtest_default(chain, atol=0.1)
end

@turing_testset "transitions" begin
Expand Down
20 changes: 15 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,27 @@ include("test_utils/AllUtils.jl")
include("core/ad.jl")
end

@testset "samplers (without AD)" begin
include("inference/AdvancedSMC.jl")
include("inference/emcee.jl")
include("inference/ess.jl")
include("inference/is.jl")
end

Turing.setrdcache(false)
for adbackend in (:forwarddiff, :tracker, :reversediff)
Turing.setadbackend(adbackend)
@info "Testing $(adbackend)"
start = time()
@testset "inference: $adbackend" begin
@testset "samplers" begin
include("inference/gibbs.jl")
include("inference/gibbs_conditional.jl")
include("inference/hmc.jl")
include("inference/is.jl")
include("inference/mh.jl")
include("inference/ess.jl")
include("inference/emcee.jl")
include("inference/AdvancedSMC.jl")
include("inference/Inference.jl")
include("contrib/inference/dynamichmc.jl")
include("contrib/inference/sghmc.jl")
include("inference/mh.jl")
end
end

Expand All @@ -72,6 +77,11 @@ include("test_utils/AllUtils.jl")
@testset "modes" begin
include("modes/ModeEstimation.jl")
end

# Useful for
# a) discovering performance regressions,
# b) figuring out why CI is timing out.
@info "Tests for $(adbackend) took $(time() - start) seconds"
end
@testset "variational optimisers" begin
include("variational/optimisers.jl")
Expand Down