Skip to content

Commit

Permalink
Refactoring (#156)
Browse files Browse the repository at this point in the history
* reduced number of files and move functionality related to particular
samplers to their respective files

* removed one line too many

* relaxed a test slightly
  • Loading branch information
torfjelde authored Apr 23, 2023
1 parent ed1ca98 commit bb5381b
Show file tree
Hide file tree
Showing 7 changed files with 683 additions and 669 deletions.
270 changes: 60 additions & 210 deletions src/MCMCTempering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@ using DocStringExtensions

include("logdensityproblems.jl")
include("abstractmcmc.jl")
include("model.jl")
include("adaptation.jl")
include("swapping.jl")
include("state.jl")
include("swapsampler.jl")
include("sampler.jl")
include("tempered_sampler.jl")
include("sampling.jl")
include("ladders.jl")
include("stepping.jl")
include("model.jl")
include("utils.jl")
include("bundle_samples.jl")

export tempered,
tempered_sample,
Expand All @@ -44,220 +43,71 @@ implements_logdensity(x) = LogDensityProblems.capabilities(x) !== nothing
maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensityModel(model) : model
maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model

# Bundling.
# Bundling of non-tempered samples.
function bundle_nontempered_samples(
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{T};
kwargs...
) where {T}
# Create the same model and sampler as we do in the initial step for `TemperedSampler`.
multimodel = MultiModel([
make_tempered_model(sampler, model, sampler.chain_to_beta[i])
for i in 1:numtemps(sampler)
])
multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)])
multitransitions = [
MultipleTransitions(sort_by_chain(ProcessOrder(), t.swaptransition, t.transition.transitions))
for t in ts
]

return AbstractMCMC.bundle_samples(
multitransitions,
multimodel,
multisampler,
MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)),
T;
kwargs...
)
end

function AbstractMCMC.bundle_samples(
ts::Vector{<:MultipleTransitions},
model::MultiModel,
sampler::MultiSampler,
state::MultipleStates,
# TODO: Generalize for any eltype `T`? Then need to overload for `Real`, etc.?
::Type{Vector{MCMCChains.Chains}};
"""
tempered(sampler, inverse_temperatures; kwargs...)
OR
tempered(sampler, num_temps; swap_strategy=ReversibleSwap(), kwargs...)
Return a tempered version of `sampler` using the provided `inverse_temperatures` or
inverse temperatures generated from `num_temps` and the `swap_strategy`.
# Arguments
- `sampler` is an algorithm or sampler object to be used for underlying sampling and to apply tempering to
- The temperature schedule can be defined either explicitly or just as an integer number of temperatures, i.e. as:
- `inverse_temperatures` containing a sequence of 'inverse temperatures' {β₀, ..., βₙ} where 0 ≤ βₙ < ... < β₁ < β₀ = 1
OR
- `num_temps`, specifying the integer number of inverse temperatures to include in a generated `inverse_temperatures`
# Keyword arguments
- `swap_strategy::AbstractSwapStrategy` specifies the method for swapping inverse temperatures between chains
- `steps_per_swap::Integer` steps are carried out between each attempt at a swap
# See also
- [`TemperedSampler`](@ref)
- For more on the swap strategies:
- [`AbstractSwapStrategy`](@ref)
- [`ReversibleSwap`](@ref)
- [`NonReversibleSwap`](@ref)
- [`SingleSwap`](@ref)
- [`SingleRandomSwap`](@ref)
- [`RandomSwap`](@ref)
- [`NoSwap`](@ref)
"""
function tempered(
sampler::AbstractMCMC.AbstractSampler,
num_temps::Integer;
swap_strategy::AbstractSwapStrategy=ReversibleSwap(),
kwargs...
)
return map(1:length(model), model.models, sampler.samplers, state.states) do i, model, sampler, state
AbstractMCMC.bundle_samples([t.transitions[i] for t in ts], model, sampler, state, MCMCChains.Chains; kwargs...)
end
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{Vector{T}};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
# NOTE: If we can't resolve the swaps, there's not really much we can do in terms
# of bundling the samples.
# TODO: Is this the best we can do?
!bundle_resolve_swaps && return ts

return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...)
end

