Description
Currently, certain methods implicitly assume functionality of AbstractMCMC.AbstractChains (e.g. indexing) that is really only provided by one particular implementation of AbstractChains, namely MCMCChains.Chains (though I believe it is also the only implementation).
Thus, we implicitly depend on MCMCChains.Chains without declaring that dependency.
In addition, this ties the tests quite heavily to Turing.jl, which is not ideal.
Possible solutions
1. Converting chains to "dummy" Vector{<:VarInfo}
As briefly discussed in #166 (comment), one feature that would be very helpful in moving away from this is a method that maps (::Model, ::AbstractChains) -> ::Vector{<:VarInfo}, and vice versa. We could then stick to working with VarInfo in DynamicPPL.jl, which would alleviate the issues mentioned above.
For example, it could look something like this:
```julia
# Qualified imports only; `getdist` was unused, and `Turing` is not needed
# since `Model` and `VarInfo` are defined in DynamicPPL.
import DynamicPPL, MCMCChains

# Adapted from `Turing.Inference.transitions_from_chain`.
# Returns one `VarInfo` per iteration of `chain`.
function varinfo_from_chain(
    model::DynamicPPL.Model,
    chain::MCMCChains.Chains;
    sampler = DynamicPPL.SampleFromPrior()
)
    # Executes `model` once to obtain a `VarInfo` with the right structure.
    vi = DynamicPPL.VarInfo(model)
    vis = map(1:length(chain)) do i
        c = chain[i]
        md = vi.metadata
        for v in keys(md)
            for vn in md[v].vns
                vn_sym = Symbol(vn)
                # Cannot use `vn_sym` to index in the chain
                # so we have to extract the corresponding "linear"
                # indices and use those.
                # `ks` is empty if `vn_sym` not in `c`.
                ks = MCMCChains.namesingroup(c, vn_sym)
                if !isempty(ks)
                    # 1st dimension is of size 1 since `c`
                    # only contains a single sample, and the
                    # last dimension is of size 1 since
                    # we're assuming we're working with a single chain.
                    val = copy(vec(c[ks].value))
                    DynamicPPL.setval!(vi, val, vn)
                    DynamicPPL.settrans!(vi, false, vn)
                else
                    error("$vn not present in chain but is required by the model")
                end
            end
        end
        new_vi = DynamicPPL.VarInfo(vi, sampler, vi[sampler])
        DynamicPPL.setlogp!(new_vi, first(chain[i][:lp])) # Is there a better way?
        return new_vi
    end
    return vis
end
```

This works fairly well: it does what it's supposed to, and I haven't found any issues with it yet.
The only downside is that it requires *one* additional execution of the model (the call to VarInfo(model)) every time we convert.
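To make the proposed mapping in both directions concrete without depending on DynamicPPL or MCMCChains, here is a minimal, self-contained sketch. It assumes hypothetical stand-ins: a "dummy VarInfo" is a Dict from variable name to its flattened values, and a chain is a matrix plus column names, where vector-valued variables are split into columns such as s[1], s[2]. The names DummyVarInfo, FlatChain, columns_in_group, varinfos_from_chain and chain_from_varinfos are illustrative only and are not part of any package.

```julia
# HYPOTHETICAL stand-ins, not the DynamicPPL / MCMCChains types:
#   * a "dummy VarInfo" maps a variable name to its flattened values;
#   * a "flat chain" is an (iterations x parameters) matrix plus column names.
const DummyVarInfo = Dict{String,Vector{Float64}}

struct FlatChain
    names::Vector{String}   # one name per column, e.g. "m", "s[1]", "s[2]"
    values::Matrix{Float64} # size (iterations, length(names))
end

# Column indices belonging to `var`: the exact name "var" or names "var[...]".
# This mirrors the purpose of `MCMCChains.namesingroup`, not its implementation.
function columns_in_group(names::Vector{String}, var::String)
    return findall(n -> n == var || startswith(n, var * "["), names)
end

# Chain -> one DummyVarInfo per iteration (columns not in `vars` are ignored).
function varinfos_from_chain(chain::FlatChain, vars::Vector{String})
    return map(1:size(chain.values, 1)) do i
        vi = DummyVarInfo()
        for var in vars
            ks = columns_in_group(chain.names, var)
            isempty(ks) && error("$var not present in chain but is required by the model")
            vi[var] = chain.values[i, ks]
        end
        vi
    end
end

# DummyVarInfos -> chain (the "vice versa" direction).
# Assumes every element of `vis` holds the same variables with the same lengths;
# a length-1 variable gets the plain name "var" rather than "var[1]".
function chain_from_varinfos(vis::Vector{DummyVarInfo}, vars::Vector{String})
    first_vi = first(vis)
    names = String[]
    for var in vars
        n = length(first_vi[var])
        n == 1 ? push!(names, var) : append!(names, ["$var[$j]" for j in 1:n])
    end
    # One 1 x p row per iteration, stacked vertically.
    rows = [permutedims(reduce(vcat, [vi[var] for var in vars])) for vi in vis]
    return FlatChain(names, reduce(vcat, rows))
end

# Example: two iterations of a model with a scalar `m` and a 2-vector `s`.
chain = FlatChain(["m", "s[1]", "s[2]", "lp"], [0.1 1.0 2.0 -3.0; 0.2 1.5 2.5 -2.0])
vis = varinfos_from_chain(chain, ["m", "s"])
vis[2]["s"]                                 # [1.5, 2.5]
chain_from_varinfos(vis, ["m", "s"]).names  # ["m", "s[1]", "s[2]"]
```

Like varinfo_from_chain, the sketch handles the log-density column separately (here it is simply skipped), and the round trip only preserves the columns for the variables the "model" asks for.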