From 180a92881857676c377efe957daef4a2b120325c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 11 Mar 2023 11:42:46 +0000 Subject: [PATCH] Introduction of `SwapSampler` + make `TemperedSampler` a fancy version of `CompositionSampler` (#152) * split the transitions and states field in TemperedState * improved internals of CompositionSampler * ongoing work * added swap sampler * added ordering specification and a TemperedComposition * integrated work on TemperedComposition into TemperedSampler and removed the former * reorederd stuff so it actually works * fixed bug in swapping computation * added length implementation for MultiModel * improved construct for TemperedSampler and added some convenience methods * fixed bundle_samples for Chains and TemperedTransition * fixed breaking bug in setparams_and_logprob!! for SwapState * remove usage of adapted HMC in tests * remove doubling of iterations when testing tempering * fixed bugs with MALA and tempering * relax atol a bit for HMC * relax another atol * TemperedComposition is now truly just a wrapper around a CompositionSampler * added method for computing roundtrips * fixed testing + added test for roundtrips * added docs for roundtrips method * added some tests for SwapSampler without tempering * remove ordering from SwapSampler since it should only interact with ProcessOrdering * simplified the sorting according to chains and processes * added some comments * some minor refactoring * some refactoring + TemperedSampler now orders the samplers correctly * remove expected_ordering and make ordering assumptions more explicit * relax type-constraints in state_for_chain so it also works with TemperedState * removed redundant implementations of swap_attempt * rename swap_betas! to swap! * moved swap_attempt as it now requires definition of SwapSampler * removed unnecessary setparams_and_logprob!! 
that should never be hit with the current codebase * removed expected_order * Apply suggestions from code review Co-authored-by: Harrison Wilde * removed unnecessary variable in tests * Update src/sampler.jl Co-authored-by: Harrison Wilde * Apply suggestions from code review Co-authored-by: Harrison Wilde * removed burn-in from step in prep for AbstractMCMC improvements * remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach * split the transitions and states field in TemperedState * improved internals of CompositionSampler * ongoing work * added swap sampler * added ordering specification and a TemperedComposition * integrated work on TemperedComposition into TemperedSampler and removed the former * reorederd stuff so it actually works * fixed bug in swapping computation * added length implementation for MultiModel * improved construct for TemperedSampler and added some convenience methods * fixed bundle_samples for Chains and TemperedTransition * fixed breaking bug in setparams_and_logprob!! for SwapState * remove usage of adapted HMC in tests * remove doubling of iterations when testing tempering * fixed bugs with MALA and tempering * relax atol a bit for HMC * relax another atol * TemperedComposition is now truly just a wrapper around a CompositionSampler * added method for computing roundtrips * fixed testing + added test for roundtrips * added docs for roundtrips method * added some tests for SwapSampler without tempering * remove ordering from SwapSampler since it should only interact with ProcessOrdering * simplified the sorting according to chains and processes * added some comments * some minor refactoring * some refactoring + TemperedSampler now orders the samplers correctly * remove expected_ordering and make ordering assumptions more explicit * relax type-constraints in state_for_chain so it also works with TemperedState * removed redundant implementations of swap_attempt * rename swap_betas! to swap! 
* moved swap_attempt as it now requires definition of SwapSampler * removed unnecessary setparams_and_logprob!! that should never be hit with the current codebase * removed expected_order * removed unnecessary variable in tests * Apply suggestions from code review Co-authored-by: Harrison Wilde * removed burn-in from step in prep for AbstractMCMC improvements * remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach * Apply suggestions from code review Co-authored-by: Harrison Wilde * added CompositionTransition + quite a few bundle_samples with a `bundle_resolve_swaps` kwarg to allow converting into chains more easily * more samples * reduce requirement for ess comparison for AHMC a bit * significant improvements to the simple Gaussian example, now testing using MCSE to get tolerances, etc. and small improvements to the rest of the tests * trying to debug these tests * more debug * fixed typy * reduce significance even further --------- Co-authored-by: Harrison Wilde --- src/MCMCTempering.jl | 153 +++++++++++++++++++++--- src/adaptation.jl | 2 +- src/ladders.jl | 4 +- src/sampler.jl | 107 ++++++++++++----- src/samplers/composition.jl | 75 ++++++------ src/samplers/multi.jl | 14 ++- src/state.jl | 172 ++++++++++++++------------- src/stepping.jl | 217 ++++++++++++++-------------------- src/swapping.jl | 64 +++-------- src/swapsampler.jl | 224 ++++++++++++++++++++++++++++++++++++ src/utils.jl | 34 ++++++ test/Project.toml | 5 +- test/abstractmcmc.jl | 41 +++++-- test/compat.jl | 20 +++- test/runtests.jl | 203 +++++++++++++++++++------------- test/setup.jl | 2 +- test/simple_gaussian.jl | 74 ++++++++++++ test/test_utils.jl | 126 ++++++++++++++++++++ 18 files changed, 1087 insertions(+), 450 deletions(-) create mode 100644 src/swapsampler.jl create mode 100644 src/utils.jl create mode 100644 test/simple_gaussian.jl create mode 100644 test/test_utils.jl diff --git a/src/MCMCTempering.jl b/src/MCMCTempering.jl index 
a63b890..cb658d5 100644 --- a/src/MCMCTempering.jl +++ b/src/MCMCTempering.jl @@ -20,11 +20,13 @@ include("abstractmcmc.jl") include("adaptation.jl") include("swapping.jl") include("state.jl") +include("swapsampler.jl") include("sampler.jl") include("sampling.jl") include("ladders.jl") include("stepping.jl") include("model.jl") +include("utils.jl") export tempered, tempered_sample, @@ -43,21 +45,82 @@ maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensity maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model # Bundling. -# TODO: Improve this, somehow. -# TODO: Move this to an extension. +# Bundling of non-tempered samples. +function bundle_nontempered_samples( + ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}}, + model::AbstractMCMC.AbstractModel, + sampler::TemperedSampler, + state::TemperedState, + ::Type{T}; + kwargs... +) where {T} + # Create the same model and sampler as we do in the initial step for `TemperedSampler`. + multimodel = MultiModel([ + make_tempered_model(sampler, model, sampler.chain_to_beta[i]) + for i in 1:numtemps(sampler) + ]) + multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)]) + multitransitions = [ + MultipleTransitions(sort_by_chain(ProcessOrder(), t.swaptransition, t.transition.transitions)) + for t in ts + ] + + return AbstractMCMC.bundle_samples( + multitransitions, + multimodel, + multisampler, + MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)), + T + ) +end + +function AbstractMCMC.bundle_samples( + ts::Vector{<:MultipleTransitions}, + model::MultiModel, + sampler::MultiSampler, + state::MultipleStates, + # TODO: Generalize for any eltype `T`? Then need to overload for `Real`, etc.? + ::Type{Vector{MCMCChains.Chains}}; + kwargs... 
+) + return map(1:length(model), model.models, sampler.samplers, state.states) do i, model, sampler, state + AbstractMCMC.bundle_samples([t.transitions[i] for t in ts], model, sampler, state, MCMCChains.Chains; kwargs...) + end +end + +# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118 +function AbstractMCMC.bundle_samples( + ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}}, + model::AbstractMCMC.AbstractModel, + sampler::TemperedSampler, + state::TemperedState, + ::Type{Vector{T}}; + bundle_resolve_swaps::Bool=false, + kwargs... +) where {T} + if bundle_resolve_swaps + return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...) + end + + # TODO: Do better? + return ts +end + function AbstractMCMC.bundle_samples( - ts::AbstractVector{<:TemperedTransition}, + ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}}, model::AbstractMCMC.AbstractModel, sampler::TemperedSampler, state::TemperedState, ::Type{MCMCChains.Chains}; kwargs... ) + # For each iteration, extract the transition of the first chain: the transitions are stored in process order, so index them via `chain_to_process`. + ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts] return AbstractMCMC.bundle_samples( - map(Base.Fix2(getproperty, :transition), filter(!Base.Fix2(getproperty, :is_swap), ts)), # Remove the swaps. + ts_actual, model, - sampler_for_chain(sampler, state), - state_for_chain(state), + sampler_for_chain(sampler, state, 1), + state_for_chain(state, 1), MCMCChains.Chains; kwargs... ) @@ -68,27 +131,85 @@ function AbstractMCMC.bundle_samples( model::AbstractMCMC.AbstractModel, sampler::CompositionSampler, state::CompositionState, - ::Type{MCMCChains.Chains}; + ::Type{T}; kwargs... -) +) where {T} + # In the case of `!saveall(sampler)`, each transition is not a `CompositionTransition`, so we just propagate + # the transitions to the `bundle_samples` for the outer sampler. Otherwise, we flatten the transitions.
+ ts_actual = saveall(sampler) ? mapreduce(t -> [inner_transition(t), outer_transition(t)], vcat, ts) : ts + # TODO: Should we really always default to outer sampler? return AbstractMCMC.bundle_samples( - ts, model, sampler.sampler_outer, state.state_outer, MCMCChains.Chains; + ts_actual, model, sampler.sampler_outer, state.state_outer, T; kwargs... ) end -# Unflatten in the case of `SequentialTransitions` +# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118 function AbstractMCMC.bundle_samples( - ts::AbstractVector{<:SequentialTransitions}, + ts::Vector, model::AbstractMCMC.AbstractModel, sampler::CompositionSampler, - state::SequentialStates, - ::Type{MCMCChains.Chains}; + state::CompositionState, + ::Type{Vector{T}}; kwargs... -) - ts_actual = [t for tseq in ts for t in tseq.transitions] +) where {T} + if !saveall(sampler) + # In this case, we just use the `outer` for everything since this is the only + # transitions we're keeping around. + return AbstractMCMC.bundle_samples( + ts, model, sampler.sampler_outer, state.state_outer, Vector{T}; + kwargs... + ) + end + + # Otherwise, we don't know what to do. + return ts +end + +function AbstractMCMC.bundle_samples( + ts::AbstractVector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}}, + model::AbstractMCMC.AbstractModel, + sampler::CompositionSampler{<:MultiSampler,<:SwapSampler}, + state::CompositionState{<:MultipleStates,<:SwapState}, + ::Type{T}; + bundle_resolve_swaps::Bool=false, + kwargs... +) where {T} + !bundle_resolve_swaps && return ts + + # Resolve the swaps. + sampler_without_saveall = @set sampler.sampler_inner.saveall = Val(false) + ts_actual = map(ts) do t + composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t)) + end + + AbstractMCMC.bundle_samples( + ts_actual, model, sampler.sampler_outer, state.state_outer, T; + kwargs... 
+ ) +end + +# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118 +function AbstractMCMC.bundle_samples( + ts::Vector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}}, + model::AbstractMCMC.AbstractModel, + sampler::CompositionSampler{<:MultiSampler,<:SwapSampler}, + state::CompositionState{<:MultipleStates,<:SwapState}, + ::Type{Vector{T}}; + bundle_resolve_swaps::Bool=false, + kwargs... +) where {T} + !bundle_resolve_swaps && return ts + + # Resolve the swaps (using the already implemented resolution in `composition_transition` + # for this particular sampler but without `saveall`). + sampler_without_saveall = @set sampler.saveall = Val(false) + ts_actual = map(ts) do t + composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t)) + end + return AbstractMCMC.bundle_samples( - ts_actual, model, sampler.sampler_outer, state.states[end], MCMCChains.Chains; + ts_actual, model, sampler.sampler_outer, state.state_outer, Vector{T}; kwargs... ) end diff --git a/src/adaptation.jl b/src/adaptation.jl index 5d80a9d..134ad6a 100644 --- a/src/adaptation.jl +++ b/src/adaptation.jl @@ -18,7 +18,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and """ struct Geometric end -defaultscale(::Geometric, Δ) = eltype(Δ)(0.9) +defaultscale(::Geometric, Δ) = float(eltype(Δ))(0.9) """ InverselyAdditive diff --git a/src/ladders.jl b/src/ladders.jl index 0ebf615..28305d1 100644 --- a/src/ladders.jl +++ b/src/ladders.jl @@ -37,9 +37,7 @@ end Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... 
> βₙ ≥ 0` """ function check_inverse_temperatures(Δ) - if length(Δ) <= 1 - error("More than one inverse temperatures must be provided.") - end + !isempty(Δ) || error("Inverse temperatures array is empty.") if !all(zero.(Δ) .≤ Δ .≤ one.(Δ)) error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].") end diff --git a/src/sampler.jl b/src/sampler.jl index 27f8e26..7ecd097 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,3 +1,20 @@ +""" + TemperedState + +A state for a tempered sampler. + +# Fields +$(FIELDS) +""" +@concrete struct TemperedState + "state for swap-sampler" + swapstate + "state for the main sampler" + state + "inverse temperature for each of the chains" + chain_to_beta +end + """ TemperedSampler <: AbstractMCMC.AbstractSampler @@ -7,44 +24,39 @@ A `TemperedSampler` struct wraps a sampler upon which to apply the Parallel Temp $(FIELDS) """ -@concrete struct TemperedSampler <: AbstractMCMC.AbstractSampler +Base.@kwdef struct TemperedSampler{SplT,A,SwapT,Adapt} <: AbstractMCMC.AbstractSampler "sampler(s) used to target the tempered distributions" - sampler + sampler::SplT "collection of inverse temperatures β; β[i] correponds i-th tempered model" - inverse_temperatures - "number of steps of `sampler` to take before proposing swaps" - swap_every - "the swap strategy that will be used when proposing swaps" - swap_strategy - # TODO: This should be replaced with `P` just being some `NoAdapt` type. + chain_to_beta::A + "strategy to use for swapping" + swapstrategy::SwapT=ReversibleSwap() + # TODO: Remove `adapt` and just consider `adaptation_states=nothing` as no adaptation. "boolean flag specifying whether or not to adapt" - adapt + adapt=false "adaptation parameters" - adaptation_states + adaptation_states::Adapt=nothing end -swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy +TemperedSampler(sampler, chain_to_beta; kwargs...) = TemperedSampler(; sampler, chain_to_beta, kwargs...) 
+ +swapsampler(sampler::TemperedSampler) = SwapSampler(sampler.swapstrategy) +# TODO: Do we need this now? getsampler(samplers, I...) = getindex(samplers, I...) getsampler(sampler::AbstractMCMC.AbstractSampler, I...) = sampler getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.sampler, I...) -""" - numsteps(sampler::TemperedSampler) - -Return number of inverse temperatures used by `sampler`. -""" -numtemps(sampler::TemperedSampler) = length(sampler.inverse_temperatures) +chain_to_process(state::TemperedState, I...) = chain_to_process(state.swapstate, I...) +process_to_chain(state::TemperedState, I...) = process_to_chain(state.swapstate, I...) """ - sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...]) + sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) Return the sampler corresponding to the chain indexed by `I...`. -If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned. """ -sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1) function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.sampler, chain_to_process(state, I...)) + return sampler_for_process(sampler, state, chain_to_process(state, I...)) end """ @@ -53,9 +65,51 @@ end Return the sampler corresponding to the process indexed by `I...`. """ function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...) - return getsampler(sampler.sampler, I...) + return _sampler_for_process_temper(sampler.sampler, state, I...) end +# If `sampler` is a `MultiSampler`, we assume it's ordered according to chains. +_sampler_for_process_temper(sampler::MultiSampler, state, I...) = sampler.samplers[process_to_chain(state, I...)] +# Otherwise, we just use the same sampler for everything. +_sampler_for_process_temper(sampler, state, I...) = sampler + +# Defer extracting the corresponding state to the `swapstate`. 
+state_for_process(state::TemperedState, I...) = state_for_process(state.swapstate, I...) + +# Here we make the model(s) using the temperatures. +function model_for_process(sampler::TemperedSampler, model, state::TemperedState, I...) + return make_tempered_model(sampler, model, beta_for_process(state, I...)) +end + +""" + beta_for_chain(state[, I...]) + +Return the β corresponding to the chain indexed by `I...`. +If `I...` is not specified, the β corresponding to `β=1.0` will be returned. +""" +beta_for_chain(state::TemperedState) = beta_for_chain(state, 1) +beta_for_chain(state::TemperedState, I...) = beta_for_chain(state.chain_to_beta, I...) +# NOTE: Array impl. is useful for testing. +beta_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...] + +""" + beta_for_process(state, I...) + +Return the β corresponding to the process indexed by `I...`. +""" +beta_for_process(state::TemperedState, I...) = beta_for_process(state.chain_to_beta, state.swapstate.process_to_chain, I...) +# NOTE: Array impl. is useful for testing. +function beta_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...) + return beta_for_chain(chain_to_beta, process_to_chain(proc2chain, I...)) +end + +""" + numtemps(sampler::TemperedSampler) + +Return the number of inverse temperatures used by `sampler`. +""" +numtemps(sampler::TemperedSampler) = length(sampler.chain_to_beta) + """ tempered(sampler, inverse_temperatures; kwargs...) OR @@ -99,7 +153,7 @@ function tempered( inverse_temperatures::Vector{<:Real}; swap_strategy::AbstractSwapStrategy=ReversibleSwap(), # TODO: Change `swap_every` to something like `number_of_iterations_per_swap`. - swap_every::Integer=1, + steps_per_swap::Integer=1, adapt::Bool=false, adapt_target::Real=0.234, adapt_stepsize::Real=1, @@ -109,14 +163,13 @@ function tempered( kwargs...
) !(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.") - swap_every ≥ 1 || error("`swap_every` must take a positive integer value.") + steps_per_swap > 0 || error("`steps_per_swap` must take a positive integer value.") inverse_temperatures = check_inverse_temperatures(inverse_temperatures) adaptation_states = init_adaptation( adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize ) # NOTE: We just make a repeated sampler for `sampler_inner`. # TODO: Generalize. Allow passing in a `MultiSampler`, etc. - sampler_inner = sampler^swap_every - # FIXME: Remove the hard-coded `2` for swap-every, and change `should_swap` acoordingly. - return TemperedSampler(sampler_inner, inverse_temperatures, 2, swap_strategy, adapt, adaptation_states) + sampler_inner = sampler^steps_per_swap + return TemperedSampler(sampler_inner, inverse_temperatures, swap_strategy, adapt, adaptation_states) end diff --git a/src/samplers/composition.jl b/src/samplers/composition.jl index 94ee691..b8fb153 100644 --- a/src/samplers/composition.jl +++ b/src/samplers/composition.jl @@ -58,21 +58,45 @@ function setparams_and_logprob!!(model, state::CompositionState, params, logprob return @set state.state_outer = setparams_and_logprob!!(model, state.state_outer, params, logprob) end +struct CompositionTransition{S1,S2} + "The outer transition" + transition_outer::S1 + "The inner transition" + transition_inner::S2 +end + +# Useful functions for interacting with composition sampler and states. 
+inner_sampler(sampler::CompositionSampler) = sampler.sampler_inner +outer_sampler(sampler::CompositionSampler) = sampler.sampler_outer + +inner_state(state::CompositionState) = state.state_inner +outer_state(state::CompositionState) = state.state_outer + +inner_transition(transition::CompositionTransition) = transition.transition_inner +outer_transition(transition::CompositionTransition) = transition.transition_outer +outer_transition(transition) = transition # in case we don't have `saveall` + +# TODO: We really don't need to use `SequentialStates` here, do we? +composition_state(sampler, state_inner, state_outer) = CompositionState(state_outer, state_inner) +function composition_transition(sampler, transition_inner, transition_outer) + return if saveall(sampler) + CompositionTransition(transition_outer, transition_inner) + else + transition_outer + end +end + function AbstractMCMC.step( rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::CompositionSampler; kwargs... ) - state_inner_initial = last(AbstractMCMC.step(rng, model, sampler.sampler_inner; kwargs...)) - state_outer_initial = last(AbstractMCMC.step(rng, model, sampler.sampler_outer; kwargs...)) + state_inner_initial = last(AbstractMCMC.step(rng, model, inner_sampler(sampler); kwargs...)) + state_outer_initial = last(AbstractMCMC.step(rng, model, outer_sampler(sampler); kwargs...)) # Create the composition state, and take a full step. - state = if saveall(sampler) - SequentialStates((state_inner_initial, state_outer_initial)) - else - CompositionState(state_outer_initial, state_inner_initial) - end + state = composition_state(sampler, state_inner_initial, state_outer_initial) return AbstractMCMC.step(rng, model, sampler, state; kwargs...) end @@ -80,18 +104,14 @@ end # in place of `CompositionState` and just have one version. # The annoying part here is that we'll have to check `saveall` on every `step` # rather than just for the initial step. 
- -# NOTE: Version which does keep track of all transitions and states. function AbstractMCMC.step( rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::CompositionSampler, - state::SequentialStates; + state; kwargs... ) - @assert length(state.states) == 2 "Composition samplers only support SequentialStates with two states." - - state_inner_prev, state_outer_prev = state.states + state_inner_prev, state_outer_prev = inner_state(state), outer_state(state) # Update the inner state. current_state_inner = state_from(model, state_inner_prev, state_outer_prev) @@ -103,29 +123,8 @@ function AbstractMCMC.step( current_state_outer = state_from(model, state_outer_prev, state_inner) transition_outer, state_outer = AbstractMCMC.step(rng, model, sampler.sampler_outer, current_state_outer; kwargs...) - return SequentialTransitions((transition_inner, transition_outer)), SequentialStates((state_inner, state_outer)) -end - -# NOTE: Version which does NOT keep track of all transitions and states. -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::AbstractMCMC.AbstractModel, - sampler::CompositionSampler, - state::CompositionState; - kwargs... -) - # Update the inner state. - current_state_inner = state_from(model, state.state_inner, state.state_outer) - - # Take a step in the inner sampler. - state_inner = last(AbstractMCMC.step(rng, model, sampler.sampler_inner, current_state_inner; kwargs...)) - - # Take a step in the outer sampler. - current_state_outer = state_from(model, state.state_outer, state_inner) - transition_outer, state_outer = AbstractMCMC.step(rng, model, sampler.sampler_outer, current_state_outer; kwargs...) - - # Create the composition state. 
- state = CompositionState(state_outer, state_inner) - - return transition_outer, state + return ( + composition_transition(sampler, transition_inner, transition_outer), + composition_state(sampler, state_inner, state_outer) + ) end diff --git a/src/samplers/multi.jl b/src/samplers/multi.jl index db4fa02..f007e00 100644 --- a/src/samplers/multi.jl +++ b/src/samplers/multi.jl @@ -57,6 +57,8 @@ end ×(model1::AbstractMCMC.AbstractModel, model2::MultiModel) = MultiModel(combine(model1, model2.models)) ×(model1::MultiModel, model2::MultiModel) = MultiModel(combine(model1.models, model2.models)) +Base.length(model::MultiModel) = length(model.models) + # TODO: Make these subtypes of `AbstractVector`? """ MultipleTransitions @@ -108,13 +110,13 @@ function getparams_and_logprob(model::MultiModel, state::MultipleStates) return map(first, params_and_logprobs), map(last, params_and_logprobs) end -function setparams_and_logprob!!(state::MultipleStates, params, logprob) - @assert length(params) == length(logprob) == length(state.states) "The number of parameters and log probabilities must match the number of states." - return @set state.states = map(setparams_and_logprob!!, state.states, params, logprob) +function setparams_and_logprob!!(state::MultipleStates, params, logprobs) + @assert length(params) == length(logprobs) == length(state.states) "The number of parameters and log probabilities must match the number of states." + return @set state.states = map(setparams_and_logprob!!, state.states, params, logprobs) end -function setparams_and_logprob!!(model::MultiModel, state::MultipleStates, params, logprob) - @assert length(params) == length(logprob) == length(state.states) "The number of parameters and log probabilities must match the number of states." 
- return @set state.states = map(setparams_and_logprob!!, model.models, state.states, params, logprob) + @assert length(model.models) == length(params) == length(logprobs) == length(state.states) "The number of models, states, parameters, and log probabilities must match." + return @set state.states = map(setparams_and_logprob!!, model.models, state.states, params, logprobs) end # TODO: Clean this up. diff --git a/src/state.jl b/src/state.jl index 753ea72..96e6c98 100644 --- a/src/state.jl +++ b/src/state.jl @@ -1,5 +1,19 @@ """ - TemperedState + ProcessOrder + +Specifies that the `model` should be treated as process-ordered. +""" +struct ProcessOrder end + +""" + ChainOrder + +Specifies that the `model` should be treated as chain-ordered. +""" +struct ChainOrder end + +""" + SwapState A general implementation of a state for a [`TemperedSampler`](@ref). @@ -17,7 +31,7 @@ Moreover, suppose we also have 4 workers/processes for which we run these chains (can also be serial wlog). We can then perform a swap in two different ways: -1. Swap the the _states_ between each process, i.e. permute `transitions_and_states`. +1. Swap the _states_ between each process, i.e. permute `transitions` and `states`. 2. Swap the _temperatures_ between each process, i.e. permute `chain_to_beta`.
(1) is possibly the most intuitive approach since it means that the i-th worker/process @@ -52,72 +66,87 @@ Chains: process_to_chain chain_to_process inverse_temperatures[process_t In this case, the chain `X` can be reconstructed as: ```julia -X[1] = states[1].transitions_and_states[1] -X[2] = states[2].transitions_and_states[2] -X[3] = states[3].transitions_and_states[2] -X[4] = states[4].transitions_and_states[3] -X[5] = states[5].transitions_and_states[3] +X[1] = states[1].states[1] +X[2] = states[2].states[2] +X[3] = states[3].states[2] +X[4] = states[4].states[3] +X[5] = states[5].states[3] ``` +and similarly for the states. + The indices here are exactly those represented by `states[k].chain_to_process[1]`. """ -@concrete struct TemperedState - "collection of `(transition, state)` pairs for each process" - transitions_and_states - "collection of (inverse) temperatures β corresponding to each chain" - chain_to_beta +@concrete struct SwapState + "collection of states for each process" + states "collection indices such that `chain_to_process[i] = j` if the j-th process corresponds to the i-th chain" chain_to_process "collection indices such that `process_chain_to[j] = i` if the i-th chain corresponds to the j-th process" process_to_chain "total number of steps taken" total_steps - "number of burn-in steps taken" - burnin_steps - "contains all necessary information for adaptation of inverse_temperatures" - adaptation_states - "flag which specifies wether this was a swap-step or not" - is_swap "swap acceptance ratios on log-scale" swap_acceptance_ratios end +# TODO: Can we support more? +function SwapState(state::MultipleStates) + process_to_chain = collect(1:length(state.states)) + chain_to_process = copy(process_to_chain) + return SwapState(state.states, chain_to_process, process_to_chain, 1, Dict{Int,Float64}()) +end + +# Defer these to `MultipleStates`. +# TODO: What is the best way to implement these? 
Should we sort according to the chain indices +# to match the order of the models? +# getparams_and_logprob(state::SwapState) = getparams_and_logprob(MultipleStates(state.states)) +# getparams_and_logprob(model, state::SwapState) = getparams_and_logprob(model, MultipleStates(state.states)) + +function setparams_and_logprob!!(model, state::SwapState, params, logprobs) + # Use the `MultipleStates`'s implementation to update the underlying states. + multistate = setparams_and_logprob!!(model, MultipleStates(state.states), params, logprobs) + # Update the states! + return @set state.states = multistate.states +end + """ - process_to_chain(state, I...) + sort_by_chain(::ChainOrder, state, xs) + sort_by_chain(::ProcessOrder, state, xs) -Return the chain index corresponding to the process index `I`. +Return `xs` sorted according to the chain indices, as specified by `state`. """ -process_to_chain(state::TemperedState, I...) = process_to_chain(state.process_to_chain, I...) -# NOTE: Array impl. is useful for testing. -process_to_chain(proc2chain::AbstractArray, I...) = proc2chain[I...] +sort_by_chain(::ChainOrder, ::Any, xs) = xs +sort_by_chain(::ProcessOrder, state, xs) = [xs[chain_to_process(state, i)] for i = 1:length(xs)] +sort_by_chain(::ProcessOrder, state, xs::Tuple) = ntuple(i -> xs[chain_to_process(state, i)], length(xs)) """ - chain_to_process(state, I...) + sort_by_process(::ProcessOrder, state, xs) + sort_by_process(::ChainOrder, state, xs) -Return the process index corresponding to the chain index `I`. +Return `xs` sorted according to the process indices, as specified by `state`. """ -chain_to_process(state::TemperedState, I...) = chain_to_process(state.chain_to_process, I...) -# NOTE: Array impl. is useful for testing. -chain_to_process(chain2proc::AbstractArray, I...) = chain2proc[I...]
+sort_by_process(::ProcessOrder, ::Any, xs) = xs +sort_by_process(::ChainOrder, state, xs) = [xs[process_to_chain(state, i)] for i = 1:length(xs)] +sort_by_process(::ChainOrder, state, xs::Tuple) = ntuple(i -> xs[process_to_chain(state, i)], length(xs)) """ - transition_for_chain(state[, I...]) + process_to_chain(state, I...) -Return the transition corresponding to the chain indexed by `I...`. -If `I...` is not specified, the transition corresponding to `β=1.0` will be returned, i.e. `I = (1, )`. +Return the chain index corresponding to the process index `I`. """ -transition_for_chain(state::TemperedState) = transition_for_chain(state, 1) -transition_for_chain(state::TemperedState, I...) = transition_for_process(state, chain_to_process(state, I...)) +process_to_chain(state::SwapState, I...) = process_to_chain(state.process_to_chain, I...) +# NOTE: Array impl. is useful for testing. +process_to_chain(proc2chain, I...) = proc2chain[I...] """ - transition_for_process(state, I...) + chain_to_process(state, I...) -Return the transition corresponding to the process indexed by `I...`. +Return the process index corresponding to the chain index `I`. """ -transition_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][1] -function transition_for_process(state::TemperedState{<:Tuple{<:MultipleTransitions,<:MultipleStates}}, I...) - return state.transitions_and_states[1].transitions[I...] -end +chain_to_process(state::SwapState, I...) = chain_to_process(state.chain_to_process, I...) +# NOTE: Array impl. is useful for testing. +chain_to_process(chain2proc, I...) = chain2proc[I...] """ state_for_chain(state[, I...]) @@ -125,72 +154,49 @@ end Return the state corresponding to the chain indexed by `I...`. If `I...` is not specified, the state corresponding to `β=1.0` will be returned. """ -state_for_chain(state::TemperedState) = state_for_chain(state, 1) -state_for_chain(state::TemperedState, I...) 
= state_for_process(state, chain_to_process(state, I...)) +state_for_chain(state) = state_for_chain(state, 1) +state_for_chain(state, I...) = state_for_process(state, chain_to_process(state, I...)) """ state_for_process(state, I...) Return the state corresponding to the process indexed by `I...`. """ -state_for_process(state::TemperedState, I...) = state.transitions_and_states[I...][2] -function state_for_process(state::TemperedState{<:Tuple{<:MultipleTransitions,<:MultipleStates}}, I...) - return state.transitions_and_states[2].states[I...] -end +state_for_process(state::SwapState, I...) = state_for_process(state.states, I...) +state_for_process(proc2state, I...) = proc2state[I...] """ - beta_for_chain(state[, I...]) + model_for_chain(ordering, sampler, model, state, I...) -Return the β corresponding to the chain indexed by `I...`. -If `I...` is not specified, the β corresponding to `β=1.0` will be returned. -""" -beta_for_chain(state::TemperedState) = beta_for_chain(state, 1) -beta_for_chain(state::TemperedState, I...) = beta_for_chain(state.chain_to_beta, I...) -# NOTE: Array impl. is useful for testing. -beta_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...] +Return the model corresponding to the chain indexed by `I...`. +`ordering` specifies what sort of order the input models follow. """ - beta_for_process(state, I...) +function model_for_chain end -Return the β corresponding to the process indexed by `I...`. """ -beta_for_process(state::TemperedState, I...) = beta_for_process(state.chain_to_beta, state.process_to_chain, I...) -# NOTE: Array impl. is useful for testing. -function beta_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...) - return beta_for_chain(chain_to_beta, process_to_chain(proc2chain, I...)) -end + model_for_process(ordering, sampler, model, state, I...) -""" - model_for_chain(sampler, model, state, I...) - -Return the model corresponding to the chain indexed by `I...`. 
-""" -function model_for_chain(sampler, model, state, I...) - return make_tempered_model(sampler, model, beta_for_chain(state, I...)) -end +Return the model corresponding to the process indexed by `I...`. +`ordering` specifies what sort of order the input models follow. """ - model_for_process(sampler, model, state, I...) +function model_for_process end -Return the model corresponding to the process indexed by `I...`. """ -function model_for_process(sampler, model, state, I...) - return make_tempered_model(sampler, model, beta_for_process(state, I...)) -end + models_by_processes(ordering, models, state) +Return the models in the order of processes, assuming `models` is sorted according to `ordering`. +See also: [`ProcessOrdering`](@ref), [`ChainOrdering`](@ref). """ - TemperedTransition +models_by_processes(ordering, models, state) = sort_by_process(ordering, state, models) -Transition type for tempered samplers. """ -struct TemperedTransition{S} - transition::S - is_swap::Bool -end + samplers_by_processes(ordering, samplers, state) -TemperedTransition(transition::S) where {S} = TemperedTransition(transition, false) - -getparams_and_logprob(transition::TemperedTransition) = getparams_and_logprob(transition.transition) -getparams_and_logprob(model, transition::TemperedTransition) = getparams_and_logprob(model, transition.transition) +Return the `samplers` in the order of processes, assuming `samplers` is sorted according to `ordering`. +See also: [`ProcessOrdering`](@ref), [`ChainOrdering`](@ref). +""" +samplers_by_processes(ordering, samplers, state) = sort_by_process(ordering, state, samplers) diff --git a/src/stepping.jl b/src/stepping.jl index f37206e..3ea2ce1 100644 --- a/src/stepping.jl +++ b/src/stepping.jl @@ -1,138 +1,88 @@ -""" - should_swap(sampler, state) - -Return `true` if a swap should happen at this iteration, and `false` otherwise. 
-""" -function should_swap(sampler::TemperedSampler, state::TemperedState) - return state.total_steps % sampler.swap_every == 1 -end - get_init_params(x, _) = x get_init_params(init_params::Nothing, _) = nothing get_init_params(init_params::AbstractVector{<:Real}, _) = copy(init_params) get_init_params(init_params::AbstractVector{<:AbstractVector{<:Real}}, i) = init_params[i] +@concrete struct TemperedTransition + swaptransition + transition +end + +function transition_for_chain(transition::TemperedTransition, I...) + chain_idx = transition.swaptransition.chain_to_process[I...] + return transition.transition.transitions[chain_idx] +end + function AbstractMCMC.step( rng::Random.AbstractRNG, - model, + model::AbstractMCMC.AbstractModel, sampler::TemperedSampler; - N_burnin::Integer=0, - burnin_progress::Bool=AbstractMCMC.PROGRESS[], - init_params=nothing, kwargs... ) - - # `TemperedState` has the transitions and states in the order of - # the processes, and performs swaps by moving the (inverse) temperatures - # `β` between the processes, rather than moving states between processes - # and keeping the `β` local to each process. - # - # Therefore we iterate over the processes and then extract the corresponding - # `β`, `sampler` and `state`, and take a initialize. - # Create a `MultiSampler` and `MultiModel`. - multimodel = MultiModel( - make_tempered_model(sampler, model, sampler.inverse_temperatures[i]) + multimodel = MultiModel([ + make_tempered_model(sampler, model, sampler.chain_to_beta[i]) for i in 1:numtemps(sampler) - ) - multisampler = MultiSampler(getsampler(sampler, i) for i in 1:numtemps(sampler)) - multitransition, multistate = AbstractMCMC.step( - rng, multimodel, multisampler; - init_params=init_params, - kwargs... 
- ) + ]) + multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)]) + multistate = last(AbstractMCMC.step(rng, multimodel, multisampler; kwargs...)) # Make sure to collect, because we'll be using `setindex!(!)` later. - process_to_chain = collect(1:length(sampler.inverse_temperatures)) + process_to_chain = collect(1:length(sampler.chain_to_beta)) # Need to `copy` because this might be mutated. chain_to_process = copy(process_to_chain) - state = TemperedState( - (multitransition, multistate), - sampler.inverse_temperatures, - process_to_chain, + swapstate = SwapState( + multistate.states, chain_to_process, + process_to_chain, 1, - 0, - sampler.adaptation_states, - false, - Dict{Int,Float64}() + Dict{Int,Float64}(), ) - # TODO: Move this to AbstractMCMC. Or better, add to AbstractMCMC a way to - # specify a callback to be used for the `discard_initial`. - if N_burnin > 0 - AbstractMCMC.@ifwithprogresslogger burnin_progress name = "Burn-in" begin - # Determine threshold values for progress logging - # (one update per 0.5% of progress) - if burnin_progress - threshold = N_burnin ÷ 200 - next_update = threshold - end - - for i in 1:N_burnin - if burnin_progress && i >= next_update - ProgressLogging.@logprogress i / N_burnin - next_update = i + threshold - end - state = no_swap_step(rng, model, sampler, state; kwargs...) - @set! state.burnin_steps += 1 - end - end - end - - return TemperedTransition(transition_for_chain(state)), state + return AbstractMCMC.step(rng, model, sampler, TemperedState(swapstate, multistate, sampler.chain_to_beta)) end function AbstractMCMC.step( rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState; - kwargs... -) - # Reset state - @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) - - isswap = should_swap(sampler, state) - if isswap - state = swap_step(rng, model, sampler, state) - @set! 
state.is_swap = true - else - state = no_swap_step(rng, model, sampler, state; kwargs...) - @set! state.is_swap = false - end - - @set! state.total_steps += 1 - - # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. - return TemperedTransition(transition_for_chain(state), isswap), state -end - -function no_swap_step( - rng::Random.AbstractRNG, - model, + model::AbstractMCMC.AbstractModel, sampler::TemperedSampler, state::TemperedState; kwargs... ) - # Create the multi-versions with the ordering corresponding to the processes. - multimodel = MultiModel(model_for_process(sampler, model, state, i) for i in 1:numtemps(sampler)) - multisampler = MultiSampler(sampler_for_process(sampler, state, i) for i in 1:numtemps(sampler)) - multistate = MultipleStates(state_for_process(state, i) for i in 1:numtemps(sampler)) - - # And then step. - multitransition, multistate_next = AbstractMCMC.step( + # Create the tempered `MultiModel`. + multimodel = MultiModel([make_tempered_model(sampler, model, beta) for beta in state.chain_to_beta]) + # Create the tempered `MultiSampler`. + # We're assuming the user has given the samplers in an order according to the initial models. + multisampler = MultiSampler(samplers_by_processes( + ChainOrder(), + [getsampler(sampler, i) for i in 1:numtemps(sampler)], + state.swapstate + )) + # Create the composition which applies `SwapSampler` first. + sampler_composition = multisampler ∘ swapsampler(sampler) + + # Step! + # NOTE: This will internally re-order the models according to processes before taking steps, + # hence the resulting transitions and states will be in the order of processes, as we desire. + transition_composition, state_composition = AbstractMCMC.step( rng, multimodel, - multisampler, - multistate; + sampler_composition, + composition_state(sampler_composition, state.swapstate, state.state); kwargs... ) - # TODO: Maybe separate `transitions` and `states`? - @set! 
state.transitions_and_states = (multitransition, multistate_next) + # Construct the `TemperedTransition` and `TemperedState`. + swaptransition = inner_transition(transition_composition) + outertransition = outer_transition(transition_composition) - return state + swapstate = inner_state(state_composition) + outerstate = outer_state(state_composition) + + return ( + TemperedTransition(swaptransition, outertransition), + TemperedState(swapstate, outerstate, state.chain_to_beta) + ) end """ @@ -146,8 +96,8 @@ is used. function swap_step( rng::Random.AbstractRNG, model, - sampler::TemperedSampler, - state::TemperedState + sampler, + state ) return swap_step(swapstrategy(sampler), rng, model, sampler, state) end @@ -155,31 +105,34 @@ end function swap_step( strategy::ReversibleSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state ) # Randomly select whether to attempt swaps between chains # corresponding to odd or even indices of the temperature ladder - odd = rand([true, false]) - for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))] - state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt) + odd = rand(rng, Bool) + # TODO: Use integer-division. 
+ for k in [Int(2 * i - odd) for i in 1:(floor((length(model) - 1 + odd) / 2))] + state = swap_attempt(rng, model, sampler, state, k, k + 1) end return state end + function swap_step( strategy::NonReversibleSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state::SwapState # we're accessing `total_steps` restrict the type here ) # Alternate between attempting to swap chains corresponding # to odd and even indices of the temperature ladder - odd = state.total_steps % (2 * sampler.swap_every) != 0 - for k in [Int(2 * i - odd) for i in 1:(floor((numtemps(sampler) - 1 + odd) / 2))] - state = swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt) + odd = state.total_steps % 2 != 0 + # TODO: Use integer-division. + for k in [Int(2 * i - odd) for i in 1:(floor((length(model) - 1 + odd) / 2))] + state = swap_attempt(rng, model, sampler, state, k, k + 1) end return state end @@ -187,45 +140,45 @@ end function swap_step( strategy::SingleSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state ) # Randomly pick one index `k` of the temperature ladder and # attempt a swap between the corresponding chain and its neighbour - k = rand(rng, 1:(numtemps(sampler) - 1)) - return swap_attempt(rng, model, sampler, state, k, k + 1, sampler.adapt) + k = rand(rng, 1:(length(model) - 1)) + return swap_attempt(rng, model, sampler, state, k, k + 1) end function swap_step( strategy::SingleRandomSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state ) # Randomly pick two temperature ladder indices in order to # attempt a swap between the corresponding chains - chains = Set(1:numtemps(sampler)) + chains = Set(1:length(model)) i = pop!(chains, rand(rng, chains)) j = pop!(chains, rand(rng, chains)) - return swap_attempt(rng, model, sampler, state, i, j, 
sampler.adapt) + return swap_attempt(rng, model, sampler, state, i, j) end function swap_step( strategy::RandomSwap, rng::Random.AbstractRNG, - model, - sampler::TemperedSampler, - state::TemperedState + model::MultiModel, + sampler, + state ) # Iterate through all of temperature ladder indices, picking random # pairs and attempting swaps between the corresponding chains - chains = Set(1:numtemps(sampler)) + chains = Set(1:length(model)) while length(chains) >= 2 i = pop!(chains, rand(rng, chains)) j = pop!(chains, rand(rng, chains)) - state = swap_attempt(rng, model, sampler, state, i, j, sampler.adapt) + state = swap_attempt(rng, model, sampler, state, i, j) end return state end @@ -234,8 +187,8 @@ function swap_step( strategy::NoSwap, rng::Random.AbstractRNG, model, - sampler::TemperedSampler, - state::TemperedState + sampler, + state ) return state end diff --git a/src/swapping.jl b/src/swapping.jl index 26ca560..6bd8eae 100644 --- a/src/swapping.jl +++ b/src/swapping.jl @@ -79,11 +79,11 @@ this overrides and disables all swapping functionality. struct NoSwap <: AbstractSwapStrategy end """ - swap_betas!(chain_to_process, process_to_chain, i, j) + swap!(chain_to_process, process_to_chain, i, j) Swaps the `i`th and `j`th temperatures in place. """ -function swap_betas!(chain_to_process, process_to_chain, i, j) +function swap!(chain_to_process, process_to_chain, i, j) # TODO: Use BangBang's `@set!!` to also support tuples? # Extract the process index for each of the chains. process_for_chain_i, process_for_chain_j = chain_to_process[i], chain_to_process[j] @@ -123,6 +123,19 @@ function compute_tempered_logdensities( return compute_tempered_logdensities(model, sampler, transition, transition_other, β) end +function compute_logdensities( + model::AbstractMCMC.AbstractModel, + model_other::AbstractMCMC.AbstractModel, + state, + state_other, +) + # TODO: Make use of `getparams_and_logprob` instead? 
+ return ( + logdensity(model, getparams(model, state)), + logdensity(model, getparams(model_other, state_other)) + ) +end + """ swap_acceptance_pt(logπi, logπj) @@ -134,50 +147,3 @@ function swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj) return (logπjθi + logπiθj) - (logπiθi + logπjθj) end - -""" - swap_attempt(rng, model, sampler, state, i, j) - -Attempt to swap the temperatures of two chains by tempering the densities and -calculating the swap acceptance ratio; then swapping if it is accepted. -""" -function swap_attempt(rng, model, sampler, state, i, j, adapt) - # Extract the relevant transitions. - sampler_i = sampler_for_chain(sampler, state, i) - sampler_j = sampler_for_chain(sampler, state, j) - transition_i = transition_for_chain(state, i) - transition_j = transition_for_chain(state, j) - state_i = state_for_chain(state, i) - state_j = state_for_chain(state, j) - β_i = beta_for_chain(state, i) - β_j = beta_for_chain(state, j) - # Evaluate logdensity for both parameters for each tempered density. - logπiθi, logπiθj = compute_tempered_logdensities( - model, sampler_i, sampler_j, transition_i, transition_j, state_i, state_j, β_i, β_j - ) - logπjθj, logπjθi = compute_tempered_logdensities( - model, sampler_j, sampler_i, transition_j, transition_i, state_j, state_i, β_j, β_i - ) - - # If the proposed temperature swap is accepted according `logα`, - # swap the temperatures for future steps. - logα = swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj) - should_swap = -Random.randexp(rng) ≤ logα - if should_swap - swap_betas!(state.chain_to_process, state.process_to_chain, i, j) - end - - # Keep track of the (log) acceptance ratios. - state.swap_acceptance_ratios[i] = logα - - # Adaptation steps affects `ρs` and `inverse_temperatures`, as the `ρs` is - # adapted before a new `inverse_temperatures` is generated and returned. - if adapt - ρs = adapt!!( - state.adaptation_states, state.chain_to_beta, i, min(one(logα), exp(logα)) - ) - @set! 
state.adaptation_states = ρs - @set! state.chain_to_beta = update_inverse_temperatures(ρs, state.chain_to_beta) - end - return state -end diff --git a/src/swapsampler.jl b/src/swapsampler.jl new file mode 100644 index 0000000..6e43ff1 --- /dev/null +++ b/src/swapsampler.jl @@ -0,0 +1,224 @@ +""" + SwapSampler <: AbstractMCMC.AbstractSampler + +# Fields +$(FIELDS) +""" +struct SwapSampler{S} <: AbstractMCMC.AbstractSampler + "swap strategy to use" + strategy::S +end + +SwapSampler() = SwapSampler(ReversibleSwap()) + +swapstrategy(sampler::SwapSampler) = sampler.strategy + +# Interaction with the state. +# NOTE: `SwapSampler` should only ever interact with `ProcessOrdering`, so we don't implement `ChainOrdering`. +function model_for_chain(ordering::ProcessOrder, sampler::SwapSampler, model::MultiModel, state::SwapState, I...) + # `model` is expected to be ordered according to process index, hence we map chain index to process index + # and extract the model corresponding to said process. + return model_for_process(ordering, sampler, model, state, chain_to_process(state, I...)) +end + +function model_for_process(::ProcessOrder, sampler::SwapSampler, model::MultiModel, state::SwapState, I...) + # `model` is expected to be ordered according to process index, hence we just extract the corresponding index. + return model.models[I...] +end + +""" + SwapTransition + +Transition type for tempered samplers. +""" +@concrete struct SwapTransition + chain_to_process + process_to_chain +end + +function composition_transition( + sampler::CompositionSampler{<:AbstractMCMC.AbstractSampler,<:SwapSampler}, + swaptransition::SwapTransition, + outertransition::MultipleTransitions +) + saveall(sampler) && return CompositionTransition(outertransition, swaptransition) + # Otherwise we have to re-order the transitions, since without the `swaptransition` there's + # no way to recover the true ordering of the transitions. 
+ return MultipleTransitions(sort_by_chain(ProcessOrder(), swaptransition, outertransition.transitions)) +end + +# NOTE: This does not have an initial `step`! This is because we need +# states to work with before we can do anything. Hence it only makes +# sense to use this sampler in composition with other samplers. +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model, + sampler::SwapSampler, + state::SwapState; + kwargs... +) + # Reset state + @set! state.swap_acceptance_ratios = empty(state.swap_acceptance_ratios) + + # Perform a swap step. + state = swap_step(rng, model, sampler, state) + @set! state.total_steps += 1 + + # We want to return the transition for the _first_ chain, i.e. the chain usually corresponding to `β=1.0`. + # TODO: What should we return here? + return SwapTransition(deepcopy(state.chain_to_process), deepcopy(state.process_to_chain)), state +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::CompositionSampler{<:AbstractMCMC.AbstractSampler,<:SwapSampler}, + state; + kwargs... +) + # Reminder: a `swap` can be implemented in two different ways: + # + # 1. Swap the models and leave ordering of (sampler, state)-pair unchanged. + # 2. Swap (sampler, state)-pairs and leave ordering of models unchanged. + # + # (1) has the properties: + # + Easy to keep `outerstate` and `swapstate` in sync since their ordering is never changed. + # - Ordering of `outerstate` no longer corresponds to ordering of models, i.e. the returned + # `outerstate.states[i]` does no longer correspond to a state targeting `model.models[i]`. + # This will have to be adjusted in the `AbstractMCMC.bundle_samples` before, say, converting + # into a `MCMCChains.Chains`. + # + # (2) has the properties: + # + Returned `outertransition` (and `outerstate`, if we want) has the same ordering as the models, + # i.e. `outerstate.states[i]` now corresponds to `model.models[i]`! 
+ # - Need to keep `outerstate` and `swapstate` in sync since their ordering now changes. + # - Need to also re-order `outersampler.samplers` :/ + # + # Here (as in, below) we go with option (1), i.e. re-order the `models`. + # A full `step` then is as follows: + # 1. Sort models according to index processes using the `swapstate` from previous iteration. + # 2. Take step with `swapsampler`. + # 3. Sort models _again_ according to index processes using the new `swapstate`, since we + # might have made a swap in (2). + # 4. Run multi-sampler. + + outersampler, swapsampler = outer_sampler(sampler), inner_sampler(sampler) + + # Get the states. + outerstate_prev, swapstate_prev = outer_state(state), inner_state(state) + + # Re-order the models. + chain2models = model.models # but keep the original chain → model around because we'll re-order again later + @set! model.models = models_by_processes(ChainOrder(), chain2models, swapstate_prev) + + # Step for the swap-sampler. + swaptransition, swapstate = AbstractMCMC.step( + rng, model, swapsampler, state_from(model, swapstate_prev, outerstate_prev); + kwargs... + ) + + # Re-order the models AGAIN, since we might have swapped some. + @set! model.models = models_by_processes(ChainOrder(), chain2models, swapstate) + + # Create the current state from `outerstate_prev` and `swapstate`, and `step` for `outersampler`. + outertransition, outerstate = AbstractMCMC.step( + # HACK: We really need the `state_from` here despite the fact that `SwapSampler` does not + # change the `swapstate.states` itself, but we might require a re-computation of certain + # quantities from the `model`, which has now potentially been re-ordered (see above). + # NOTE: We do NOT do `state_from(model, outerstate_prev, swapstate)` because as of now, + # `swapstate` does not implement `getparams_and_logprob`. + rng, model, outersampler, state_from(model, outerstate_prev, outerstate_prev); + kwargs... + ) + + # TODO: Should we re-order the transitions? 
+ # Currently, one has to re-order the `outertransition` according to `swaptransition` + # in the `bundle_samples`. Is this the right approach though? + # TODO: We should at least re-order transitions in the case where `saveall(sampler) == false`! + # In this case, we'll just return the transition without the swap-transition, hence making it + # impossible to reconstruct the actual ordering! + return ( + composition_transition(sampler, swaptransition, outertransition), + composition_state(sampler, swapstate, outerstate) + ) +end + +# NOTE: The default initial `step` for `CompositionSampler` simply calls the two different +# `step` methods, but since `SwapSampler` does not have such an implementation this will fail. +# Instead we overload the initial `step` for `CompositionSampler` involving `SwapSampler` to +# first take a `step` using the non-swapsampler and then construct `SwapState` from the resulting state. +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::CompositionSampler{<:AbstractMCMC.AbstractSampler,<:SwapSampler}; + kwargs... +) + # This should hopefully be a `MultipleStates` or something since we're working with a `MultiModel`. + state_outer_initial = last(AbstractMCMC.step(rng, model, outer_sampler(sampler); kwargs...)) + # NOTE: Since `SwapState` wraps a sequence of states from another sampler, we need `state_outer_initial` + # to initialize the `SwapState`. + state_inner_initial = SwapState(state_outer_initial) + + # Create the composition state, and take a full step. + state = composition_state(sampler, state_inner_initial, state_outer_initial) + return AbstractMCMC.step(rng, model, sampler, state; kwargs...) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::CompositionSampler{<:SwapSampler,<:AbstractMCMC.AbstractSampler}; + kwargs... +) + # This should hopefully be a `MultipleStates` or something since we're working with a `MultiModel`. 
+ state_inner_initial = last(AbstractMCMC.step(rng, model, inner_sampler(sampler); kwargs...)) + # NOTE: Since `SwapState` wraps a sequence of states from another sampler, we need `state_inner_initial` + # to initialize the `SwapState`. + state_outer_initial = SwapState(state_inner_initial) + + # Create the composition state, and take a full step. + state = composition_state(sampler, state_inner_initial, state_outer_initial) + return AbstractMCMC.step(rng, model, sampler, state; kwargs...) +end + +@nospecialize function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + sampler::CompositionSampler{<:SwapSampler,<:SwapSampler}; + kwargs... +) + error("`SwapSampler` requires states from sampler other than `SwapSampler` to be initialized") +end + +""" + swap_attempt(rng, model, sampler, state, i, j) + +Attempt to swap the temperatures of two chains by tempering the densities and +calculating the swap acceptance ratio; then swapping if it is accepted. +""" +function swap_attempt(rng::Random.AbstractRNG, model::MultiModel, sampler::SwapSampler, state, i, j) + # Extract the relevant states. + state_i = state_for_chain(state, i) + state_j = state_for_chain(state, j) + # Evaluate logdensity for both parameters for each tempered density. + # NOTE: `SwapSampler` should only be working with models ordered according to `ProcessOrder`, + # never `ChainOrder`, hence why we have the below. + model_i = model_for_chain(ProcessOrder(), sampler, model, state, i) + model_j = model_for_chain(ProcessOrder(), sampler, model, state, j) + logπiθi, logπiθj = compute_logdensities(model_i, model_j, state_i, state_j) + logπjθj, logπjθi = compute_logdensities(model_j, model_i, state_j, state_i) + + # If the proposed temperature swap is accepted according to `logα`, + # swap the temperatures for future steps. 
+ logα = swap_acceptance_pt(logπiθi, logπiθj, logπjθi, logπjθj) + should_swap = -Random.randexp(rng) ≤ logα + if should_swap + swap!(state.chain_to_process, state.process_to_chain, i, j) + end + + # Keep track of the (log) acceptance ratios. + state.swap_acceptance_ratios[i] = logα + + # TODO: Handle adaptation. + return state +end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..17b9125 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,34 @@ +# TODO: Move. +chain_to_process(transition::SwapTransition, I...) = transition.chain_to_process[I...] + +""" + roundtrips(transitions) + +Return sequence of `(start_index, turnpoint_index, end_index)`-triples representing roundtrips. +""" +function roundtrips(transitions::AbstractVector{<:TemperedTransition}) + return roundtrips(map(Base.Fix2(getproperty, :swaptransition), transitions)) +end +function roundtrips(transitions::AbstractVector{<:SwapTransition}) + result = Tuple{Int,Int,Int}[] + start_index, turn_index = 1, nothing + for (i, t) in enumerate(transitions) + n = length(t.chain_to_process) + if isnothing(turn_index) + # Looking for the turn. + if chain_to_process(t, 1) == n + turn_index = i + end + else + # Looking for the return/end. + if chain_to_process(t, 1) == 1 + push!(result, (start_index, turn_index, i)) + # Reset. 
+ start_index = i + turn_index = nothing + end + end + end + + return result +end diff --git a/test/Project.toml b/test/Project.toml index 0f4c639..492577c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -10,8 +11,10 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" @@ -23,6 +26,6 @@ Bijectors = "0.10" Distributions = "0.24, 0.25" LogDensityProblems = "2" LogDensityProblemsAD = "1" -MCMCChains = "5.5" +MCMCChains = "6" Turing = "0.24" julia = "1" diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index 440e2a4..3cc9c8a 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -19,11 +19,7 @@ state_initial, ) - if MCMCTempering.saveall(spl_composed) - @test state_composed_initial isa MCMCTempering.SequentialStates - else - @test state_composed_initial isa MCMCTempering.CompositionState - end + @test state_composed_initial isa MCMCTempering.CompositionState # Take two steps with `spl`. rng = Random.MersenneTwister(42) @@ -42,11 +38,9 @@ # Make sure the state types stay consistent. 
if MCMCTempering.saveall(spl_composed) - @test transition isa MCMCTempering.SequentialTransitions - @test state_composed isa MCMCTempering.SequentialStates - else - @test state_composed isa MCMCTempering.CompositionState + @test transition isa MCMCTempering.CompositionTransition end + @test state_composed isa MCMCTempering.CompositionState end params_composed, logp_composed = MCMCTempering.getparams_and_logprob(logdensity_model, state_composed) @@ -62,6 +56,7 @@ ) # Should be the same length because the `SequentialTransitions` will be unflattened. + @test chain_composed isa MCMCChains.Chains @test length(chain_composed) == length(chain) end @@ -160,4 +155,32 @@ @test map(last, params_and_logp) == logp_multi end end + + @testset "SwapSampler" begin + # SwapSampler without tempering (i.e. in a composition and using `MultiModel`, etc.) + init_params = [[5.0], [5.0]] + mdl1 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(4.9999, 1))) + mdl2 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(5.0001, 1))) + spl1 = RWMH(MvNormal(Zeros(dimension(mdl1)), I)) + spl2 = let σ² = 1e-2 + MALA(∇ -> MvNormal(σ² * ∇, 2σ² * I)) + end + swapspl = MCMCTempering.SwapSampler() + spl_full = (spl1 × spl2) ∘ swapspl + product_model = LogDensityModel(mdl1) × LogDensityModel(mdl2) + # Sample! + multisamples = sample(product_model, spl_full, 1000; init_params=init_params, progress=false) + # Extract the transitions corresponding to each of the models. + model_transitions = mapreduce(hcat, multisamples) do t + [MCMCTempering.outer_transition(t).transitions[MCMCTempering.inner_transition(t).process_to_chain]...] + end + # Make sure we actually got some swaps going and we were using different types of states + # for both models. + @test length(unique(typeof, model_transitions[1, :])) ≥ 1 + @test length(unique(typeof, model_transitions[2, :])) ≥ 1 + + # Check that means are roughly okay. 
+ model_params = map(first ∘ MCMCTempering.getparams, model_transitions) + @test vec(mean(model_params; dims=2)) ≈ [5.0, 5.0] atol=0.2 + end end diff --git a/test/compat.jl b/test/compat.jl index d244b50..7052bb1 100644 --- a/test/compat.jl +++ b/test/compat.jl @@ -6,11 +6,23 @@ function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition return transition end MCMCTempering.getparams_and_logprob(transition::AdvancedMH.GradientTransition) = transition.params, transition.lp -function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.GradientTransition, params, lp) - Setfield.@set! transition.params = params - Setfield.@set! transition.lp = lp - return transition +# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible. +function MCMCTempering.setparams_and_logprob!!(model, transition::AdvancedMH.GradientTransition, params, lp) + # NOTE: We have to re-compute the gradient here because this will be used in the subsequent `step` for + # the MALA sampler. + return AdvancedMH.GradientTransition(params, AdvancedMH.logdensity_and_gradient(model, params)...) end # AdvancedHMC.jl MCMCTempering.getparams_and_logprob(t::AdvancedHMC.Transition) = t.z.θ, t.z.ℓπ.value +MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState) = MCMCTempering.getparams_and_logprob(state.transition) + +# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible. +function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, lp) + # NOTE: Need to recompute the gradient because it might be used in the next integration step. 
+ hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) + return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, params, state.transition.z.r; + ℓκ=state.transition.z.ℓκ + ) +end diff --git a/test/runtests.jl b/test/runtests.jl index 85a076c..160cabf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,8 @@ Several properties of the tempered sampler are tested before returning: # Keyword arguments - `num_iterations`: The number of iterations to run the sampler for. Defaults to `2_000`. -- `swap_every`: The number of iterations between each swap attempt. Defaults to `2`. +- `steps_per_swap`: The number of iterations between each swap attempt. Defaults to `1`. +- `adapt`: Whether to adapt the sampler. Defaults to `false`. - `adapt_target`: The target acceptance rate for the swaps. Defaults to `0.234`. - `adapt_rtol`: The relative tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.1`. - `adapt_atol`: The absolute tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.05`. @@ -26,59 +27,63 @@ Several properties of the tempered sampler are tested before returning: - `init_params`: The initial parameters to use for the sampler. Defaults to `nothing`. - `param_names`: The names of the parameters in the chain; used to construct the resulting chain. Defaults to `missing`. - `progress`: Whether to show a progress bar. Defaults to `false`. -- `kwargs...`: Additional keyword arguments to pass to `MCMCTempering.tempered`. 
""" function test_and_sample_model( model, sampler, - inverse_temperatures, - swap_strategy=MCMCTempering.SingleSwap(); + inverse_temperatures; + swap_strategy=MCMCTempering.SingleSwap(), mean_swap_rate_bound=0.1, compare_mean_swap_rate=≥, num_iterations=2_000, - swap_every=1, + steps_per_swap=1, + adapt=false, adapt_target=0.234, adapt_rtol=0.1, adapt_atol=0.05, init_params=nothing, param_names=missing, progress=false, - kwargs... + minimum_roundtrips=nothing ) - # NOTE: Every other `step` will perform a swap. - num_iterations_tempered = 2 * num_iterations - # Make the tempered sampler. sampler_tempered = tempered( sampler, inverse_temperatures; swap_strategy=swap_strategy, - swap_every=swap_every, + steps_per_swap=steps_per_swap, adapt_target=adapt_target, - kwargs... ) + @test sampler_tempered.swapstrategy == swap_strategy + @test MCMCTempering.swapsampler(sampler_tempered).strategy == swap_strategy + # Store the states. states_tempered = [] callback = StateHistoryCallback(states_tempered) # Sample. samples_tempered = AbstractMCMC.sample( - model, sampler_tempered, num_iterations_tempered; + model, sampler_tempered, num_iterations; callback=callback, progress=progress, init_params=init_params ) + if !isnothing(minimum_roundtrips) + # Make sure we've had at least some roundtrips. + @test length(MCMCTempering.roundtrips(samples_tempered)) ≥ minimum_roundtrips + end + # Let's make sure the process ↔ chain mapping is valid. numtemps = MCMCTempering.numtemps(sampler_tempered) - for state in states_tempered - for i = 1:numtemps + @test all(states_tempered) do state + all(1:numtemps) do i # These two should be inverses of each other. - @test MCMCTempering.process_to_chain(state, MCMCTempering.chain_to_process(state, i)) == i + MCMCTempering.process_to_chain(state, MCMCTempering.chain_to_process(state, i)) == i end end # Extract the states that were swapped. 
- states_swapped = filter(Base.Fix2(getproperty, :is_swap), states_tempered) + states_swapped = map(Base.Fix2(getproperty, :swapstate), states_tempered) # Swap acceptance ratios should be compared against the target acceptance in case of adaptation. swap_acceptance_ratios = mapreduce( collect ∘ values ∘ Base.Fix2(getproperty, :swap_acceptance_ratios), @@ -97,37 +102,27 @@ function test_and_sample_model( end # Extract the history of chain indices. - process_to_chain_history_list = map(states_tempered) do state + process_to_chain_history_list = map(states_swapped) do state state.process_to_chain end process_to_chain_history = permutedims(reduce(hcat, process_to_chain_history_list), (2, 1)) # Check that the swapping has been done correctly. - process_to_chain_uniqueness = map(states_tempered) do state + process_to_chain_uniqueness = map(states_swapped) do state length(unique(state.process_to_chain)) == length(state.process_to_chain) end @test all(process_to_chain_uniqueness) - # For the currently implemented strategies, the index process should not move by more than 1. - @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) + # For every strategy except `RandomSwap`, the index process should not move by more than 1. + if !(swap_strategy isa Union{MCMCTempering.SingleRandomSwap,MCMCTempering.RandomSwap}) + @test all(abs.(diff(process_to_chain_history[:, 1])) .≤ 1) + end - chain_to_process_uniqueness = map(states_tempered) do state + chain_to_process_uniqueness = map(states_swapped) do state length(unique(state.chain_to_process)) == length(state.chain_to_process) end @test all(chain_to_process_uniqueness) - # Tests that we have at least swapped some times (say at least 10% of attempted swaps). - swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row - # Some of the strategies performs multiple swaps in a swap-iteration, - # but we want to count the number of iterations for which we had a successful swap, - # i.e. 
only count non-zero elements in a row _once_. Hence the `min`. - min(1, sum(abs, row)) - end - @test compare_mean_swap_rate( - sum(swap_success_indicators), - (num_iterations_tempered / swap_every) * mean_swap_rate_bound - ) - # Compare the tempered sampler to the untempered sampler. state_tempered = states_tempered[end] chain_tempered = AbstractMCMC.bundle_samples( @@ -140,44 +135,50 @@ function test_and_sample_model( param_names=param_names ) + # Tests that we have at least swapped some times (say at least 10% of attempted swaps). + swap_success_indicators = map(eachrow(diff(process_to_chain_history; dims=1))) do row + # Some of the strategies performs multiple swaps in a swap-iteration, + # but we want to count the number of iterations for which we had a successful swap, + # i.e. only count non-zero elements in a row _once_. Hence the `min`. + min(1, sum(abs, row)) + end + + num_nonswap_steps_taken = length(chain_tempered) + @test num_nonswap_steps_taken == (num_iterations * steps_per_swap) + @test compare_mean_swap_rate( + sum(swap_success_indicators), + (num_nonswap_steps_taken / steps_per_swap) * mean_swap_rate_bound + ) + return chain_tempered end function compare_chains( chain::MCMCChains.Chains, chain_tempered::MCMCChains.Chains; atol=1e-6, rtol=1e-6, - compare_std=true, compare_ess=true, - compare_ess_slack=0.8, + compare_ess_slack=0.5, # HACK: this is very low which is unnecessary in most cases, but it's too random isbroken=false ) - desc = describe(chain)[1].nt - desc_tempered = describe(chain_tempered)[1].nt + mean = to_dict(MCMCChains.mean(chain)) + mean_tempered = to_dict(MCMCChains.mean(chain_tempered)) # Compare the means. if isbroken - @test_broken desc.mean ≈ desc_tempered.mean atol = atol rtol = rtol + @test_broken all(isapprox(mean[sym], mean_tempered[sym]; atol, rtol) for sym in keys(mean)) else - @test desc.mean ≈ desc_tempered.mean atol = atol rtol = rtol - end - - # Compare the std. of the chains. 
- if compare_std - if isbroken - @test_broken desc.std ≈ desc_tempered.std atol = atol rtol = rtol - else - @test desc.std ≈ desc_tempered.std atol = atol rtol = rtol - end + @test all(isapprox(mean[sym], mean_tempered[sym]; atol, rtol) for sym in keys(mean)) end # Compare the ESS. if compare_ess - ess = MCMCChains.ess_rhat(chain).nt.ess - ess_tempered = MCMCChains.ess_rhat(chain_tempered).nt.ess + ess = to_dict(MCMCChains.ess(chain)) + ess_tempered = to_dict(MCMCChains.ess(chain_tempered)) + @info "" ess ess_tempered if isbroken - @test_broken all(ess_tempered .≥ ess .* compare_ess_slack) + @test_broken all(ess_tempered[sym] ≥ ess[sym] * compare_ess_slack for sym in keys(ess)) else - @test all(ess_tempered .≥ ess .* compare_ess_slack) + @test all(ess_tempered[sym] ≥ ess[sym] * compare_ess_slack for sym in keys(ess)) end end end @@ -197,7 +198,7 @@ end chain_to_beta = [1.0, 0.75, 0.5, 0.25] # Make swap chain 1 (now on process 1) ↔ chain 2 (now on process 2) - MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 1, 2) + MCMCTempering.swap!(chain_to_process, process_to_chain, 1, 2) # Expected result: chain 1 is now on process 2, chain 2 is now on process 1. target_process_to_chain = [2, 1, 3, 4] @test process_to_chain[chain_to_process] == 1:length(process_to_chain) @@ -213,7 +214,7 @@ end end # Make swap chain 2 (now on process 1) ↔ chain 3 (now on process 3) - MCMCTempering.swap_betas!(chain_to_process, process_to_chain, 2, 3) + MCMCTempering.swap!(chain_to_process, process_to_chain, 2, 3) # Expected result: chain 3 is now on process 1, chain 2 is now on process 3. target_process_to_chain = [3, 1, 2, 4] @test process_to_chain[chain_to_process] == 1:length(process_to_chain) @@ -244,13 +245,13 @@ end [1.0, 1e-3], # extreme temperatures -> don't exect much swapping to occur num_iterations=num_iterations, adapt=false, - init_params = [[0.0], [1000.0]], # initialized far apart - # At most 1% of swaps should be successful. 
+ init_params=[[0.0], [1000.0]], # initialized far apart + # At MOST 1% of swaps should be successful. mean_swap_rate_bound=0.01, compare_mean_swap_rate=≤, ) # `atol` is fairly high because we haven't run this for "too" long. - @test mean(chain_tempered[:, 1, :]) ≈ 1 atol=0.2 + @test mean(chain_tempered[:, 1, :]) ≈ 1 atol=0.3 end @testset "GMM 1D" begin @@ -260,7 +261,7 @@ end ) # Setup non-tempered. - sampler_rwmh = RWMH(MvNormal(0.1 * ones(1))) + sampler_rwmh = RWMH(MvNormal(0.1 * Diagonal(Ones(1)))) # Simple geometric ladder inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.95 .^ (0:20)) @@ -270,12 +271,15 @@ end model, sampler_rwmh, inverse_temperatures, + swap_strategy=MCMCTempering.NonReversibleSwap(), num_iterations=num_iterations, adapt=false, # At least 25% of swaps should be successful. mean_swap_rate_bound=0.25, compare_mean_swap_rate=≥, progress=false, + # Make sure we have _some_ roundtrips. + minimum_roundtrips=10, ) # # Compare the chains. @@ -312,16 +316,19 @@ end MCMCTempering.RandomSwap() ] - @testset "$(swapstrategy)" for swapstrategy in swapstrategies + @testset "$(swap_strategy)" for swap_strategy in swapstrategies chain_tempered = test_and_sample_model( model, sampler, inverse_temperatures, num_iterations=num_iterations, - swapstrategy=swapstrategy, + swap_strategy=swap_strategy, adapt=false, + # Make sure we have _some_ roundtrips. + minimum_roundtrips=10, ) - compare_chains(chain, chain_tempered, rtol=0.1, compare_std=false, compare_ess=true) + + compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true) end end @@ -348,15 +355,14 @@ end end @testset "AdvancedHMC.jl" begin - num_iterations = 2_000 + num_iterations = 5_000 # Set up HMC smpler. 
initial_ϵ = 0.1 integrator = AdvancedHMC.Leapfrog(initial_ϵ) proposal = AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS, AdvancedHMC.GeneralisedNoUTurn}(integrator) metric = AdvancedHMC.DiagEuclideanMetric(LogDensityProblems.dimension(model)) - adaptor = AdvancedHMC.StanHMCAdaptor(AdvancedHMC.MassMatrixAdaptor(metric), AdvancedHMC.StepSizeAdaptor(0.8, integrator)) - sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric, adaptor) + sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric) # Sample using HMC. samples_hmc = sample(model, sampler_hmc, num_iterations; init_params=copy(init_params), progress=false) @@ -366,27 +372,47 @@ end ) map_parameters!(b, chain_hmc) + # Make sure that we get the "same" result when only using the inverse temperature 1. + sampler_tempered = MCMCTempering.TemperedSampler(sampler_hmc, [1]) + chain_tempered = sample( + model, sampler_tempered, num_iterations; + init_params=copy(init_params), + chain_type=MCMCChains.Chains, + param_names=param_names, + progress=false, + ) + map_parameters!(b, chain_tempered) + compare_chains( + chain_hmc, chain_tempered; + atol=0.2, + compare_ess=true, + isbroken=false + ) + # Sample using tempered HMC. chain_tempered = test_and_sample_model( model, sampler_hmc, - [1, 0.25, 0.1, 0.01], + [1, 0.75, 0.5, 0.25, 0.1, 0.01], swap_strategy=MCMCTempering.ReversibleSwap(), num_iterations=num_iterations, adapt=false, - mean_swap_rate_bound=0, + mean_swap_rate_bound=0.1, init_params=copy(init_params), param_names=param_names, progress=false ) map_parameters!(b, chain_tempered) - - # TODO: Make it not broken, i.e. produce reasonable results. - compare_chains(chain_hmc, chain_tempered, atol=0.2, compare_std=false, compare_ess=true, isbroken=false) + compare_chains( + chain_hmc, chain_tempered; + atol=0.3, + compare_ess=true, + isbroken=false, + ) end - + @testset "AdvancedMH.jl" begin - num_iterations = 2_000 + num_iterations = 10_000 d = LogDensityProblems.dimension(model) # Set up MALA sampler. 
@@ -394,17 +420,33 @@ end sampler_mh = MALA(∇ -> MvNormal(σ² * ∇, 2σ² * I)) # Sample using MALA. - samples_mh = AbstractMCMC.sample( + chain_mh = AbstractMCMC.sample( model, sampler_mh, num_iterations; - init_params=copy(init_params), progress=false - ) - chain_mh = AbstractMCMC.bundle_samples( - samples_mh, MCMCTempering.maybe_wrap_model(model), sampler_mh, samples_mh[1], MCMCChains.Chains; - param_names=param_names + init_params=copy(init_params), + progress=false, + chain_type=MCMCChains.Chains, + param_names=param_names, ) map_parameters!(b, chain_mh) - # Sample using tempered MALA. + # Make sure that we get the "same" result when only using the inverse temperature 1. + sampler_tempered = MCMCTempering.TemperedSampler(sampler_mh, [1]) + chain_tempered = sample( + model, sampler_tempered, num_iterations; + init_params=copy(init_params), + chain_type=MCMCChains.Chains, + param_names=param_names, + progress=false, + ) + map_parameters!(b, chain_tempered) + compare_chains( + chain_mh, chain_tempered; + atol=0.2, + compare_ess=true, + isbroken=false, + ) + + # Sample using actual tempering. 
chain_tempered = test_and_sample_model( model, sampler_mh, @@ -412,16 +454,17 @@ end swap_strategy=MCMCTempering.ReversibleSwap(), num_iterations=num_iterations, adapt=false, - mean_swap_rate_bound=0, + mean_swap_rate_bound=0.1, init_params=copy(init_params), param_names=param_names ) map_parameters!(b, chain_tempered) # Need a large atol as MH is not great on its own - compare_chains(chain_mh, chain_tempered, atol=0.4, compare_std=false, compare_ess=true, isbroken=false) + compare_chains(chain_mh, chain_tempered, atol=0.2, compare_ess=true, isbroken=false) end end include("abstractmcmc.jl") + include("simple_gaussian.jl") end diff --git a/test/setup.jl b/test/setup.jl index 9eedc5a..90195ca 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -18,5 +18,5 @@ using Turing: Turing, DynamicPPL include("utils.jl") +include("test_utils.jl") include("compat.jl") - diff --git a/test/simple_gaussian.jl b/test/simple_gaussian.jl new file mode 100644 index 0000000..1ff19b8 --- /dev/null +++ b/test/simple_gaussian.jl @@ -0,0 +1,74 @@ +@testset "Simple tempered Gaussian (closed form)" begin + μ = Zeros(1) + inverse_temperatures = MCMCTempering.check_inverse_temperatures(0.8 .^ (0:10)) + variances_true = inv.(inverse_temperatures) + std_true_dict = map(variances_true) do v + Dict(:param_1 => √v) + end + tempered_dists = [MvNormal(Zeros(1), I / β) for β in inverse_temperatures] + tempered_multimodel = MCMCTempering.MultiModel(map(LogDensityModel ∘ DistributionLogDensity, tempered_dists)) + + init_params = zeros(length(μ)) + + num_samples = 1_000 + num_burnin = num_samples ÷ 2 + thin = 10 + + # Samplers. + rwmh = RWMH(MvNormal(Zeros(1), I)) + rwmh_tempered = TemperedSampler(rwmh, inverse_temperatures) + rwmh_product = MCMCTempering.MultiSampler(Fill(rwmh, length(tempered_dists))) + rwmh_product_with_swap = rwmh_product ∘ MCMCTempering.SwapSampler() + + # Sample. 
+ @testset "TemperedSampler" begin + chains_product = sample( + DistributionLogDensity(tempered_dists[1]), rwmh_tempered, num_samples; + init_params, + bundle_resolve_swaps=true, + chain_type=Vector{MCMCChains.Chains}, + progress=false, + discard_initial=num_burnin, + thinning=thin, + ) + test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + end + + @testset "MultiSampler without swapping" begin + chains_product = sample( + tempered_multimodel, rwmh_product, num_samples; + init_params, + chain_type=Vector{MCMCChains.Chains}, + progress=false, + discard_initial=num_burnin, + thinning=thin, + ) + test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + end + + @testset "MultiSampler with swapping (saveall=true)" begin + chains_product = sample( + tempered_multimodel, rwmh_product_with_swap, num_samples; + init_params, + bundle_resolve_swaps=true, + chain_type=Vector{MCMCChains.Chains}, + progress=false, + discard_initial=num_burnin, + thinning=thin, + ) + test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + end + + @testset "MultiSampler with swapping (saveall=false)" begin + chains_product = sample( + tempered_multimodel, Setfield.@set(rwmh_product_with_swap.saveall = Val(false)), num_samples; + init_params, + chain_type=Vector{MCMCChains.Chains}, + progress=false, + discard_initial=num_burnin, + thinning=thin, + ) + test_chains_with_monotonic_variance(chains_product, Zeros(length(chains_product)), std_true_dict) + end +end + diff --git a/test/test_utils.jl b/test/test_utils.jl new file mode 100644 index 0000000..d343a8e --- /dev/null +++ b/test/test_utils.jl @@ -0,0 +1,126 @@ +using MCMCDiagnosticTools, Statistics, DataFrames + +""" + to_dict(c::MCMCChains.ChainDataFrame[, col::Symbol]) + +Return a dictionary mapping parameter names to the values in column `col` of `c`. + +# Arguments +- `c`: A `MCMCChains.ChainDataFrame` object.
+- `col`: The column to extract values from. Defaults to the first column that is not `:parameters`. +""" +to_dict(c::MCMCChains.ChainDataFrame) = to_dict(c, first(filter(!=(:parameters), keys(c.nt)))) +function to_dict(c::MCMCChains.ChainDataFrame, col::Symbol) + df = DataFrame(c) + return Dict(sym => df[findfirst(==(sym), df[:, :parameters]), col] for sym in df.parameters) +end + +""" + atol_for_chain(chain; significance=1e-3, kind=Statistics.mean) + +Return a dictionary of absolute tolerances for each parameter in `chain`, computed +as the confidence interval width for the mean of the parameter with `significance`. +""" +function atol_for_chain(chain; significance=1e-3, kind=Statistics.mean) + param_names = names(chain, :parameters) + # Can reject H0 if, say, `abs(mean(chain2) - mean(chain1)) > confidence_width`. + # Or alternatively, compare means but with `atol` set to the `confidence_width`. + # NOTE: Failure to reject, i.e. passing the tests, does not imply that the means are equal. + mcse = to_dict(MCMCChains.mcse(chain; kind), :mcse) + return Dict(sym => quantile(Normal(0, mcse[sym]), 1 - significance/2) for sym in param_names) +end + +thin_to(chain, n) = chain[1:length(chain) ÷ n:end] + +""" + test_means(chain, mean_true; kwargs...) + +Test that the mean of each parameter in `chain` is approximately `mean_true`. + +# Arguments +- `chain`: A `MCMCChains.Chains` object. +- `mean_true`: A `Real` or `AbstractDict` mapping parameter names to their true mean. +- `kwargs...`: Passed to `atol_for_chain`. +""" +function test_means(chain::MCMCChains.Chains, mean_true::Real; kwargs...) + return test_means(chain, Dict(sym => mean_true for sym in names(chain, :parameters)); kwargs...) +end +function test_means(chain::MCMCChains.Chains, mean_true::AbstractDict; n=length(chain), kwargs...) + chain = thin_to(chain, n) + atol = atol_for_chain(chain; kwargs...) 
+ @test all(isapprox(mean(chain[sym]), mean_true[sym], atol=atol[sym]) for sym in names(chain, :parameters)) +end + +""" + test_std(chain, std_true; kwargs...) + +Test that the standard deviation of each parameter in `chain` is approximately `std_true`. + +# Arguments +- `chain`: A `MCMCChains.Chains` object. +- `std_true`: A `Real` or `AbstractDict` mapping parameter names to their true standard deviation. +- `kwargs...`: Passed to `atol_for_chain`. +""" +function test_std(chain::MCMCChains.Chains, std_true::Real; kwargs...) + return test_std(chain, Dict(sym => std_true for sym in names(chain, :parameters)); kwargs...) +end +function test_std(chain::MCMCChains.Chains, std_true::AbstractDict; n=length(chain), kwargs...) + chain = thin_to(chain, n) + atol = atol_for_chain(chain; kind=Statistics.std, kwargs...) + @info "std" [(std(chain[sym]), std_true[sym], atol[sym]) for sym in names(chain, :parameters)] + @test all(isapprox(std(chain[sym]), std_true[sym], atol=atol[sym]) for sym in names(chain, :parameters)) +end + +""" + test_std_monotonicity(chains; isbroken=false, kwargs...) + +Test that the standard deviation of each parameter in `chains` is monotonically increasing. + +# Arguments +- `chains`: A vector of `MCMCChains.Chains` objects. +- `isbroken`: If `true`, then the test will be marked as broken. +- `kwargs...`: Passed to `atol_for_chain`. +""" +function test_std_monotonicity(chains::AbstractVector{<:MCMCChains.Chains}; isbroken::Bool=false, kwargs...) + param_names = names(first(chains), :parameters) + # We should technically use a Bonferroni-correction here, but whatever. + atols = [atol_for_chain(chain; kind=Statistics.std, kwargs...)
for chain in chains] + stds = [Dict(sym => std(chain[sym]) for sym in param_names) for chain in chains] + + num_chains = length(chains) + lbs = [Dict(sym => stds[i][sym] - atols[i][sym] for sym in param_names) for i in 1:num_chains] + ubs = [Dict(sym => stds[i][sym] + atols[i][sym] for sym in param_names) for i in 1:num_chains] + + for i = 2:num_chains + for sym in param_names + # If the upper-bound of the current is smaller than the lower-bound of the previous, then + # we can reject the null hypothesis that they are ordered. + if isbroken + @test_broken ubs[i][sym] ≥ lbs[i - 1][sym] + else + @test ubs[i][sym] ≥ lbs[i - 1][sym] + end + end + end +end + +""" + test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=1e-4, kwargs...) + +Test that the mean and standard deviation of each parameter in `chains` is approximately `mean_true` +and `std_true`, respectively. Also test that the standard deviation is monotonically increasing. + +# Arguments +- `chains`: A vector of `MCMCChains.Chains` objects. +- `mean_true`: A vector of `Real` or `AbstractDict` mapping parameter names to their true mean. +- `std_true`: A vector of `Real` or `AbstractDict` mapping parameter names to their true standard deviation. +- `significance`: The significance level of the per-chain mean and standard deviation tests; the monotonicity check always uses `0.05`. +- `kwargs...`: Passed to `atol_for_chain`. +""" +function test_chains_with_monotonic_variance(chains, mean_true, std_true; significance=1e-4, kwargs...) + @testset "chain $i" for i = 1:length(chains) + test_means(chains[i], mean_true[i]; significance, kwargs...) + test_std(chains[i], std_true[i]; significance, kwargs...) + end + test_std_monotonicity(chains; significance=0.05) +end