Skip to content

Commit f927308

Browse files
penelopeysmsunxd3
andauthored
Remove Sampler, remove InferenceAlgorithm, transfer initialstep, init_strategy, and other functions from DynamicPPL to Turing (#2689)
* Remove `Sampler` and move its interface to Turing * Test fixes (this is admittedly quite tiring) * Fix a couple of Gibbs tests (no doubt there are more) * actually fix the Gibbs ones * actually fix it this time * fix typo * point to breaking * Improve loadstate implementation * Re-add tests that were removed from DynamicPPL * Fix qualifier in src/mcmc/external_sampler.jl Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> * Remove the default argument for initial_params * [skip ci] Remove DynamicPPL sources --------- Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
1 parent bbbde35 commit f927308

32 files changed

+623
-409
lines changed

Project.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,3 @@ julia = "1.10.8"
9090
[extras]
9191
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
9292
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
93-
94-
[sources]
95-
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"}

ext/TuringDynamicHMCExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
4444
stepsize::S
4545
end
4646

47-
function DynamicPPL.initialstep(
47+
function Turing.Inference.initialstep(
4848
rng::Random.AbstractRNG,
4949
model::DynamicPPL.Model,
50-
spl::DynamicPPL.Sampler{<:DynamicNUTS},
50+
spl::DynamicNUTS,
5151
vi::DynamicPPL.AbstractVarInfo;
5252
kwargs...,
5353
)
@@ -59,7 +59,7 @@ function DynamicPPL.initialstep(
5959

6060
# Define log-density function.
6161
= DynamicPPL.LogDensityFunction(
62-
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
62+
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
6363
)
6464

6565
# Perform initial step.
@@ -80,14 +80,14 @@ end
8080
function AbstractMCMC.step(
8181
rng::Random.AbstractRNG,
8282
model::DynamicPPL.Model,
83-
spl::DynamicPPL.Sampler{<:DynamicNUTS},
83+
spl::DynamicNUTS,
8484
state::DynamicNUTSState;
8585
kwargs...,
8686
)
8787
# Compute next sample.
8888
vi = state.vi
8989
= state.logdensity
90-
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
90+
steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ℓ, state.stepsize)
9191
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
9292

9393
# Create next sample and state.

src/Turing.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ export
160160
maximum_a_posteriori,
161161
maximum_likelihood,
162162
MAP,
163-
MLE
163+
MLE,
164+
# Chain save/resume
165+
loadstate
164166

165167
end

src/mcmc/Inference.jl

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ using DynamicPPL:
2222
getsym,
2323
getdist,
2424
Model,
25-
Sampler,
2625
DefaultContext
2726
using Distributions, Libtask, Bijectors
2827
using DistributionsAD: VectorOfMultivariate
@@ -50,8 +49,7 @@ import Random
5049
import MCMCChains
5150
import StatsBase: predict
5251

53-
export InferenceAlgorithm,
54-
Hamiltonian,
52+
export Hamiltonian,
5553
StaticHamiltonian,
5654
AdaptiveHamiltonian,
5755
MH,
@@ -71,15 +69,16 @@ export InferenceAlgorithm,
7169
RepeatSampler,
7270
Prior,
7371
predict,
74-
externalsampler
72+
externalsampler,
73+
init_strategy,
74+
loadstate
7575

76-
###############################################
77-
# Abstract interface for inference algorithms #
78-
###############################################
79-
80-
const TURING_CHAIN_TYPE = MCMCChains.Chains
76+
#########################################
77+
# Generic AbstractMCMC methods dispatch #
78+
#########################################
8179

82-
include("algorithm.jl")
80+
const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
81+
include("abstractmcmc.jl")
8382

8483
####################
8584
# Sampler wrappers #
@@ -312,8 +311,8 @@ getlogevidence(transitions, sampler, state) = missing
312311
# Default MCMCChains.Chains constructor.
313312
function AbstractMCMC.bundle_samples(
314313
ts::Vector{<:Transition},
315-
model::AbstractModel,
316-
spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
314+
model::DynamicPPL.Model,
315+
spl::AbstractSampler,
317316
state,
318317
chain_type::Type{MCMCChains.Chains};
319318
save_state=false,
@@ -374,8 +373,8 @@ end
374373

375374
function AbstractMCMC.bundle_samples(
376375
ts::Vector{<:Transition},
377-
model::AbstractModel,
378-
spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
376+
model::DynamicPPL.Model,
377+
spl::AbstractSampler,
379378
state,
380379
chain_type::Type{Vector{NamedTuple}};
381380
kwargs...,
@@ -416,7 +415,7 @@ function group_varnames_by_symbol(vns)
416415
return d
417416
end
418417

419-
function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples)
418+
function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples)
420419
nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples))
421420
return setinfo(c, merge(nt, c.info))
422421
end
@@ -435,18 +434,12 @@ include("sghmc.jl")
435434
include("emcee.jl")
436435
include("prior.jl")
437436

