
Implicit dependence on MCMCChains #170

@torfjelde

Description


Currently, certain methods implicitly assume functionality of AbstractMCMC.AbstractChains (e.g. indexing) that is only provided by one particular implementation of AbstractChains, namely MCMCChains.Chains (which, I believe, is also the only implementation).
Thus, we depend on MCMCChains.Chains without declaring that dependency.

In addition, this ties the tests quite heavily to Turing.jl, which is not ideal.

Possible solutions

1. Converting chains to "dummy" Vector{<:VarInfo}

As briefly discussed in #166 (comment), one feature that would help us move away from this is a method that maps (::Model, ::AbstractChains) -> ::Vector{<:VarInfo}, and vice versa. We could then stick to working with VarInfo within DynamicPPL.jl, which would alleviate the issues mentioned above.

It could for example look something like:

```julia
using MCMCChains
import DynamicPPL
import DynamicPPL: Model, VarInfo

# Adapted from `Turing.Inference.transitions_from_chain`
function varinfo_from_chain(
    model::Model,
    chain::MCMCChains.Chains;
    sampler = DynamicPPL.SampleFromPrior()
)
    # Executes `model` once to obtain a `VarInfo` with the right structure.
    vi = VarInfo(model)

    vis = map(1:length(chain)) do i
        c = chain[i]
        md = vi.metadata
        for v in keys(md)
            for vn in md[v].vns
                vn_sym = Symbol(vn)

                # Cannot use `vn_sym` to index in the chain
                # so we have to extract the corresponding "linear"
                # indices and use those.
                # `ks` is empty if `vn_sym` not in `c`.
                ks = MCMCChains.namesingroup(c, vn_sym)

                if !isempty(ks)
                    # 1st dimension is of size 1 since `c`
                    # only contains a single sample, and the
                    # last dimension is of size 1 since
                    # we're assuming we're working with a single chain.
                    val = copy(vec(c[ks].value))
                    DynamicPPL.setval!(vi, val, vn)
                    DynamicPPL.settrans!(vi, false, vn)
                else
                    error("$vn not present in chain but is required by the model")
                end
            end
        end
        # `vi` is reused across iterations, so copy its current values
        # into a fresh `VarInfo` for this sample.
        new_vi = VarInfo(vi, sampler, vi[sampler])
        DynamicPPL.setlogp!(new_vi, first(chain[i][:lp])) # Is there a better way?

        return new_vi
    end

    return vis
end
```

This works fairly well: it does what it is supposed to, and I haven't found any issues yet.

The only downside is that it requires one additional execution of the model (the call to VarInfo(model)) every time we convert.
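To make the proposed round trip concrete without depending on MCMCChains or DynamicPPL, here is a toy sketch in plain Julia. It assumes a chain is nothing more than a vector of column names plus an (iterations × parameters) matrix, and that a per-sample "VarInfo" is a Dict from variable symbol to its flattened values. ToyChain, group_indices, chain_to_samples and samples_to_chain are all hypothetical names, not DynamicPPL or MCMCChains API; group_indices only mimics what MCMCChains.namesingroup is used for above (collecting the linear names m[1], m[2], ... that belong to m).

```julia
# Hypothetical stand-in for single-chain sampler output: column names and an
# (iterations × parameters) matrix of values. This is NOT MCMCChains.Chains.
struct ToyChain
    names::Vector{String}
    values::Matrix{Float64}
end

# Column indices whose name is `sym` itself or `sym[...]` (e.g. "m[1]", "m[2]"),
# i.e. the "linear" names that make up one model variable.
function group_indices(names::Vector{String}, sym::Symbol)
    s = string(sym)
    return findall(n -> n == s || startswith(n, s * "["), names)
end

# Chain -> one Dict per iteration, keeping only the variables the "model" needs.
function chain_to_samples(chain::ToyChain, vars::Vector{Symbol})
    return map(1:size(chain.values, 1)) do i
        sample = Dict{Symbol,Vector{Float64}}()
        for v in vars
            ks = group_indices(chain.names, v)
            isempty(ks) && error("$v not present in chain but is required by the model")
            sample[v] = chain.values[i, ks]
        end
        sample
    end
end

# Inverse direction: per-iteration Dicts -> chain with linearised names.
# Toy simplification: a length-1 variable is always named without an index.
function samples_to_chain(samples::Vector{Dict{Symbol,Vector{Float64}}}, vars::Vector{Symbol})
    names = String[]
    for v in vars
        n = length(first(samples)[v])
        append!(names, n == 1 ? [string(v)] : ["$v[$j]" for j in 1:n])
    end
    # One row per sample, columns ordered as in `names`.
    rows = [permutedims(reduce(vcat, [s[v] for v in vars])) for s in samples]
    return ToyChain(names, reduce(vcat, rows))
end
```

Columns the model does not ask for (such as lp here) are simply dropped; a real implementation would also have to handle log-density, transformation flags and the VarInfo structure itself, as the code above does.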

2. ???
