-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduction of compositions and products of samplers and models #151
Changes from 12 commits
b8a3111
ae8588e
7db1dce
10905f2
9e08c3a
a90845a
a77c32a
38c0a52
ad22454
933c0bd
0b44758
a4016e1
ef97a94
180a928
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -9,11 +9,14 @@ using ProgressLogging: ProgressLogging | |||||
using ConcreteStructs: @concrete | ||||||
using Setfield: @set, @set! | ||||||
|
||||||
using MCMCChains: MCMCChains | ||||||
|
||||||
using InverseFunctions | ||||||
|
||||||
using DocStringExtensions | ||||||
|
||||||
include("logdensityproblems.jl") | ||||||
include("abstractmcmc.jl") | ||||||
include("adaptation.jl") | ||||||
include("swapping.jl") | ||||||
include("state.jl") | ||||||
|
@@ -39,16 +42,80 @@ implements_logdensity(x) = LogDensityProblems.capabilities(x) !== nothing | |||||
maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensityModel(model) : model | ||||||
maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model | ||||||
|
||||||
# Bundling. | ||||||
# TODO: Improve this, somehow. | ||||||
# TODO: Move this to an extension. | ||||||
function AbstractMCMC.bundle_samples( | ||||||
ts::AbstractVector, | ||||||
ts::AbstractVector{<:TemperedTransition}, | ||||||
model::AbstractMCMC.AbstractModel, | ||||||
sampler::TemperedSampler, | ||||||
state::TemperedState, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I think we should drop the type signature on state, as it means a user needs to somehow access the states to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But we are using methods that expects this type, no?
|
||||||
chain_type::Type; | ||||||
::Type{MCMCChains.Chains}; | ||||||
kwargs... | ||||||
) | ||||||
return AbstractMCMC.bundle_samples( | ||||||
map(Base.Fix2(getproperty, :transition), filter(!Base.Fix2(getproperty, :is_swap), ts)), # Remove the swaps. | ||||||
model, | ||||||
sampler_for_chain(sampler, state), | ||||||
state_for_chain(state), | ||||||
MCMCChains.Chains; | ||||||
kwargs... | ||||||
) | ||||||
end | ||||||
|
||||||
function AbstractMCMC.bundle_samples( | ||||||
ts::AbstractVector, | ||||||
model::AbstractMCMC.AbstractModel, | ||||||
sampler::CompositionSampler, | ||||||
state::CompositionState, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Same as above |
||||||
::Type{MCMCChains.Chains}; | ||||||
kwargs... | ||||||
) | ||||||
return AbstractMCMC.bundle_samples( | ||||||
ts, model, sampler.sampler_outer, state.state_outer, MCMCChains.Chains; | ||||||
kwargs... | ||||||
) | ||||||
end | ||||||
|
||||||
# Unflatten in the case of `SequentialTransitions` | ||||||
function AbstractMCMC.bundle_samples( | ||||||
ts::AbstractVector{<:SequentialTransitions}, | ||||||
model::AbstractMCMC.AbstractModel, | ||||||
sampler::CompositionSampler, | ||||||
state::SequentialStates, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Same as above |
||||||
::Type{MCMCChains.Chains}; | ||||||
kwargs... | ||||||
) | ||||||
ts_actual = [t for tseq in ts for t in tseq.transitions] | ||||||
return AbstractMCMC.bundle_samples( | ||||||
ts_actual, model, sampler.sampler_outer, state.states[end], MCMCChains.Chains; | ||||||
kwargs... | ||||||
) | ||||||
end | ||||||
|
||||||
function AbstractMCMC.bundle_samples( | ||||||
ts::AbstractVector, | ||||||
model::AbstractMCMC.AbstractModel, | ||||||
sampler::RepeatedSampler, | ||||||
state, | ||||||
::Type{MCMCChains.Chains}; | ||||||
kwargs... | ||||||
) | ||||||
return AbstractMCMC.bundle_samples(ts, model, sampler.sampler, state, MCMCChains.Chains; kwargs...) | ||||||
end | ||||||
|
||||||
# Unflatten in the case of `SequentialTransitions`. | ||||||
function AbstractMCMC.bundle_samples( | ||||||
ts::AbstractVector{<:SequentialTransitions}, | ||||||
model::AbstractMCMC.AbstractModel, | ||||||
sampler::RepeatedSampler, | ||||||
state::SequentialStates, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Same as above |
||||||
::Type{MCMCChains.Chains}; | ||||||
kwargs... | ||||||
) | ||||||
AbstractMCMC.bundle_samples( | ||||||
ts, maybe_wrap_model(model), sampler_for_chain(sampler, state, 1), state_for_chain(state, 1), chain_type; | ||||||
ts_actual = [t for tseq in ts for t in tseq.transitions] | ||||||
return AbstractMCMC.bundle_samples( | ||||||
ts_actual, model, sampler.sampler, state.states[end], MCMCChains.Chains; | ||||||
kwargs... | ||||||
) | ||||||
end | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
using Setfield | ||
using AbstractMCMC: AbstractMCMC | ||
|
||
import LinearAlgebra: × | ||
|
||
""" | ||
getparams([model, ]state) | ||
|
||
Get the parameters from the `state`. | ||
|
||
Default implementation uses [`getparams_and_logprob`](@ref). | ||
""" | ||
getparams(state) = first(getparams_and_logprob(state)) | ||
getparams(model, state) = first(getparams_and_logprob(model, state)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To clarify, have you built this as the default so that to support tempering someone only needs to implement one function: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tempering, compositions, etc. And yes, it's so that the bare minimum to get things working is to implement |
||
|
||
""" | ||
getlogprob([model, ]state) | ||
|
||
Get the log probability of the `state`. | ||
|
||
Default implementation uses [`getparams_and_logprob`](@ref). | ||
""" | ||
getlogprob(state) = last(getparams_and_logprob(state)) | ||
getlogprob(model, state) = last(getparams_and_logprob(model, state)) | ||
|
||
""" | ||
getparams_and_logprob([model, ]state) | ||
|
||
Return a vector of parameters from the `state`. | ||
|
||
See also: [`setparams_and_logprob!!`](@ref). | ||
""" | ||
getparams_and_logprob(model, state) = getparams_and_logprob(state) | ||
|
||
""" | ||
setparams_and_logprob!!([model, ]state, params) | ||
|
||
Set the parameters in the state to `params`, possibly mutating if it makes sense. | ||
|
||
See also: [`getparams_and_logprob`](@ref). | ||
""" | ||
setparams_and_logprob!!(model, state, params, logprob) = setparams_and_logprob!!(state, params, logprob) | ||
|
||
""" | ||
state_from(model, state_target, state_source[, transition_source, transition_target]) | ||
|
||
Return a new state similar to `state_target` but updated from `state_source`, which could be | ||
a different type of state. | ||
""" | ||
function state_from(model, state_target, state_source, transition_target, transition_source) | ||
return state_from(model, state_target, state_source) | ||
end | ||
function state_from(model, state_target, state_source) | ||
params, logp = getparams_and_logprob(model, state_source) | ||
return setparams_and_logprob!!(model, state_target, params, logp) | ||
end | ||
|
||
""" | ||
SequentialTransitions | ||
|
||
A `SequentialTransitions` object is a container for a sequence of transitions. | ||
""" | ||
struct SequentialTransitions{A} | ||
transitions::A | ||
end | ||
|
||
# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the | ||
# last transition/state. | ||
getparams_and_logprob(transitions::SequentialTransitions) = getparams_and_logprob(transitions.transitions[end]) | ||
function getparams_and_logprob(model, transitions::SequentialTransitions) | ||
return getparams_and_logprob(model, transitions.transitions[end]) | ||
end | ||
|
||
function setparams_and_logprob!!(transitions::SequentialTransitions, params, logprob) | ||
return @set transitions.transitions[end] = setparams_and_logprob!!(transitions.transitions[end], params, logprob) | ||
end | ||
function setparams_and_logprob!!(model, transitions::SequentialTransitions, params, logprob) | ||
return @set transitions.transitions[end] = setparams_and_logprob!!(model, transitions.transitions[end], params, logprob) | ||
end | ||
|
||
""" | ||
SequentialStates | ||
|
||
A `SequentialStates` object is a container for a sequence of states. | ||
""" | ||
struct SequentialStates{A} | ||
states::A | ||
end | ||
|
||
# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the | ||
# last transition/state. | ||
getparams_and_logprob(state::SequentialStates) = getparams_and_logprob(state.states[end]) | ||
getparams_and_logprob(model, state::SequentialStates) = getparams_and_logprob(model, state.states[end]) | ||
|
||
function setparams_and_logprob!!(state::SequentialStates, params, logprob) | ||
return @set state.states[end] = setparams_and_logprob!!(state.states[end], params, logprob) | ||
end | ||
function setparams_and_logprob!!(model, state::SequentialStates, params, logprob) | ||
return @set state.states[end] = setparams_and_logprob!!(model, state.states[end], params, logprob) | ||
end | ||
|
||
# Includes. | ||
include("samplers/composition.jl") | ||
include("samplers/repeated.jl") | ||
include("samplers/multi.jl") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,7 +98,8 @@ function tempered( | |
sampler::AbstractMCMC.AbstractSampler, | ||
inverse_temperatures::Vector{<:Real}; | ||
swap_strategy::AbstractSwapStrategy=ReversibleSwap(), | ||
swap_every::Integer=10, | ||
# TODO: Change `swap_every` to something like `number_of_iterations_per_swap`. | ||
swap_every::Integer=1, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with a name change proposal, I was thinking it might be worth adopting the micro- and macro-step terms I find myself using sometimes. Just an idea but we could make the arg for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we've addressed this nicely in the other PR; agree? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm sort of, see my comment on that PR |
||
adapt::Bool=false, | ||
adapt_target::Real=0.234, | ||
adapt_stepsize::Real=1, | ||
|
@@ -108,10 +109,14 @@ function tempered( | |
kwargs... | ||
) | ||
!(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.") | ||
swap_every > 1 || error("`swap_every` must take a positive integer value greater than 1.") | ||
swap_every ≥ 1 || error("`swap_every` must take a positive integer value greater ≥1.") | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
inverse_temperatures = check_inverse_temperatures(inverse_temperatures) | ||
adaptation_states = init_adaptation( | ||
adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize | ||
) | ||
return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, adaptation_states) | ||
# NOTE: We just make a repeated sampler for `sampler_inner`. | ||
# TODO: Generalize. Allow passing in a `MultiSampler`, etc. | ||
sampler_inner = sampler^swap_every | ||
HarrisonWilde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# FIXME: Remove the hard-coded `2` for swap-every, and change `should_swap` acoordingly. | ||
return TemperedSampler(sampler_inner, inverse_temperatures, 2, swap_strategy, adapt, adaptation_states) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
""" | ||
CompositionSampler <: AbstractMCMC.AbstractSampler | ||
|
||
A `CompositionSampler` is a container for a sequence of samplers. | ||
|
||
# Fields | ||
$(FIELDS) | ||
|
||
# Examples | ||
```julia | ||
composed_sampler = sampler_inner ∘ sampler_outer # or `CompositionSampler(sampler_inner, sampler_outer, Val(true))` | ||
AbstractMCMC.step(rng, model, composed_sampler) # one step of `sampler_inner`, and one step of `sampler_outer` | ||
``` | ||
""" | ||
struct CompositionSampler{S1,S2,SaveAll} <: AbstractMCMC.AbstractSampler | ||
"The outer sampler" | ||
sampler_outer::S1 | ||
"The inner sampler" | ||
sampler_inner::S2 | ||
"Whether to save all the transitions or just the last one" | ||
saveall::SaveAll | ||
end | ||
|
||
CompositionSampler(sampler_outer, sampler_inner) = CompositionSampler(sampler_outer, sampler_inner, Val(true)) | ||
|
||
Base.:∘(s_outer::AbstractMCMC.AbstractSampler, s_inner::AbstractMCMC.AbstractSampler) = CompositionSampler(s_outer, s_inner) | ||
|
||
""" | ||
saveall(sampler) | ||
|
||
Return whether the sampler saves all the transitions or just the last one. | ||
""" | ||
saveall(sampler::CompositionSampler) = sampler.saveall | ||
saveall(::CompositionSampler{<:Any,<:Any,Val{SaveAll}}) where {SaveAll} = SaveAll | ||
|
||
""" | ||
CompositionState | ||
|
||
A `CompositionState` is a container for a sequence of states. | ||
|
||
# Fields | ||
$(FIELDS) | ||
""" | ||
struct CompositionState{S1,S2} | ||
"The outer state" | ||
state_outer::S1 | ||
"The inner state" | ||
state_inner::S2 | ||
end | ||
|
||
getparams_and_logprob(state::CompositionState) = getparams_and_logprob(state.state_outer) | ||
getparams_and_logprob(model, state::CompositionState) = getparams_and_logprob(model, state.state_outer) | ||
|
||
function setparams_and_logprob!!(state::CompositionState, params, logprob) | ||
return @set state.state_outer = setparams_and_logprob!!(state.state_outer, params, logprob) | ||
end | ||
function setparams_and_logprob!!(model, state::CompositionState, params, logprob) | ||
return @set state.state_outer = setparams_and_logprob!!(model, state.state_outer, params, logprob) | ||
end | ||
|
||
function AbstractMCMC.step( | ||
rng::Random.AbstractRNG, | ||
model::AbstractMCMC.AbstractModel, | ||
sampler::CompositionSampler; | ||
kwargs... | ||
) | ||
state_inner_initial = last(AbstractMCMC.step(rng, model, sampler.sampler_inner; kwargs...)) | ||
state_outer_initial = last(AbstractMCMC.step(rng, model, sampler.sampler_outer; kwargs...)) | ||
|
||
# Create the composition state, and take a full step. | ||
state = if saveall(sampler) | ||
SequentialStates((state_inner_initial, state_outer_initial)) | ||
else | ||
CompositionState(state_outer_initial, state_inner_initial) | ||
end | ||
return AbstractMCMC.step(rng, model, sampler, state; kwargs...) | ||
end | ||
|
||
# TODO: Do we even need two versions? We could technically use `SequentialStates` | ||
# in place of `CompositionState` and just have one version. | ||
# The annoying part here is that we'll have to check `saveall` on every `step` | ||
# rather than just for the initial step. | ||
|
||
# NOTE: Version which does keep track of all transitions and states. | ||
function AbstractMCMC.step( | ||
rng::Random.AbstractRNG, | ||
model::AbstractMCMC.AbstractModel, | ||
sampler::CompositionSampler, | ||
state::SequentialStates; | ||
kwargs... | ||
) | ||
@assert length(state.states) == 2 "Composition samplers only support SequentialStates with two states." | ||
|
||
state_inner_prev, state_outer_prev = state.states | ||
|
||
# Update the inner state. | ||
current_state_inner = state_from(model, state_inner_prev, state_outer_prev) | ||
|
||
# Take a step in the inner sampler. | ||
transition_inner, state_inner = AbstractMCMC.step(rng, model, sampler.sampler_inner, current_state_inner; kwargs...) | ||
|
||
# Take a step in the outer sampler. | ||
current_state_outer = state_from(model, state_outer_prev, state_inner) | ||
transition_outer, state_outer = AbstractMCMC.step(rng, model, sampler.sampler_outer, current_state_outer; kwargs...) | ||
|
||
return SequentialTransitions((transition_inner, transition_outer)), SequentialStates((state_inner, state_outer)) | ||
end | ||
|
||
# NOTE: Version which does NOT keep track of all transitions and states. | ||
function AbstractMCMC.step( | ||
rng::Random.AbstractRNG, | ||
model::AbstractMCMC.AbstractModel, | ||
sampler::CompositionSampler, | ||
state::CompositionState; | ||
kwargs... | ||
) | ||
# Update the inner state. | ||
current_state_inner = state_from(model, state.state_inner, state.state_outer) | ||
|
||
# Take a step in the inner sampler. | ||
state_inner = last(AbstractMCMC.step(rng, model, sampler.sampler_inner, current_state_inner; kwargs...)) | ||
|
||
# Take a step in the outer sampler. | ||
current_state_outer = state_from(model, state.state_outer, state_inner) | ||
transition_outer, state_outer = AbstractMCMC.step(rng, model, sampler.sampler_outer, current_state_outer; kwargs...) | ||
|
||
# Create the composition state. | ||
state = CompositionState(state_outer, state_inner) | ||
|
||
return transition_outer, state | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we happy with adding the MCMCChains dependency, we have avoided it until now, not strongly opposed just wanting to make sure this is considered
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about me make them optional dependencies? I.e. using extensions on Julia >1.8 and Requires.jl on older versions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I think we maybe should keep MCMCChains as a dep for now because it means that we can just always return a
Chains
object insample
which will generally be much nicer to work with for a user rather than these nested transition types you see if you don't use MCMCchains 😕We can look at making it optional later 👍