diff --git a/src/samplers/multi.jl b/src/samplers/multi.jl index 636ea0f..db4fa02 100644 --- a/src/samplers/multi.jl +++ b/src/samplers/multi.jl @@ -158,20 +158,51 @@ function AbstractMCMC.step( return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states)) end -# function AbstractMCMC.step( -# rng::Random.AbstractRNG, -# model::MultiModel, -# sampler::RepeatedSampler{<:MultiSampler}, -# states::SequentialStates{<:MultipleStates}; -# kwargs... -# ) -# multisampler = sampler.sampler -# multistates = last(states.states) -# @assert length(model.models) == length(multisampler.samplers) == length(states.states) "Number of models, samplers, and states must be equal." -# transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state -# # Just re-wrap each of the samplers in a `RepeatedSampler` and call it's implementation. -# AbstractMCMC.step(rng, model, RepeatedSampler(sampler, sampler.num_repeat), state; kwargs...) -# end - -# return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states)) -# end +# NOTE: In the case of a `RepeatedSampler{<:MultiSampler}`, it's better to, effectively, re-order +# the samplers so that we make a `MultiSampler` of `RepeatedSampler`s. +# We don't want to mutate the sampler, so instead we just convert the sequence of multi-states into +# a multi-state of sequential states, and then work with this ordering in subsequent calls to `step`. +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + repeated_sampler::RepeatedSampler{<:MultiSampler}, + states::SequentialStates; + kwargs... +) + @debug "Working with RepeatedSampler{<:MultiSampler}; converting a sequence of multi-states into a multi-state of sequential states" + + multisampler = repeated_sampler.sampler + multistates = last(states.states) + @assert length(model.models) == length(multisampler.samplers) == length(multistates.states) "Number of models $(length(model.models)), samplers $(length(multisampler.samplers)), and states $(length(multistates.states)) must be equal." + transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state + # Just re-wrap each of the samplers in a `RepeatedSampler` and call it's implementation. + AbstractMCMC.step( + rng, model, RepeatedSampler(sampler, repeated_sampler.num_repeat), SequentialStates([state]); + kwargs... + ) + end + + return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states)) +end + +# And then we define how `RepeatedSampler{<:MultiSampler}` should work with a `MultipleStates`. +# NOTE: If `saveall(sampler)` is `false`, this is also the implementation we'll hit. +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::MultiModel, + repeated_sampler::RepeatedSampler{<:MultiSampler}, + multistates::MultipleStates; + kwargs... +) + multisampler = repeated_sampler.sampler + @assert length(model.models) == length(multisampler.samplers) == length(multistates.states) "Number of models $(length(model.models)), samplers $(length(multisampler.samplers)), and states $(length(multistates.states)) must be equal." + transition_and_states = asyncmap(model.models, multisampler.samplers, multistates.states) do model, sampler, state + # Just re-wrap each of the samplers in a `RepeatedSampler` and call it's implementation. + AbstractMCMC.step( + rng, model, RepeatedSampler(sampler, repeated_sampler.num_repeat), state; + kwargs... + ) + end + + return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states)) +end