Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

Expand All @@ -16,6 +17,7 @@ AbstractMCMC = "1"
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
Distributions = "0.22, 0.23"
MacroTools = "0.5.1"
NaturalSort = "1"
ZygoteRules = "0.2"
julia = "1"

Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Distributions
using Bijectors

import AbstractMCMC
import NaturalSort
import MacroTools
import ZygoteRules

Expand Down
20 changes: 4 additions & 16 deletions src/prob_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,13 @@ function Distributions.loglikelihood(
# Element-wise likelihood for each value in chain
chain = right.chain
ctx = LikelihoodContext()
return map(1:length(chain)) do i
c = chain[i]
_setval!(vi, c)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
logps = map(iters) do (sample_idx, chain_idx)
setval!(vi, chain, sample_idx, chain_idx)
model(vi, SampleFromPrior(), ctx)
return getlogp(vi)
end
return reshape(logps, size(chain, 1), size(chain, 3))
else
# Likelihood without chain
# Rhs values are used in the context
Expand Down Expand Up @@ -231,16 +232,3 @@ end
return :(Model{$(Tuple(missings))}(model.f, $(to_namedtuple_expr(argnames, argvals)),
model.defaults))
end

# NOTE(review): legacy chain-based value setter — per this diff it is being
# removed in favour of the keyed `setval!(vi, chains, sample_idx, chain_idx)`
# API added in src/varinfo.jl; confirm no remaining callers before deleting.
_setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c)
# For every symbol `n` in the typed metadata, unrolls (at compile time, via the
# generated function) a loop that copies the chain values into `vi`.
@generated function _setval!(md::NamedTuple{names}, vi, c) where {names}
return Expr(:block, map(names) do n
quote
for vn in md.$n.vns
# Look the variable up in the chain by its name and flatten to a vector.
val = vec(c[Symbol(vn)])
setval!(vi, val, vn)
# Values read from a chain are in the original (constrained) space,
# so mark the variable as untransformed.
settrans!(vi, false, vn)
end
end
end...)
end
48 changes: 47 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.

The values may or may not be transformed to Euclidean space.
"""
setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = val
# `[val;]` turns both scalars and arrays into a vector, so a scalar `val` can be
# assigned into the (ranged) `vals` storage slice just like an array-valued one.
setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = [val;]

"""
getval(vi::VarInfo, vns::Vector{<:VarName})
Expand Down Expand Up @@ -1144,3 +1144,49 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler)
setgid!(vi, spl.selector, vn)
end
end

"""
    setval!(vi::AbstractVarInfo, x)

Set the values in `vi` from a keyed collection `x` (e.g. a `NamedTuple` or
dictionary), matching entries to variables by name via `_setval_kernel!`.
"""
setval!(vi::AbstractVarInfo, x) = _setval!(vi, values(x), keys(x))

"""
    setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)

Set the values in `vi` from sample `sample_idx` of chain `chain_idx` in `chains`.
The chain's parameter names are matched against the variable names in `vi`.
"""
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
    return _setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
end

# Generic fallback: try every variable in `vi` against the provided keys.
# The parameters are named `vals`/`ks` (not `values`/`keys`) so they do not
# shadow `Base.values`/`Base.keys` inside the body.
function _setval!(vi::AbstractVarInfo, vals, ks)
    for vn in Base.keys(vi)
        _setval_kernel!(vi, vn, vals, ks)
    end
    return vi
