Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/DynamicPPL-CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:
jobs:
test:
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
matrix:
version:
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "1.0"
Bijectors = "0.5.2, 0.6"
Distributions = "0.22, 0.23"
MacroTools = "0.5.1"
StaticArrays = "0.12.2"
ZygoteRules = "0.2"
julia = "1"

Expand Down
4 changes: 3 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ using Bijectors
using MacroTools

import AbstractMCMC
import Random
import StaticArrays
import ZygoteRules

import Random

import Base: Symbol,
==,
hash,
Expand Down
35 changes: 25 additions & 10 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
"Distributions."

const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo)
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_logps)

"""
isassumption(expr)
Expand Down Expand Up @@ -234,24 +234,25 @@ function generate_tilde(left, right, args)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume)(
_context, _sampler, $tmpright, $vn, $inds, _varinfo)
_context, _sampler, $tmpright, $vn, $inds, _varinfo, _logps)
else
$(DynamicPPL.tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
end
end
end

return quote
$(top...)
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds, _varinfo)
$left = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
_varinfo, _logps)
end
end

# If the LHS is a literal, it is always an observation
return quote
$(top...)
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo, _logps)
end
end

Expand All @@ -278,25 +279,26 @@ function generate_dot_tilde(left, right, args)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
else
$(DynamicPPL.dot_tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
end
end
end

return quote
$(top...)
$left .= $(DynamicPPL.dot_tilde_assume)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo, _logps)
end
end

# If the LHS is a literal, it is always an observation
return quote
$(top...)
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo,
_logps)
end
end

Expand Down Expand Up @@ -333,16 +335,29 @@ function build_output(model_info)
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
end

@gensym(evaluator, generator)
@gensym(evaluator, innerevaluator, generator)
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))

return quote
function $evaluator(
model::$(DynamicPPL.Model),
varinfo::$(DynamicPPL.VarInfo),
sampler::$(DynamicPPL.AbstractSampler),
context::$(DynamicPPL.AbstractContext),
)
logps = $(DynamicPPL.initlogps)(varinfo)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this will create a type instability. I just noticed it now. Perhaps another reason to have logps inside VarInfo and resize it accordingly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's please fix this in another PR before releasing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's due to our use of StaticArrays here and JuliaLang/julia#34902. I guess we should just use arrays, compared to all other allocations and model evaluation that should not matter at all.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree 👍

result = $innerevaluator(model, varinfo, sampler, context, logps)
$(DynamicPPL.acclogp!)(varinfo, $(Base.sum)(logps))
return result
end

function $innerevaluator(
_model::$(DynamicPPL.Model),
_varinfo::$(DynamicPPL.VarInfo),
_sampler::$(DynamicPPL.AbstractSampler),
_context::$(DynamicPPL.AbstractContext),
_logps,
)
$unwrap_data_expr
$main_body
Expand Down
50 changes: 27 additions & 23 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@ function tilde(ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
end

"""
tilde_assume(ctx, sampler, right, vn, inds, vi)
tilde_assume(ctx, sampler, right, vn, inds, vi, logps)

Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the sampled value.
accumulate the log probability in `logps` (separately for each thread), and return the
sampled value.

Falls back to `tilde(ctx, sampler, right, vn, inds, vi)`.
"""
function tilde_assume(ctx, sampler, right, vn, inds, vi)
function tilde_assume(ctx, sampler, right, vn, inds, vi, logps)
value, logp = tilde(ctx, sampler, right, vn, inds, vi)
acclogp!(vi, logp)
logps[Threads.threadid()] += logp
return value
end

Expand Down Expand Up @@ -75,28 +76,29 @@ end
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)

Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
accumulate the log probability, and return the observed value.
accumulate the log probability in `logps` (separately for each thread), and return the
observed value.

Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name
and indices; if needed, these can be accessed through this function, though.
"""
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
function tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps)
logp = tilde(ctx, sampler, right, left, vi)
acclogp!(vi, logp)
logps[Threads.threadid()] += logp
return left
end

"""
tilde_observe(ctx, sampler, right, left, vi)
tilde_observe(ctx, sampler, right, left, vi, logps)

Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the
observed value.
Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability in `logps`
(separately for each thread), and return the observed value.

Falls back to `tilde(ctx, sampler, right, left, vi)`.
"""
function tilde_observe(ctx, sampler, right, left, vi)
function tilde_observe(ctx, sampler, right, left, vi, logps)
logp = tilde(ctx, sampler, right, left, vi)
acclogp!(vi, logp)
logps[Threads.threadid()] += logp
return left
end

Expand Down Expand Up @@ -199,13 +201,14 @@ end
dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)

Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the
model inputs), accumulate the log probability, and return the sampled value.
model inputs), accumulate the log probability in `logps` (separately for each thread), and
return the sampled value.

