Description
Please move this over to Turing.jl if appropriate, but I believe the issue lies in a function in DynamicPPL.
A number of samplers ignore the state passed via resume_from: when the sampler state saved from a previous sample call is passed as resume_from, that state is not used for the next run. This can be demonstrated by code that should error but does not:
(jl_L8WFYz) pkg> status
Status `/tmp/jl_L8WFYz/Project.toml`
  [fce5fe82] Turing v0.40.2
julia> using Turing
julia> @model function model_test()
       x~Normal(0,1)
       y~Normal(x,1)
       end
model_test (generic function with 2 methods)
julia> tst_chn = sample(
           model_test() | (y=0.5,), 
           MH(),
           MCMCSerial(),
           3,
           1,
           save_state = true)
Sampling (Chain 1 of 1) 100%|███████████████████████████████| Time: 0:00:06
Chains MCMC chain (3×4×1 Array{Float64, 3}):
Iterations        = 1:1:3
Number of chains  = 1
Samples per chain = 3
Wall duration     = 5.98 seconds
Compute duration  = 5.98 seconds
parameters        = x
internals         = lp, logprior, loglikelihood
Use `describe(chains)` for summary statistics and quantiles.
julia> tst_chn2 = sample(
            model_test() | (x=0.5,),
            MH(),
            MCMCSerial(),
            3,
            1,
            resume_from=tst_chn.info.samplerstate)
Sampling (Chain 1 of 1) 100%|███████████████████████████████| Time: 0:00:02
Chains MCMC chain (3×4×1 Array{Float64, 3}):
Iterations        = 1:1:3
Number of chains  = 1
Samples per chain = 3
Wall duration     = 1.33 seconds
Compute duration  = 1.33 seconds
parameters        = y
internals         = lp, logprior, loglikelihood
Use `describe(chains)` for summary statistics and quantiles.
This code should error, because it should be using the varinfo from my saved sampler state, which has an x field and not a y field!
Through some debugging I found that calls to AbstractMCMC.mcmcsample run code like the following for each of the parallel-execution methods (shown here for MCMCDistributed):
function mcmcsample(
    rng::Random.AbstractRNG,
    model::AbstractModel,
    sampler::AbstractSampler,
    ::MCMCDistributed,
    N::Integer,
    nchains::Integer;
    progress::Union{Bool,Symbol}=PROGRESS[],
    progressname="Sampling ($(Distributed.nworkers()) process$(_pluralise(Distributed.nworkers(); plural="es")))",
    initial_params=nothing,
    initial_state=nothing,
    kwargs...,
)
...
    _initial_state =
        initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
...
end
This means that when the following function in DynamicPPL is called:
function AbstractMCMC.sample(
    rng::Random.AbstractRNG,
    model::Model,
    sampler::Sampler,
    N::Integer;
    chain_type=default_chain_type(sampler),
    resume_from=nothing,
    initial_state=loadstate(resume_from),
    kwargs...,
)
    return AbstractMCMC.mcmcsample(
        rng, model, sampler, N; chain_type, initial_state, kwargs...
    )
end
the default initial_state=loadstate(resume_from) is overridden by the initial_state of nothing that is passed in explicitly, so the resume_from state is lost!
I would propose changing this function so that, if resume_from is passed and initial_state is nothing, it calls loadstate itself.
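The override can be reproduced without Turing. The following is a minimal sketch, not the actual library code: `demo_loadstate`, `demo_inner`, and `demo_outer` are hypothetical stand-ins for `loadstate`, the DynamicPPL `sample` method, and the multi-chain driver, and it assumes the driver always forwards `initial_state` explicitly.

```julia
# Hypothetical stand-in for DynamicPPL's loadstate (assumption: identity on a saved state).
demo_loadstate(::Nothing) = nothing
demo_loadstate(state) = state

# Mirrors the keyword layout of the DynamicPPL method quoted above:
# the default for initial_state is only evaluated when the caller omits it.
function demo_inner(; resume_from=nothing, initial_state=demo_loadstate(resume_from))
    return initial_state
end

# Mirrors a driver that always passes initial_state explicitly, even when it is nothing.
function demo_outer(; initial_state=nothing, kwargs...)
    return demo_inner(; initial_state, kwargs...)
end

kept = demo_inner(; resume_from=:saved_state)  # default is evaluated, state survives
lost = demo_outer(; resume_from=:saved_state)  # explicit nothing wins, state is dropped
```

Calling the inner function directly keeps the state, while going through the outer function loses it, which matches the behaviour described above.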
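A minimal sketch of that proposal, again using hypothetical stand-ins (`fix_loadstate`, `fix_inner`, `fix_outer`) rather than the real DynamicPPL and AbstractMCMC functions, might look like this. It assumes a single state value; the real multi-chain code handles one state per chain, which this sketch does not model.

```julia
# Hypothetical stand-in for DynamicPPL's loadstate (assumption: identity on a saved state).
fix_loadstate(::Nothing) = nothing
fix_loadstate(state) = state

# Proposed shape: resolve resume_from inside the body instead of in the keyword default,
# so an explicitly passed initial_state of nothing no longer discards it.
function fix_inner(; resume_from=nothing, initial_state=nothing)
    if initial_state === nothing && resume_from !== nothing
        initial_state = fix_loadstate(resume_from)
    end
    return initial_state
end

# Same driver behaviour as before: initial_state is always forwarded explicitly.
function fix_outer(; initial_state=nothing, kwargs...)
    return fix_inner(; initial_state, kwargs...)
end

resumed = fix_outer(; resume_from=:saved_state)
explicit = fix_outer(; initial_state=:explicit_state, resume_from=:saved_state)
fresh = fix_outer()
```

With this check, an explicitly supplied `initial_state` still takes priority, and `resume_from` is only consulted when no state was given.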