Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
stat::N

"""
Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true)

Construct a new `Turing.Inference.Transition` object using the outputs of a
sampler step.
Expand All @@ -148,17 +148,38 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition

`sampler_transition` is the transition object returned by the sampler
itself and is only used to extract statistics of interest.

By default, the model is re-evaluated in order to obtain values of:
- the values of the parameters as per user parameterisation (`vals_as_in_model`)
- the various components of the log joint probability (`logprior`, `loglikelihood`)
that are guaranteed to be correct.

If you **know** for a fact that the VarInfo `vi` already contains this information,
then you can set `reevaluate=false` to skip the re-evaluation step.

!!! warning
Note that in general this is unsafe and may lead to wrong results.

If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that
the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`,
and `LogLikelihoodAccumulator` set up with the correct values. Note that the
`ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it
must be set up to track `x := y` statements.
"""
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition)
vi = DynamicPPL.setaccs!!(
vi,
(
DynamicPPL.ValuesAsInModelAccumulator(true),
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
),
)
_, vi = DynamicPPL.evaluate!!(model, vi)
function Transition(
model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true
)
if reevaluate
vi = DynamicPPL.setaccs!!(
vi,
(
DynamicPPL.ValuesAsInModelAccumulator(true),
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
),
)
_, vi = DynamicPPL.evaluate!!(model, vi)
end

# Extract all the information we need
vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
Expand All @@ -175,12 +196,18 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
function Transition(
model::DynamicPPL.Model,
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
sampler_transition,
sampler_transition;
reevaluate=true,
)
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
# much faster to convert it to a typed varinfo first, hence this method.
# https://github.com/TuringLang/Turing.jl/issues/2604
return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition)
return Transition(
model,
DynamicPPL.typed_varinfo(untyped_vi),
sampler_transition;
reevaluate=reevaluate,
)
end
end

Expand Down
15 changes: 11 additions & 4 deletions src/mcmc/prior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@ function AbstractMCMC.step(
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context)
)
_, vi = DynamicPPL.evaluate!!(sampling_model, VarInfo())
return Transition(model, vi, nothing), nothing
vi = VarInfo()
vi = DynamicPPL.setaccs!!(
vi,
(
DynamicPPL.ValuesAsInModelAccumulator(true),
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
),
)
_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
return Transition(model, vi, nothing; reevaluate=false), nothing
end

DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains
Comment on lines -22 to -23
Copy link
Member Author

@penelopeysm penelopeysm Aug 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these lines are obsolete, it's actually the following method that gets called:

DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChains.Chains

23 changes: 23 additions & 0 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,29 @@ using Turing
@test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11
@test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1
end

@testset "accumulators are set correctly" begin
# Prior() uses `reevaluate=false` when constructing a
# `Turing.Inference.Transition`, so we had better make sure that it
# does capture colon-eq statements, as we can't rely on the default
# `Transition` constructor to do this for us.
@model function coloneq()
x ~ Normal()
10.0 ~ Normal(x)
z := 1.0
return nothing
end
chain = sample(coloneq(), Prior(), N)
@test chain isa MCMCChains.Chains
@test all(x -> x == 1.0, chain[:z])
# And for the same reason we should also make sure that the logp
# components are correctly calculated.
@test isapprox(chain[:logprior], logpdf.(Normal(), chain[:x]))
@test isapprox(chain[:loglikelihood], logpdf.(Normal.(chain[:x]), 10.0))
@test isapprox(chain[:lp], chain[:logprior] .+ chain[:loglikelihood])
# And that the outcome is not influenced by the likelihood
@test mean(chain, :x) ≈ 0.0 atol = 0.1
end
end

@testset "chain ordering" begin
Expand Down
Loading