|
113 | 113 | end |
114 | 114 | end |
115 | 115 |
|
116 | | - @testset "SampleFromPrior and SampleUniform" begin |
117 | | - @model function gdemo(x, y) |
118 | | - s ~ InverseGamma(2, 3) |
119 | | - m ~ Normal(2.0, sqrt(s)) |
120 | | - x ~ Normal(m, sqrt(s)) |
121 | | - return y ~ Normal(m, sqrt(s)) |
122 | | - end |
123 | | - |
124 | | - model = gdemo(1.0, 2.0) |
125 | | - N = 1_000 |
126 | | - |
127 | | - chains = sample(model, SampleFromPrior(), N; progress=false) |
128 | | - @test chains isa Vector{<:VarInfo} |
129 | | - @test length(chains) == N |
130 | | - |
131 | | - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. |
132 | | - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 |
133 | | - |
134 | | - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. |
135 | | - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 |
136 | | - |
137 | | - chains = sample(model, SampleFromUniform(), N; progress=false) |
138 | | - @test chains isa Vector{<:VarInfo} |
139 | | - @test length(chains) == N |
140 | | - |
141 | | - # `m` is Gaussian, i.e. no transformation is used, so it |
142 | | - # will be drawn from U[-2, 2] and its mean should be 0. |
143 | | - @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 |
144 | | - |
145 | | - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. |
146 | | - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 |
147 | | - end |
148 | | - |
149 | | - @testset "init" begin |
150 | | - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS |
151 | | - N = 1000 |
152 | | - chain_init = sample(model, SampleFromUniform(), N; progress=false) |
153 | | - |
154 | | - for vn in keys(first(chain_init)) |
155 | | - if AbstractPPL.subsumes(@varname(s), vn) |
156 | | - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. |
157 | | - dist = InverseGamma(2, 3) |
158 | | - b = DynamicPPL.link_transform(dist) |
159 | | - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 |
160 | | - elseif AbstractPPL.subsumes(@varname(m), vn) |
161 | | - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. |
162 | | - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 |
163 | | - else |
164 | | - error("Unknown variable name: $vn") |
165 | | - end |
166 | | - end |
167 | | - end |
168 | | - end |
169 | | - |
170 | 116 | @testset "Initial parameters" begin |
171 | 117 | # dummy algorithm that just returns initial value and does not perform any sampling |
172 | 118 | abstract type OnlyInitAlg end |
|
0 commit comments