Skip to content

Commit 5d72768

Browse files
authored
Merge pull request #147 from TuringLang/probmacro_bugfixes
Fix prob macros
2 parents bde1f74 + 7a64232 commit 5d72768

File tree

6 files changed

+158
-32
lines changed

6 files changed

+158
-32
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
11+
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1314

@@ -16,6 +17,7 @@ AbstractMCMC = "1"
1617
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
1718
Distributions = "0.22, 0.23"
1819
MacroTools = "0.5.1"
20+
NaturalSort = "1"
1921
ZygoteRules = "0.2"
2022
julia = "1"
2123

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Distributions
55
using Bijectors
66

77
import AbstractMCMC
8+
import NaturalSort
89
import MacroTools
910
import ZygoteRules
1011

src/prob_macro.jl

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,13 @@ function Distributions.loglikelihood(
190190
# Element-wise likelihood for each value in chain
191191
chain = right.chain
192192
ctx = LikelihoodContext()
193-
return map(1:length(chain)) do i
194-
c = chain[i]
195-
_setval!(vi, c)
193+
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
194+
logps = map(iters) do (sample_idx, chain_idx)
195+
setval!(vi, chain, sample_idx, chain_idx)
196196
model(vi, SampleFromPrior(), ctx)
197197
return getlogp(vi)
198198
end
199+
return reshape(logps, size(chain, 1), size(chain, 3))
199200
else
200201
# Likelihood without chain
201202
# Rhs values are used in the context
@@ -231,16 +232,3 @@ end
231232
return :(Model{$(Tuple(missings))}(model.f, $(to_namedtuple_expr(argnames, argvals)),
232233
model.defaults))
233234
end
234-
235-
_setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c)
236-
@generated function _setval!(md::NamedTuple{names}, vi, c) where {names}
237-
return Expr(:block, map(names) do n
238-
quote
239-
for vn in md.$n.vns
240-
val = vec(c[Symbol(vn)])
241-
setval!(vi, val, vn)
242-
settrans!(vi, false, vn)
243-
end
244-
end
245-
end...)
246-
end

src/varinfo.jl

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.
279279
280280
The values may or may not be transformed to Euclidean space.
281281
"""
282-
setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = val
282+
setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = [val;]
283283

284284
"""
285285
getval(vi::VarInfo, vns::Vector{<:VarName})
@@ -1144,3 +1144,49 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler)
11441144
setgid!(vi, spl.selector, vn)
11451145
end
11461146
end
1147+
1148+
setval!(vi::AbstractVarInfo, x) = _setval!(vi, values(x), keys(x))
1149+
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
1150+
return _setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
1151+
end
1152+
1153+
function _setval!(vi::AbstractVarInfo, values, keys)
1154+
for vn in Base.keys(vi)
1155+
_setval_kernel!(vi, vn, values, keys)
1156+
end
1157+
return vi
1158+
end
1159+
_setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, values, keys)
1160+
@generated function _typed_setval!(
1161+
vi::TypedVarInfo,
1162+
metadata::NamedTuple{names},
1163+
values,
1164+
keys
1165+
) where {names}
1166+
updates = map(names) do n
1167+
quote
1168+
for vn in metadata.$n.vns
1169+
_setval_kernel!(vi, vn, values, keys)
1170+
end
1171+
end
1172+
end
1173+
1174+
return quote
1175+
$(updates...)
1176+
return vi
1177+
end
1178+
end
1179+
1180+
function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
1181+
sym = Symbol(vn)
1182+
regex = Regex("^$sym\$|^$sym\\[")
1183+
indices = findall(x -> match(regex, string(x)) !== nothing, keys)
1184+
if !isempty(indices)
1185+
sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural)
1186+
val = mapreduce(vcat, sorted_indices) do i
1187+
values[i]
1188+
end
1189+
setval!(vi, val, vn)
1190+
settrans!(vi, false, vn)
1191+
end
1192+
end

test/prob_macro.jl

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Random.seed!(129)
1010

1111
@testset "prob_macro" begin
1212
@testset "scalar" begin
13-
@model demo(x) = begin
13+
@model function demo(x)
1414
m ~ Normal()
1515
x ~ Normal(m, 1)
1616
end
@@ -29,37 +29,44 @@ Random.seed!(129)
2929
@test logprob"x = xval | m = mval, model = model" == loglike
3030
@test logprob"x = xval, m = mval | model = model" == logjoint
3131

32+
varinfo = VarInfo(demo(missing))
33+
@test logprob"x = xval, m = mval | model = model, varinfo = varinfo" == logjoint
34+
3235
varinfo = VarInfo(demo(xval))
3336
@test logprob"m = mval | model = model, varinfo = varinfo" == logprior
3437
@test logprob"m = mval | x = xval, model = model, varinfo = varinfo" == logprior
3538
@test logprob"x = xval | m = mval, model = model, varinfo = varinfo" == loglike
36-
varinfo = VarInfo(demo(missing))
37-
@test logprob"x = xval, m = mval | model = model, varinfo = varinfo" == logjoint
3839

3940
chain = sample(demo(xval), IS(), iters; save_state = true)
4041
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())
41-
lps = logpdf.(Normal.(vec(chain["m"]), 1), xval)
42+
lps = logpdf.(Normal.(chain["m"], 1), xval)
4243
@test logprob"x = xval | chain = chain" == lps
4344
@test logprob"x = xval | chain = chain2, model = model" == lps
44-
varinfo = VarInfo(demo(xval))
4545
@test logprob"x = xval | chain = chain, varinfo = varinfo" == lps
4646
@test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps
47+
48+
# multiple chains
49+
pchain = chainscat(chain, chain)
50+
pchain2 = chainscat(chain2, chain2)
51+
plps = repeat(lps, 1, 2)
52+
@test logprob"x = xval | chain = pchain" == plps
53+
@test logprob"x = xval | chain = pchain2, model = model" == plps
54+
@test logprob"x = xval | chain = pchain, varinfo = varinfo" == plps
55+
@test logprob"x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps
4756
end
4857