function AbstractMCMC.bundle_samples(
ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{Vector{MCMCChains.Chains}};
kwargs...
)
return bundle_nontempered_samples(ts, model, sampler, state, Vector{MCMCChains.Chains}; kwargs...)
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
::Type{MCMCChains.Chains};
kwargs...
)
# Extract the transitions ordered, which are ordered according to processes, according to the chains.
ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts]
return AbstractMCMC.bundle_samples(
ts_actual,
model,
sampler_for_chain(sampler, state, 1),
state_for_chain(state, 1),
MCMCChains.Chains;
return tempered(
sampler, generate_inverse_temperatures(num_temps, swap_strategy);
swap_strategy = swap_strategy,
kwargs...
)
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::CompositionState,
::Type{T};
kwargs...
) where {T}
# In the case of `!saveall(sampler)`, the state is not a `CompositionTransition` so we just propagate
# the transitions to the `bundle_samples` for the outer stuff. Otherwise, we flatten the transitions.
ts_actual = saveall(sampler) ? mapreduce(t -> [inner_transition(t), outer_transition(t)], vcat, ts) : ts
# TODO: Should we really always default to outer sampler?
return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
kwargs...
)
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::CompositionState,
::Type{Vector{T}};
kwargs...
) where {T}
if !saveall(sampler)
# In this case, we just use the `outer` for everything since this is the only
# transitions we're keeping around.
return AbstractMCMC.bundle_samples(
ts, model, sampler.sampler_outer, state.state_outer, Vector{T};
kwargs...
)
end

# Otherwise, we don't know what to do.
return ts
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
state::CompositionState{<:MultipleStates,<:SwapState},
::Type{T};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
# NOTE: If we can't resolve the swaps, there's not really much we can do in terms
# of bundling the samples.
# TODO: Is this the best we can do?
!bundle_resolve_swaps && return ts

# Resolve the swaps (using the already implemented resolution in `composition_transition`
# for this particular sampler but without `saveall`).
sampler_without_saveall = @set sampler.saveall = Val(false)
ts_actual = map(ts) do t
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
end

return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
kwargs...
)
end

# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
function AbstractMCMC.bundle_samples(
ts::Vector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
state::CompositionState{<:MultipleStates,<:SwapState},
::Type{Vector{T}};
bundle_resolve_swaps::Bool=false,
kwargs...
) where {T}
# NOTE: If we can't resolve the swaps, there's not really much we can do in terms
# of bundling the samples.
# TODO: Is this the best we can do?
!bundle_resolve_swaps && return ts

# Resolve the swaps (using the already implemented resolution in `composition_transition`
# for this particular sampler but without `saveall`).
sampler_without_saveall = @set sampler.saveall = Val(false)
ts_actual = map(ts) do t
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
end

return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.state_outer, Vector{T};
kwargs...
)
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector,
model::AbstractMCMC.AbstractModel,
sampler::RepeatedSampler,
state,
::Type{MCMCChains.Chains};
kwargs...
)
return AbstractMCMC.bundle_samples(ts, model, sampler.sampler, state, MCMCChains.Chains; kwargs...)
end

# Unflatten in the case of `SequentialTransitions`.
function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:SequentialTransitions},
model::AbstractMCMC.AbstractModel,
sampler::RepeatedSampler,
state::SequentialStates,
::Type{MCMCChains.Chains};
function tempered(
sampler::AbstractMCMC.AbstractSampler,
inverse_temperatures::Vector{<:Real};
swap_strategy::AbstractSwapStrategy=ReversibleSwap(),
steps_per_swap::Integer=1,
adapt::Bool=false,
adapt_target::Real=0.234,
adapt_stepsize::Real=1,
adapt_eta::Real=0.66,
adapt_schedule=Geometric(),
adapt_scale=defaultscale(adapt_schedule, inverse_temperatures),
kwargs...
)
ts_actual = [t for tseq in ts for t in tseq.transitions]
return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler, state.states[end], MCMCChains.Chains;
kwargs...
!(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.")
steps_per_swap > 0 || error("`steps_per_swap` must take a positive integer value.")
inverse_temperatures = check_inverse_temperatures(inverse_temperatures)
adaptation_states = init_adaptation(
adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize
)
# NOTE: We just make a repeated sampler for `sampler_inner`.
# TODO: Generalize. Allow passing in a `MultiSampler`, etc.
sampler_inner = sampler^steps_per_swap
return TemperedSampler(sampler_inner, inverse_temperatures, swap_strategy, adapt, adaptation_states)
end

end
Loading

0 comments on commit bb5381b

Please sign in to comment.