Skip to content

Commit

Permalink
specialize step for combination of RepeatedSampler and MultiSampler
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Mar 3, 2023
1 parent 0b44758 commit a4016e1
Showing 1 changed file with 48 additions and 17 deletions.
65 changes: 48 additions & 17 deletions src/samplers/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,51 @@ function AbstractMCMC.step(
return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
end

# function AbstractMCMC.step(
# rng::Random.AbstractRNG,
# model::MultiModel,
# sampler::RepeatedSampler{<:MultiSampler},
# states::SequentialStates{<:MultipleStates};
# kwargs...
# )
# multisampler = sampler.sampler
# multistates = last(states.states)
# @assert length(model.models) == length(multisampler.samplers) == length(states.states) "Number of models, samplers, and states must be equal."
# transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state
# # Just re-wrap each of the samplers in a `RepeatedSampler` and call it's implementation.
# AbstractMCMC.step(rng, model, RepeatedSampler(sampler, sampler.num_repeat), state; kwargs...)
# end

# return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
# end
# NOTE: In the case of a `RepeatedSampler{<:MultiSampler}`, it's better to, effectively, re-order
# the samplers so that we make a `MultiSampler` of `RepeatedSampler`s.
# We don't want to mutate the sampler, so instead we just convert the sequence of multi-states into
# a multi-state of sequential states, and then work with this ordering in subsequent calls to `step`.
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::MultiModel,
repeated_sampler::RepeatedSampler{<:MultiSampler},
states::SequentialStates;
kwargs...
)
@debug "Working with RepeatedSampler{<:MultiSampler}; converting a sequence of multi-states into a multi-state of sequential states"

multisampler = repeated_sampler.sampler
multistates = last(states.states)
@assert length(model.models) == length(multisampler.samplers) == length(multistates.states) "Number of models $(length(model.models)), samplers $(length(multisampler.samplers)), and states $(length(multistates.states)) must be equal."
transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state
# Just re-wrap each of the samplers in a `RepeatedSampler` and call it's implementation.
AbstractMCMC.step(
rng, model, RepeatedSampler(sampler, repeated_sampler.num_repeat), SequentialStates([state]);
kwargs...
)
end

return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
end

# And then we define how `RepeatedSampler{<:MultiSampler}` should work with a `MultipleStates`.
# NOTE: If `saveall(sampler)` is `false`, this is also the implementation we'll hit.
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::MultiModel,
repeated_sampler::RepeatedSampler{<:MultiSampler},
multistates::MultipleStates;
kwargs...
)
multisampler = repeated_sampler.sampler
@assert length(model.models) == length(multisampler.samplers) == length(multistates.states) "Number of models $(length(model.models)), samplers $(length(multisampler.samplers)), and states $(length(multistates.states)) must be equal."
transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state
# Just re-wrap each of the samplers in a `RepeatedSampler` and call it's implementation.
AbstractMCMC.step(
rng, model, RepeatedSampler(sampler, repeated_sampler.num_repeat), state;
kwargs...
)
end

return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
end

0 comments on commit a4016e1

Please sign in to comment.