Skip to content

Commit 07aab61

Browse files
authored
Allow more flexible initial_params (#1064)
* Enable NamedTuple/Dict initialisation * Add more tests
1 parent 908d402 commit 07aab61

File tree

3 files changed

+61
-8
lines changed

3 files changed

+61
-8
lines changed

src/contexts/init.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,11 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS
9494
params::P
9595
fallback::S
9696
function InitFromParams(
97-
params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing}
97+
params::AbstractDict{<:VarName},
98+
fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(),
9899
)
99100
return new{typeof(params),typeof(fallback)}(params, fallback)
100101
end
101-
function InitFromParams(params::AbstractDict{<:VarName})
102-
return InitFromParams(params, InitFromPrior())
103-
end
104102
function InitFromParams(
105103
params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
106104
)

src/sampler.jl

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden.
5353
"""
5454
init_strategy(::AbstractSampler) = InitFromPrior()
5555

56+
"""
57+
_convert_initial_params(initial_params)
58+
59+
Convert `initial_params` to an `AbstractInitStrategy` if it is not already one.
60+
"""
61+
_convert_initial_params(initial_params::AbstractInitStrategy) = initial_params
62+
function _convert_initial_params(nt::NamedTuple)
63+
@info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead."
64+
return InitFromParams(nt)
65+
end
66+
function _convert_initial_params(d::AbstractDict{<:VarName})
67+
@info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead."
68+
return InitFromParams(d)
69+
end
70+
function _convert_initial_params(::AbstractVector)
71+
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally an `AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code."
72+
throw(ArgumentError(errmsg))
73+
end
74+
5675
function AbstractMCMC.sample(
5776
rng::Random.AbstractRNG,
5877
model::Model,
@@ -63,7 +82,13 @@ function AbstractMCMC.sample(
6382
kwargs...,
6483
)
6584
return AbstractMCMC.mcmcsample(
66-
rng, model, sampler, N; initial_params, initial_state, kwargs...
85+
rng,
86+
model,
87+
sampler,
88+
N;
89+
initial_params=_convert_initial_params(initial_params),
90+
initial_state,
91+
kwargs...,
6792
)
6893
end
6994

@@ -79,7 +104,15 @@ function AbstractMCMC.sample(
79104
kwargs...,
80105
)
81106
return AbstractMCMC.mcmcsample(
82-
rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs...
107+
rng,
108+
model,
109+
sampler,
110+
parallel,
111+
N,
112+
nchains;
113+
initial_params=map(_convert_initial_params, initial_params),
114+
initial_state,
115+
kwargs...,
83116
)
84117
end
85118

test/sampler.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,27 @@
138138
end
139139
end
140140

141+
# check that Vector no longer works
142+
@test_throws ArgumentError sample(
143+
model, sampler, 1; initial_params=[4, -1], progress=false
144+
)
145+
@test_throws ArgumentError sample(
146+
model, sampler, 1; initial_params=[missing, -1], progress=false
147+
)
148+
141149
# model with two variables: initialization s = 4, m = -1
142150
@model function twovars()
143151
s ~ InverseGamma(2, 3)
144152
return m ~ Normal(0, sqrt(s))
145153
end
146154
model = twovars()
147155
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
148-
let inits = InitFromParams((; s=4, m=-1))
156+
for inits in (
157+
InitFromParams((s=4, m=-1)),
158+
(s=4, m=-1),
159+
InitFromParams(Dict(@varname(s) => 4, @varname(m) => -1)),
160+
Dict(@varname(s) => 4, @varname(m) => -1),
161+
)
149162
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
150163
@test chain[1].metadata.s.vals == [4]
151164
@test chain[1].metadata.m.vals == [-1]
@@ -169,7 +182,16 @@
169182
end
170183

171184
# set only m = -1
172-
for inits in (InitFromParams((; s=missing, m=-1)), InitFromParams((; m=-1)))
185+
for inits in (
186+
InitFromParams((; s=missing, m=-1)),
187+
InitFromParams(Dict(@varname(s) => missing, @varname(m) => -1)),
188+
(; s=missing, m=-1),
189+
Dict(@varname(s) => missing, @varname(m) => -1),
190+
InitFromParams((; m=-1)),
191+
InitFromParams(Dict(@varname(m) => -1)),
192+
(; m=-1),
193+
Dict(@varname(m) => -1),
194+
)
173195
chain = sample(model, sampler, 1; initial_params=inits, progress=false)
174196
@test !ismissing(chain[1].metadata.s.vals[1])
175197
@test chain[1].metadata.m.vals == [-1]

0 commit comments

Comments
 (0)