@@ -10,7 +10,7 @@ Random.seed!(129)
1010
1111@testset " prob_macro" begin
1212 @testset " scalar" begin
13- @model demo (x) = begin
13+ @model function demo (x)
1414 m ~ Normal ()
1515 x ~ Normal (m, 1 )
1616 end
@@ -29,37 +29,44 @@ Random.seed!(129)
2929 @test logprob " x = xval | m = mval, model = model" == loglike
3030 @test logprob " x = xval, m = mval | model = model" == logjoint
3131
32+ varinfo = VarInfo (demo (missing ))
33+ @test logprob " x = xval, m = mval | model = model, varinfo = varinfo" == logjoint
34+
3235 varinfo = VarInfo (demo (xval))
3336 @test logprob " m = mval | model = model, varinfo = varinfo" == logprior
3437 @test logprob " m = mval | x = xval, model = model, varinfo = varinfo" == logprior
3538 @test logprob " x = xval | m = mval, model = model, varinfo = varinfo" == loglike
36- varinfo = VarInfo (demo (missing ))
37- @test logprob " x = xval, m = mval | model = model, varinfo = varinfo" == logjoint
3839
3940 chain = sample (demo (xval), IS (), iters; save_state = true )
4041 chain2 = Chains (chain. value, chain. logevidence, chain. name_map, NamedTuple ())
41- lps = logpdf .(Normal .(vec ( chain[" m" ]) , 1 ), xval)
42+ lps = logpdf .(Normal .(chain[" m" ], 1 ), xval)
4243 @test logprob " x = xval | chain = chain" == lps
4344 @test logprob " x = xval | chain = chain2, model = model" == lps
44- varinfo = VarInfo (demo (xval))
4545 @test logprob " x = xval | chain = chain, varinfo = varinfo" == lps
4646 @test logprob " x = xval | chain = chain2, model = model, varinfo = varinfo" == lps
47+
48+ # multiple chains
49+ pchain = chainscat (chain, chain)
50+ pchain2 = chainscat (chain2, chain2)
51+ plps = repeat (lps, 1 , 2 )
52+ @test logprob " x = xval | chain = pchain" == plps
53+ @test logprob " x = xval | chain = pchain2, model = model" == plps
54+ @test logprob " x = xval | chain = pchain, varinfo = varinfo" == plps
55+ @test logprob " x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps
4756 end
4857
4958 @testset " vector" begin
5059 n = 5
51- @model demo (x, n = n, :: Type{T} = Float64) where {T} = begin
52- m = Vector {T} (undef, n)
53- @. m ~ Normal ()
54- @. x ~ Normal .(m, 1 )
60+ @model function demo (x, n = n)
61+ m ~ MvNormal (n, 1.0 )
62+ x ~ MvNormal (m, 1.0 )
5563 end
5664 mval = rand (n)
5765 xval = rand (n)
5866 iters = 1000
5967
60- logprior = sum (logpdf .(Normal (), mval))
61- like (m, x) = sum (logpdf .(Normal .(m, 1 ), x))
62- loglike = like (mval, xval)
68+ logprior = logpdf (MvNormal (n, 1.0 ), mval)
69+ loglike = logpdf (MvNormal (mval, 1.0 ), xval)
6370 logjoint = logprior + loglike
6471
6572 model = demo (xval)
@@ -76,12 +83,49 @@ Random.seed!(129)
7683 chain2 = Chains (chain. value, chain. logevidence, chain. name_map, NamedTuple ())
7784
7885 names = namesingroup (chain, " m" )
79- lps = map (1 : iters) do iter
80- like ([chain[iter, name, 1 ] for name in names], xval)
81- end
86+ lps = [
87+ logpdf (MvNormal (chain. value[i, names, j], 1.0 ), xval)
88+ for i in 1 : size (chain, 1 ), j in 1 : size (chain, 3 )
89+ ]
8290 @test logprob " x = xval | chain = chain" == lps
8391 @test logprob " x = xval | chain = chain2, model = model" == lps
8492 @test logprob " x = xval | chain = chain, varinfo = varinfo" == lps
8593 @test logprob " x = xval | chain = chain2, model = model, varinfo = varinfo" == lps
94+
95+ # multiple chains
96+ pchain = chainscat (chain, chain)
97+ pchain2 = chainscat (chain2, chain2)
98+ plps = repeat (lps, 1 , 2 )
99+ @test logprob " x = xval | chain = pchain" == plps
100+ @test logprob " x = xval | chain = pchain2, model = model" == plps
101+ @test logprob " x = xval | chain = pchain, varinfo = varinfo" == plps
102+ @test logprob " x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps
103+ end
104+
105+ @testset " issue#137" begin
106+ @model function model1 (y, group, n_groups)
107+ σ ~ truncated (Cauchy (0 , 1 ), 0 , Inf )
108+ α ~ filldist (Normal (0 , 10 ), n_groups)
109+ μ = α[group]
110+ y ~ MvNormal (μ, σ)
111+ end
112+
113+ y = randn (100 )
114+ group = rand (1 : 4 , 100 )
115+ n_groups = 4
116+
117+ chain1 = sample (model1 (y, group, n_groups), NUTS (0.65 ), 2_000 ; save_state= true )
118+ logprob " y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain1"
119+
120+ @model function model2 (y, group, n_groups)
121+ σ ~ truncated (Cauchy (0 , 1 ), 0 , Inf )
122+ α ~ filldist (Normal (0 , 10 ), n_groups)
123+ for i in 1 : length (y)
124+ y[i] ~ Normal (α[group[i]], σ)
125+ end
126+ end
127+
128+ chain2 = sample (model2 (y, group, n_groups), NUTS (0.65 ), 2_000 ; save_state= true )
129+ logprob " y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2"
86130 end
87131end
0 commit comments