Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NUTS kernel options #342

Merged
merged 16 commits into from
Jul 28, 2023
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.5.1"
version = "0.5.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
13 changes: 10 additions & 3 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function AbstractMCMC.step(

# Define integration algorithm
# Find a good step size ϵ if one was not provided
init_params = make_init_params(spl, logdensity, init_params)
init_params = make_init_params(rng, spl, logdensity, init_params)
ϵ = make_step_size(rng, spl, hamiltonian, init_params)
integrator = make_integrator(spl, ϵ)

Expand Down Expand Up @@ -251,7 +251,12 @@ end
#############
### Utils ###
#############
function make_init_params(spl::AbstractHMCSampler, logdensity, init_params)
function make_init_params(
rng::AbstractRNG,
spl::AbstractHMCSampler,
logdensity,
init_params,
)
T = sampler_eltype(spl)
if init_params == nothing
d = LogDensityProblems.dimension(logdensity)
Expand Down Expand Up @@ -354,7 +359,9 @@ end
#########

function make_kernel(spl::NUTS, integrator::AbstractIntegrator)
return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
return HMCKernel(
Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn(spl.max_depth, spl.Δ_max)),
)
end

function make_kernel(spl::HMC, integrator::AbstractIntegrator)
Expand Down
54 changes: 52 additions & 2 deletions test/constructors.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
using AdvancedHMC, AbstractMCMC, Random
include("common.jl")

# Test helpers: read the kernel hyperparameter(s) back out of a sampler state.
# Every method looks at `state.κ.τ.termination_criterion`; the sampler argument
# is only used to select which field(s) of that criterion to return.

# HMC: the `L` field of the termination criterion.
function get_kernel_hyperparams(spl::HMC, state)
    criterion = state.κ.τ.termination_criterion
    return criterion.L
end

# HMCDA: the `λ` field of the termination criterion.
function get_kernel_hyperparams(spl::HMCDA, state)
    criterion = state.κ.τ.termination_criterion
    return criterion.λ
end

# NUTS: the tuple `(max_depth, Δ_max)` taken from the termination criterion.
function get_kernel_hyperparams(spl::NUTS, state)
    criterion = state.κ.τ.termination_criterion
    return (criterion.max_depth, criterion.Δ_max)
end

# Test helpers: report the *type* of a kernel hyperparameter stored in
# `state.κ.τ.termination_criterion`, selected by the sampler type.

# HMC: type of the `L` field.
function get_kernel_hyperparamsT(spl::HMC, state)
    criterion = state.κ.τ.termination_criterion
    return typeof(criterion.L)
end

# HMCDA: type of the `λ` field.
function get_kernel_hyperparamsT(spl::HMCDA, state)
    criterion = state.κ.τ.termination_criterion
    return typeof(criterion.λ)
end

# NUTS: type of the `Δ_max` field only (`max_depth` is not inspected here).
function get_kernel_hyperparamsT(spl::NUTS, state)
    criterion = state.κ.τ.termination_criterion
    return typeof(criterion.Δ_max)
end

@testset "Constructors" begin
d = 2
θ_init = randn(d)
rng = Random.default_rng()
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)

@testset "$T" for T in [Float32, Float64]
Expand All @@ -14,6 +24,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
Expand All @@ -22,6 +33,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
Expand All @@ -30,6 +42,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
Expand All @@ -38,6 +51,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = UnitEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
Expand All @@ -46,6 +60,7 @@ include("common.jl")
adaptor_type = NoAdaptation,
metric_type = DenseEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = 25,
),
),
(
Expand All @@ -54,6 +69,7 @@ include("common.jl")
adaptor_type = NesterovDualAveraging,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = one(T),
),
),
# This should perform the correct promotion for the 2nd argument.
Expand All @@ -63,14 +79,16 @@ include("common.jl")
adaptor_type = NesterovDualAveraging,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = one(T),
),
),
(
NUTS(T(0.8)),
NUTS(T(0.8); max_depth = 20, Δ_max = T(2000.0)),
(
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = (20, T(2000.0)),
),
),
(
Expand All @@ -79,6 +97,7 @@ include("common.jl")
adaptor_type = StanHMCAdaptor,
metric_type = UnitEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = (10, T(1000.0)),
),
),
(
Expand All @@ -87,6 +106,7 @@ include("common.jl")
adaptor_type = StanHMCAdaptor,
metric_type = DenseEuclideanMetric{T},
integrator_type = Leapfrog{T},
kernel_hp = (10, T(1000.0)),
),
),
(
Expand All @@ -95,6 +115,7 @@ include("common.jl")
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = JitteredLeapfrog{T,T},
kernel_hp = (10, T(1000.0)),
),
),
(
Expand All @@ -103,14 +124,14 @@ include("common.jl")
adaptor_type = StanHMCAdaptor,
metric_type = DiagEuclideanMetric{T},
integrator_type = TemperedLeapfrog{T,T},
kernel_hp = (10, T(1000.0)),
),
),
]
# Make sure the sampler element type is preserved.
@test AdvancedHMC.sampler_eltype(sampler) == T

# Step.
rng = Random.default_rng()
transition, state =
AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init)

Expand All @@ -126,6 +147,35 @@ include("common.jl")
@test AdvancedHMC.getmetric(state) isa expected.metric_type
@test AdvancedHMC.getintegrator(state) isa expected.integrator_type
@test AdvancedHMC.getadaptor(state) isa expected.adaptor_type

# Verify that the kernel is receiving the hyperparameters
@test get_kernel_hyperparams(sampler, state) == expected.kernel_hp
if typeof(sampler) <: HMC
@test get_kernel_hyperparamsT(sampler, state) == Int64
else
@test get_kernel_hyperparamsT(sampler, state) == T
end
end
end
end

@testset "Utils" begin
    @testset "init_params" begin
        # Fixture: dimension, a user-supplied starting point, and a NUTS sampler
        # over the log density wrapped from `ℓπ_gdemo`.
        d = 2
        θ_init = randn(d)
        rng = Random.default_rng()
        model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
        logdensity = model.logdensity
        spl = NUTS(0.8)
        T = AdvancedHMC.sampler_eltype(spl)

        # Constructed here but not inspected by any assertion below.
        metric = make_metric(spl, logdensity)
        hamiltonian = Hamiltonian(metric, model)

        # No initial parameters given: expect a `Vector{T}` of length `d`.
        generated = make_init_params(rng, spl, logdensity, nothing)
        @test typeof(generated) == Vector{T}
        @test length(generated) == d

        # Initial parameters given: expect the very same object back.
        passed_through = make_init_params(rng, spl, logdensity, θ_init)
        @test passed_through === θ_init
    end
end
Loading