Skip to content

Commit 06f5bc6

Browse files
authored
Try #150:
2 parents ac7a649 + e74052a commit 06f5bc6

File tree

21 files changed

+800
-976
lines changed

21 files changed

+800
-976
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313

1414
[compat]
15-
AbstractMCMC = "1"
15+
AbstractMCMC = "2"
1616
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
1717
ChainRulesCore = "0.9.7"
1818
Distributions = "0.23.8"

src/DynamicPPL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ export AbstractVarInfo,
6161
Sample,
6262
init,
6363
vectorize,
64-
set_resume!,
6564
# Model
6665
Model,
6766
getmissings,

src/model.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,6 @@ See also: [`evaluate_threadsafe`](@ref)
109109
"""
110110
function evaluate_threadunsafe(rng, model, varinfo, sampler, context)
111111
resetlogp!(varinfo)
112-
if has_eval_num(sampler)
113-
sampler.state.eval_num += 1
114-
end
115112
return _evaluate(rng, model, varinfo, sampler, context)
116113
end
117114

@@ -128,9 +125,6 @@ See also: [`evaluate_threadunsafe`](@ref)
128125
"""
129126
function evaluate_threadsafe(rng, model, varinfo, sampler, context)
130127
resetlogp!(varinfo)
131-
if has_eval_num(sampler)
132-
sampler.state.eval_num += 1
133-
end
134128
wrapper = ThreadSafeVarInfo(varinfo)
135129
result = _evaluate(rng, model, wrapper, sampler, context)
136130
setlogp!(varinfo, getlogp(wrapper))

