
Rework sample() call stack to use LogDensityFunction #2555

Open · @penelopeysm

Overview of progress

Current situation

Right now, a call to sample(::Model, ::InferenceAlgorithm, ::Int) first goes to src/mcmc/Inference.jl, where the InferenceAlgorithm gets wrapped in a DynamicPPL.Sampler, e.g.

function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    alg::InferenceAlgorithm,
    N::Integer;
    check_model::Bool=true,
    kwargs...,
)
    check_model && _check_model(model, alg)
    return AbstractMCMC.sample(rng, model, Sampler(alg), N; kwargs...)
end

This then goes to src/mcmc/$sampler.jl which defines the methods sample(::Model, ::Sampler{<:InferenceAlgorithm}, ::Int), e.g.

Turing.jl/src/mcmc/hmc.jl, lines 82 to 104 at 5acc97f:

function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::DynamicPPL.Model,
    sampler::Sampler{<:AdaptiveHamiltonian},
    N::Integer;
    chain_type=DynamicPPL.default_chain_type(sampler),
    resume_from=nothing,
    initial_state=DynamicPPL.loadstate(resume_from),
    progress=PROGRESS[],
    nadapts=sampler.alg.n_adapts,
    discard_adapt=true,
    discard_initial=-1,
    kwargs...,
)
    if resume_from === nothing
        # If `nadapts` is `-1`, then the user called a convenience
        # constructor like `NUTS()` or `NUTS(0.65)`,
        # and we should set a default for them.
        if nadapts == -1
            _nadapts = min(1000, N ÷ 2)
        else
            _nadapts = nadapts
        end

This then goes to AbstractMCMC's sample:

https://github.com/TuringLang/AbstractMCMC.jl/blob/fdaa0ebce22ce227b068e847415cd9ee0e15c004/src/sample.jl#L255-L259
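For reference, the AbstractMCMC method at that permalink is essentially a thin forwarding layer; roughly (a paraphrase, not the exact code at that commit), it looks like:

function AbstractMCMC.sample(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.AbstractModel,
    sampler::AbstractMCMC.AbstractSampler,
    N_or_isdone;
    kwargs...,
)
    # All the generic work (progress logging, the stepping loop, chain bundling)
    # happens inside mcmcsample.
    return AbstractMCMC.mcmcsample(rng, model, sampler, N_or_isdone; kwargs...)
end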

Which then calls step(::AbstractRNG, ::Model, ::Sampler{<:InferenceAlgorithm}), which is defined in DynamicPPL:

https://github.com/TuringLang/DynamicPPL.jl/blob/072234d094d1d68064bf259d3c3e815a87c18c8e/src/sampler.jl#L108-L126
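A rough paraphrase of that DynamicPPL method (simplified; see the permalink for the exact code): it builds a fresh VarInfo for the model, optionally overwrites it with user-supplied initial parameters, and then hands off to initialstep.

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    spl::DynamicPPL.Sampler;
    initial_params=nothing,
    kwargs...,
)
    # Evaluate the model once to construct a VarInfo holding initial values.
    vi = DynamicPPL.VarInfo(rng, model)
    # (The real method also handles `initial_params` here, overwriting the
    # sampled values and recomputing the joint log probability.)
    # Hand off to the sampler-specific initialstep, defined back in Turing.
    return DynamicPPL.initialstep(rng, model, spl, vi; initial_params=initial_params, kwargs...)
end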

Which then calls initialstep, which goes back to being defined in src/mcmc/$sampler.jl:

Turing.jl/src/mcmc/hmc.jl, lines 141 to 149 at 5acc97f:

function DynamicPPL.initialstep(
    rng::AbstractRNG,
    model::AbstractModel,
    spl::Sampler{<:Hamiltonian},
    vi_original::AbstractVarInfo;
    initial_params=nothing,
    nadapts=0,
    kwargs...,
)

(this signature claims to accept any AbstractModel, but it only really works for DynamicPPL.Model)

Inside here, we finally construct a LogDensityFunction from the model. So there are a great many steps between the time sample() is called and the point where a LogDensityFunction is actually constructed.
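Schematically, the construction inside initialstep looks something like the following (names and keyword arguments simplified; the real code also sets up the metric, the Hamiltonian, and the adaptor):

# Schematic only: the HMC initialstep builds a LogDensityFunction so that the
# log density and its gradient can be queried via the LogDensityProblems interface.
vi = DynamicPPL.link!!(vi_original, model)   # move to unconstrained space
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)
lp = LogDensityProblems.logdensity(ldf, vi[:])   # log density at the current parameters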

Proposal

Rework everything below the very first call to accept LogDensityFunction rather than Model. That is to say, the method sample(::Model, ::InferenceAlgorithm, ::Int) should look something like this:

function sample(rng::Random.AbstractRNG, model::Model, alg::InferenceAlgorithm, N::Int; kwargs...)
    adtype = get_adtype(alg) # returns nothing (e.g. for MH) or AbstractADType (e.g. NUTS)
    ldf = DynamicPPL.LogDensityFunction(model; adtype=adtype)
    spl = DynamicPPL.Sampler(alg)
    sample(rng, ldf, spl, N; kwargs...)
end

# similar methods for sample(::Random.AbstractRNG, ::Model, ::Sampler{<:InferenceAlgorithm})
# as well as sample(::Random.AbstractRNG, ::LogDensityFunction, ::InferenceAlgorithm)

function sample(
    rng::Random.AbstractRNG,
    ldf::DynamicPPL.LogDensityFunction,
    spl::DynamicPPL.Sampler{<:InferenceAlgorithm}, N::Int;
    chain_type=MCMCChains.Chains, check=true, ...
)
    # All of Turing's magic behaviour, e.g. checking the model, setting the chain_type,
    # should happen in this method.
    check && check_model(ldf.model)
    ... (whatever else)
    # Then we can call mcmcsample, and that means that everything from mcmcsample
    # onwards _knows_ that it's dealing with a LogDensityFunction and a Sampler{...}.
    AbstractMCMC.mcmcsample(...)
end

# handle rng-less methods (ugly boilerplate, because somebody decided that the *optional*
# rng argument should be the first argument and not a keyword argument, ugh!)
function sample(model::Model, alg::InferenceAlgorithm, N::Int; kwargs...)
    sample(Random.default_rng(), model, alg, N; kwargs...)
end
# similar for sample(::Model, ::Sampler{<:InferenceAlgorithm})
#  as well as sample(::LogDensityFunction, ::InferenceAlgorithm)
#  as well as sample(::LogDensityFunction, ::Sampler{<:InferenceAlgorithm})

# oh don't forget to handle the methods with MCMCEnsemble... argh...
# so that's 16 methods to make sure that everything works correctly
# - 2x with/without rng
# - 2x for Model and LDF
# - 2x for InferenceAlgorithm and Sampler{...}
# - 2x for standard and parallel
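
# For concreteness, one of the parallel variants might look like this (a sketch
# only, mirroring AbstractMCMC's existing ensemble signature; names hypothetical):
function sample(
    rng::Random.AbstractRNG,
    ldf::DynamicPPL.LogDensityFunction,
    spl::DynamicPPL.Sampler{<:InferenceAlgorithm},
    ensemble::AbstractMCMC.AbstractMCMCEnsemble,
    N::Int,
    n_chains::Int;
    kwargs...,
)
    # same "magic behaviour" as the single-chain method above, then forward to
    # AbstractMCMC.mcmcsample with the ensemble argument
end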

This would require making several changes across DynamicPPL and Turing. It (thankfully) probably does not need to touch AbstractMCMC, as long as we make LogDensityFunction a subtype of AbstractMCMC.AbstractModel (so that mcmcsample can work). That should be fine, because AbstractModel has no interface.
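For reference, that change would essentially be a one-line edit to the struct definition in DynamicPPL; a hypothetical sketch (type parameters and fields simplified, not the actual definition):

# Hypothetical sketch: only the supertype annotation changes; the real struct's
# type parameters and fields (model, varinfo, context, adtype, ...) are simplified here.
struct LogDensityFunction{M,V,C} <: AbstractMCMC.AbstractModel
    model::M
    varinfo::V
    context::C
end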

Why?

For one, this is probably the best way to let people have greater control over their sampling process: for example, a user could construct the LogDensityFunction themselves (choosing, say, the AD type) and pass it directly to sample.

More philosophically, it's IMO the first step that's necessary towards encapsulating Turing's "magic behaviour" at the very top level of the call stack. We know that a DynamicPPL.Model on its own does not actually give enough information about how to evaluate it — it's only LogDensityFunction that contains the necessary information. Thus, it shouldn't be the job of the low-level functions like step to make this decision — they should just 'receive' objects that are already complete.

I'm still not sure how this will work with Gibbs; I haven't looked at it deeply enough.
