Skip to content

Commit 42bfc8a

Browse files
committed
Remove Sampler and move its interface to Turing
1 parent deff3fd commit 42bfc8a

28 files changed

+423
-387
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,4 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
9292
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
9393

9494
[sources]
95-
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"}
95+
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/no-sampler"}

ext/TuringDynamicHMCExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
4444
stepsize::S
4545
end
4646

47-
function DynamicPPL.initialstep(
47+
function Turing.Inference.initialstep(
4848
rng::Random.AbstractRNG,
4949
model::DynamicPPL.Model,
50-
spl::DynamicPPL.Sampler{<:DynamicNUTS},
50+
spl::DynamicNUTS,
5151
vi::DynamicPPL.AbstractVarInfo;
5252
kwargs...,
5353
)
@@ -59,7 +59,7 @@ function DynamicPPL.initialstep(
5959

6060
# Define log-density function.
6161
= DynamicPPL.LogDensityFunction(
62-
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
62+
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
6363
)
6464

6565
# Perform initial step.
@@ -80,14 +80,14 @@ end
8080
function AbstractMCMC.step(
8181
rng::Random.AbstractRNG,
8282
model::DynamicPPL.Model,
83-
spl::DynamicPPL.Sampler{<:DynamicNUTS},
83+
spl::DynamicNUTS,
8484
state::DynamicNUTSState;
8585
kwargs...,
8686
)
8787
# Compute next sample.
8888
vi = state.vi
8989
= state.logdensity
90-
steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize)
90+
steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ℓ, state.stepsize)
9191
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
9292

9393
# Create next sample and state.

src/Turing.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ export
160160
maximum_a_posteriori,
161161
maximum_likelihood,
162162
MAP,
163-
MLE
163+
MLE,
164+
# Chain save/resume
165+
loadstate
164166

165167
end

src/mcmc/Inference.jl

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ using DynamicPPL:
2222
getsym,
2323
getdist,
2424
Model,
25-
Sampler,
2625
DefaultContext
2726
using Distributions, Libtask, Bijectors
2827
using DistributionsAD: VectorOfMultivariate
@@ -50,8 +49,7 @@ import Random
5049
import MCMCChains
5150
import StatsBase: predict
5251

53-
export InferenceAlgorithm,
54-
Hamiltonian,
52+
export Hamiltonian,
5553
StaticHamiltonian,
5654
AdaptiveHamiltonian,
5755
MH,
@@ -71,15 +69,16 @@ export InferenceAlgorithm,
7169
RepeatSampler,
7270
Prior,
7371
predict,
74-
externalsampler
72+
externalsampler,
73+
init_strategy,
74+
loadstate
7575

76-
###############################################
77-
# Abstract interface for inference algorithms #
78-
###############################################
79-
80-
const TURING_CHAIN_TYPE = MCMCChains.Chains
76+
#########################################
77+
# Generic AbstractMCMC methods dispatch #
78+
#########################################
8179

82-
include("algorithm.jl")
80+
const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
81+
include("abstractmcmc.jl")
8382

8483
####################
8584
# Sampler wrappers #
@@ -312,8 +311,8 @@ getlogevidence(transitions, sampler, state) = missing
312311
# Default MCMCChains.Chains constructor.
313312
function AbstractMCMC.bundle_samples(
314313
ts::Vector{<:Transition},
315-
model::AbstractModel,
316-
spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
314+
model::DynamicPPL.Model,
315+
spl::AbstractSampler,
317316
state,
318317
chain_type::Type{MCMCChains.Chains};
319318
save_state=false,
@@ -374,8 +373,8 @@ end
374373

375374
function AbstractMCMC.bundle_samples(
376375
ts::Vector{<:Transition},
377-
model::AbstractModel,
378-
spl::Union{Sampler{<:InferenceAlgorithm},RepeatSampler},
376+
model::DynamicPPL.Model,
377+
spl::AbstractSampler,
379378
state,
380379
chain_type::Type{Vector{NamedTuple}};
381380
kwargs...,
@@ -416,7 +415,7 @@ function group_varnames_by_symbol(vns)
416415
return d
417416
end
418417

419-
function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples)
418+
function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples)
420419
nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples))
421420
return setinfo(c, merge(nt, c.info))
422421
end
@@ -435,18 +434,12 @@ include("sghmc.jl")
435434
include("emcee.jl")
436435
include("prior.jl")
437436

438-
#################################################
439-
# Generic AbstractMCMC methods dispatch #
440-
#################################################
441-
442-
include("abstractmcmc.jl")
443-
444437
################
445438
# Typing tools #
446439
################
447440

448441
function DynamicPPL.get_matching_type(
449-
spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV}
442+
spl::Union{PG,SMC}, vi, ::Type{TV}
450443
) where {T,N,TV<:Array{T,N}}
451444
return Array{T,N}
452445
end

src/mcmc/abstractmcmc.jl