4958
@testset "vector" begin
5059
n = 5
51-
@model demo(x, n = n, ::Type{T} = Float64) where {T} = begin
52-
m = Vector{T}(undef, n)
53-
@. m ~ Normal()
54-
@. x ~ Normal.(m, 1)
60+
@model function demo(x, n = n)
61+
m ~ MvNormal(n, 1.0)
62+
x ~ MvNormal(m, 1.0)
5563
end
5664
mval = rand(n)
5765
xval = rand(n)
5866
iters = 1000
5967

60-
logprior = sum(logpdf.(Normal(), mval))
61-
like(m, x) = sum(logpdf.(Normal.(m, 1), x))
62-
loglike = like(mval, xval)
68+
logprior = logpdf(MvNormal(n, 1.0), mval)
69+
loglike = logpdf(MvNormal(mval, 1.0), xval)
6370
logjoint = logprior + loglike
6471

6572
model = demo(xval)
@@ -76,12 +83,49 @@ Random.seed!(129)
7683
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())
7784

7885
names = namesingroup(chain, "m")
79-
lps = map(1:iters) do iter
80-
like([chain[iter, name, 1] for name in names], xval)
81-
end
86+
lps = [
87+
logpdf(MvNormal(chain.value[i, names, j], 1.0), xval)
88+
for i in 1:size(chain, 1), j in 1:size(chain, 3)
89+
]
8290
@test logprob"x = xval | chain = chain" == lps
8391
@test logprob"x = xval | chain = chain2, model = model" == lps
8492
@test logprob"x = xval | chain = chain, varinfo = varinfo" == lps
8593
@test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps
94+
95+
# multiple chains
96+
pchain = chainscat(chain, chain)
97+
pchain2 = chainscat(chain2, chain2)
98+
plps = repeat(lps, 1, 2)
99+
@test logprob"x = xval | chain = pchain" == plps
100+
@test logprob"x = xval | chain = pchain2, model = model" == plps
101+
@test logprob"x = xval | chain = pchain, varinfo = varinfo" == plps
102+
@test logprob"x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps
103+
end
104+
105+
@testset "issue#137" begin
106+
@model function model1(y, group, n_groups)
107+
σ ~ truncated(Cauchy(0, 1), 0, Inf)
108+
α ~ filldist(Normal(0, 10), n_groups)
109+
μ = α[group]
110+
y ~ MvNormal(μ, σ)
111+
end
112+
113+
y = randn(100)
114+
group = rand(1:4, 100)
115+
n_groups = 4
116+
117+
chain1 = sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
118+
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain1"
119+
120+
@model function model2(y, group, n_groups)
121+
σ ~ truncated(Cauchy(0, 1), 0, Inf)
122+
α ~ filldist(Normal(0, 10), n_groups)
123+
for i in 1:length(y)
124+
y[i] ~ Normal(α[group[i]], σ)
125+
end
126+
end
127+
128+
chain2 = sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
129+
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2"
86130
end
87131
end

test/varinfo.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,4 +505,49 @@ include(dir*"/test/test_utils/AllUtils.jl")
505505
@test vi.metadata.w.gids[1] == Set([hmc.selector])
506506
@test vi.metadata.u.gids[1] == Set([hmc.selector])
507507
end
508+
509+
@testset "setval!" begin
510+
@model function testmodel(x)
511+
n = length(x)
512+
s ~ truncated(Normal(), 0, Inf)
513+
m ~ MvNormal(n, 1.0)
514+
x ~ MvNormal(m, s)
515+
end
516+
517+
x = randn(5)
518+
model = testmodel(x)
519+
520+
# UntypedVarInfo
521+
vi = VarInfo()
522+
model(vi, SampleFromPrior())
523+
524+
vicopy = deepcopy(vi)
525+
DynamicPPL.setval!(vicopy, (m = zeros(5),))
526+
@test vicopy[@varname(m)] == zeros(5)
527+
@test vicopy[@varname(s)] == vi[@varname(s)]
528+
529+
DynamicPPL.setval!(vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...))
530+
@test vicopy[@varname(m)] == 1:5
531+
@test vicopy[@varname(s)] == vi[@varname(s)]
532+
533+
DynamicPPL.setval!(vicopy, (s = 42,))
534+
@test vicopy[@varname(m)] == 1:5
535+
@test vicopy[@varname(s)] == 42
536+
537+
# TypedVarInfo
538+
vi = VarInfo(model)
539+
540+
vicopy = deepcopy(vi)
541+
DynamicPPL.setval!(vicopy, (m = zeros(5),))
542+
@test vicopy[@varname(m)] == zeros(5)
543+
@test vicopy[@varname(s)] == vi[@varname(s)]
544+
545+
DynamicPPL.setval!(vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...))
546+
@test vicopy[@varname(m)] == 1:5
547+
@test vicopy[@varname(s)] == vi[@varname(s)]
548+
549+
DynamicPPL.setval!(vicopy, (s = 42,))
550+
@test vicopy[@varname(m)] == 1:5
551+
@test vicopy[@varname(s)] == 42
552+
end
508553
end

0 commit comments

Comments
 (0)