2 changes: 1 addition & 1 deletion Project.toml
@@ -45,7 +45,7 @@ DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8"
DynamicPPL = "0.16"
DynamicPPL = "0.17.2"
EllipticalSliceSampling = "0.4"
ForwardDiff = "0.10.3"
Libtask = "0.4, 0.5.3"
6 changes: 3 additions & 3 deletions docs/src/for-developers/how_turing_implements_abstractmcmc.md
@@ -197,7 +197,7 @@ Consider an instance `m` of `Model` and a sampler `spl`, with associated `VarInfo`
* recall that the code for `m.f(vi, ...)` is automatically generated by compilation of the `@model` macro
* for every tilde statement in the `@model` declaration, this code contains a call to `assume(vi, ...)` if the variable on the LHS of the tilde is a **model parameter to infer**, and `observe(vi, ...)` if the variable on the LHS of the tilde is an **observation**
* in the file corresponding to your sampling method (ie in `Turing.jl/src/inference/<your_method>.jl`), you have **overloaded** `assume` and `observe`, so that they can modify `vi` to include the information and samples that you care about!
-* at a minimum, `assume` and `observe` return the log density `lp` of the sample or observation. The model evaluation function then immediately calls `acclogp!(vi, lp)`, which adds `lp` to the value of the log joint density stored in `vi`.
+* at a minimum, `assume` and `observe` return the log density `lp` of the sample or observation. The model evaluation function then immediately calls `acclogp!!(vi, lp)`, which adds `lp` to the value of the log joint density stored in `vi`.

Here's what `assume` looks like for Importance Sampling:
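(The function body is collapsed in this diff view. A minimal sketch of what the post-change version plausibly looks like, assuming the `assume` signature and the `!!` conventions used elsewhere in this PR — illustrative, not the verbatim file contents:)

```julia
function DynamicPPL.assume(rng, spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi)
    if haskey(vi, vn)
        # Reuse a value that is already stored in `vi`.
        r = vi[vn]
    else
        # Sample from the prior and record the value; `push!!` may return a new `vi`.
        r = rand(rng, dist)
        vi = push!!(vi, vn, r, dist, spl)
    end
    # For importance sampling, `assume` itself contributes no extra log-weight.
    return r, 0, vi
end
```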

@@ -226,10 +226,10 @@ It simply returns the density (in the discrete case, the probability) of the obs
We focus on the AbstractMCMC functions that are overridden in `is.jl` and executed inside `mcmcsample`: `step!`, which is called `n_samples` times, and `sample_end!`, which is executed once after those `n_samples` iterations.

* During the \$\$i\$\$-th iteration, `step!` does 3 things:
-* `empty!(spl.state.vi)`: remove information about the previous sample from the sampler's `VarInfo`
+* `empty!!(spl.state.vi)`: remove information about the previous sample from the sampler's `VarInfo`
* `model(rng, spl.state.vi, spl)`: call the model evaluation function
* calls to `assume` add the samples from the prior \$\$s\_i\$\$ and \$\$m\_i\$\$ to `spl.state.vi`
-* calls to both `assume` and `observe` are followed by the line `acclogp!(vi, lp)`, where `lp` is an output of `assume` and `observe`
+* calls to both `assume` and `observe` are followed by the line `acclogp!!(vi, lp)`, where `lp` is an output of `assume` and `observe`
* `lp` is set to 0 after `assume`, and to the value of the density at the observation after `observe`
* when all the tilde statements have been covered, `spl.state.vi.logp[]` is the sum of the `lp`, ie the likelihood \$\$\log p(x, y \mid s\_i, m\_i) = \log p(x \mid s\_i, m\_i) + \log p(y \mid s\_i, m\_i)\$\$ of the observations given the latent variable samples \$\$s\_i\$\$ and \$\$m\_i\$\$.
* `return Transition(spl)`: build a transition from the sampler, and return that transition
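
Put together, the iteration described in this list amounts to something like the following sketch (the helper name `step_sketch` is hypothetical; the `spl.state` layout and the old-style model-call signature are assumed from the surrounding text):

```julia
function step_sketch(rng, model, spl)
    # 1. Remove the previous sample's contents from the sampler's `VarInfo`.
    empty!!(spl.state.vi)
    # 2. Evaluate the model: `assume` draws from the prior, `observe` scores the
    #    data, and both feed their `lp` into `acclogp!!(vi, lp)`.
    model(rng, spl.state.vi, spl)
    # 3. Package the accumulated sample and log-likelihood into a transition.
    return Transition(spl)
end
```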
4 changes: 2 additions & 2 deletions src/contrib/inference/dynamichmc.jl
@@ -97,7 +97,7 @@ function DynamicPPL.initialstep(

# Update the variables.
vi[spl] = Q.q
-DynamicPPL.setlogp!(vi, Q.ℓq)
+DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create first sample and state.
sample = Transition(vi)
@@ -127,7 +127,7 @@ function AbstractMCMC.step(

# Update the variables.
vi[spl] = Q.q
-DynamicPPL.setlogp!(vi, Q.ℓq)
+DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create next sample and state.
sample = Transition(vi)
12 changes: 7 additions & 5 deletions src/core/ad.jl
@@ -108,9 +108,11 @@ function gradient_logp(
logp_old = getlogp(vi)
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
-model(new_vi, sampler, ctx)
+new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
logp = getlogp(new_vi)
-setlogp!(vi, ForwardDiff.value(logp))
+# Don't need to capture the resulting `vi` since this is only
+# needed if `vi` is mutable.
+setlogp!!(vi, ForwardDiff.value(logp))
Comment (Member): I guess ideally we would want a no-op here when `setlogp!!` returns a new object - if we just call `setlogp!!` with a `SimpleVarInfo` it would create a new `VarInfo` object even though it's never used. Maybe one could check if `vi === new_vi` and only call `setlogp!!` if it is true (assuming that in this case `setlogp!!` would not return a new object)? But I guess this should be left for the `SimpleVarInfo` PR?

Reply (Member Author): Will this not be optimized away in the case where `vi` is immutable? I sort of assumed so. But I guess we might as well give the compiler a helping hand just to be sure.

return logp
end
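
(For readers skimming this diff: the `!!` suffix follows the BangBang convention — mutate the argument in place when it supports mutation, otherwise return an updated copy — so in general the result is rebound. A minimal self-contained illustration, assuming DynamicPPL >= 0.17; not code from this PR:)

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()

vi = VarInfo(demo())                 # a mutable `VarInfo`
# `setlogp!!` mutates `vi` here, but rebinding the result also stays
# correct for immutable `AbstractVarInfo` implementations.
vi = DynamicPPL.setlogp!!(vi, 0.0)
getlogp(vi)                          # == 0.0
```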

@@ -120,7 +122,7 @@
config = ForwardDiff.GradientConfig(f, θ, chunk)
∂l∂θ = ForwardDiff.gradient!(similar(θ), f, θ, config)
l = getlogp(vi)
-setlogp!(vi, logp_old)
+setlogp!!(vi, logp_old)

return l, ∂l∂θ
end
@@ -137,7 +139,7 @@ function gradient_logp(
# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
-model(new_vi, sampler, ctx)
+new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
return getlogp(new_vi)
end

@@ -162,7 +164,7 @@ function gradient_logp(
# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
-model(new_vi, sampler, context)
+new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, context))
return getlogp(new_vi)
end

2 changes: 1 addition & 1 deletion src/core/container.jl
@@ -43,7 +43,7 @@ function AdvancedPS.reset_model(f::TracedModel)
end

function AdvancedPS.reset_logprob!(f::TracedModel)
-DynamicPPL.resetlogp!(f.varinfo)
+DynamicPPL.resetlogp!!(f.varinfo)
Comment (Member): This also relies on the fact that it mutates `f.varinfo`.

Reply (Member Author): Yup, but `TracedModel` isn't compatible with immutable `AbstractVarInfo` 😕

return
end
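
(To make the caveat in this thread concrete — a hedged illustration assuming a later DynamicPPL release where the immutable `SimpleVarInfo` mentioned above exists; not code from this PR:)

```julia
using DynamicPPL

svi = SimpleVarInfo((x = 1.0,), 42.0)  # immutable; log density initialised to 42.0
DynamicPPL.resetlogp!!(svi)            # returns a *new* object; `svi` is unchanged
getlogp(svi)                           # still 42.0 - the update was silently dropped
svi = DynamicPPL.resetlogp!!(svi)      # rebinding is required to keep the reset
getlogp(svi)                           # == 0.0
```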

20 changes: 11 additions & 9 deletions src/inference/AdvancedSMC.jl
@@ -110,8 +110,8 @@ function DynamicPPL.initialstep(
# Reset the VarInfo.
reset_num_produce!(vi)
set_retained_vns_del_by_spl!(vi, spl)
-resetlogp!(vi)
-empty!(vi)
+resetlogp!!(vi)
+empty!!(vi)
Comment on lines +113 to +114 (Member): Does `empty!!` not call `resetlogp!!`? Maybe this would be reasonable? In any case, also here it is mandatory that the calls are actually mutating.

Comment (Member): Similar to the calls below.

Reply (Member Author): Same response as above: these samplers aren't compatible with immutable `VarInfo` 😕


# Create a new set of particles.
particles = AdvancedPS.ParticleContainer(
@@ -249,7 +249,7 @@ function DynamicPPL.initialstep(
# Reset the VarInfo before new sweep
reset_num_produce!(vi)
set_retained_vns_del_by_spl!(vi, spl)
-resetlogp!(vi)
+resetlogp!!(vi)

# Create a new set of particles
num_particles = spl.alg.nparticles
@@ -281,7 +281,7 @@ function AbstractMCMC.step(
)
# Reset the VarInfo before new sweep.
reset_num_produce!(vi)
-resetlogp!(vi)
+resetlogp!!(vi)

# Create reference particle for which the samples will be retained.
reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi))
@@ -315,6 +315,8 @@ function AbstractMCMC.step(
return transition, _vi
end

+DynamicPPL.use_threadsafe_eval(::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo) = false

function DynamicPPL.assume(
rng,
spl::Sampler{<:Union{PG,SMC}},
@@ -326,7 +328,7 @@ function DynamicPPL.assume(
if inspace(vn, spl)
if ~haskey(vi, vn)
r = rand(rng, dist)
-push!(vi, vn, r, dist, spl)
+push!!(vi, vn, r, dist, spl)
elseif is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = rand(rng, dist)
@@ -342,17 +344,17 @@ function DynamicPPL.assume(
r = vi[vn]
else
r = rand(rng, dist)
-push!(vi, vn, r, dist, DynamicPPL.Selector(:invalid))
+push!!(vi, vn, r, dist, DynamicPPL.Selector(:invalid))
end
lp = logpdf_with_trans(dist, r, istrans(vi, vn))
-acclogp!(vi, lp)
+acclogp!!(vi, lp)
end
-return r, 0
+return r, 0, vi
end

function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
produce(logpdf(dist, value))
-return 0
+return 0, vi
end

# Convenient constructor
5 changes: 4 additions & 1 deletion src/inference/Inference.jl
@@ -3,7 +3,10 @@ module Inference
using ..Core
using ..Utilities
using DynamicPPL: Metadata, VarInfo, TypedVarInfo,
-islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
+islinked, invlink!, link!,
+setindex!!, push!!,
+setlogp!!, getlogp,
+tonamedtuple, VarName, getsym, vectorize,
settrans!, _getvns, getdist,
Model, Sampler, SampleFromPrior, SampleFromUniform,
DefaultContext, PriorContext,
15 changes: 7 additions & 8 deletions src/inference/emcee.jl
@@ -38,22 +38,21 @@ function AbstractMCMC.step(

# Update the parameters if provided.
if haskey(kwargs, :init_params)
-for vi in vis
-initialize_parameters!(vi, kwargs[:init_params], spl)
+vis = map(vis) do vi
+vi = initialize_parameters!!(vi, kwargs[:init_params], spl)

# Update log joint probability.
-model(rng, vi, SampleFromPrior())
+last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior()))
end
end

# Compute initial transition and states.
-transition = map(vis) do vi
-Transition(vi)
-end
+transition = map(Transition, vis)

+# TODO: Make compatible with immutable `AbstractVarInfo`.
state = EmceeState(
vis[1],
map(vis) do vi
# Transform to unconstrained space.
DynamicPPL.link!(vi, spl)
AMH.Transition(vi[spl], getlogp(vi))
end
@@ -78,7 +77,7 @@

# Compute the next transition and state.
transition = map(states) do _state
-vi[spl] = _state.params
+vi = setindex!!(vi, _state.params, spl)
DynamicPPL.invlink!(vi, spl)
t = Transition(tonamedtuple(vi), _state.lp)
DynamicPPL.link!(vi, spl)
22 changes: 11 additions & 11 deletions src/inference/ess.jl
@@ -71,8 +71,8 @@
)

# update sample and log-likelihood
-vi[spl] = sample
-setlogp!(vi, state.loglikelihood)
+vi = setindex!!(vi, sample, spl)
+vi = setlogp!!(vi, state.loglikelihood)

return Transition(vi), vi
end
@@ -111,6 +111,7 @@ EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
sampler = p.sampler
varinfo = p.varinfo
+# TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
vns = _getvns(varinfo, sampler)
for vn in Iterators.flatten(values(vns))
set_flag!(varinfo, vn, "del")
@@ -131,17 +132,16 @@ end

function (ℓ::ESSLogLikelihood)(f)
sampler = ℓ.sampler
-varinfo = ℓ.varinfo
-varinfo[sampler] = f
-ℓ.model(varinfo, sampler)
+varinfo = setindex!!(ℓ.varinfo, f, sampler)
+varinfo = last(DynamicPPL.evaluate!!(ℓ.model, varinfo, sampler))
return getlogp(varinfo)
end

function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn, vi)
-if inspace(vn, sampler)
-return DynamicPPL.tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi)
+return if inspace(vn, sampler)
+DynamicPPL.tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi)
else
-return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
+DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
end
end

@@ -151,10 +151,10 @@ end

function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vns, vi)
# TODO: Or should we do `all(Base.Fix2(inspace, sampler), vns)`?
-if inspace(first(vns), sampler)
-return DynamicPPL.dot_tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, vi)
+return if inspace(first(vns), sampler)
+DynamicPPL.dot_tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, vi)
else
-return DynamicPPL.dot_tilde_assume(rng, ctx, SampleFromPrior(), right, left, vns, vi)
+DynamicPPL.dot_tilde_assume(rng, ctx, SampleFromPrior(), right, left, vns, vi)
end
end

5 changes: 4 additions & 1 deletion src/inference/gibbs.jl
@@ -163,6 +163,9 @@ function DynamicPPL.initialstep(
vi::AbstractVarInfo;
kwargs...
)
+# TODO: Technically this only works for `VarInfo` or `ThreadSafeVarInfo{<:VarInfo}`.
+# Should we enforce this?

# Create tuple of samplers
algs = spl.alg.algs
i = 0
@@ -230,7 +233,7 @@ function AbstractMCMC.step(
states = map(samplers, state.states) do _sampler, _state
# Recompute `vi.logp` if needed.
if _sampler.selector.rerun
-model(rng, vi, _sampler)
+vi = last(DynamicPPL.evaluate!!(model, rng, vi, _sampler))
end

# Update state of current sampler with updated `VarInfo` object.
6 changes: 4 additions & 2 deletions src/inference/gibbs_conditional.jl
@@ -79,8 +79,10 @@ function AbstractMCMC.step(
condvals = conditioned(tonamedtuple(vi))
conddist = spl.alg.conditional(condvals)
updated = rand(rng, conddist)
-vi[spl] = [updated;] # setindex allows only vectors in this case...
-model(rng, vi, SampleFromPrior()) # update log joint probability
+# Setindex allows only vectors in this case.
+vi = setindex!!(vi, [updated;], spl)
+# Update log joint probability.
+vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior()))

return nothing, vi
end