Lines changed: 111 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,98 @@ function _check_model(model::DynamicPPL.Model)
44
new_model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
55
return DynamicPPL.check_model(new_model, VarInfo(); error_on_failure=true)
66
end
7-
function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
7+
function _check_model(model::DynamicPPL.Model, ::AbstractSampler)
88
return _check_model(model)
99
end
1010

11+
"""
12+
Turing.Inference.init_strategy(spl::AbstractSampler)
13+
14+
Get the default initialization strategy for a given sampler `spl`, i.e. how initial
15+
parameters for sampling are chosen if not specified by the user. By default, this is
16+
`InitFromPrior()`, which samples initial parameters from the prior distribution.
17+
"""
18+
init_strategy(::AbstractSampler) = DynamicPPL.InitFromPrior()
19+
20+
"""
21+
_convert_initial_params(initial_params)
22+
23+
Convert `initial_params` to a `DynamicPPl.AbstractInitStrategy` if it is not already one, or
24+
throw a useful error message.
25+
"""
26+
_convert_initial_params(initial_params::DynamicPPL.AbstractInitStrategy) = initial_params
27+
function _convert_initial_params(nt::NamedTuple)
28+
@info "Using a NamedTuple for `initial_params` will be deprecated in a future release. Please use `InitFromParams(namedtuple)` instead."
29+
return InitFromParams(nt)
30+
end
31+
function _convert_initial_params(d::AbstractDict{<:VarName})
32+
@info "Using a Dict for `initial_params` will be deprecated in a future release. Please use `InitFromParams(dict)` instead."
33+
return InitFromParams(d)
34+
end
35+
function _convert_initial_params(::AbstractVector{<:Real})
36+
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally a `DynamicPPL.AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported. Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code."
37+
throw(ArgumentError(errmsg))
38+
end
39+
function _convert_initial_params(@nospecialize(_::Any))
40+
errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or a `DynamicPPL.AbstractInitStrategy`."
41+
throw(ArgumentError(errmsg))
42+
end
43+
44+
"""
45+
default_varinfo(rng, model, sampler)
46+
47+
Return a default varinfo object for the given `model` and `sampler`.
48+
The default method for this returns a NTVarInfo (i.e. 'typed varinfo').
49+
"""
50+
function default_varinfo(
51+
rng::Random.AbstractRNG, model::DynamicPPL.Model, ::AbstractSampler
52+
)
53+
# Note that in `AbstractMCMC.step`, the values in the varinfo returned here are
54+
# immediately overwritten by a subsequent call to `init!!`. The reason why we
55+
# _do_ create a varinfo with parameters here (as opposed to simply returning
56+
# an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty
57+
# typed VarInfo would fail. This can happen if two VarNames have different types
58+
# but share the same symbol (e.g. `x.a` and `x.b`).
59+
# TODO(mhauru) Fix push!! to work with arbitrary lens types, and then remove the arguments
60+
# and return an empty VarInfo instead.
61+
return DynamicPPL.typed_varinfo(VarInfo(rng, model))
62+
end
63+
1164
#########################################
1265
# Default definitions for the interface #
1366
#########################################
1467

15-
const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
16-
1768
function AbstractMCMC.sample(
18-
model::AbstractModel, alg::InferenceAlgorithm, N::Integer; kwargs...
69+
model::DynamicPPL.Model, spl::AbstractSampler, N::Integer; kwargs...
1970
)
20-
return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...)
71+
return AbstractMCMC.sample(Random.default_rng(), model, spl, N; kwargs...)
2172
end
2273

2374
function AbstractMCMC.sample(
2475
rng::AbstractRNG,
25-
model::AbstractModel,
26-
alg::InferenceAlgorithm,
76+
model::DynamicPPL.Model,
77+
spl::AbstractSampler,
2778
N::Integer;
79+
initial_params=init_strategy(spl),
2880
check_model::Bool=true,
2981
chain_type=DEFAULT_CHAIN_TYPE,
3082
kwargs...,
3183
)
32-
check_model && _check_model(model, alg)
33-
return AbstractMCMC.sample(rng, model, Sampler(alg), N; chain_type, kwargs...)
84+
check_model && _check_model(model, spl)
85+
return AbstractMCMC.sample(
86+
rng,
87+
model,
88+
spl,
89+
N;
90+
initial_params=_convert_initial_params(initial_params),
91+
chain_type,
92+
kwargs...,
93+
)
3494
end
3595

3696
function AbstractMCMC.sample(
37-
model::AbstractModel,
38-
alg::InferenceAlgorithm,
97+
model::DynamicPPL.Model,
98+
alg::AbstractSampler,
3999
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
40100
N::Integer,
41101
n_chains::Integer;
@@ -47,18 +107,53 @@ function AbstractMCMC.sample(
47107
end
48108

49109
function AbstractMCMC.sample(
50-
rng::AbstractRNG,
51-
model::AbstractModel,
52-
alg::InferenceAlgorithm,
110+
rng::Random.AbstractRNG,
111+
model::DynamicPPL.Model,
112+
spl::AbstractSampler,
53113
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
54114
N::Integer,
55115
n_chains::Integer;
56116
chain_type=DEFAULT_CHAIN_TYPE,
57117
check_model::Bool=true,
118+
initial_params=init_strategy(spl),
58119
kwargs...,
59120
)
60-
check_model && _check_model(model, alg)
121+
check_model && _check_model(model, spl)
61122
return AbstractMCMC.sample(
62-
rng, model, Sampler(alg), ensemble, N, n_chains; chain_type, kwargs...
123+
rng,
124+
model,
125+
spl,
126+
ensemble,
127+
N,
128+
n_chains;
129+
chain_type,
130+
initial_params=map(_convert_initial_params, initial_params),
131+
kwargs...,
63132
)
64133
end
134+
135+
loadstate(c::MCMCChains.Chains) = c.info.samplerstate
136+
137+
# TODO(penelopeysm): Remove initialstep and generalise MCMC sampling procedures
138+
function initialstep end
139+
140+
function AbstractMCMC.step(
141+
rng::Random.AbstractRNG,
142+
model::DynamicPPL.Model,
143+
spl::AbstractSampler;
144+
initial_params,
145+
kwargs...,
146+
)
147+
# Generate the default varinfo. Note that any parameters inside this varinfo
148+
# will be immediately overwritten by the next call to `init!!`.
149+
vi = default_varinfo(rng, model, spl)
150+
151+
# Fill it with initial parameters. Note that, if `InitFromParams` is used, the
152+
# parameters provided must be in unlinked space (when inserted into the
153+
# varinfo, they will be adjusted to match the linking status of the
154+
# varinfo).
155+
_, vi = DynamicPPL.init!!(rng, model, vi, initial_params)
156+
157+
# Call the actual function that does the first step.
158+
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
159+
end

src/mcmc/algorithm.jl

Lines changed: 0 additions & 16 deletions
This file was deleted.

src/mcmc/emcee.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Foreman-Mackey, D., Hogg, D. W., Lang, D., & Goodman, J. (2013).
1313
emcee: The MCMC Hammer. Publications of the Astronomical Society of the
1414
Pacific, 125 (925), 306. https://doi.org/10.1086/670067
1515
"""
16-
struct Emcee{E<:AMH.Ensemble} <: InferenceAlgorithm
16+
struct Emcee{E<:AMH.Ensemble} <: AbstractSampler
1717
ensemble::E
1818
end
1919

@@ -33,23 +33,20 @@ end
3333

3434
# Utility function to tetrieve the number of walkers
3535
_get_n_walkers(e::Emcee) = e.ensemble.n_walkers
36-
_get_n_walkers(spl::Sampler{<:Emcee}) = _get_n_walkers(spl.alg)
3736

3837
# Because Emcee expects n_walkers initialisations, we need to override this
39-
function DynamicPPL.init_strategy(spl::Sampler{<:Emcee})
38+
function Turing.Inference.init_strategy(spl::Emcee)
4039
return fill(DynamicPPL.InitFromPrior(), _get_n_walkers(spl))
4140
end
42-
# TODO(penelopeysm / DPPL 0.38) This is type piracy (!!!) The function
43-
# `_convert_initial_params` will be moved to Turing soon, and this piracy SHOULD be removed
44-
# in https://github.com/TuringLang/Turing.jl/pull/2689, PLEASE make sure it is!
45-
function DynamicPPL._convert_initial_params(
41+
# We also have to explicitly allow this or else it will error...
42+
function Turing.Inference._convert_initial_params(
4643
x::AbstractVector{<:DynamicPPL.AbstractInitStrategy}
4744
)
4845
return x
4946
end
5047

5148
function AbstractMCMC.step(
52-
rng::Random.AbstractRNG, model::Model, spl::Sampler{<:Emcee}; initial_params, kwargs...
49+
rng::Random.AbstractRNG, model::Model, spl::Emcee; initial_params, kwargs...
5350
)
5451
# Sample from the prior
5552
n = _get_n_walkers(spl)
@@ -83,7 +80,7 @@ function AbstractMCMC.step(
8380
end
8481

8582
function AbstractMCMC.step(
86-
rng::AbstractRNG, model::Model, spl::Sampler{<:Emcee}, state::EmceeState; kwargs...
83+
rng::AbstractRNG, model::Model, spl::Emcee, state::EmceeState; kwargs...
8784
)
8885
# Generate a log joint function.
8986
vi = state.vi
@@ -95,7 +92,7 @@ function AbstractMCMC.step(
9592
)
9693

9794
# Compute the next states.
98-
t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states)
95+
t, states = AbstractMCMC.step(rng, densitymodel, spl.ensemble, state.states)
9996

10097
# Compute the next transition and state.
10198
transition = map(states) do _state
@@ -110,7 +107,7 @@ end
110107
function AbstractMCMC.bundle_samples(
111108
samples::Vector{<:Vector},
112109
model::AbstractModel,
113-
spl::Sampler{<:Emcee},
110+
spl::Emcee,
114111
state::EmceeState,
115112
chain_type::Type{MCMCChains.Chains};
116113
save_state=false,

0 commit comments

Comments
 (0)