Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduction of compositions and products of samplers and models #151

Merged
merged 14 commits into from
Mar 11, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we happy with adding the MCMCChains dependency, we have avoided it until now, not strongly opposed just wanting to make sure this is considered

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about me make them optional dependencies? I.e. using extensions on Julia >1.8 and Requires.jl on older versions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think we maybe should keep MCMCChains as a dep for now because it means that we can just always return a Chains object in sample which will generally be much nicer to work with for a user rather than these nested transition types you see if you don't use MCMCchains 😕

We can look at making it optional later 👍

ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand Down
75 changes: 71 additions & 4 deletions src/MCMCTempering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ using ProgressLogging: ProgressLogging
using ConcreteStructs: @concrete
using Setfield: @set, @set!

using MCMCChains: MCMCChains

using InverseFunctions

using DocStringExtensions

include("logdensityproblems.jl")
include("abstractmcmc.jl")
include("adaptation.jl")
include("swapping.jl")
include("state.jl")
Expand All @@ -39,16 +42,80 @@ implements_logdensity(x) = LogDensityProblems.capabilities(x) !== nothing
maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensityModel(model) : model
maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model

# Bundling.
# TODO: Improve this, somehow.
# TODO: Move this to an extension.
function AbstractMCMC.bundle_samples(
ts::AbstractVector,
ts::AbstractVector{<:TemperedTransition},
model::AbstractMCMC.AbstractModel,
sampler::TemperedSampler,
state::TemperedState,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
state::TemperedState,
state,

I think we should drop the type signature on state, as it means a user needs to somehow access the states to call bundle_samples manually, this was in my other PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we are using methods that expects this type, no?
The reason behind being restrictive with types here are:

  1. bundle_samples should really just be compatible with whatever the step can return for that particular combination of inputs, hence it makes sense to specialize on the same types as we do in step.
  2. Not specializing on the state here can potentially lead to unnecessary method ambiguities since this is heavily overloaded method.

chain_type::Type;
::Type{MCMCChains.Chains};
kwargs...
)
return AbstractMCMC.bundle_samples(
map(Base.Fix2(getproperty, :transition), filter(!Base.Fix2(getproperty, :is_swap), ts)), # Remove the swaps.
model,
sampler_for_chain(sampler, state),
state_for_chain(state),
MCMCChains.Chains;
kwargs...
)
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::CompositionState,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
state::CompositionState,
state,

Same as above

::Type{MCMCChains.Chains};
kwargs...
)
return AbstractMCMC.bundle_samples(
ts, model, sampler.sampler_outer, state.state_outer, MCMCChains.Chains;
kwargs...
)
end

# Unflatten in the case of `SequentialTransitions`
function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:SequentialTransitions},
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::SequentialStates,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
state::SequentialStates,
state,

Same as above

::Type{MCMCChains.Chains};
kwargs...
)
ts_actual = [t for tseq in ts for t in tseq.transitions]
return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler_outer, state.states[end], MCMCChains.Chains;
kwargs...
)
end

function AbstractMCMC.bundle_samples(
ts::AbstractVector,
model::AbstractMCMC.AbstractModel,
sampler::RepeatedSampler,
state,
::Type{MCMCChains.Chains};
kwargs...
)
return AbstractMCMC.bundle_samples(ts, model, sampler.sampler, state, MCMCChains.Chains; kwargs...)
end

