kwarg nadapt #332

Merged: 5 commits, Jul 26, 2023
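This PR drops the `n_adapts` field from the `NUTS` and `HMCDA` constructors and instead accepts it as a keyword argument of `AbstractMCMC.sample`, defaulting to `min(div(N, 10), 1_000)`. Below is a minimal sketch of the resulting call pattern; it uses only signatures visible in the diff and assumes the `ℓπ_gdemo` log-density helper from `test/common.jl`.

```julia
using AdvancedHMC, AbstractMCMC, Random

# Samplers no longer carry the number of adaptation steps.
nuts = NUTS(0.8)          # target acceptance rate only
hmcda = HMCDA(0.8, 1.0)   # target acceptance rate and target leapfrog length

rng = MersenneTwister(0)
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)  # helper from test/common.jl

# n_adapts is now a keyword of sample; omitting it uses min(div(N, 10), 1_000).
samples = AbstractMCMC.sample(
    rng, model, nuts, 2_000;
    n_adapts = 1_000,
    init_params = randn(rng, 2),
    progress = false,
)
```

With the default left in place, a 2_000-draw run adapts for min(div(2_000, 10), 1_000) = 200 iterations.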
11 changes: 5 additions & 6 deletions src/abstractmcmc.jl
@@ -37,6 +37,7 @@ function AbstractMCMC.sample(
model::LogDensityModel,
sampler::AbstractHMCSampler,
N::Integer;
n_adapts::Int = min(div(N, 10), 1_000),
progress = true,
verbose = false,
callback = nothing,
@@ -52,6 +53,7 @@
model,
sampler,
N;
n_adapts = n_adapts,
progress = progress,
verbose = verbose,
callback = callback,
@@ -66,6 +68,7 @@ function AbstractMCMC.sample(
parallel::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
nchains::Integer;
n_adapts::Int = min(div(N, 10), 1_000),
progress = true,
verbose = false,
callback = nothing,
@@ -84,6 +87,7 @@ function AbstractMCMC.sample(
parallel,
N,
nchains;
n_adapts = n_adapts,
progress = progress,
verbose = verbose,
callback = callback,
@@ -150,7 +154,7 @@ function AbstractMCMC.step(

# Adapt h and spl.
tstat = stat(t)
n_adapts = get_nadapts(spl)
n_adapts = kwargs[:n_adapts]
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
tstat = merge(tstat, (is_adapt = isadapted,))

@@ -336,11 +340,6 @@ end

#########

get_nadapts(spl::Union{HMCSampler,NUTS,HMCDA}) = spl.n_adapts
get_nadapts(spl::HMC) = 0

#########

function make_kernel(spl::NUTS, integrator::AbstractIntegrator)
return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
end
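Because `get_nadapts` is removed and `AbstractMCMC.step` now reads `kwargs[:n_adapts]`, the single-step interface takes the same keyword. A sketch of stepping a sampler manually with adaptation disabled, mirroring the updated tests (again assuming `ℓπ_gdemo` from `test/common.jl`):

```julia
using AdvancedHMC, AbstractMCMC, Random

rng = MersenneTwister(0)
logdensitymodel = AbstractMCMC.LogDensityModel(ℓπ_gdemo)  # helper from test/common.jl
nuts = NUTS(0.8)

# n_adapts = 0 disables adaptation in this step; a positive value keeps the
# adaptor active for the first n_adapts iterations of a sampling loop.
transition, state = AbstractMCMC.step(
    rng, logdensitymodel, nuts;
    n_adapts = 0,
    init_params = randn(rng, 2),
)
```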
11 changes: 3 additions & 8 deletions src/constructors.jl
@@ -56,8 +56,6 @@ NUTS(n_adapts=1000, δ=0.65) # Use 1000 adaption steps, and target accept ratio
```
"""
struct NUTS{T<:Real} <: AbstractHMCSampler{T}
"Number of adaptation steps."
n_adapts::Int
"Target acceptance rate for dual averaging."
δ::T
"Maximum doubling tree depth."
@@ -73,7 +71,6 @@ struct NUTS{T<:Real} <: AbstractHMCSampler{T}
end

function NUTS(
n_adapts,
δ;
max_depth = 10,
Δ_max = 1000.0,
@@ -82,7 +79,7 @@ function NUTS(
metric = :diagonal,
)
T = typeof(δ)
return NUTS(n_adapts, δ, max_depth, T(Δ_max), T(init_ϵ), integrator, metric)
return NUTS(δ, max_depth, T(Δ_max), T(init_ϵ), integrator, metric)
end

###########
@@ -143,8 +140,6 @@ For more information, please view the following paper ([arXiv link](https://arxi
Research 15, no. 1 (2014): 1593-1623.
"""
struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
"`Number of adaptation steps."
n_adapts::Int
"Target acceptance rate for dual averaging."
δ::T
"Target leapfrog length."
@@ -157,10 +152,10 @@ struct HMCDA{T<:Real} <: AbstractHMCSampler{T}
metric::Union{Symbol,AbstractMetric}
end

function HMCDA(n_adapts, δ, λ; init_ϵ = 0.0, integrator = :leapfrog, metric = :diagonal)
function HMCDA(δ, λ; init_ϵ = 0.0, integrator = :leapfrog, metric = :diagonal)
if typeof(δ) != typeof(λ)
@warn "typeof(δ) != typeof(λ) --> using typeof(δ)"
end
T = typeof(δ)
return HMCDA(n_adapts, δ, T(λ), T(init_ϵ), integrator, metric)
return HMCDA(δ, T(λ), T(init_ϵ), integrator, metric)
end
14 changes: 7 additions & 7 deletions test/abstractmcmc.jl
@@ -8,9 +8,9 @@ include("common.jl")
n_adapts = 5_000
θ_init = randn(rng, 2)

nuts = NUTS(n_adapts, 0.8)
nuts = NUTS(0.8)
hmc = HMC(0.05, 100)
hmcda = HMCDA(n_adapts, 0.8, 0.1)
hmcda = HMCDA(0.8, 0.1)

integrator = Leapfrog(1e-3)
κ = AdvancedHMC.make_kernel(nuts, integrator)
@@ -27,7 +27,7 @@ include("common.jl")
model,
nuts,
n_adapts + n_samples;
nadapts = n_adapts,
n_adapts = n_adapts,
init_params = θ_init,
progress = false,
verbose = false,
@@ -50,7 +50,7 @@ include("common.jl")
model,
hmc,
n_adapts + n_samples;
nadapts = n_adapts,
n_adapts = n_adapts,
init_params = θ_init,
progress = false,
verbose = false,
@@ -73,7 +73,7 @@ include("common.jl")
model,
custom,
n_adapts + n_samples;
nadapts = n_adapts,
n_adapts = 0,
init_params = θ_init,
progress = false,
verbose = false,
@@ -99,7 +99,7 @@ include("common.jl")
model,
custom,
10;
nadapts = 0,
n_adapts = 0,
init_params = θ_init,
progress = false,
verbose = false,
@@ -109,7 +109,7 @@ include("common.jl")
model,
custom,
10;
nadapts = 0,
n_adapts = 0,
init_params = θ_init,
progress = false,
verbose = false,
2 changes: 1 addition & 1 deletion test/adaptation.jl
@@ -8,7 +8,7 @@ function runnuts(ℓπ, metric; n_samples = 3_000)
θ_init = rand(D)
rng = MersenneTwister(0)

nuts = NUTS(n_adapts, 0.8)
nuts = NUTS(0.8)
h = Hamiltonian(metric, ℓπ, ForwardDiff)
step_size = AdvancedHMC.make_step_size(rng, nuts, h, θ_init)
integrator = AdvancedHMC.make_integrator(nuts, step_size)
23 changes: 11 additions & 12 deletions test/constructors.jl
@@ -2,11 +2,11 @@ using AdvancedHMC, AbstractMCMC, Random
include("common.jl")

# Initialize samplers
nuts = NUTS(1000, 0.8)
nuts_32 = NUTS(1000, 0.8f0)
nuts = NUTS(0.8)
nuts_32 = NUTS(0.8f0)
hmc = HMC(0.1, 25)
hmcda = HMCDA(1000, 0.8, 1.0)
hmcda_32 = HMCDA(1000, 0.8f0, 1.0)
hmcda = HMCDA(0.8, 1.0)
hmcda_32 = HMCDA(0.8f0, 1.0)

integrator = Leapfrog(1e-3)
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
@@ -25,7 +25,6 @@ custom = HMCSampler(kernel, metric, adaptor)
@test typeof(nuts) <: AbstractMCMC.AbstractSampler

# NUTS
@test nuts.n_adapts == 1000
@test nuts.δ == 0.8
@test nuts.max_depth == 10
@test nuts.Δ_max == 1000.0
@@ -34,7 +33,6 @@ custom = HMCSampler(kernel, metric, adaptor)
@test nuts.metric == :diagonal

# NUTS Float32
@test nuts_32.n_adapts == 1000
@test nuts_32.δ == 0.8f0
@test nuts_32.max_depth == 10
@test nuts_32.Δ_max == 1000.0f0
@@ -47,15 +45,13 @@ custom = HMCSampler(kernel, metric, adaptor)
@test hmc.metric == :diagonal

# HMCDA
@test hmcda.n_adapts == 1000
@test hmcda.δ == 0.8
@test hmcda.λ == 1.0
@test hmcda.init_ϵ == 0.0
@test hmcda.integrator == :leapfrog
@test hmcda.metric == :diagonal

# HMCDA Float32
@test hmcda_32.n_adapts == 1000
@test hmcda_32.δ == 0.8f0
@test hmcda_32.λ == 1.0f0
@test hmcda_32.init_ϵ == 0.0f0
@@ -65,11 +61,14 @@ end
rng = MersenneTwister(0)
θ_init = randn(rng, 2)
logdensitymodel = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
_, nuts_state = AbstractMCMC.step(rng, logdensitymodel, nuts; init_params = θ_init)
_, hmc_state = AbstractMCMC.step(rng, logdensitymodel, hmc; init_params = θ_init)
_, nuts_state =
AbstractMCMC.step(rng, logdensitymodel, nuts; n_adapts = 0, init_params = θ_init)
_, hmc_state =
AbstractMCMC.step(rng, logdensitymodel, hmc; n_adapts = 0, init_params = θ_init)
_, nuts_32_state =
AbstractMCMC.step(rng, logdensitymodel, nuts_32; init_params = θ_init)
_, custom_state = AbstractMCMC.step(rng, logdensitymodel, custom; init_params = θ_init)
AbstractMCMC.step(rng, logdensitymodel, nuts_32; n_adapts = 0, init_params = θ_init)
_, custom_state =
AbstractMCMC.step(rng, logdensitymodel, custom; n_adapts = 0, init_params = θ_init)

# Metric
@test typeof(nuts_state.metric) == DiagEuclideanMetric{Float64,Vector{Float64}}
2 changes: 1 addition & 1 deletion test/sampler.jl
@@ -159,7 +159,7 @@ end
end
end
@testset "drop_warmup" begin
nuts = NUTS(n_adapts, 0.8)
nuts = NUTS(0.8)
metric = DiagEuclideanMetric(D)
h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
integrator = Leapfrog(ϵ)