438-
#################################################
439-
# Generic AbstractMCMC methods dispatch #
440-
#################################################
441-
442-
include("abstractmcmc.jl")
443-
444437
################
445438
# Typing tools #
446439
################
447440

448441
function DynamicPPL.get_matching_type(
449-
spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV}
442+
spl::Union{PG,SMC}, vi, ::Type{TV}
450443
) where {T,N,TV<:Array{T,N}}
451444
return Array{T,N}
452445
end

src/mcmc/abstractmcmc.jl

Lines changed: 125 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,98 @@ function _check_model(model::DynamicPPL.Model)
44
new_model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
55
return DynamicPPL.check_model(new_model, VarInfo(); error_on_failure=true)
66
end
7-
function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
7+
function _check_model(model::DynamicPPL.Model, ::AbstractSampler)
88
return _check_model(model)
99
end
1010

11+
"""
12+
Turing.Inference.init_strategy(spl::AbstractSampler)
13+
14+
Get the default initialization strategy for a given sampler `spl`, i.e. how initial
15+
parameters for sampling are chosen if not specified by the user. By default, this is
16+
`InitFromPrior()`, which samples initial parameters from the prior distribution.
17+
"""
18+
init_strategy(::AbstractSampler) = DynamicPPL.InitFromPrior()
19+
20+
"""
21+
_convert_initial_params(initial_params)
22+
23+
Convert `initial_params` to a `DynamicPPl.AbstractInitStrategy` if it is not already one, or
24+
throw a useful error message.
25+
"""
26+
_convert_initial_params(initial_params::DynamicPPL.AbstractInitStrategy) = initial_params
27+
function _convert_initial_params(nt::NamedTuple)
28+
@info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead."
29+
return DynamicPPL.InitFromParams(nt)
30+
end
31+
function _convert_initial_params(d::AbstractDict{<:VarName})
32+
@info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead."
33+
return DynamicPPL.InitFromParams(d)
34+
end
35+
function _convert_initial_params(::AbstractVector{<:Real})
36+
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally a `DynamicPPL.AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code."
37+
throw(ArgumentError(errmsg))
38+
end
39+
function _convert_initial_params(@nospecialize(_::Any))
40+
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or a `DynamicPPL.AbstractInitStrategy`."
41+
throw(ArgumentError(errmsg))
42+
end
43+
44+
"""
45+
default_varinfo(rng, model, sampler)
46+
47+
Return a default varinfo object for the given `model` and `sampler`.
48+
The default method for this returns a NTVarInfo (i.e. 'typed varinfo').
49+
"""
50+
function default_varinfo(
51+
rng::Random.AbstractRNG, model::DynamicPPL.Model, ::AbstractSampler
52+
)
53+
# Note that in `AbstractMCMC.step`, the values in the varinfo returned here are
54+
# immediately overwritten by a subsequent call to `init!!`. The reason why we
55+
# _do_ create a varinfo with parameters here (as opposed to simply returning
56+
# an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty
57+
# typed VarInfo would fail. This can happen if two VarNames have different types
58+
# but share the same symbol (e.g. `x.a` and `x.b`).
59+
# TODO(mhauru) Fix push!! to work with arbitrary lens types, and then remove the arguments
60+
# and return an empty VarInfo instead.
61+
return DynamicPPL.typed_varinfo(VarInfo(rng, model))
62+
end
63+
1164
#########################################
1265
# Default definitions for the interface #
1366
#########################################
1467

15-
const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
16-
1768
function AbstractMCMC.sample(
18-
model::AbstractModel, alg::InferenceAlgorithm, N::Integer; kwargs...
69+
model::DynamicPPL.Model, spl::AbstractSampler, N::Integer; kwargs...
1970
)
20-
return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...)
71+
return AbstractMCMC.sample(Random.default_rng(), model, spl, N; kwargs...)
2172
end
2273

2374
function AbstractMCMC.sample(
2475
rng::AbstractRNG,
25-
model::AbstractModel,
26-
alg::InferenceAlgorithm,
76+
model::DynamicPPL.Model,
77+
spl::AbstractSampler,
2778
N::Integer;
79+
initial_params=init_strategy(spl),
2880
check_model::Bool=true,
2981
chain_type=DEFAULT_CHAIN_TYPE,
3082
kwargs...,
3183
)
32-
check_model && _check_model(model, alg)
33-
return AbstractMCMC.sample(rng, model, Sampler(alg), N; chain_type, kwargs...)
84+
check_model && _check_model(model, spl)
85+
return AbstractMCMC.mcmcsample(
86+
rng,
87+
model,
88+
spl,
89+
N;
90+
initial_params=_convert_initial_params(initial_params),
91+
chain_type,
92+
kwargs...,
93+
)
3494
end
3595

3696
function AbstractMCMC.sample(
37-
model::AbstractModel,
38-
alg::InferenceAlgorithm,
97+
model::DynamicPPL.Model,
98+
alg::AbstractSampler,
3999
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
40100
N::Integer,
41101
n_chains::Integer;
@@ -47,18 +107,66 @@ function AbstractMCMC.sample(
47107
end
48108

49109
function AbstractMCMC.sample(
50-
rng::AbstractRNG,
51-
model::AbstractModel,
52-
alg::InferenceAlgorithm,
110+
rng::Random.AbstractRNG,
111+
model::DynamicPPL.Model,
112+
spl::AbstractSampler,
53113
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
54114
N::Integer,
55115
n_chains::Integer;
56116
chain_type=DEFAULT_CHAIN_TYPE,
57117
check_model::Bool=true,
118+
initial_params=fill(init_strategy(spl), n_chains),
58119
kwargs...,
59120
)
60-
check_model && _check_model(model, alg)
61-
return AbstractMCMC.sample(
62-
rng, model, Sampler(alg), ensemble, N, n_chains; chain_type, kwargs...
121+
check_model && _check_model(model, spl)
122+
if !(initial_params isa AbstractVector) || length(initial_params) != n_chains
123+
errmsg = "`initial_params` must be an AbstractVector of length `n_chains`; one element per chain"
124+
throw(ArgumentError(errmsg))
125+
end
126+
return AbstractMCMC.mcmcsample(
127+
rng,
128+
model,
129+
spl,
130+
ensemble,
131+
N,
132+
n_chains;
133+
chain_type,
134+
initial_params=map(_convert_initial_params, initial_params),
135+
kwargs...,
63136
)
64137
end
138+
139+
function loadstate(chain::MCMCChains.Chains)
140+
if !haskey(chain.info, :samplerstate)
141+
throw(
142+
ArgumentError(
143+
"the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`",
144+
),
145+
)
146+
end
147+
return chain.info[:samplerstate]
148+
end
149+
150+
# TODO(penelopeysm): Remove initialstep and generalise MCMC sampling procedures
151+
function initialstep end
152+
153+
function AbstractMCMC.step(
154+
rng::Random.AbstractRNG,
155+
model::DynamicPPL.Model,
156+
spl::AbstractSampler;
157+
initial_params,
158+
kwargs...,
159+
)
160+
# Generate the default varinfo. Note that any parameters inside this varinfo
161+
# will be immediately overwritten by the next call to `init!!`.
162+
vi = default_varinfo(rng, model, spl)
163+
164+
# Fill it with initial parameters. Note that, if `InitFromParams` is used, the
165+
# parameters provided must be in unlinked space (when inserted into the
166+
# varinfo, they will be adjusted to match the linking status of the
167+
# varinfo).
168+
_, vi = DynamicPPL.init!!(rng, model, vi, initial_params)
169+
170+
# Call the actual function that does the first step.
171+
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
172+
end

src/mcmc/algorithm.jl

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments
 (0)