@@ -4,38 +4,98 @@ function _check_model(model::DynamicPPL.Model)
44    new_model =  DynamicPPL. setleafcontext (model, DynamicPPL. InitContext ())
55    return  DynamicPPL. check_model (new_model, VarInfo (); error_on_failure= true )
66end 
7- function  _check_model (model:: DynamicPPL.Model , alg :: InferenceAlgorithm )
7+ function  _check_model (model:: DynamicPPL.Model , :: AbstractSampler )
88    return  _check_model (model)
99end 
1010
11+ """ 
12+     Turing.Inference.init_strategy(spl::AbstractSampler) 
13+ 
14+ Get the default initialization strategy for a given sampler `spl`, i.e. how initial 
15+ parameters for sampling are chosen if not specified by the user. By default, this is 
16+ `InitFromPrior()`, which samples initial parameters from the prior distribution. 
17+ """ 
18+ init_strategy (:: AbstractSampler ) =  DynamicPPL. InitFromPrior ()
19+ 
20+ """ 
21+     _convert_initial_params(initial_params) 
22+ 
23+ Convert `initial_params` to a `DynamicPPl.AbstractInitStrategy` if it is not already one, or 
24+ throw a useful error message. 
25+ """ 
26+ _convert_initial_params (initial_params:: DynamicPPL.AbstractInitStrategy ) =  initial_params
27+ function  _convert_initial_params (nt:: NamedTuple )
28+     @info  " Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead." 
29+     return  DynamicPPL. InitFromParams (nt)
30+ end 
31+ function  _convert_initial_params (d:: AbstractDict{<:VarName} )
32+     @info  " Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead." 
33+     return  DynamicPPL. InitFromParams (d)
34+ end 
35+ function  _convert_initial_params (:: AbstractVector{<:Real} )
36+     errmsg =  " `initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally a `DynamicPPL.AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code." 
37+     throw (ArgumentError (errmsg))
38+ end 
39+ function  _convert_initial_params (@nospecialize (_:: Any ))
40+     errmsg =  " `initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or a `DynamicPPL.AbstractInitStrategy`." 
41+     throw (ArgumentError (errmsg))
42+ end 
43+ 
44+ """ 
45+     default_varinfo(rng, model, sampler) 
46+ 
47+ Return a default varinfo object for the given `model` and `sampler`. 
48+ The default method for this returns a NTVarInfo (i.e. 'typed varinfo'). 
49+ """ 
50+ function  default_varinfo (
51+     rng:: Random.AbstractRNG , model:: DynamicPPL.Model , :: AbstractSampler 
52+ )
53+     #  Note that in `AbstractMCMC.step`, the values in the varinfo returned here are
54+     #  immediately overwritten by a subsequent call to `init!!`. The reason why we
55+     #  _do_ create a varinfo with parameters here (as opposed to simply returning
56+     #  an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty
57+     #  typed VarInfo would fail. This can happen if two VarNames have different types
58+     #  but share the same symbol (e.g. `x.a` and `x.b`).
59+     #  TODO (mhauru) Fix push!! to work with arbitrary lens types, and then remove the arguments
60+     #  and return an empty VarInfo instead.
61+     return  DynamicPPL. typed_varinfo (VarInfo (rng, model))
62+ end 
63+ 
1164# ########################################
1265#  Default definitions for the interface #
1366# ########################################
1467
15- const  DEFAULT_CHAIN_TYPE =  MCMCChains. Chains
16- 
1768function  AbstractMCMC. sample (
18-     model:: AbstractModel , alg :: InferenceAlgorithm , N:: Integer ; kwargs... 
69+     model:: DynamicPPL.Model , spl :: AbstractSampler , N:: Integer ; kwargs... 
1970)
20-     return  AbstractMCMC. sample (Random. default_rng (), model, alg , N; kwargs... )
71+     return  AbstractMCMC. sample (Random. default_rng (), model, spl , N; kwargs... )
2172end 
2273
2374function  AbstractMCMC. sample (
2475    rng:: AbstractRNG ,
25-     model:: AbstractModel ,
26-     alg :: InferenceAlgorithm ,
76+     model:: DynamicPPL.Model ,
77+     spl :: AbstractSampler ,
2778    N:: Integer ;
79+     initial_params= init_strategy (spl),
2880    check_model:: Bool = true ,
2981    chain_type= DEFAULT_CHAIN_TYPE,
3082    kwargs... ,
3183)
32-     check_model &&  _check_model (model, alg)
33-     return  AbstractMCMC. sample (rng, model, Sampler (alg), N; chain_type, kwargs... )
84+     check_model &&  _check_model (model, spl)
85+     return  AbstractMCMC. mcmcsample (
86+         rng,
87+         model,
88+         spl,
89+         N;
90+         initial_params= _convert_initial_params (initial_params),
91+         chain_type,
92+         kwargs... ,
93+     )
3494end 
3595
3696function  AbstractMCMC. sample (
37-     model:: AbstractModel ,
38-     alg:: InferenceAlgorithm ,
97+     model:: DynamicPPL.Model ,
98+     alg:: AbstractSampler ,
3999    ensemble:: AbstractMCMC.AbstractMCMCEnsemble ,
40100    N:: Integer ,
41101    n_chains:: Integer ;
@@ -47,18 +107,66 @@ function AbstractMCMC.sample(
47107end 
48108
49109function  AbstractMCMC. sample (
50-     rng:: AbstractRNG ,
51-     model:: AbstractModel ,
52-     alg :: InferenceAlgorithm ,
110+     rng:: Random. AbstractRNG
111+     model:: DynamicPPL.Model ,
112+     spl :: AbstractSampler ,
53113    ensemble:: AbstractMCMC.AbstractMCMCEnsemble ,
54114    N:: Integer ,
55115    n_chains:: Integer ;
56116    chain_type= DEFAULT_CHAIN_TYPE,
57117    check_model:: Bool = true ,
118+     initial_params= fill (init_strategy (spl), n_chains),
58119    kwargs... ,
59120)
60-     check_model &&  _check_model (model, alg)
61-     return  AbstractMCMC. sample (
62-         rng, model, Sampler (alg), ensemble, N, n_chains; chain_type, kwargs... 
121+     check_model &&  _check_model (model, spl)
122+     if  ! (initial_params isa  AbstractVector) ||  length (initial_params) !=  n_chains
123+         errmsg =  " `initial_params` must be an AbstractVector of length `n_chains`; one element per chain" 
124+         throw (ArgumentError (errmsg))
125+     end 
126+     return  AbstractMCMC. mcmcsample (
127+         rng,
128+         model,
129+         spl,
130+         ensemble,
131+         N,
132+         n_chains;
133+         chain_type,
134+         initial_params= map (_convert_initial_params, initial_params),
135+         kwargs... ,
63136    )
64137end 
138+ 
139+ function  loadstate (chain:: MCMCChains.Chains )
140+     if  ! haskey (chain. info, :samplerstate )
141+         throw (
142+             ArgumentError (
143+                 " the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`" 
144+             ),
145+         )
146+     end 
147+     return  chain. info[:samplerstate ]
148+ end 
149+ 
150+ #  TODO (penelopeysm): Remove initialstep and generalise MCMC sampling procedures
151+ function  initialstep end 
152+ 
153+ function  AbstractMCMC. step (
154+     rng:: Random.AbstractRNG ,
155+     model:: DynamicPPL.Model ,
156+     spl:: AbstractSampler ;
157+     initial_params,
158+     kwargs... ,
159+ )
160+     #  Generate the default varinfo. Note that any parameters inside this varinfo
161+     #  will be immediately overwritten by the next call to `init!!`.
162+     vi =  default_varinfo (rng, model, spl)
163+ 
164+     #  Fill it with initial parameters. Note that, if `InitFromParams` is used, the
165+     #  parameters provided must be in unlinked space (when inserted into the
166+     #  varinfo, they will be adjusted to match the linking status of the
167+     #  varinfo).
168+     _, vi =  DynamicPPL. init!! (rng, model, vi, initial_params)
169+ 
170+     #  Call the actual function that does the first step.
171+     return  initialstep (rng, model, spl, vi; initial_params, kwargs... )
172+ end 
0 commit comments