src/prob_macro.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ function logprior(
136136

137137
# When all of model args are on the lhs of |, this is also equal to the logjoint.
138138
model = make_prior_model(left, right, _model)
139-
vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi
139+
vi = _vi === nothing ? VarInfo(deepcopy(model), SampleFromPrior(), PriorContext()) : _vi
140140
foreach(keys(vi.metadata)) do n
141141
@assert n in keys(left) "Variable $n is not defined."
142142
end

src/sampler.jl

Lines changed: 76 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler`
2+
# That would let us use all defaults for Sampler, combine it with other samplers etc.
13
"""
24
Robust initialization method for model parameters in Hamiltonian samplers.
35
"""
@@ -17,55 +19,93 @@ function init(rng, dist, ::SampleFromUniform, n::Int)
1719
return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
1820
end
1921

20-
"""
21-
has_eval_num(spl::AbstractSampler)
22-
23-
Check whether `spl` has a field called `eval_num` in its state variables or not.
24-
"""
25-
has_eval_num(spl::SampleFromUniform) = false
26-
has_eval_num(spl::SampleFromPrior) = false
27-
has_eval_num(spl::AbstractSampler) = :eval_num in fieldnames(typeof(spl.state))
28-
29-
"""
30-
An abstract type that mutable sampler state structs inherit from.
31-
"""
32-
abstract type AbstractSamplerState end
33-
3422
"""
3523
Sampler{T}
3624
37-
Generic interface for implementing inference algorithms.
38-
An implementation of an algorithm should include the following:
39-
40-
1. A type specifying the algorithm and its parameters, derived from InferenceAlgorithm
41-
2. A method of `sample` function that produces results of inference, which is where actual inference happens.
42-
43-
DynamicPPL translates models to chunks that call the modelling functions at specified points.
44-
The dispatch is based on the value of a `sampler` variable.
45-
To include a new inference algorithm implements the requirements mentioned above in a separate file,
46-
then include that file at the end of this one.
25+
Generic sampler type for inference algorithms in DynamicPPL.
4726
"""
48-
mutable struct Sampler{T, S<:AbstractSamplerState} <: AbstractSampler
49-
alg :: T
50-
info :: Dict{Symbol, Any} # sampler infomation
51-
selector :: Selector
52-
state :: S
27+
struct Sampler{T} <: AbstractSampler
28+
alg::T
29+
# TODO: remove selector & add space
30+
selector::Selector
5331
end
5432
Sampler(alg) = Sampler(alg, Selector())
5533
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
56-
Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s)
34+
Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)
5735

5836
# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
59-
60-
function AbstractMCMC.step!(
37+
function AbstractMCMC.step(
6138
rng::Random.AbstractRNG,
6239
model::Model,
6340
sampler::Union{SampleFromUniform,SampleFromPrior},
64-
::Integer,
65-
transition;
41+
state = nothing;
6642
kwargs...
6743
)
6844
vi = VarInfo()
69-
model(vi, sampler)
70-
return vi
45+
model(rng, vi, sampler)
46+
return vi, nothing
47+
end
48+
49+
# initial step: general interface for resuming and
50+
function AbstractMCMC.step(
51+
rng::Random.AbstractRNG,
52+
model::Model,
53+
spl::Sampler;
54+
resume_from = nothing,
55+
kwargs...
56+
)
57+
if resume_from !== nothing
58+
state = loadstate(resume_from)
59+
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
60+
end
61+
62+
# Sample initial values.
63+
_spl = initialsampler(spl)
64+
vi = VarInfo(rng, model, _spl)
65+
66+
# Update the parameters if provided.
67+
if haskey(kwargs, :init_params)
68+
initialize_parameters!(vi, kwargs[:init_params], spl)
69+
70+
# Update joint log probability.
71+
model(rng, vi, _spl)
72+
end
73+
74+
return initialstep(rng, model, spl, vi; kwargs...)
75+
end
76+
77+
function loadstate end
78+
79+
initialsampler(spl::Sampler) = SampleFromPrior()
80+
81+
function initialstep end
82+
83+
function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler)
84+
@debug "Using passed-in initial variable values" init_params
85+
86+
# Flatten parameters.
87+
init_theta = mapreduce(vcat, init_params) do x
88+
vec([x;])
89+
end
90+
91+
# Get all values.
92+
linked = islinked(vi, spl)
93+
linked && invlink!(vi, spl)
94+
theta = vi[spl]
95+
length(theta) == length(init_theta_flat) ||
96+
error("Provided initial value doesn't match the dimension of the model")
97+
98+
# Update values that are provided.
99+
for i in 1:length(init_theta)
100+
x = init_theta[i]
101+
if x !== missing
102+
theta[i] = x
103+
end
104+
end
105+
106+
# Update in `vi`.
107+
vi[spl] = theta
108+
linked && link!(vi, spl)
109+
110+
return
71111
end

src/varinfo.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,21 +105,30 @@ end
105105
const UntypedVarInfo = VarInfo{<:Metadata}
106106
const TypedVarInfo = VarInfo{<:NamedTuple}
107107

108-
function VarInfo(model::Model, ctx = DefaultContext())
109-
vi = VarInfo()
110-
model(vi, SampleFromPrior(), ctx)
111-
return TypedVarInfo(vi)
112-
end
113-
114108
function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector)
115109
new_vi = deepcopy(old_vi)
116110
new_vi[spl] = x
117111
return new_vi
118112
end
113+
119114
function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
120115
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
121116
VarInfo(md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)))
122117
end
118+
119+
VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...)
120+
121+
function VarInfo(
122+
rng::Random.AbstractRNG,
123+
model::Model,
124+
sampler::AbstractSampler = SampleFromPrior(),
125+
context::AbstractContext = DefaultContext(),
126+
)
127+
varinfo = VarInfo()
128+
model(rng, varinfo, sampler, context)
129+
return TypedVarInfo(varinfo)
130+
end
131+
123132
@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x) where {names, space}
124133
exprs = []
125134
offset = :(0)
@@ -1000,7 +1009,6 @@ from a distribution `dist` to `VarInfo` `vi`.
10001009
The sampler is passed here to invalidate its cache where defined.
10011010
"""
10021011
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler)
1003-
spl.info[:cache_updated] = CACHERESET
10041012
return push!(vi, vn, r, dist, spl.selector)
10051013
end
10061014
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler)

test/Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
44
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
5+
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
56
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
67
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
78
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -32,14 +33,15 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
3233
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3334

3435
[compat]
35-
AbstractMCMC = "1.0.1"
36+
AbstractMCMC = "2"
3637
AdvancedHMC = "0.2.25"
37-
AdvancedMH = "0.5.1"
38+
AdvancedMH = "0.5.2"
39+
BangBang = "0.3"
3840
Bijectors = "0.8.2"
3941
Distributions = "0.23.8"
4042
DistributionsAD = "0.6.3"
4143
DocStringExtensions = "0.8.2"
42-
EllipticalSliceSampling = "0.2.2"
44+
EllipticalSliceSampling = "0.3"
4345
ForwardDiff = "0.10.12"
4446
Libtask = "0.4.1"
4547
LogDensityProblems = "0.10.3"

test/Turing/contrib/inference/dynamichmc.jl

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -41,82 +41,82 @@ function DynamicNUTS{AD}(space::Symbol...) where AD
4141
DynamicNUTS{AD, space}()
4242
end
4343

44-
mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState
44+
struct DynamicNUTSState{V<:AbstractVarInfo,D}
4545
vi::V
4646
draws::Vector{D}
4747
end
4848

4949
DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space
5050

51-
function AbstractMCMC.sample_init!(
51+
# initial step: general interface for resuming and
52+
DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform()
53+
54+
function DynamicPPL.initialstep(
5255
rng::AbstractRNG,
5356
model::Model,
5457
spl::Sampler{<:DynamicNUTS},
55-
N::Integer;
58+
vi::AbstractVarInfo;
5659
kwargs...
5760
)
5861
# Set up lp function.
5962
function _lp(x)
60-
gradient_logp(x, spl.state.vi, model, spl)
63+
gradient_logp(x, vi, model, spl)
6164
end
6265

63-
# Set the parameters to a starting value.
64-
initialize_parameters!(spl; kwargs...)
65-
66-
model(spl.state.vi, SampleFromUniform())
67-
link!(spl.state.vi, spl)
68-
l, dl = _lp(spl.state.vi[spl])
66+
link!(vi, spl)
67+
l, dl = _lp(vi[spl])
6968
while !isfinite(l) || !isfinite(dl)
70-
model(spl.state.vi, SampleFromUniform())
71-
link!(spl.state.vi, spl)
72-
l, dl = _lp(spl.state.vi[spl])
69+
model(vi, SampleFromUniform())
70+
link!(vi, spl)
71+
l, dl = _lp(vi[spl])
7372
end
7473

75-
if spl.selector.tag == :default && !islinked(spl.state.vi, spl)
76-
link!(spl.state.vi, spl)
77-
model(spl.state.vi, spl)
74+
if spl.selector.tag == :default && !islinked(vi, spl)
75+
link!(vi, spl)
76+
model(vi, spl)
7877
end
7978

8079
results = mcmc_with_warmup(
8180
rng,
8281
FunctionLogDensity(
83-
length(spl.state.vi[spl]),
82+
length(vi[spl]),
8483
_lp
8584
),
8685
N
8786
)
87+
draws = results.chain
8888

89-
spl.state.draws = results.chain
89+
# Compute first transition and state.
90+
draw = popfirst!(draws)
91+
vi[spl] = draw
92+
transition = Transition(vi)
93+
state = DynamicNUTSState(vi, draws)
94+
95+
return transition, state
9096
end
9197

9298
function AbstractMCMC.step!(
9399
rng::AbstractRNG,
94100
model::Model,
95101
spl::Sampler{<:DynamicNUTS},
96-
N::Integer,
97-
transition;
102+
state::DynamicNUTSState;
98103
kwargs...
99104
)
105+
# Extract VarInfo object.
106+
vi = state.vi
107+
100108
# Pop the next draw off the vector.
101-
draw = popfirst!(spl.state.draws)
102-
spl.state.vi[spl] = draw
103-
return Transition(spl)
104-
end
109+
draw = popfirst!(state.draws)
110+
vi[spl] = draw
105111

106-
function Sampler(
107-
alg::DynamicNUTS,
108-
model::Model,
109-
s::Selector=Selector()
110-
)
111-
# Construct a state, using a default function.
112-
state = DynamicNUTSState(VarInfo(model), [])
112+
# Compute next transition.
113+
transition = Transition(vi)
113114

114-
# Return a new sampler.
115-
return Sampler(alg, Dict{Symbol,Any}(), s, state)
115+
return transition, state
116116
end
117117

118-
# Disable the progress logging for DynamicHMC, since it has its own progress meter.
119-
function AbstractMCMC.sample(
118+
# Disable the progress logging for DynamicHMC, since it has its own progress meter.
119+
function AbstractMCMC.sample(
120120
rng::AbstractRNG,
121121
model::AbstractModel,
122122
alg::DynamicNUTS,

0 commit comments

Comments
 (0)