|  | 
| 113 | 113 |         end | 
| 114 | 114 |     end | 
| 115 | 115 | 
 | 
| 116 |  | -    @testset "SampleFromPrior and SampleUniform" begin | 
| 117 |  | -        @model function gdemo(x, y) | 
| 118 |  | -            s ~ InverseGamma(2, 3) | 
| 119 |  | -            m ~ Normal(2.0, sqrt(s)) | 
| 120 |  | -            x ~ Normal(m, sqrt(s)) | 
| 121 |  | -            return y ~ Normal(m, sqrt(s)) | 
| 122 |  | -        end | 
| 123 |  | - | 
| 124 |  | -        model = gdemo(1.0, 2.0) | 
| 125 |  | -        N = 1_000 | 
| 126 |  | - | 
| 127 |  | -        chains = sample(model, SampleFromPrior(), N; progress=false) | 
| 128 |  | -        @test chains isa Vector{<:VarInfo} | 
| 129 |  | -        @test length(chains) == N | 
| 130 |  | - | 
| 131 |  | -        # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. | 
| 132 |  | -        @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 | 
| 133 |  | - | 
| 134 |  | -        # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. | 
| 135 |  | -        @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 | 
| 136 |  | - | 
| 137 |  | -        chains = sample(model, SampleFromUniform(), N; progress=false) | 
| 138 |  | -        @test chains isa Vector{<:VarInfo} | 
| 139 |  | -        @test length(chains) == N | 
| 140 |  | - | 
| 141 |  | -        # `m` is Gaussian, i.e. no transformation is used, so it | 
| 142 |  | -        # will be drawn from U[-2, 2] and its mean should be 0. | 
| 143 |  | -        @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 | 
| 144 |  | - | 
| 145 |  | -        # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. | 
| 146 |  | -        @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 | 
| 147 |  | -    end | 
| 148 |  | - | 
| 149 |  | -    @testset "init" begin | 
| 150 |  | -        @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS | 
| 151 |  | -            N = 1000 | 
| 152 |  | -            chain_init = sample(model, SampleFromUniform(), N; progress=false) | 
| 153 |  | - | 
| 154 |  | -            for vn in keys(first(chain_init)) | 
| 155 |  | -                if AbstractPPL.subsumes(@varname(s), vn) | 
| 156 |  | -                    # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. | 
| 157 |  | -                    dist = InverseGamma(2, 3) | 
| 158 |  | -                    b = DynamicPPL.link_transform(dist) | 
| 159 |  | -                    @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 | 
| 160 |  | -                elseif AbstractPPL.subsumes(@varname(m), vn) | 
| 161 |  | -                    # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. | 
| 162 |  | -                    @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 | 
| 163 |  | -                else | 
| 164 |  | -                    error("Unknown variable name: $vn") | 
| 165 |  | -                end | 
| 166 |  | -            end | 
| 167 |  | -        end | 
| 168 |  | -    end | 
| 169 |  | - | 
| 170 | 116 |     @testset "Initial parameters" begin | 
| 171 | 117 |         # dummy algorithm that just returns initial value and does not perform any sampling | 
| 172 | 118 |         abstract type OnlyInitAlg end | 
|  | 
0 commit comments