1+ # TODO : Make `UniformSampling` and `Prior` algs + just use `Sampler`
2+ # That would let us use all defaults for Sampler, combine it with other samplers etc.
13"""
24Robust initialization method for model parameters in Hamiltonian samplers.
35"""
@@ -17,55 +19,93 @@ function init(rng, dist, ::SampleFromUniform, n::Int)
1719 return istransformable (dist) ? inittrans (rng, dist, n) : rand (rng, dist, n)
1820end
1921
20- """
21- has_eval_num(spl::AbstractSampler)
22-
23- Check whether `spl` has a field called `eval_num` in its state variables or not.
24- """
25- has_eval_num (spl:: SampleFromUniform ) = false
26- has_eval_num (spl:: SampleFromPrior ) = false
27- has_eval_num (spl:: AbstractSampler ) = :eval_num in fieldnames (typeof (spl. state))
28-
29- """
30- An abstract type that mutable sampler state structs inherit from.
31- """
32- abstract type AbstractSamplerState end
33-
3422"""
3523 Sampler{T}
3624
37- Generic interface for implementing inference algorithms.
38- An implementation of an algorithm should include the following:
39-
40- 1. A type specifying the algorithm and its parameters, derived from InferenceAlgorithm
41- 2. A method of `sample` function that produces results of inference, which is where actual inference happens.
42-
43- DynamicPPL translates models to chunks that call the modelling functions at specified points.
44- The dispatch is based on the value of a `sampler` variable.
45- To include a new inference algorithm implements the requirements mentioned above in a separate file,
46- then include that file at the end of this one.
25+ Generic sampler type for inference algorithms in DynamicPPL.
4726"""
48- mutable struct Sampler{T, S<: AbstractSamplerState } <: AbstractSampler
49- alg :: T
50- info :: Dict{Symbol, Any} # sampler infomation
51- selector :: Selector
52- state :: S
27+ struct Sampler{T} <: AbstractSampler
28+ alg:: T
29+ # TODO : remove selector & add space
30+ selector:: Selector
5331end
5432Sampler (alg) = Sampler (alg, Selector ())
5533Sampler (alg, model:: Model ) = Sampler (alg, model, Selector ())
56- Sampler (alg, model:: Model , s:: Selector ) = Sampler (alg, model, s)
34+ Sampler (alg, model:: Model , s:: Selector ) = Sampler (alg, s)
5735
5836# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
59-
60- function AbstractMCMC. step! (
37+ function AbstractMCMC. step (
6138 rng:: Random.AbstractRNG ,
6239 model:: Model ,
6340 sampler:: Union{SampleFromUniform,SampleFromPrior} ,
64- :: Integer ,
65- transition;
41+ state = nothing ;
6642 kwargs...
6743)
6844 vi = VarInfo ()
69- model (vi, sampler)
70- return vi
45+ model (rng, vi, sampler)
46+ return vi, nothing
47+ end
48+
49+ # initial step: general interface for resuming and
50+ function AbstractMCMC. step (
51+ rng:: Random.AbstractRNG ,
52+ model:: Model ,
53+ spl:: Sampler ;
54+ resume_from = nothing ,
55+ kwargs...
56+ )
57+ if resume_from != = nothing
58+ state = loadstate (resume_from)
59+ return AbstractMCMC. step (rng, model, spl, state; kwargs... )
60+ end
61+
62+ # Sample initial values.
63+ _spl = initialsampler (spl)
64+ vi = VarInfo (rng, model, _spl)
65+
66+ # Update the parameters if provided.
67+ if haskey (kwargs, :init_params )
68+ initialize_parameters! (vi, kwargs[:init_params ], spl)
69+
70+ # Update joint log probability.
71+ model (rng, vi, _spl)
72+ end
73+
74+ return initialstep (rng, model, spl, vi; kwargs... )
75+ end
76+
77+ function loadstate end
78+
79+ initialsampler (spl:: Sampler ) = SampleFromPrior ()
80+
81+ function initialstep end
82+
83+ function initialize_parameters! (vi:: AbstractVarInfo , init_params, spl:: Sampler )
84+ @debug " Using passed-in initial variable values" init_params
85+
86+ # Flatten parameters.
87+ init_theta = mapreduce (vcat, init_params) do x
88+ vec ([x;])
89+ end
90+
91+ # Get all values.
92+ linked = islinked (vi, spl)
93+ linked && invlink! (vi, spl)
94+ theta = vi[spl]
95+ length (theta) == length (init_theta_flat) ||
96+ error (" Provided initial value doesn't match the dimension of the model" )
97+
98+ # Update values that are provided.
99+ for i in 1 : length (init_theta)
100+ x = init_theta[i]
101+ if x != = missing
102+ theta[i] = x
103+ end
104+ end
105+
106+ # Update in `vi`.
107+ vi[spl] = theta
108+ linked && link! (vi, spl)
109+
110+ return
71111end
0 commit comments