Falls back to `dot_tilde(ctx, sampler, right, left, vn, inds, vi)`.
"""
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi, logps)
value, logp = dot_tilde(ctx, sampler, right, left, vn, inds, vi)
acclogp!(vi, logp)
logps[Threads.threadid()] += logp
return value
end

Expand Down Expand Up @@ -381,31 +384,32 @@ function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi)
end

"""
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi)
dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi, logps)

Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs),
accumulate the log probability, and return the observed value.
accumulate the log probability in `logps` (separately for each thread), and return the
observed value.

Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable
name and indices; if needed, these can be accessed through this function, though.
"""
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi)
function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi, logps)
logp = dot_tilde(ctx, sampler, right, left, vi)
acclogp!(vi, logp)
logps[Threads.threadid()] += logp
return left
end

"""
dot_tilde_observe(ctx, sampler, right, left, vi)
dot_tilde_observe(ctx, sampler, right, left, vi, logps)

Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log
probability, and return the observed value.
probability in `logps` (separately for each thread), and return the observed value.

Falls back to `dot_tilde(ctx, sampler, right, left, vi)`.
"""
function dot_tilde_observe(ctx, sampler, right, left, vi)
function dot_tilde_observe(ctx, sampler, right, left, vi, logps)
logp = dot_tilde(ctx, sampler, right, left, vi)
acclogp!(vi, logp)
logps[Threads.threadid()] += logp
return left
end

Expand Down
13 changes: 13 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ function getargs_tilde(expr::Expr)
return
end

"""
initlogps(varinfo)

Return an `MVector` of length `Threads.nthreads()` filled with `zero(getlogp(varinfo))`.

It is used for accumulating the log probability in the model evaluation in a thread-safe
way.
"""
function initlogps(varinfo)
T = typeof(getlogp(varinfo))
return zeros(StaticArrays.MVector{Threads.nthreads(),T})
end

############################################
# Julia 1.2 temporary fix - Julia PR 33303 #
############################################
Expand Down
51 changes: 48 additions & 3 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,13 @@ end

# Test use of internal names
@model testmodel(x) = begin
x[1] ~ Bernoulli(0.5)
x[1] ~  Bernoulli(0.5)
global varinfo_ = _varinfo
global sampler_ = _sampler
global model_ = _model
global context_ = _context
global lp = getlogp(_varinfo)
global logps_ = _logps
global lp = sum(_logps)
return x
end
model = testmodel([1.0])
Expand All @@ -226,6 +227,15 @@ end
@test model_ === model
@test sampler_ === SampleFromPrior()
@test context_ === DefaultContext()
@test length(logps_) == Threads.nthreads()
@test sum(logps_) == lp
for i in 1:length(logps_)
if i == Threads.threadid()
@test logps_[i] == lp
else
@test iszero(logps_[i])
end
end

# test DPPL#61
@model testmodel(z) = begin
Expand All @@ -240,7 +250,7 @@ end
function makemodel(p)
@model testmodel(x) = begin
x[1] ~ Bernoulli(p)
global lp = getlogp(_varinfo)
global lp = sum(_logps)
return x
end
return testmodel
Expand Down Expand Up @@ -580,4 +590,39 @@ end
model = demo()
@test all(iszero(model()) for _ in 1:1000)
end
@testset "threading" begin
@info "Peforming threading tests with $(Threads.nthreads()) threads"

x = rand(10_000)

@model function wthreads(x)
x[1] ~ Normal(0, 1)
Threads.@threads for i in 2:length(x)
x[i] ~ Normal(x[i-1], 1)
end
end

vi = VarInfo()
wthreads(x)(vi)
lp_w_threads = getlogp(vi)

println("With threading:")
@time wthreads(x)(vi)

@model function wothreads(x)
x[1] ~ Normal(0, 1)
for i in 2:length(x)
x[i] ~ Normal(x[i-1], 1)
end
end

vi = VarInfo()
wothreads(x)(vi)
lp_wo_threads = getlogp(vi)

println("Without threading:")
@time wothreads(x)(vi)

@test lp_w_threads ≈ lp_wo_threads
end
end
6 changes: 3 additions & 3 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -471,18 +471,18 @@ include(dir*"/test/test_utils/AllUtils.jl")
@test mapreduce(x -> x.gids, vcat, vi1.metadata) ==
[Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set{Selector}(), Set{Selector}()]

@inferred g_demo_f(vi1, hmc)
@test_broken @inferred g_demo_f(vi1, hmc)
@test mapreduce(x -> x.gids, vcat, vi1.metadata) ==
[Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set([hmc.selector]), Set([hmc.selector])]

g = Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f)
pg, hmc = g.state.samplers
vi = empty!(TypedVarInfo(vi))
@inferred g_demo_f(vi, SampleFromPrior())
@test_broken @inferred g_demo_f(vi, SampleFromPrior())
pg.state.vi = vi
step!(Random.GLOBAL_RNG, g_demo_f, pg, 1)
vi = pg.state.vi
@inferred g_demo_f(vi, hmc)
@test_broken @inferred g_demo_f(vi, hmc)
@test vi.metadata.x.gids[1] == Set([pg.selector])
@test vi.metadata.y.gids[1] == Set([pg.selector])
@test vi.metadata.z.gids[1] == Set([pg.selector])
Expand Down