1+ # TODO : Make `UniformSampling` and `Prior` algs + just use `Sampler`
2+ # That would let us use all defaults for Sampler, combine it with other samplers etc.
13"""
24Robust initialization method for model parameters in Hamiltonian samplers.
35"""
@@ -17,55 +19,123 @@ function init(rng, dist, ::SampleFromUniform, n::Int)
1719 return istransformable (dist) ? inittrans (rng, dist, n) : rand (rng, dist, n)
1820end
1921
20- """
21- has_eval_num(spl::AbstractSampler)
22-
23- Check whether `spl` has a field called `eval_num` in its state variables or not.
24- """
25- has_eval_num (spl:: SampleFromUniform ) = false
26- has_eval_num (spl:: SampleFromPrior ) = false
27- has_eval_num (spl:: AbstractSampler ) = :eval_num in fieldnames (typeof (spl. state))
28-
29- """
30- An abstract type that mutable sampler state structs inherit from.
31- """
32- abstract type AbstractSamplerState end
33-
3422"""
3523 Sampler{T}
3624
37- Generic interface for implementing inference algorithms.
38- An implementation of an algorithm should include the following:
39-
40- 1. A type specifying the algorithm and its parameters, derived from InferenceAlgorithm
41- 2. A method of `sample` function that produces results of inference, which is where actual inference happens.
25+ Generic sampler type for inference algorithms of type `T` in DynamicPPL.
4226
43- DynamicPPL translates models to chunks that call the modelling functions at specified points.
44- The dispatch is based on the value of a `sampler` variable.
45- To include a new inference algorithm implements the requirements mentioned above in a separate file,
46- then include that file at the end of this one.
27+ `Sampler` should implement the AbstractMCMC interface, and in particular
28+ [`AbstractMCMC.step`](@ref). A default implementation of the initial sampling step is
29+ provided that supports resuming sampling from a previous state and setting initial
30+ parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref)
31+ for loading previous states and actually performing the initial sampling step,
32+ respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref)
33+ that specifies how the initial parameter values are sampled if they are not provided.
34+ By default, values are sampled from the prior.
4735"""
48- mutable struct Sampler{T, S<: AbstractSamplerState } <: AbstractSampler
49- alg :: T
50- info :: Dict{Symbol, Any} # sampler infomation
51- selector :: Selector
52- state :: S
36+ struct Sampler{T} <: AbstractSampler
37+ alg:: T
38+ selector:: Selector # Can we remove it?
39+ # TODO : add space such that we can integrate existing external samplers in DynamicPPL
5340end
5441Sampler (alg) = Sampler (alg, Selector ())
5542Sampler (alg, model:: Model ) = Sampler (alg, model, Selector ())
56- Sampler (alg, model:: Model , s:: Selector ) = Sampler (alg, model, s)
43+ Sampler (alg, model:: Model , s:: Selector ) = Sampler (alg, s)
5744
5845# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
59-
60- function AbstractMCMC. step! (
46+ function AbstractMCMC. step (
6147 rng:: Random.AbstractRNG ,
6248 model:: Model ,
6349 sampler:: Union{SampleFromUniform,SampleFromPrior} ,
64- :: Integer ,
65- transition;
50+ state = nothing ;
6651 kwargs...
6752)
6853 vi = VarInfo ()
69- model (vi, sampler)
70- return vi
54+ model (rng, vi, sampler)
55+ return vi, nothing
56+ end
57+
58+ # initial step: general interface for resuming and
59+ function AbstractMCMC. step (
60+ rng:: Random.AbstractRNG ,
61+ model:: Model ,
62+ spl:: Sampler ;
63+ resume_from = nothing ,
64+ kwargs...
65+ )
66+ if resume_from != = nothing
67+ state = loadstate (resume_from)
68+ return AbstractMCMC. step (rng, model, spl, state; kwargs... )
69+ end
70+
71+ # Sample initial values.
72+ _spl = initialsampler (spl)
73+ vi = VarInfo (rng, model, _spl)
74+
75+ # Update the parameters if provided.
76+ if haskey (kwargs, :init_params )
77+ initialize_parameters! (vi, kwargs[:init_params ], spl)
78+
79+ # Update joint log probability.
80+ model (rng, vi, _spl)
81+ end
82+
83+ return initialstep (rng, model, spl, vi; kwargs... )
84+ end
85+
86+ """
87+ loadstate(data)
88+
89+ Load sampler state from `data`.
90+ """
91+ function loadstate end
92+
93+ """
94+ initialsampler(sampler::Sampler)
95+
96+ Return the sampler that is used for generating the initial parameters when sampling with
97+ `sampler`.
98+
99+ By default, it returns an instance of [`SampleFromPrior`](@ref).
100+ """
101+ initialsampler (spl:: Sampler ) = SampleFromPrior ()
102+
103+ function initialize_parameters! (vi:: AbstractVarInfo , init_params, spl:: Sampler )
104+ @debug " Using passed-in initial variable values" init_params
105+
106+ # Flatten parameters.
107+ init_theta = mapreduce (vcat, init_params) do x
108+ vec ([x;])
109+ end
110+
111+ # Get all values.
112+ linked = islinked (vi, spl)
113+ linked && invlink! (vi, spl)
114+ theta = vi[spl]
115+ length (theta) == length (init_theta_flat) ||
116+ error (" Provided initial value doesn't match the dimension of the model" )
117+
118+ # Update values that are provided.
119+ for i in 1 : length (init_theta)
120+ x = init_theta[i]
121+ if x != = missing
122+ theta[i] = x
123+ end
124+ end
125+
126+ # Update in `vi`.
127+ vi[spl] = theta
128+ linked && link! (vi, spl)
129+
130+ return
71131end
132+
133+ """
134+ initialstep(rng, model, sampler, varinfo; kwargs...)
135+
136+ Perform the initial sampling step of the `sampler` for the `model`.
137+
138+ The `varinfo` contains the initial samples, which can be provided by the user or
139+ sampled randomly.
140+ """
141+ function initialstep end
0 commit comments