end
# Typed varinfo: dispatch on the metadata `NamedTuple` so the per-symbol loop
# can be unrolled in `_typed_setval!`.
_setval!(vi::TypedVarInfo, vals, ks) = _typed_setval!(vi, vi.metadata, vals, ks)
@generated function _typed_setval!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think simple generated functions like this can be replaced by a map do-block on the named tuple directly. Last I checked the Julia compiler inferred and inlined it just fine with a do-block.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be done in many places in DPPL.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's what I assumed as well but it's not true in general. At least in March I found the generated function to be much more efficient on a simple example (TuringLang/Turing.jl#1167 (comment)), and hence IMO one really has to benchmark every possible switch from generated function to regular map (which is a bit unfortunate since I'd like to just use regular functions wherever possible...).

I'll benchmark this example to see if we could get rid of the generated function here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably ok to keep as-is here, and perform refactoring to change all places in another PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this one of those things where the compiler will "inline" everything up-to some fixed threshold? Same as with recursions. I remember looking into this stuff when working on Bijectors.jl, and there was essentially a fixed depth (I think ~20 or something) at which point the recursion (even though the methods were type-stable) wouldn't be unrolled. I bet it's the same here, where if names is sufficiently small then map will be the same as generated, but if names is large it won't since this can cause issues, e.g. a Turing model with millions of univariate parameters will probably not be too comfortable for the compiler. With that said, I agree with forcing "inlining" by the use of generated functions since if someone is running a model with millions of parameters, it's likely that you're still just looking at <30 different symbols.

vi::TypedVarInfo,
metadata::NamedTuple{names},
values,
keys
) where {names}
updates = map(names) do n
quote
for vn in metadata.$n.vns
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this always going to be just a single element now, which is the same as VarName($n)? If not, could you be so kind and explain? 👼 That was quite confusing to me when working on the other PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK metadata.$n.vns will be a proper vector of multiple elements as soon as the model contains random variables with different indexing, e.g., a vector-valued random variable x for which separate samples x[1] and x[2] are sampled in the generative model definition will lead to a TypedVarInfo object in which metadata is a NamedTuple with key :x and a corresponding value of type Metadata that contains all information of x[1] and x[2] - and hence a field vns that contains the VarName instances for both x[1] and x[2]. See

"""
TypedVarInfo(vi::UntypedVarInfo)
This function finds all the unique `sym`s from the instances of `VarName{sym}` found in
`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the
global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as
a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each
symbol.
"""
for the current implementation that creates TypedVarInfo objects from untyped VarInfo objects (usually after running the model once).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. So if x ~ MvNormal it won't be different, but if x[1] ~ Normal, x[2] ~ Normal then it will be?

Thanks man:)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly.

_setval_kernel!(vi, vn, values, keys)
end
end
end

return quote
$(updates...)
return vi
end
end

# Copy the entries of `values` whose `keys` refer to `vn` into `vi`.
#
# A key matches if it is exactly the variable's symbol (`m`) or an indexed
# entry of it (`m[1]`, `m[2]`, ...). If any key matches, the matching values
# are concatenated in natural order and written into `vi`, and the variable is
# marked as untransformed; otherwise `vi` is left untouched.
function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
    name = Symbol(vn)
    pattern = Regex("^$name\$|^$name\\[")
    hits = [i for (i, k) in enumerate(keys) if occursin(pattern, string(k))]
    isempty(hits) && return nothing
    # Natural (not lexicographic) order so that e.g. "m[10]" sorts after "m[2]".
    sort!(hits; by=i -> string(keys[i]), lt=NaturalSort.natural)
    val = reduce(vcat, (values[i] for i in hits))
    setval!(vi, val, vn)
    # Chain/NamedTuple values live in the original (constrained) space.
    settrans!(vi, false, vn)
end
74 changes: 59 additions & 15 deletions test/prob_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Random.seed!(129)

@testset "prob_macro" begin
@testset "scalar" begin
@model demo(x) = begin
@model function demo(x)
m ~ Normal()
x ~ Normal(m, 1)
end
Expand All @@ -29,37 +29,44 @@ Random.seed!(129)
@test logprob"x = xval | m = mval, model = model" == loglike
@test logprob"x = xval, m = mval | model = model" == logjoint

varinfo = VarInfo(demo(missing))
@test logprob"x = xval, m = mval | model = model, varinfo = varinfo" == logjoint

varinfo = VarInfo(demo(xval))
@test logprob"m = mval | model = model, varinfo = varinfo" == logprior
@test logprob"m = mval | x = xval, model = model, varinfo = varinfo" == logprior
@test logprob"x = xval | m = mval, model = model, varinfo = varinfo" == loglike
varinfo = VarInfo(demo(missing))
@test logprob"x = xval, m = mval | model = model, varinfo = varinfo" == logjoint

chain = sample(demo(xval), IS(), iters; save_state = true)
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())
lps = logpdf.(Normal.(vec(chain["m"]), 1), xval)
lps = logpdf.(Normal.(chain["m"], 1), xval)
@test logprob"x = xval | chain = chain" == lps
@test logprob"x = xval | chain = chain2, model = model" == lps
varinfo = VarInfo(demo(xval))
@test logprob"x = xval | chain = chain, varinfo = varinfo" == lps
@test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps

# multiple chains
pchain = chainscat(chain, chain)
pchain2 = chainscat(chain2, chain2)
plps = repeat(lps, 1, 2)
@test logprob"x = xval | chain = pchain" == plps
@test logprob"x = xval | chain = pchain2, model = model" == plps
@test logprob"x = xval | chain = pchain, varinfo = varinfo" == plps
@test logprob"x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps
end