# Unflatten in the case of `SequentialTransitions`.
function AbstractMCMC.bundle_samples(
ts::AbstractVector{<:SequentialTransitions},
model::AbstractMCMC.AbstractModel,
sampler::RepeatedSampler,
state::SequentialStates,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
state::SequentialStates,
state,

Same as above

::Type{MCMCChains.Chains};
kwargs...
)
AbstractMCMC.bundle_samples(
ts, maybe_wrap_model(model), sampler_for_chain(sampler, state, 1), state_for_chain(state, 1), chain_type;
ts_actual = [t for tseq in ts for t in tseq.transitions]
return AbstractMCMC.bundle_samples(
ts_actual, model, sampler.sampler, state.states[end], MCMCChains.Chains;
kwargs...
)
end
Expand Down
106 changes: 106 additions & 0 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using Setfield
using AbstractMCMC: AbstractMCMC

import LinearAlgebra: ×

"""
getparams([model, ]state)

Get the parameters from the `state`.

Default implementation uses [`getparams_and_logprob`](@ref).
"""
getparams(state) = first(getparams_and_logprob(state))
getparams(model, state) = first(getparams_and_logprob(model, state))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clarify, have you built this as the default so that to support tempering someone only needs to implement one function: getparams_and_logprob rather than getparams and getlogprob separately?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tempering, compositions, etc. And yes, it's so that the bare minimum to get things working is to implement getparams_and_logprob, and then if you want to make sure that every computational path is as fast as possible, you can implement getparams, etc. For example, in certain cases we might only need the parameters and so it might be unnecessary to also get the logprob.


"""
getlogprob([model, ]state)

Get the log probability of the `state`.

Default implementation uses [`getparams_and_logprob`](@ref).
"""
getlogprob(state) = last(getparams_and_logprob(state))
getlogprob(model, state) = last(getparams_and_logprob(model, state))

"""
getparams_and_logprob([model, ]state)

Return a vector of parameters from the `state`.

See also: [`setparams_and_logprob!!`](@ref).
"""
getparams_and_logprob(model, state) = getparams_and_logprob(state)

"""
setparams_and_logprob!!([model, ]state, params)

Set the parameters in the state to `params`, possibly mutating if it makes sense.

See also: [`getparams_and_logprob`](@ref).
"""
setparams_and_logprob!!(model, state, params, logprob) = setparams_and_logprob!!(state, params, logprob)

"""
state_from(model, state_target, state_source[, transition_source, transition_target])

Return a new state similar to `state_target` but updated from `state_source`, which could be
a different type of state.
"""
function state_from(model, state_target, state_source, transition_target, transition_source)
return state_from(model, state_target, state_source)
end
function state_from(model, state_target, state_source)
params, logp = getparams_and_logprob(model, state_source)
return setparams_and_logprob!!(model, state_target, params, logp)
end

"""
SequentialTransitions

A `SequentialTransitions` object is a container for a sequence of transitions.
"""
struct SequentialTransitions{A}
transitions::A
end

# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the
# last transition/state.
getparams_and_logprob(transitions::SequentialTransitions) = getparams_and_logprob(transitions.transitions[end])
function getparams_and_logprob(model, transitions::SequentialTransitions)
return getparams_and_logprob(model, transitions.transitions[end])
end

function setparams_and_logprob!!(transitions::SequentialTransitions, params, logprob)
return @set transitions.transitions[end] = setparams_and_logprob!!(transitions.transitions[end], params, logprob)
end
function setparams_and_logprob!!(model, transitions::SequentialTransitions, params, logprob)
return @set transitions.transitions[end] = setparams_and_logprob!!(model, transitions.transitions[end], params, logprob)
end

"""
SequentialStates

A `SequentialStates` object is a container for a sequence of states.
"""
struct SequentialStates{A}
states::A
end

# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the
# last transition/state.
getparams_and_logprob(state::SequentialStates) = getparams_and_logprob(state.states[end])
getparams_and_logprob(model, state::SequentialStates) = getparams_and_logprob(model, state.states[end])

function setparams_and_logprob!!(state::SequentialStates, params, logprob)
return @set state.states[end] = setparams_and_logprob!!(state.states[end], params, logprob)
end
function setparams_and_logprob!!(model, state::SequentialStates, params, logprob)
return @set state.states[end] = setparams_and_logprob!!(model, state.states[end], params, logprob)
end

# Includes.
include("samplers/composition.jl")
include("samplers/repeated.jl")
include("samplers/multi.jl")

11 changes: 8 additions & 3 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ function tempered(
sampler::AbstractMCMC.AbstractSampler,
inverse_temperatures::Vector{<:Real};
swap_strategy::AbstractSwapStrategy=ReversibleSwap(),
swap_every::Integer=10,
# TODO: Change `swap_every` to something like `number_of_iterations_per_swap`.
swap_every::Integer=1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with a name change proposal, I was thinking it might be worth adopting the micro- and macro-step terms I find myself using sometimes. Just an idea but we could make the arg for tempered_sample n_macro_steps and then for tempered n_micro_steps, equally could be inner and outer to match the compositional sampler definition.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we've addressed this nicely in the other PR; agree?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm sort of, see my comment on that PR

adapt::Bool=false,
adapt_target::Real=0.234,
adapt_stepsize::Real=1,
Expand All @@ -108,10 +109,14 @@ function tempered(
kwargs...
)
!(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.")
swap_every > 1 || error("`swap_every` must take a positive integer value greater than 1.")
swap_every 1 || error("`swap_every` must take a positive integer value greater 1.")
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
inverse_temperatures = check_inverse_temperatures(inverse_temperatures)
adaptation_states = init_adaptation(
adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize
)
return TemperedSampler(sampler, inverse_temperatures, swap_every, swap_strategy, adapt, adaptation_states)
# NOTE: We just make a repeated sampler for `sampler_inner`.
# TODO: Generalize. Allow passing in a `MultiSampler`, etc.
sampler_inner = sampler^swap_every
HarrisonWilde marked this conversation as resolved.
Show resolved Hide resolved
# FIXME: Remove the hard-coded `2` for swap-every, and change `should_swap` acoordingly.
return TemperedSampler(sampler_inner, inverse_temperatures, 2, swap_strategy, adapt, adaptation_states)
end
131 changes: 131 additions & 0 deletions src/samplers/composition.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""
CompositionSampler <: AbstractMCMC.AbstractSampler

A `CompositionSampler` is a container for a sequence of samplers.

# Fields
$(FIELDS)

# Examples
```julia
composed_sampler = sampler_inner ∘ sampler_outer # or `CompositionSampler(sampler_inner, sampler_outer, Val(true))`
AbstractMCMC.step(rng, model, composed_sampler) # one step of `sampler_inner`, and one step of `sampler_outer`
```
"""
struct CompositionSampler{S1,S2,SaveAll} <: AbstractMCMC.AbstractSampler
"The outer sampler"
sampler_outer::S1
"The inner sampler"
sampler_inner::S2
"Whether to save all the transitions or just the last one"
saveall::SaveAll
end

CompositionSampler(sampler_outer, sampler_inner) = CompositionSampler(sampler_outer, sampler_inner, Val(true))

Base.:∘(s_outer::AbstractMCMC.AbstractSampler, s_inner::AbstractMCMC.AbstractSampler) = CompositionSampler(s_outer, s_inner)

"""
saveall(sampler)

Return whether the sampler saves all the transitions or just the last one.
"""
saveall(sampler::CompositionSampler) = sampler.saveall
saveall(::CompositionSampler{<:Any,<:Any,Val{SaveAll}}) where {SaveAll} = SaveAll

"""
CompositionState

A `CompositionState` is a container for a sequence of states.

# Fields
$(FIELDS)
"""
struct CompositionState{S1,S2}
"The outer state"
state_outer::S1
"The inner state"
state_inner::S2
end

getparams_and_logprob(state::CompositionState) = getparams_and_logprob(state.state_outer)
getparams_and_logprob(model, state::CompositionState) = getparams_and_logprob(model, state.state_outer)

function setparams_and_logprob!!(state::CompositionState, params, logprob)
return @set state.state_outer = setparams_and_logprob!!(state.state_outer, params, logprob)
end
function setparams_and_logprob!!(model, state::CompositionState, params, logprob)
return @set state.state_outer = setparams_and_logprob!!(model, state.state_outer, params, logprob)
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler;
kwargs...
)
state_inner_initial = last(AbstractMCMC.step(rng, model, sampler.sampler_inner; kwargs...))
state_outer_initial = last(AbstractMCMC.step(rng, model, sampler.sampler_outer; kwargs...))

# Create the composition state, and take a full step.
state = if saveall(sampler)
SequentialStates((state_inner_initial, state_outer_initial))
else
CompositionState(state_outer_initial, state_inner_initial)
end
return AbstractMCMC.step(rng, model, sampler, state; kwargs...)
end

# TODO: Do we even need two versions? We could technically use `SequentialStates`
# in place of `CompositionState` and just have one version.
# The annoying part here is that we'll have to check `saveall` on every `step`
# rather than just for the initial step.

# NOTE: Version which does keep track of all transitions and states.
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::SequentialStates;
kwargs...
)
@assert length(state.states) == 2 "Composition samplers only support SequentialStates with two states."

state_inner_prev, state_outer_prev = state.states

# Update the inner state.
current_state_inner = state_from(model, state_inner_prev, state_outer_prev)

# Take a step in the inner sampler.
transition_inner, state_inner = AbstractMCMC.step(rng, model, sampler.sampler_inner, current_state_inner; kwargs...)

# Take a step in the outer sampler.
current_state_outer = state_from(model, state_outer_prev, state_inner)
transition_outer, state_outer = AbstractMCMC.step(rng, model, sampler.sampler_outer, current_state_outer; kwargs...)

return SequentialTransitions((transition_inner, transition_outer)), SequentialStates((state_inner, state_outer))
end

# NOTE: Version which does NOT keep track of all transitions and states.
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.AbstractModel,
sampler::CompositionSampler,
state::CompositionState;
kwargs...
)
# Update the inner state.
current_state_inner = state_from(model, state.state_inner, state.state_outer)

# Take a step in the inner sampler.
state_inner = last(AbstractMCMC.step(rng, model, sampler.sampler_inner, current_state_inner; kwargs...))

# Take a step in the outer sampler.
current_state_outer = state_from(model, state.state_outer, state_inner)
transition_outer, state_outer = AbstractMCMC.step(rng, model, sampler.sampler_outer, current_state_outer; kwargs...)

# Create the composition state.
state = CompositionState(state_outer, state_inner)

return transition_outer, state
end
Loading