Skip to content

Commit 990eeb2

Browse files
authored
Prior should use PriorContext (#2170)
* `Prior` now uses `PriorContext` * `Prior` sampler now implements `step` instead of relying on `SampleFromPrior` * remove unnecessary `make_prior_model` * bump patch version
1 parent 39f5d5b commit 990eeb2

File tree

3 files changed

+33
-31
lines changed

3 files changed

+33
-31
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.30.3"
3+
version = "0.30.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/Inference.jl

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,23 @@ DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.l
115115
# Algorithm for sampling from the prior
116116
struct Prior <: InferenceAlgorithm end
117117

118+
function AbstractMCMC.step(
119+
rng::Random.AbstractRNG,
120+
model::DynamicPPL.Model,
121+
sampler::DynamicPPL.Sampler{<:Prior},
122+
state=nothing;
123+
kwargs...,
124+
)
125+
vi = last(DynamicPPL.evaluate!!(
126+
model,
127+
VarInfo(),
128+
SamplingContext(
129+
rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()
130+
)
131+
))
132+
return vi, nothing
133+
end
134+
118135
"""
119136
mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real)
120137
@@ -230,36 +247,6 @@ function AbstractMCMC.sample(
230247
chain_type=chain_type, progress=progress, kwargs...)
231248
end
232249

233-
function AbstractMCMC.sample(
234-
rng::AbstractRNG,
235-
model::AbstractModel,
236-
alg::Prior,
237-
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
238-
N::Integer,
239-
n_chains::Integer;
240-
chain_type=DynamicPPL.default_chain_type(alg),
241-
progress=PROGRESS[],
242-
kwargs...
243-
)
244-
return AbstractMCMC.sample(rng, model, SampleFromPrior(), ensemble, N, n_chains;
245-
chain_type, progress, kwargs...)
246-
end
247-
248-
function AbstractMCMC.sample(
249-
rng::AbstractRNG,
250-
model::AbstractModel,
251-
alg::Prior,
252-
N::Integer;
253-
chain_type=DynamicPPL.default_chain_type(alg),
254-
resume_from=nothing,
255-
initial_state=DynamicPPL.loadstate(resume_from),
256-
progress=PROGRESS[],
257-
kwargs...
258-
)
259-
return AbstractMCMC.mcmcsample(rng, model, SampleFromPrior(), N;
260-
chain_type, initial_state, progress, kwargs...)
261-
end
262-
263250
##########################
264251
# Chain making utilities #
265252
##########################

test/mcmc/Inference.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,21 @@
140140
@test all(haskey(x, :lp) for x in chains)
141141
@test mean(x[:s][1] for x in chains) 3 atol=0.1
142142
@test mean(x[:m][1] for x in chains) 0 atol=0.1
143+
144+
@testset "#2169" begin
145+
# Not exactly the same as the issue, but similar.
146+
@model function issue2169_model()
147+
if DynamicPPL.leafcontext(__context__) isa DynamicPPL.PriorContext
148+
x ~ Normal(0, 1)
149+
else
150+
x ~ Normal(1000, 1)
151+
end
152+
end
153+
154+
model = issue2169_model()
155+
chain = sample(model, Prior(), 10)
156+
@test all(mean(chain[:x]) .< 5)
157+
end
143158
end
144159

145160
@testset "chain ordering" begin

0 commit comments

Comments
 (0)