@testset "vector" begin
n = 5
@model demo(x, n = n, ::Type{T} = Float64) where {T} = begin
m = Vector{T}(undef, n)
@. m ~ Normal()
@. x ~ Normal.(m, 1)
@model function demo(x, n = n)
m ~ MvNormal(n, 1.0)
x ~ MvNormal(m, 1.0)
end
mval = rand(n)
xval = rand(n)
iters = 1000

logprior = sum(logpdf.(Normal(), mval))
like(m, x) = sum(logpdf.(Normal.(m, 1), x))
loglike = like(mval, xval)
logprior = logpdf(MvNormal(n, 1.0), mval)
loglike = logpdf(MvNormal(mval, 1.0), xval)
logjoint = logprior + loglike

model = demo(xval)
Expand All @@ -76,12 +83,49 @@ Random.seed!(129)
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())

names = namesingroup(chain, "m")
lps = map(1:iters) do iter
like([chain[iter, name, 1] for name in names], xval)
end
lps = [
logpdf(MvNormal(chain.value[i, names, j], 1.0), xval)
for i in 1:size(chain, 1), j in 1:size(chain, 3)
]
@test logprob"x = xval | chain = chain" == lps
@test logprob"x = xval | chain = chain2, model = model" == lps
@test logprob"x = xval | chain = chain, varinfo = varinfo" == lps
@test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps

# multiple chains
pchain = chainscat(chain, chain)
pchain2 = chainscat(chain2, chain2)
plps = repeat(lps, 1, 2)
@test logprob"x = xval | chain = pchain" == plps
@test logprob"x = xval | chain = pchain2, model = model" == plps
@test logprob"x = xval | chain = pchain, varinfo = varinfo" == plps
@test logprob"x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps
end

# Regression tests for issue #137: `logprob"... | chain = ..."` with a chain
# produced from models whose observations are vector-valued.
# NOTE(review): these are smoke tests only — the `logprob` calls have no `@test`
# around them, so they pass as long as no error is thrown; consider asserting
# on the returned values.
@testset "issue#137" begin
# Vectorised formulation: a single multivariate likelihood over all `y`.
@model function model1(y, group, n_groups)
σ ~ truncated(Cauchy(0, 1), 0, Inf)
α ~ filldist(Normal(0, 10), n_groups)
μ = α[group]
y ~ MvNormal(μ, σ)
end

y = randn(100)
group = rand(1:4, 100)
n_groups = 4

chain1 = sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain1"

# Scalar formulation of the same model: one `Normal` likelihood per element.
@model function model2(y, group, n_groups)
σ ~ truncated(Cauchy(0, 1), 0, Inf)
α ~ filldist(Normal(0, 10), n_groups)
for i in 1:length(y)
y[i] ~ Normal(α[group[i]], σ)
end
end

chain2 = sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2"
end
end
45 changes: 45 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -505,4 +505,49 @@ include(dir*"/test/test_utils/AllUtils.jl")
@test vi.metadata.w.gids[1] == Set([hmc.selector])
@test vi.metadata.u.gids[1] == Set([hmc.selector])
end

@testset "setval!" begin
    @model function testmodel(x)
        n = length(x)
        s ~ truncated(Normal(), 0, Inf)
        m ~ MvNormal(n, 1.0)
        x ~ MvNormal(m, s)
    end

    x = randn(5)
    model = testmodel(x)

    # Exercise `setval!` on both an untyped and a typed varinfo.
    untyped_vi = VarInfo()
    model(untyped_vi, SampleFromPrior())
    typed_vi = VarInfo(model)

    for vi in (untyped_vi, typed_vi)
        vicopy = deepcopy(vi)

        # Setting one variable leaves the others untouched.
        DynamicPPL.setval!(vicopy, (m = zeros(5),))
        @test vicopy[@varname(m)] == zeros(5)
        @test vicopy[@varname(s)] == vi[@varname(s)]

        # Scrambled per-element keys must be reassembled in natural order.
        DynamicPPL.setval!(vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...))
        @test vicopy[@varname(m)] == 1:5
        @test vicopy[@varname(s)] == vi[@varname(s)]

        DynamicPPL.setval!(vicopy, (s = 42,))
        @test vicopy[@varname(m)] == 1:5
        @test vicopy[@varname(s)] == 42
    end
end
end