Skip to content

Commit 1d477ca

Browse files
committed
Allow Prior to skip model re-evaluation
1 parent b41a4b1 commit 1d477ca

File tree

2 files changed

+51
-15
lines changed

2 files changed

+51
-15
lines changed

src/mcmc/Inference.jl

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
136136
stat::N
137137

138138
"""
139-
Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
139+
Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true)
140140
141141
Construct a new `Turing.Inference.Transition` object using the outputs of a
142142
sampler step.
@@ -148,17 +148,38 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
148148
149149
`sampler_transition` is the transition object returned by the sampler
150150
itself and is only used to extract statistics of interest.
151+
152+
By default, the model is re-evaluated in order to obtain values of:
153+
- the values of the parameters as per user parameterisation (`vals_as_in_model`)
154+
- the various components of the log joint probability (`logprior`, `loglikelihood`)
155+
that are guaranteed to be correct.
156+
157+
If you **know** for a fact that the VarInfo `vi` already contains this information,
158+
then you can set `reevaluate=false` to skip the re-evaluation step.
159+
160+
!!! warning
161+
Note that in general this is unsafe and may lead to wrong results.
162+
163+
If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that
164+
the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`,
165+
and `LogLikelihoodAccumulator` set up with the correct values. Note that the
166+
`ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it
167+
must be set up to track `x := y` statements.
151168
"""
152-
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition)
153-
vi = DynamicPPL.setaccs!!(
154-
vi,
155-
(
156-
DynamicPPL.ValuesAsInModelAccumulator(true),
157-
DynamicPPL.LogPriorAccumulator(),
158-
DynamicPPL.LogLikelihoodAccumulator(),
159-
),
160-
)
161-
_, vi = DynamicPPL.evaluate!!(model, vi)
169+
function Transition(
170+
model::DynamicPPL.Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true
171+
)
172+
if reevaluate
173+
vi = DynamicPPL.setaccs!!(
174+
vi,
175+
(
176+
DynamicPPL.ValuesAsInModelAccumulator(true),
177+
DynamicPPL.LogPriorAccumulator(),
178+
DynamicPPL.LogLikelihoodAccumulator(),
179+
),
180+
)
181+
_, vi = DynamicPPL.evaluate!!(model, vi)
182+
end
162183

163184
# Extract all the information we need
164185
vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
@@ -175,12 +196,18 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
175196
function Transition(
176197
model::DynamicPPL.Model,
177198
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
178-
sampler_transition,
199+
sampler_transition;
200+
reevaluate=true,
179201
)
180202
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
181203
# much faster to convert it to a typed varinfo first, hence this method.
182204
# https://github.com/TuringLang/Turing.jl/issues/2604
183-
return Transition(model, DynamicPPL.typed_varinfo(untyped_vi), sampler_transition)
205+
return Transition(
206+
model,
207+
DynamicPPL.typed_varinfo(untyped_vi),
208+
sampler_transition;
209+
reevaluate=reevaluate,
210+
)
184211
end
185212
end
186213

src/mcmc/prior.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,17 @@ function AbstractMCMC.step(
1616
sampling_model = DynamicPPL.contextualize(
1717
model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context)
1818
)
19-
_, vi = DynamicPPL.evaluate!!(sampling_model, VarInfo())
20-
return Transition(model, vi, nothing), nothing
19+
vi = VarInfo()
20+
vi = DynamicPPL.setaccs!!(
21+
vi,
22+
(
23+
DynamicPPL.ValuesAsInModelAccumulator(true),
24+
DynamicPPL.LogPriorAccumulator(),
25+
DynamicPPL.LogLikelihoodAccumulator(),
26+
),
27+
)
28+
_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
29+
return Transition(model, vi, nothing; reevaluate=false), nothing
2130
end
2231

2332
DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains

0 commit comments

Comments
 (0)