Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
70eb771
Update Turing test folder (#173)
devmotion Oct 5, 2020
6c516ac
Update to AbstractMCMC 2
devmotion Jul 4, 2020
35a7341
Fix test errors
devmotion Aug 1, 2020
e7f2ace
Bump test dependency
devmotion Sep 8, 2020
258ac16
Update more test dependencies
devmotion Sep 8, 2020
666f5af
Add BangBang
devmotion Sep 8, 2020
c665b9f
Remove Turing.DEBUG
devmotion Sep 8, 2020
7a8183e
Fix typo
devmotion Sep 8, 2020
fdd8aeb
Fix rebase artifact
devmotion Oct 5, 2020
ce547ae
Fix another remaining DEBUG
devmotion Oct 5, 2020
f7aad24
Fix some MH issues
devmotion Oct 5, 2020
a975d58
Fix HMC
devmotion Oct 5, 2020
feb1d76
Fix ESS
devmotion Oct 6, 2020
1c2b2aa
Fix DynamicHMC
devmotion Oct 6, 2020
c3d0024
Fix resume
devmotion Oct 6, 2020
1b1bed9
Allow VarInfo construction with AbstractContext but without AbstractS…
devmotion Oct 6, 2020
3cfe0de
Fix general inference and prediction
devmotion Oct 6, 2020
2c8573b
Update HMC
devmotion Oct 6, 2020
c54e542
Update src/prob_macro.jl
devmotion Oct 6, 2020
8e16e77
Update test/Turing/inference/is.jl
devmotion Oct 6, 2020
4a828dc
Update test/Turing/inference/gibbs.jl
devmotion Oct 6, 2020
71b6a78
Update test/Turing/inference/gibbs.jl
devmotion Oct 6, 2020
143a507
Update Gibbs sampler
devmotion Oct 6, 2020
1a01329
Add docstrings
devmotion Oct 10, 2020
ae40012
Remove `range` from Gibbs sampler
devmotion Oct 10, 2020
a005c1e
Some fixes
devmotion Nov 12, 2020
522cf7f
Merge branch 'master' into abstractmcmc2
devmotion Nov 12, 2020
941decc
Bump version and remove deprecations
devmotion Nov 13, 2020
ccb3296
Update test/Turing/contrib/inference/dynamichmc.jl
devmotion Nov 16, 2020
6d6d663
Update test/Turing/inference/Inference.jl
devmotion Nov 16, 2020
e3d754f
Apply suggestions
devmotion Nov 22, 2020
fa3deb2
Merge branch 'master' into abstractmcmc2
devmotion Nov 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.9.8"
version = "0.10.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -12,7 +12,7 @@ NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
AbstractMCMC = "1"
AbstractMCMC = "2"
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
ChainRulesCore = "0.9.7"
Distributions = "0.23.8"
Expand Down
3 changes: 0 additions & 3 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ export AbstractVarInfo,
Sample,
init,
vectorize,
set_resume!,
# Model
Model,
getmissings,
Expand Down Expand Up @@ -122,6 +121,4 @@ include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")

include("deprecations.jl")

end # module
22 changes: 0 additions & 22 deletions src/deprecations.jl

This file was deleted.

6 changes: 0 additions & 6 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ See also: [`evaluate_threadsafe`](@ref)
"""
function evaluate_threadunsafe(rng, model, varinfo, sampler, context)
resetlogp!(varinfo)
if has_eval_num(sampler)
sampler.state.eval_num += 1
end
return _evaluate(rng, model, varinfo, sampler, context)
end

Expand All @@ -143,9 +140,6 @@ See also: [`evaluate_threadunsafe`](@ref)
"""
function evaluate_threadsafe(rng, model, varinfo, sampler, context)
resetlogp!(varinfo)
if has_eval_num(sampler)
sampler.state.eval_num += 1
end
wrapper = ThreadSafeVarInfo(varinfo)
result = _evaluate(rng, model, wrapper, sampler, context)
setlogp!(varinfo, getlogp(wrapper))
Expand Down
140 changes: 105 additions & 35 deletions src/sampler.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler`
# That would let us use all defaults for Sampler, combine it with other samplers etc.
"""
Robust initialization method for model parameters in Hamiltonian samplers.
"""
Expand All @@ -17,55 +19,123 @@ function init(rng, dist, ::SampleFromUniform, n::Int)
return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
end

"""
has_eval_num(spl::AbstractSampler)

Check whether `spl` has a field called `eval_num` in its state variables or not.
"""
has_eval_num(spl::SampleFromUniform) = false
has_eval_num(spl::SampleFromPrior) = false
has_eval_num(spl::AbstractSampler) = :eval_num in fieldnames(typeof(spl.state))

"""
An abstract type that mutable sampler state structs inherit from.
"""
abstract type AbstractSamplerState end

"""
Sampler{T}

Generic interface for implementing inference algorithms.
An implementation of an algorithm should include the following:

1. A type specifying the algorithm and its parameters, derived from InferenceAlgorithm
2. A method of `sample` function that produces results of inference, which is where actual inference happens.
Generic sampler type for inference algorithms of type `T` in DynamicPPL.

DynamicPPL translates models to chunks that call the modelling functions at specified points.
The dispatch is based on the value of a `sampler` variable.
To include a new inference algorithm implements the requirements mentioned above in a separate file,
then include that file at the end of this one.
`Sampler` should implement the AbstractMCMC interface, and in particular
[`AbstractMCMC.step`](@ref). A default implementation of the initial sampling step is
provided that supports resuming sampling from a previous state and setting initial
parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref)
for loading previous states and actually performing the initial sampling step,
respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref)
that specifies how the initial parameter values are sampled if they are not provided.
By default, values are sampled from the prior.
"""
mutable struct Sampler{T, S<:AbstractSamplerState} <: AbstractSampler
alg :: T
info :: Dict{Symbol, Any} # sampler infomation
selector :: Selector
state :: S
struct Sampler{T} <: AbstractSampler
alg::T
selector::Selector # Can we remove it?
# TODO: add space such that we can integrate existing external samplers in DynamicPPL
end
Sampler(alg) = Sampler(alg, Selector())
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s)
Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)

# AbstractMCMC interface for SampleFromUniform and SampleFromPrior

function AbstractMCMC.step!(
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::Model,
sampler::Union{SampleFromUniform,SampleFromPrior},
::Integer,
transition;
state = nothing;
kwargs...
)
vi = VarInfo()
model(vi, sampler)
return vi
model(rng, vi, sampler)
return vi, nothing
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::Model,
spl::Sampler;
resume_from = nothing,
kwargs...
)
if resume_from !== nothing
state = loadstate(resume_from)
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end

# Sample initial values.
_spl = initialsampler(spl)
vi = VarInfo(rng, model, _spl)

# Update the parameters if provided.
if haskey(kwargs, :init_params)
initialize_parameters!(vi, kwargs[:init_params], spl)

# Update joint log probability.
model(rng, vi, _spl)
end

return initialstep(rng, model, spl, vi; kwargs...)
end

"""
loadstate(data)

Load sampler state from `data`.
"""
function loadstate end

"""
initialsampler(sampler::Sampler)

Return the sampler that is used for generating the initial parameters when sampling with
`sampler`.

By default, it returns an instance of [`SampleFromPrior`](@ref).
"""
initialsampler(spl::Sampler) = SampleFromPrior()

function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler)
@debug "Using passed-in initial variable values" init_params

# Flatten parameters.
init_theta = mapreduce(vcat, init_params) do x
vec([x;])
end

# Get all values.
linked = islinked(vi, spl)
linked && invlink!(vi, spl)
theta = vi[spl]
length(theta) == length(init_theta_flat) ||
error("Provided initial value doesn't match the dimension of the model")

# Update values that are provided.
for i in 1:length(init_theta)
x = init_theta[i]
if x !== missing
theta[i] = x
end
end

# Update in `vi`.
vi[spl] = theta
linked && link!(vi, spl)

return
end

"""
initialstep(rng, model, sampler, varinfo; kwargs...)

Perform the initial sampling step of the `sampler` for the `model`.

The `varinfo` contains the initial samples, which can be provided by the user or
sampled randomly.
"""
function initialstep end
30 changes: 23 additions & 7 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,38 @@ end
const UntypedVarInfo = VarInfo{<:Metadata}
const TypedVarInfo = VarInfo{<:NamedTuple}

function VarInfo(model::Model, ctx = DefaultContext())
vi = VarInfo()
model(vi, SampleFromPrior(), ctx)
return TypedVarInfo(vi)
end

function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector)
new_vi = deepcopy(old_vi)
new_vi[spl] = x
return new_vi
end

function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
VarInfo(md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)))
end

function VarInfo(
rng::Random.AbstractRNG,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
context::AbstractContext = DefaultContext(),
)
varinfo = VarInfo()
model(rng, varinfo, sampler, context)
return TypedVarInfo(varinfo)
end
VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...)

# without AbstractSampler
function VarInfo(
rng::Random.AbstractRNG,
model::Model,
context::AbstractContext,
)
return VarInfo(rng, model, SampleFromPrior(), context)
end

@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x) where {names, space}
exprs = []
offset = :(0)
Expand Down Expand Up @@ -1000,7 +1017,6 @@ from a distribution `dist` to `VarInfo` `vi`.
The sampler is passed here to invalidate its cache where defined.
"""
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler)
spl.info[:cache_updated] = CACHERESET
return push!(vi, vn, r, dist, spl.selector)
end
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler)
Expand Down
8 changes: 5 additions & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down Expand Up @@ -31,15 +32,16 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "1.0.1"
AbstractMCMC = "2.1"
AdvancedHMC = "0.2.25"
AdvancedMH = "0.5.1"
AdvancedMH = "0.5.2"
AdvancedVI = "0.1"
BangBang = "0.3"
Bijectors = "0.8.2"
Distributions = "0.23.8"
DistributionsAD = "0.6.3"
DocStringExtensions = "0.8.2"
EllipticalSliceSampling = "0.2.2"
EllipticalSliceSampling = "0.3"
ForwardDiff = "0.10.12"
Libtask = "0.4.1"
LogDensityProblems = "0.10.3"
Expand Down
Loading