Description
Overview of progress
- Hamiltonians: `hmc.jl`, `sghmc.jl`, `DynamicHMCExt` (sample with LogDensityFunction: part 1, #2588)
- ESS: `ess.jl` (sample with LogDensityFunction: part 2, #2590)
- MH: `mh.jl` (sample with LogDensityFunction: part 2, #2590)
- IS: `is.jl`
- Prior: `prior.jl`
- Emcee: `emcee.jl`
- Particle samplers: `particle_mcmc.jl`
- `ExternalSampler`
- `RepeatSampler`
- Add method warning to `AbstractMCMC.step(rng, model, spl[, state])`
Current situation
Right now, if you call `sample(::Model, ::InferenceAlgorithm, ::Int)`, this first goes to `src/mcmc/Inference.jl`, where the `InferenceAlgorithm` gets wrapped in a `DynamicPPL.Sampler` (see `src/mcmc/Inference.jl`, lines 268 to 278 at 5acc97f).

This then goes to `src/mcmc/$sampler.jl`, which defines the methods `sample(::Model, ::Sampler{<:InferenceAlgorithm}, ::Int)` (e.g. lines 82 to 104 at 5acc97f).

This then goes to AbstractMCMC's `sample`, which then calls `step(::AbstractRNG, ::Model, ::Sampler{<:InferenceAlgorithm})`, which is defined in DynamicPPL.

That in turn calls `initialstep`, which goes back to being defined in `src/mcmc/$sampler.jl` (lines 141 to 149 at 5acc97f). Although this signature claims to work on `AbstractModel`, it only really works for `DynamicPPL.Model`.

Inside here, we finally construct a `LogDensityFunction` from the model. So, there are very many steps between the time that `sample()` is called and the time when a `LogDensityFunction` is actually constructed.
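In short, the current dispatch chain looks roughly like this (a simplified, comments-only summary of the steps above; line numbers refer to commit 5acc97f):

```julia
# sample(model::Model, alg::InferenceAlgorithm, N)                   # src/mcmc/Inference.jl: wrap alg in Sampler
#   -> sample(model::Model, spl::Sampler{<:InferenceAlgorithm}, N)   # src/mcmc/$sampler.jl
#     -> AbstractMCMC.sample / mcmcsample                            # AbstractMCMC
#       -> step(rng, model, spl)                                     # DynamicPPL
#         -> initialstep(rng, model, spl, vi)                        # src/mcmc/$sampler.jl
#           -> LogDensityFunction(model, ...)                        # only here is the LDF constructed
```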
Proposal
Rework everything below the very first call to accept a `LogDensityFunction` rather than a `Model`. That is to say, the method `sample(::Model, ::InferenceAlgorithm, ::Int)` should look something like this:
```julia
function sample(rng::Random.AbstractRNG, model::Model, alg::InferenceAlgorithm, N::Int; kwargs...)
    adtype = get_adtype(alg)  # returns nothing (e.g. for MH) or an AbstractADType (e.g. for NUTS)
    ldf = DynamicPPL.LogDensityFunction(model; adtype=adtype)
    spl = DynamicPPL.Sampler(alg)
    sample(rng, ldf, spl, N; kwargs...)
end

# similar methods for sample(::Random.AbstractRNG, ::Model, ::Sampler{<:InferenceAlgorithm})
# as well as sample(::Random.AbstractRNG, ::LogDensityFunction, ::InferenceAlgorithm)

function sample(
    rng::Random.AbstractRNG,
    ldf::DynamicPPL.LogDensityFunction,
    spl::DynamicPPL.Sampler{<:InferenceAlgorithm},
    N::Int;
    chain_type=MCMCChains.Chains, check=true, ...
)
    # All of Turing's magic behaviour, e.g. checking the model, setting the chain_type,
    # should happen in this method.
    check && check_model(ldf.model)
    ... (whatever else)
    # Then we can call mcmcsample, and that means that everything from mcmcsample
    # onwards _knows_ that it's dealing with a LogDensityFunction and a Sampler{...}.
    AbstractMCMC.mcmcsample(...)
end

# handle rng-less methods (ugly boilerplate, because somebody decided that the *optional*
# rng argument should be the first argument and not a keyword argument, ugh!)
function sample(model::Model, alg::InferenceAlgorithm, N::Int; kwargs...)
    sample(Random.default_rng(), model, alg, N; kwargs...)
end

# similar for sample(::Model, ::Sampler{<:InferenceAlgorithm})
# as well as sample(::LogDensityFunction, ::InferenceAlgorithm)
# as well as sample(::LogDensityFunction, ::Sampler{<:InferenceAlgorithm})
# oh, don't forget to handle the methods with MCMCEnsemble... argh...
# so that's 16 methods to make sure that everything works correctly:
# - 2x with/without rng
# - 2x for Model and LDF
# - 2x for InferenceAlgorithm and Sampler{...}
# - 2x for standard and parallel
```
This would require making several changes across DynamicPPL and Turing. It (thankfully) probably does not need to touch AbstractMCMC, as long as we make `LogDensityFunction` a subtype of `AbstractMCMC.AbstractModel` (so that `mcmcsample` can work). That should be fine, because `AbstractModel` has no interface.
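For illustration, the core of that change is just a supertype declaration. The struct below is a stand-in with made-up field names, not DynamicPPL's actual definition:

```julia
using AbstractMCMC

# Illustrative stand-in only: the point is that declaring the subtype is enough for
# AbstractMCMC.mcmcsample to accept the object, since AbstractMCMC.AbstractModel
# imposes no interface of its own.
struct MyLogDensityFunction{M,V,AD} <: AbstractMCMC.AbstractModel
    model::M    # the underlying DynamicPPL.Model
    varinfo::V  # the VarInfo used for evaluation
    adtype::AD  # nothing, or an ADTypes.AbstractADType for gradient-based samplers
end
```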
Why?
For one, this is probably the best way to let people have greater control over their sampling process. For example:

- If you want to use a particular type of VarInfo, this interface would allow you to construct it, make the LogDensityFunction, and then pass it to `sample()`. Right now, this is actually very difficult to do. (Just try it!!) Note that this also provides a natural interface for opting into ThreadSafeVarInfo (cf. "ThreadSafeVarInfo and threadid", DynamicPPL.jl#924). See the sketch after this list.
- It allows constructing the LDF with a particular adtype (this was part of the point behind making adtype a field of the LDF, cf. "Remove LogDensityProblemsAD; wrap adtype in LogDensityFunction", DynamicPPL.jl#806).
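As a concrete, hypothetical example of the first point, user code under this proposal might look something like the following. The `sample(::LogDensityFunction, ...)` method does not exist yet, and the exact `LogDensityFunction` constructor arguments are assumptions:

```julia
using Turing
using DynamicPPL
using ADTypes: AutoForwardDiff

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end
model = demo() | (; y = 1.5)

# Construct your own VarInfo and LogDensityFunction up front, choosing the adtype yourself...
vi = DynamicPPL.VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=AutoForwardDiff())

# ...and then hand the LDF straight to sample. This is the method proposed above;
# it is not available in current Turing releases.
chain = sample(ldf, NUTS(), 1000)
```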
More philosophically, it's IMO the first step that's necessary towards encapsulating Turing's "magic behaviour" at the very top level of the call stack. We know that a `DynamicPPL.Model` on its own does not actually give enough information about how to evaluate it; it's only `LogDensityFunction` that contains the necessary information. Thus, it shouldn't be the job of low-level functions like `step` to make this decision: they should just 'receive' objects that are already complete.
I'm still not sure about
How this will work with Gibbs. I haven't looked at it deeply enough.