|  | 
| 1 | 1 | @testset "sampler.jl" begin | 
|  | 2 | +    @testset "initial_state and resume_from kwargs" begin | 
|  | 3 | +        # Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our | 
|  | 4 | +        # overloaded method. | 
|  | 5 | +        @model f() = x ~ Normal() | 
|  | 6 | +        model = f() | 
|  | 7 | +        # This sampler just returns the state it was given as its 'sample' | 
|  | 8 | +        struct S <: AbstractMCMC.AbstractSampler end | 
|  | 9 | +        function AbstractMCMC.step( | 
|  | 10 | +            rng::Random.AbstractRNG, | 
|  | 11 | +            model::Model, | 
|  | 12 | +            sampler::Sampler{<:S}, | 
|  | 13 | +            state=nothing; | 
|  | 14 | +            kwargs..., | 
|  | 15 | +        ) | 
|  | 16 | +            if state === nothing | 
|  | 17 | +                s = rand() | 
|  | 18 | +                return s, s | 
|  | 19 | +            else | 
|  | 20 | +                return state, state | 
|  | 21 | +            end | 
|  | 22 | +        end | 
|  | 23 | +        spl = Sampler(S()) | 
|  | 24 | + | 
|  | 25 | +        function AbstractMCMC.bundle_samples( | 
|  | 26 | +            samples::Vector{Float64}, | 
|  | 27 | +            model::Model, | 
|  | 28 | +            sampler::Sampler{<:S}, | 
|  | 29 | +            state, | 
|  | 30 | +            chain_type::Type{MCMCChains.Chains}; | 
|  | 31 | +            kwargs..., | 
|  | 32 | +        ) | 
|  | 33 | +            return MCMCChains.Chains(samples, [:x]; info=(samplerstate=state,)) | 
|  | 34 | +        end | 
|  | 35 | + | 
|  | 36 | +        N_iters, N_chains = 10, 3 | 
|  | 37 | + | 
|  | 38 | +        @testset "single-chain sampling" begin | 
|  | 39 | +            chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains) | 
|  | 40 | +            initial_value = chn[:x][1] | 
|  | 41 | +            @test all(chn[:x] .== initial_value) # sanity check | 
|  | 42 | +            # using `initial_state` | 
|  | 43 | +            chn2 = sample( | 
|  | 44 | +                model, | 
|  | 45 | +                spl, | 
|  | 46 | +                N_iters; | 
|  | 47 | +                progress=false, | 
|  | 48 | +                initial_state=chn.info.samplerstate, | 
|  | 49 | +                chain_type=MCMCChains.Chains, | 
|  | 50 | +            ) | 
|  | 51 | +            @test all(chn2[:x] .== initial_value) | 
|  | 52 | +            # using `resume_from` | 
|  | 53 | +            chn3 = sample( | 
|  | 54 | +                model, | 
|  | 55 | +                spl, | 
|  | 56 | +                N_iters; | 
|  | 57 | +                progress=false, | 
|  | 58 | +                resume_from=chn, | 
|  | 59 | +                chain_type=MCMCChains.Chains, | 
|  | 60 | +            ) | 
|  | 61 | +            @test all(chn3[:x] .== initial_value) | 
|  | 62 | +        end | 
|  | 63 | + | 
|  | 64 | +        @testset "multiple-chain sampling" begin | 
|  | 65 | +            chn = sample( | 
|  | 66 | +                model, | 
|  | 67 | +                spl, | 
|  | 68 | +                MCMCThreads(), | 
|  | 69 | +                N_iters, | 
|  | 70 | +                N_chains; | 
|  | 71 | +                progress=false, | 
|  | 72 | +                chain_type=MCMCChains.Chains, | 
|  | 73 | +            ) | 
|  | 74 | +            initial_value = chn[:x][1, :] | 
|  | 75 | +            @test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check | 
|  | 76 | +            # using `initial_state` | 
|  | 77 | +            chn2 = sample( | 
|  | 78 | +                model, | 
|  | 79 | +                spl, | 
|  | 80 | +                MCMCThreads(), | 
|  | 81 | +                N_iters, | 
|  | 82 | +                N_chains; | 
|  | 83 | +                progress=false, | 
|  | 84 | +                initial_state=chn.info.samplerstate, | 
|  | 85 | +                chain_type=MCMCChains.Chains, | 
|  | 86 | +            ) | 
|  | 87 | +            @test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters) | 
|  | 88 | +            # using `resume_from` | 
|  | 89 | +            chn3 = sample( | 
|  | 90 | +                model, | 
|  | 91 | +                spl, | 
|  | 92 | +                MCMCThreads(), | 
|  | 93 | +                N_iters, | 
|  | 94 | +                N_chains; | 
|  | 95 | +                progress=false, | 
|  | 96 | +                resume_from=chn, | 
|  | 97 | +                chain_type=MCMCChains.Chains, | 
|  | 98 | +            ) | 
|  | 99 | +            @test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters) | 
|  | 100 | +        end | 
|  | 101 | +    end | 
|  | 102 | + | 
| 2 | 103 |     @testset "SampleFromPrior and SampleUniform" begin | 
| 3 | 104 |         @model function gdemo(x, y) | 
| 4 | 105 |             s ~ InverseGamma(2, 3) | 
|  | 
0 commit comments