Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ export AbstractVarInfo,
getargnames,
getdefaults,
getgenerator,
runmodel!,
# Samplers
Sampler,
SampleFromPrior,
Expand Down
29 changes: 18 additions & 11 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ function generate_tilde(left, right, model_info)
ctx = model_info[:main_body_names][:ctx]
sampler = model_info[:main_body_names][:sampler]

@gensym tmpright
@gensym tmpright tmpleft
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]
Expand All @@ -290,8 +290,8 @@ function generate_tilde(left, right, model_info)
assumption = [
:($out = $(DynamicPPL.tilde_assume)($ctx, $sampler, $tmpright, $vn, $inds,
$vi)),
:($left = $out[1]),
:($(DynamicPPL.acclogp!)($vi, $out[2]))
:($(DynamicPPL.acclogp!)($vi, $out[2])),
:($left = $out[1])
]

# It can only be an observation if the LHS is an argument of the model
Expand All @@ -303,11 +303,13 @@ function generate_tilde(left, right, model_info)
if $isassumption
$(assumption...)
else
$tmpleft = $left
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vn,
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vn,
$inds, $vi)
)
$tmpleft
end
end
end
Expand All @@ -321,10 +323,12 @@ function generate_tilde(left, right, model_info)
# If the LHS is a literal, it is always an observation
return quote
$(top...)
$tmpleft = $left
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
)
$tmpleft
end
end

Expand All @@ -341,7 +345,7 @@ function generate_dot_tilde(left, right, model_info)
ctx = model_info[:main_body_names][:ctx]
sampler = model_info[:main_body_names][:sampler]

@gensym tmpright
@gensym tmpright tmpleft
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]
Expand All @@ -353,8 +357,8 @@ function generate_dot_tilde(left, right, model_info)
assumption = [
:($out = $(DynamicPPL.dot_tilde_assume)($ctx, $sampler, $tmpright, $left,
$vn, $inds, $vi)),
:($left .= $out[1]),
:($(DynamicPPL.acclogp!)($vi, $out[2]))
:($(DynamicPPL.acclogp!)($vi, $out[2])),
:($left .= $out[1])
]

# It can only be an observation if the LHS is an argument of the model
Expand All @@ -366,11 +370,13 @@ function generate_dot_tilde(left, right, model_info)
if $isassumption
$(assumption...)
else
$tmpleft = $left
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left,
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft,
$vn, $inds, $vi)
)
$tmpleft
end
end
end
Expand All @@ -384,10 +390,12 @@ function generate_dot_tilde(left, right, model_info)
# If the LHS is a literal, it is always an observation
return quote
$(top...)
$tmpleft = $left
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
)
$tmpleft
end
end

Expand Down Expand Up @@ -443,7 +451,6 @@ function build_output(model_info)
$ctx::$(DynamicPPL.AbstractContext),
)
$unwrap_data_expr
$(DynamicPPL.resetlogp!)($vi)
$main_body
end

