Skip to content

Commit 680c81a

Browse files
Move default model evaluation code to DynamicPPL (#1151)
* Remove distribution wrappers (~> DynamicPPL) * Remove asssume/observer fallback code (~> DynamicPPL) * Update imports/exports/qualified names * Remove unnecessary code from utils (is in DPPL now) * fix rebase * bump DPPL compat version Co-authored-by: mohamed82008 <mohamed82008@gmail.com>
1 parent e946541 commit 680c81a

File tree

11 files changed

+38
-478
lines changed

11 files changed

+38
-478
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ BinaryProvider = "0.5.6"
4040
Distributions = "0.22, 0.23"
4141
DistributionsAD = "0.4.8"
4242
DocStringExtensions = "0.8"
43-
DynamicPPL = "0.4"
43+
DynamicPPL = "0.5"
4444
EllipticalSliceSampling = "0.2"
4545
ForwardDiff = "0.10.3"
4646
Libtask = "0.3.1"

src/Turing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Markdown, Libtask, MacroTools
1616
using Tracker: Tracker
1717

1818
import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex!
19-
import DynamicPPL: getspace, runmodel!
19+
import DynamicPPL: getspace, runmodel!, NoDist, NamedDist
2020

2121
const PROGRESS = Ref(true)
2222
function turnprogress(switch::Bool)

src/inference/AdvancedSMC.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ function AbstractMCMC.sample_end!(
260260
spl.state.average_logevidence = loge
261261
end
262262

263-
function assume(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, ::VarInfo)
263+
function DynamicPPL.assume(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, ::VarInfo)
264264
vi = current_trace().vi
265265
if vn in getspace(spl)
266266
if ~haskey(vi, vn)
@@ -289,7 +289,7 @@ function assume(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName,
289289
return r, 0
290290
end
291291

292-
function observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
292+
function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
293293
produce(logpdf(dist, value))
294294
return 0
295295
end

0 commit comments

Comments
 (0)