Introduction of SwapSampler + make TemperedSampler a fancy version of `CompositionSampler` (#152)

* split the transitions and states field in TemperedState

* improved internals of CompositionSampler

* ongoing work

* added swap sampler

* added ordering specification and a TemperedComposition

* integrated work on TemperedComposition into TemperedSampler and
removed the former

* reordered stuff so it actually works

* fixed bug in swapping computation

* added length implementation for MultiModel

* improved construct for TemperedSampler and added some convenience methods

* fixed bundle_samples for Chains and TemperedTransition

* fixed breaking bug in setparams_and_logprob!! for SwapState

* remove usage of adapted HMC in tests

* remove doubling of iterations when testing tempering

* fixed bugs with MALA and tempering

* relax atol a bit for HMC

* relax another atol

* TemperedComposition is now truly just a wrapper around a CompositionSampler

* added method for computing roundtrips

* fixed testing + added test for roundtrips

* added docs for roundtrips method

* added some tests for SwapSampler without tempering

* remove ordering from SwapSampler since it should only interact with ProcessOrdering

* simplified the sorting according to chains and processes

* added some comments

* some minor refactoring

* some refactoring + TemperedSampler now orders the samplers correctly

* remove expected_ordering and make ordering assumptions more explicit

* relax type-constraints in state_for_chain so it also works with TemperedState

* removed redundant implementations of swap_attempt

* rename swap_betas! to swap!

* moved swap_attempt as it now requires definition of SwapSampler

* removed unnecessary setparams_and_logprob!! that should never be hit
with the current codebase

* removed expected_order

* Apply suggestions from code review

Co-authored-by: Harrison Wilde <harrisondwilde@outlook.com>

* removed unnecessary variable in tests

* Update src/sampler.jl

Co-authored-by: Harrison Wilde <harrisondwilde@outlook.com>

* Apply suggestions from code review

Co-authored-by: Harrison Wilde <harrisondwilde@outlook.com>

* removed burn-in from step in prep for AbstractMCMC improvements

* remove getparams_and_logprob implementation for SwapState as it's
unclear what is the right approach

* added CompositionTransition + quite a few bundle_samples with a
`bundle_resolve_swaps` kwarg to allow converting into chains more easily

* more samples

* reduce requirement for ess comparison for AHMC a bit

* significant improvements to the simple Gaussian example, now testing
using MCSE to get tolerances, etc. and small improvements to the rest
of the tests

* trying to debug these tests

* more debug

* fixed typo

* reduce significance even further

---------

Co-authored-by: Harrison Wilde <harrisondwilde@outlook.com>
torfjelde and HarrisonWilde authored Mar 11, 2023
1 parent ef97a94 commit 180a928
Showing 18 changed files with 1,087 additions and 450 deletions.
153 changes: 137 additions & 16 deletions src/MCMCTempering.jl
@@ -20,11 +20,13 @@ include("abstractmcmc.jl")
include("adaptation.jl")
include("swapping.jl")
include("state.jl")
include("swapsampler.jl")
include("sampler.jl")
include("sampling.jl")
include("ladders.jl")
include("stepping.jl")
include("model.jl")
include("utils.jl")

export tempered,
tempered_sample,
@@ -43,21 +45,82 @@ maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensity
maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model

# Bundling.
# TODO: Improve this, somehow.
# TODO: Move this to an extension.
# Bundling of non-tempered samples.
function bundle_nontempered_samples(
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{T};
kwargs...
) where {T}
# Create the same model and sampler as we do in the initial step for `TemperedSampler`.
multimodel = MultiModel([
make_tempered_model(sampler, model, sampler.chain_to_beta[i])
for i in 1:numtemps(sampler)
])
multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)])
multitransitions = [
MultipleTransitions(sort_by_chain(ProcessOrder(), t.swaptransition, t.transition.transitions))
for t in ts
]

return AbstractMCMC.bundle_samples(
multitransitions,
multimodel,
multisampler,
MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)),
T
)
end

function AbstractMCMC.bundle_samples(
ts::Vector{<:MultipleTransitions},
model::MultiModel,
sampler::MultiSampler,
state::MultipleStates,
# TODO: Generalize for any eltype `T`? Then need to overload for `Real`, etc.?
::Type{Vector{MCMCChains.Chains}};
kwargs...
)
return map(1:length(model), model.models, sampler.samplers, state.states) do i, model, sampler, state
AbstractMCMC.bundle_samples([t.transitions[i] for t in ts], model, sampler, state, MCMCChains.Chains; kwargs...)
end
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{Vector{T}};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
if bundle_resolve_swaps
return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...)
end

# TODO: Do better?
return ts
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:TemperedTransition},
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{MCMCChains.Chains};
kwargs...
)
# Extract the transitions ordered, which are ordered according to processes, according to the chains.
ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts]
return AbstractMCMC.bundle_samples(
map(Base.Fix2(getproperty, :transition), filter(!Base.Fix2(getproperty, :is_swap), ts)), # Remove the swaps.
ts_actual,
model,
sampler_for_chain(sampler, state),
state_for_chain(state),
sampler_for_chain(sampler, state, 1),
state_for_chain(state, 1),
MCMCChains.Chains;
kwargs...
)
@@ -68,27 +131,85 @@ function AbstractMCMC.bundle_samples(
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::CompositionState,
::Type{MCMCChains.Chains};
::Type{T};
kwargs...
)
) where {T}
# In the case of `!saveall(sampler)`, the state is not a `CompositionTransition` so we just propagate
# the transitions to the `bundle_samples` for the outer stuff. Otherwise, we flatten the transitions.
ts_actual = saveall(sampler) ? mapreduce(t -> [inner_transition(t), outer_transition(t)], vcat, ts) : ts
# TODO: Should we really always default to outer sampler?
return AbstractMCMC.bundle_samples(
ts, model, sampler.sampler_outer, state.state_outer, MCMCChains.Chains;
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
kwargs...
)
end

# Unflatten in the case of `SequentialTransitions`
# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:SequentialTransitions},
ts::Vector,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::SequentialStates,
::Type{MCMCChains.Chains};
state::CompositionState,
::Type{Vector{T}};
kwargs...
)
ts_actual = [t for tseq in ts for t in tseq.transitions]
) where {T}
if !saveall(sampler)
# In this case, we just use the `outer` for everything since this is the only
# transitions we're keeping around.
return AbstractMCMC.bundle_samples(
ts, model, sampler.sampler_outer, state.state_outer, Vector{T};
kwargs...
)
end

# Otherwise, we don't know what to do.
return ts
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
state::CompositionState{<:MultipleStates,<:SwapState},
::Type{T};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
!bundle_resolve_swaps && return ts

# Resolve the swaps.
sampler_without_saveall = @set sampler.sampler_inner.saveall = Val(false)
ts_actual = map(ts) do t
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
end

AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
kwargs...
)
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
state::CompositionState{<:MultipleStates,<:SwapState},
::Type{Vector{T}};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
!bundle_resolve_swaps && return ts

# Resolve the swaps (using the already implemented resolution in `composition_transition`
# for this particular sampler but without `saveall`).
sampler_without_saveall = @set sampler.saveall = Val(false)
ts_actual = map(ts) do t
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
end

return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.states[end], MCMCChains.Chains;
ts_actual, model, sampler.sampler_outer, state.state_outer, Vector{T};
kwargs...
)
end
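
To make the process/chain bookkeeping behind these `bundle_samples` methods concrete, here is a minimal sketch in plain Julia. It assumes each iteration stores its values ordered by process together with a `chain_to_process` vector, mirroring the expression `t.transition.transitions[first(t.swaptransition.chain_to_process)]` in the diff. The names `ToyTransition`, `cold_chain`, and `resolve_swaps` are hypothetical and are not part of MCMCTempering.

```julia
# Hypothetical stand-in for a tempered transition: values are stored per process,
# and `chain_to_process[c]` gives the process currently running chain `c`.
struct ToyTransition
    values::Vector{Float64}          # ordered by process
    chain_to_process::Vector{Int}    # chain index -> process index
end

# Samples of chain 1 (assumed here to be the β = 1 chain), one per iteration.
cold_chain(ts) = [t.values[first(t.chain_to_process)] for t in ts]

# Reorder every iteration so that entry `c` belongs to chain `c`.
resolve_swaps(ts) = [t.values[t.chain_to_process] for t in ts]

ts = [
    ToyTransition([0.1, 5.0], [1, 2]),  # no swap yet: chain 1 lives on process 1
    ToyTransition([4.0, 0.2], [2, 1]),  # after a swap: chain 1 lives on process 2
]

@assert cold_chain(ts) == [0.1, 0.2]
@assert resolve_swaps(ts) == [[0.1, 5.0], [0.2, 4.0]]
```

The sketch only illustrates the indexing; the real methods additionally rebuild the matching model, sampler, and state before delegating to `AbstractMCMC.bundle_samples`.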
2 changes: 1 addition & 1 deletion src/adaptation.jl
@@ -18,7 +18,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and
"""
struct Geometric end

defaultscale(::Geometric, Δ) = eltype(Δ)(0.9)
defaultscale(::Geometric, Δ) = float(eltype(Δ))(0.9)

"""
InverselyAdditive
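
The change to `defaultscale` matters when the ladder has an integer element type. The following is a small sketch of the difference in plain Julia, assuming an integer ladder such as `[1, 0]` is passed. `scale_old` and `scale_new` are hypothetical names used only for this illustration.

```julia
scale_old(Δ) = eltype(Δ)(0.9)          # converts 0.9 to the element type as-is
scale_new(Δ) = float(eltype(Δ))(0.9)   # converts to the floating-point counterpart

Δ_float = [1.0, 0.5, 0.0]
Δ_int = [1, 0]

@assert scale_old(Δ_float) == 0.9
@assert scale_new(Δ_float) == 0.9
@assert scale_new(Δ_int) == 0.9        # float(Int) is Float64

# With an integer ladder the old form fails, since 0.9 is not an exact Int.
threw = try
    scale_old(Δ_int)
    false
catch err
    err isa InexactError
end
@assert threw
```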
4 changes: 1 addition & 3 deletions src/ladders.jl
@@ -37,9 +37,7 @@ end
Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0`
"""
function check_inverse_temperatures(Δ)
if length(Δ) <= 1
error("More than one inverse temperatures must be provided.")
end
!isempty(Δ) || error("Inverse temperatures array is empty.")
if !all(zero.(Δ) .≤ Δ .≤ one.(Δ))
error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].")
end
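
The relaxed check can be sketched as a self-contained function. This is a hypothetical re-implementation for illustration (`check_ladder` is not the package's function), and it assumes that "sorted" means decreasing order starting from β = 1, as the docstring suggests.

```julia
# Hypothetical sketch of the validation, not the package's implementation.
function check_ladder(Δ)
    !isempty(Δ) || error("Inverse temperatures array is empty.")
    all(β -> zero(β) ≤ β ≤ one(β), Δ) ||
        error("All inverse temperatures must lie in [0, 1].")
    # Assumption: the ladder is returned in decreasing order, starting from β = 1.
    return sort(Δ; rev=true)
end

@assert check_ladder([0.25, 1.0, 0.5]) == [1.0, 0.5, 0.25]
@assert check_ladder([1.0]) == [1.0]   # a single temperature passes the new check
```

A ladder of length one was rejected by the removed `length(Δ) <= 1` branch; with the new `isempty` guard only an empty ladder is an error.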
107 changes: 80 additions & 27 deletions src/sampler.jl
@@ -1,3 +1,20 @@
"""
TemperedState
A state for a tempered sampler.
# Fields
$(FIELDS)
"""
@concrete struct TemperedState
"state for swap-sampler"
swapstate
"state for the main sampler"
state
"inverse temperature for each of the chains"
chain_to_beta
end

"""
TemperedSampler <: AbstractMCMC.AbstractSampler
@@ -7,44 +24,39 @@ A `TemperedSampler` struct wraps a sampler upon which to apply the Parallel Temp
$(FIELDS)
"""
@concrete struct TemperedSampler <: AbstractMCMC.AbstractSampler
Base.@kwdef struct TemperedSampler{SplT,A,SwapT,Adapt} <: AbstractMCMC.AbstractSampler
"sampler(s) used to target the tempered distributions"
sampler
sampler::SplT
"collection of inverse temperatures β; β[i] corresponds to the i-th tempered model"
inverse_temperatures
"number of steps of `sampler` to take before proposing swaps"
swap_every
"the swap strategy that will be used when proposing swaps"
swap_strategy
# TODO: This should be replaced with `P` just being some `NoAdapt` type.
chain_to_beta::A
"strategy to use for swapping"
swapstrategy::SwapT=ReversibleSwap()
# TODO: Remove `adapt` and just consider `adaptation_states=nothing` as no adaptation.
"boolean flag specifying whether or not to adapt"
adapt
adapt=false
"adaptation parameters"
adaptation_states
adaptation_states::Adapt=nothing
end

swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy
TemperedSampler(sampler, chain_to_beta; kwargs...) = TemperedSampler(; sampler, chain_to_beta, kwargs...)

swapsampler(sampler::TemperedSampler) = SwapSampler(sampler.swapstrategy)

# TODO: Do we need this now?
getsampler(samplers, I...) = getindex(samplers, I...)
getsampler(sampler::AbstractMCMC.AbstractSampler, I...) = sampler
getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.sampler, I...)

"""
numsteps(sampler::TemperedSampler)
Return number of inverse temperatures used by `sampler`.
"""
numtemps(sampler::TemperedSampler) = length(sampler.inverse_temperatures)
chain_to_process(state::TemperedState, I...) = chain_to_process(state.swapstate, I...)
process_to_chain(state::TemperedState, I...) = process_to_chain(state.swapstate, I...)

"""
sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...])
sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...)
Return the sampler corresponding to the chain indexed by `I...`.
If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned.
"""
sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1)
function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...)
return getsampler(sampler.sampler, chain_to_process(state, I...))
return sampler_for_process(sampler, state, chain_to_process(state, I...))
end

"""
@@ -53,9 +65,51 @@ end
Return the sampler corresponding to the process indexed by `I...`.
"""
function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...)
return getsampler(sampler.sampler, I...)
return _sampler_for_process_temper(sampler.sampler, state, I...)
end

# If `sampler` is a `MultiSampler`, we assume it's ordered according to chains.
_sampler_for_process_temper(sampler::MultiSampler, state, I...) = sampler.samplers[process_to_chain(state, I...)]
# Otherwise, we just use the same sampler for everything.
_sampler_for_process_temper(sampler, state, I...) = sampler

# Defer extracting the corresponding state to the `swapstate`.
state_for_process(state::TemperedState, I...) = state_for_process(state.swapstate, I...)

# Here we make the model(s) using the temperatures.
function model_for_process(sampler::TemperedSampler, model, state::TemperedState, I...)
return make_tempered_model(sampler, model, beta_for_process(state, I...))
end

"""
beta_for_chain(state[, I...])
Return the β corresponding to the chain indexed by `I...`.
If `I...` is not specified, the β corresponding to `β=1.0` will be returned.
"""
beta_for_chain(state::TemperedState) = beta_for_chain(state, 1)
beta_for_chain(state::TemperedState, I...) = beta_for_chain(state.chain_to_beta, I...)
# NOTE: Array impl. is useful for testing.
beta_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...]

"""
beta_for_process(state, I...)
Return the β corresponding to the process indexed by `I...`.
"""
beta_for_process(state::TemperedState, I...) = beta_for_process(state.chain_to_beta, state.swapstate.process_to_chain, I...)
# NOTE: Array impl. is useful for testing.
function beta_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...)
return beta_for_chain(chain_to_beta, process_to_chain(proc2chain, I...))
end

"""
numsteps(sampler::TemperedSampler)
Return number of inverse temperatures used by `sampler`.
"""
numtemps(sampler::TemperedSampler) = length(sampler.chain_to_beta)

"""
tempered(sampler, inverse_temperatures; kwargs...)
OR
@@ -99,7 +153,7 @@ function tempered(
inverse_temperatures::Vector{<:Real};
swap_strategy::AbstractSwapStrategy=ReversibleSwap(),
# TODO: Change `swap_every` to something like `number_of_iterations_per_swap`.
swap_every::Integer=1,
steps_per_swap::Integer=1,
adapt::Bool=false,
adapt_target::Real=0.234,
adapt_stepsize::Real=1,
@@ -109,14 +163,13 @@ function tempered(
kwargs...
)
!(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.")
swap_every ≥ 1 || error("`swap_every` must take a positive integer value.")
steps_per_swap > 0 || error("`steps_per_swap` must take a positive integer value.")
inverse_temperatures = check_inverse_temperatures(inverse_temperatures)
adaptation_states = init_adaptation(
adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize
)
# NOTE: We just make a repeated sampler for `sampler_inner`.
# TODO: Generalize. Allow passing in a `MultiSampler`, etc.
sampler_inner = sampler^swap_every
# FIXME: Remove the hard-coded `2` for swap-every, and change `should_swap` accordingly.
return TemperedSampler(sampler_inner, inverse_temperatures, 2, swap_strategy, adapt, adaptation_states)
sampler_inner = sampler^steps_per_swap
return TemperedSampler(sampler_inner, inverse_temperatures, swap_strategy, adapt, adaptation_states)
end
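
The array-based `beta_for_chain` and `beta_for_process` methods in this file can be illustrated with plain vectors. The sketch below assumes that `process_to_chain` is a permutation vector and that `chain_to_process` is its inverse. The `toy_` functions are hypothetical stand-ins, not the package's definitions.

```julia
# Inverse temperature of each chain; chain 1 is assumed to be the β = 1 target.
chain_to_beta = [1.0, 0.5, 0.1]

# Assignment after some swaps: process p currently runs chain process_to_chain[p].
process_to_chain = [3, 1, 2]
# Assumed inverse mapping: chain c currently runs on process chain_to_process[c].
chain_to_process = invperm(process_to_chain)

toy_beta_for_chain(chain_to_beta, c) = chain_to_beta[c]
toy_beta_for_process(chain_to_beta, process_to_chain, p) =
    toy_beta_for_chain(chain_to_beta, process_to_chain[p])

@assert chain_to_process == [2, 3, 1]
# Process 1 runs chain 3, so it targets β = 0.1.
@assert toy_beta_for_process(chain_to_beta, process_to_chain, 1) == 0.1
# The process hosting chain 1 always targets β = 1.
@assert toy_beta_for_process(chain_to_beta, process_to_chain, chain_to_process[1]) == 1.0
```

Keeping the temperatures indexed by chain and moving only the index vectors means a swap never has to touch `chain_to_beta` itself.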