Expand Down
18 changes: 9 additions & 9 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,34 +109,34 @@ function Model{missings}(
return Model{missings}(model.f, args, modelgen)
end

"""
(model::Model)([spl = SampleFromPrior(), ctx = DefaultContext()])

Sample from `model` using the sampler `spl`.
"""
function (model::Model)(
vi::AbstractVarInfo=VarInfo(),
spl::AbstractSampler=SampleFromPrior(),
ctx::AbstractContext=DefaultContext()
)
return model.f(model, vi, spl, ctx)
return model(VarInfo(), spl, ctx)
end


"""
runmodel!(model::Model, vi::AbstractVarInfo[, spl::AbstractSampler, ctx::AbstractContext])
(model::Model)(vi::AbstractVarInfo[, spl = SampleFromPrior(), ctx = DefaultContext()])

Sample from `model` using the sampler `spl` storing the sample and log joint probability in `vi`.
Resets the `vi` and increases `spl`s `state.eval_num`.
"""
function runmodel!(
model::Model,
function (model::Model)(
vi::AbstractVarInfo,
spl::AbstractSampler=SampleFromPrior(),
ctx::AbstractContext=DefaultContext()
)
setlogp!(vi, 0)
resetlogp!(vi)
if has_eval_num(spl)
spl.state.eval_num += 1
end
model(vi, spl, ctx)
return vi
return model.f(model, vi, spl, ctx)
end


Expand Down
2 changes: 1 addition & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ const TypedVarInfo = VarInfo{<:NamedTuple}

function VarInfo(model::Model, ctx = DefaultContext())
vi = VarInfo()
runmodel!(model, vi, SampleFromPrior(), ctx)
model(vi, SampleFromPrior(), ctx)
return TypedVarInfo(vi)
end

Expand Down
2 changes: 1 addition & 1 deletion test/Turing/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using Markdown, Libtask, MacroTools
using Tracker: Tracker

import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex!
import DynamicPPL: getspace, runmodel!
import DynamicPPL: getspace

const PROGRESS = Ref(true)
function turnprogress(switch::Bool)
Expand Down
6 changes: 3 additions & 3 deletions test/Turing/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ function AbstractMCMC.sample_init!(
gradient_logp(x, spl.state.vi, model, spl)
end

runmodel!(model, spl.state.vi, SampleFromUniform())
model(spl.state.vi, SampleFromUniform())

if spl.selector.tag == :default
link!(spl.state.vi, spl)
runmodel!(model, spl.state.vi, spl)
model(spl.state.vi, spl)
end

# Set the parameters to a starting value.
Expand Down Expand Up @@ -145,4 +145,4 @@ function AbstractMCMC.psample(
end
return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains;
chain_type=chain_type, progress=false, kwargs...)
end
end
4 changes: 2 additions & 2 deletions test/Turing/contrib/inference/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function step(
Turing.DEBUG && @debug "X-> R..."
if spl.selector.tag != :default
link!(vi, spl)
runmodel!(model, vi, spl)
model(vi, spl)
end

Turing.DEBUG && @debug "recording old variables..."
Expand Down Expand Up @@ -198,7 +198,7 @@ function step(
Turing.DEBUG && @debug "X-> R..."
if spl.selector.tag != :default
link!(vi, spl)
runmodel!(model, vi, spl)
model(vi, spl)
end

Turing.DEBUG && @debug "recording old variables..."
Expand Down
2 changes: 1 addition & 1 deletion test/Turing/core/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Distributions, LinearAlgebra
using ..Utilities, Reexport
using Tracker: Tracker
using ..Turing: Turing
using DynamicPPL: Model, runmodel!,
using DynamicPPL: Model,
AbstractSampler, Sampler, SampleFromPrior
using LinearAlgebra: copytri!
using Bijectors: PDMatDistribution
Expand Down
6 changes: 4 additions & 2 deletions test/Turing/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ function gradient_logp(
logp_old = getlogp(vi)
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
logp = getlogp(runmodel!(model, new_vi, sampler))
model(new_vi, sampler)
logp = getlogp(new_vi)
setlogp!(vi, ForwardDiff.value(logp))
return logp
end
Expand All @@ -119,7 +120,8 @@ function gradient_logp(
# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
return getlogp(runmodel!(model, new_vi, sampler))
model(new_vi, sampler)
return getlogp(new_vi)
end

# Compute forward and reverse passes.
Expand Down
3 changes: 2 additions & 1 deletion test/Turing/core/compat/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ function gradient_logp(
# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
return getlogp(runmodel!(model, new_vi, sampler))
model(new_vi, sampler)
return getlogp(new_vi)
end

# Compute forward and reverse passes.
Expand Down
2 changes: 1 addition & 1 deletion test/Turing/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ..Core, ..Utilities
using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo,
islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
settrans!, _getvns, getdist, CACHERESET, AbstractSampler,
Model, runmodel!, Sampler, SampleFromPrior, SampleFromUniform,
Model, Sampler, SampleFromPrior, SampleFromUniform,
Selector, AbstractSamplerState, DefaultContext, PriorContext,
LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist
using Distributions, Libtask, Bijectors
Expand Down
6 changes: 3 additions & 3 deletions test/Turing/inference/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ function AbstractMCMC.step!(

# recompute log-likelihood in logp
if spl.selector.tag !== :default
runmodel!(model, vi, spl)
model(vi, spl)
end

# define previous sampler state
Expand Down Expand Up @@ -117,7 +117,7 @@ function EllipticalSliceSampling.sample_prior(rng::Random.AbstractRNG, model::ES
vi = spl.state.vi
vns = _getvns(vi, spl)
set_flag!(vi, vns[1][1], "del")
runmodel!(model.model, vi, spl)
model.model(vi, spl)
return vi[spl]
end

Expand All @@ -140,7 +140,7 @@ function Distributions.loglikelihood(model::ESSModel, f)
spl = model.spl
vi = spl.state.vi
vi[spl] = f
runmodel!(model.model, vi, spl)
model.model(vi, spl)
getlogp(vi)
end

Expand Down
6 changes: 3 additions & 3 deletions test/Turing/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ function AbstractMCMC.sample_init!(
# non-Gibbs sampling.
if !islinked(spl.state.vi, spl) && spl.selector.tag == :default
link!(spl.state.vi, spl)
runmodel!(model, spl.state.vi, spl)
model(spl.state.vi, spl)
end
end

Expand Down Expand Up @@ -343,7 +343,7 @@ function AbstractMCMC.step!(
# Transform the space
Turing.DEBUG && @debug "X-> R..."
link!(spl.state.vi, spl)
runmodel!(model, spl.state.vi, spl)
model(spl.state.vi, spl)
# Update Hamiltonian
metric = gen_metric(length(spl.state.vi[spl]), spl)
∂logπ∂θ = gen_∂logπ∂θ(spl.state.vi, spl, model)
Expand Down Expand Up @@ -413,7 +413,7 @@ function gen_logπ(vi::VarInfo, spl::Sampler, model)
function logπ(x)::Float64
x_old, lj_old = vi[spl], getlogp(vi)
vi[spl] = x
runmodel!(model, vi, spl)
model(vi, spl)
lj = getlogp(vi)
vi[spl] = x_old
setlogp!(vi, lj_old)
Expand Down
4 changes: 2 additions & 2 deletions test/Turing/inference/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function gen_logπ_mh(spl::Sampler, model)
x_old, lj_old = vi[spl], getlogp(vi)
# vi[spl] = x
set_namedtuple!(vi, x)
runmodel!(model, vi)
model(vi)
lj = getlogp(vi)
vi[spl] = x_old
setlogp!(vi, lj_old)
Expand Down Expand Up @@ -231,7 +231,7 @@ function AbstractMCMC.step!(
kwargs...
)
if spl.selector.rerun # Recompute joint in logp
runmodel!(model, spl.state.vi)
model(spl.state.vi)
end

# Retrieve distribution and value NamedTuples.
Expand Down
5 changes: 1 addition & 4 deletions test/independence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,5 @@
end
end
model = coinflip([1,1,0])

vi = VarInfo()

runmodel!(model, vi, SampleFromPrior(), LikelihoodContext())
model(SampleFromPrior(), LikelihoodContext())
end
13 changes: 6 additions & 7 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using .Turing, Random
using AbstractMCMC: step!
using DynamicPPL: Selector, reconstruct, invlink, CACHERESET,
SampleFromPrior, Sampler, runmodel!, SampleFromUniform,
SampleFromPrior, Sampler, SampleFromUniform,
_getidcs, set_retained_vns_del_by_spl!, is_flagged,
set_flag!, unset_flag!, VarInfo, TypedVarInfo,
getlogp, setlogp!, resetlogp!, acclogp!, vectorize,
Expand Down Expand Up @@ -114,19 +114,18 @@ include(dir*"/test/test_utils/AllUtils.jl")
test_base!(vi)
test_base!(empty!(TypedVarInfo(vi)))
end
@testset "runmodel!" begin
# Test that eval_num is incremented when calling runmodel!
@testset "in-place" begin
# Test that eval_num is incremented when running the model
@model testmodel() = begin
x ~ Normal()
end
alg = HMC(0.1, 5)
spl = Sampler(alg, testmodel())
vi = VarInfo()
m = testmodel()
m(vi)
runmodel!(m, vi, spl)
vi = VarInfo(m)
m(vi, spl)
@test spl.state.eval_num == 1
runmodel!(m, vi, spl)
m(vi, spl)
@test spl.state.eval_num == 2
end
@testset "flags" begin
Expand Down