Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand Down
3 changes: 3 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
using Distributions
using Bijectors
using MacroTools

import AbstractMCMC
import Random
import ZygoteRules

import Base: Symbol,
Expand Down
62 changes: 22 additions & 40 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,56 +111,31 @@ function observe(spl::Sampler, weight)
error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))")
end

# If parameters exist, they are used and not overwritten.
function assume(
spl::SampleFromPrior,
spl::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::VarInfo,
)
if haskey(vi, vn)
if is_flagged(vi, vn, "del")
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = rand(dist)
r = init(dist, spl)
vi[vn] = vectorize(dist, r)
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
end
else
r = rand(dist)
r = init(dist, spl)
push!(vi, vn, r, dist, spl)
settrans!(vi, false, vn)
end
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
end

# Always overwrites the parameters with new ones.
function assume(
spl::SampleFromUniform,
dist::Distribution,
vn::VarName,
vi::VarInfo,
)
if haskey(vi, vn)
unset_flag!(vi, vn, "del")
r = init(dist)
vi[vn] = vectorize(dist, r)
settrans!(vi, true, vn)
setorder!(vi, vn, get_num_produce(vi))
else
r = init(dist)
push!(vi, vn, r, dist, spl)
settrans!(vi, true, vn)
end
# NOTE: The importance weight is not correctly computed here because
# r is genereated from some uniform distribution which is different from the prior
# acclogp!(vi, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)))

return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
end

function observe(
spl::Union{SampleFromPrior, SampleFromUniform},
dist::Distribution,
Expand Down Expand Up @@ -307,53 +282,60 @@ function get_and_set_val!(
vi::VarInfo,
vns::AbstractVector{<:VarName},
dist::MultivariateDistribution,
spl::AbstractSampler,
spl::Union{SampleFromPrior,SampleFromUniform},
)
n = length(vns)
if haskey(vi, vns[1])
if is_flagged(vi, vns[1], "del")
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
r = init(dist, spl, n)
for i in 1:n
vn = vns[i]
vi[vn] = vectorize(dist, r[:, i])
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
end
else
r = vi[vns]
r = vi[vns]
end
else
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
r = init(dist, spl, n)
for i in 1:n
push!(vi, vns[i], r[:,i], dist, spl)
settrans!(vi, false, vn)
end
end
return r
end

function get_and_set_val!(
vi::VarInfo,
vns::AbstractArray{<:VarName},
dists::Union{Distribution, AbstractArray{<:Distribution}},
spl::AbstractSampler,
spl::Union{SampleFromPrior,SampleFromUniform},
)
if haskey(vi, vns[1])
if is_flagged(vi, vns[1], "del")
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
f = (vn, dist) -> init(dist, spl)
r = f.(vns, dists)
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
vi[vn] = vectorize(dist, r[i])
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
end
else
r = reshape(vi[vec(vns)], size(vns))
r = reshape(vi[vec(vns)], size(vns))
end
else
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
f = (vn, dist) -> init(dist, spl)
r = f.(vns, dists)
push!.(Ref(vi), vns, r, dists, Ref(spl))
settrans!.(Ref(vi), false, vns)
end
return r
end
Expand Down
24 changes: 24 additions & 0 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ struct SampleFromPrior <: AbstractSampler end

getspace(::Union{SampleFromPrior, SampleFromUniform}) = ()

# Initializations.
init(dist, ::SampleFromPrior) = rand(dist)
init(dist, ::SampleFromUniform) = istransformable(dist) ? inittrans(dist) : rand(dist)

init(dist, ::SampleFromPrior, n::Int) = rand(dist, n)
function init(dist, ::SampleFromUniform, n::Int)
return istransformable(dist) ? inittrans(dist, n) : rand(dist, n)
end

"""
has_eval_num(spl::AbstractSampler)

Expand Down Expand Up @@ -43,3 +52,18 @@ end
Sampler(alg) = Sampler(alg, Selector())
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s)

# AbstractMCMC interface for SampleFromUniform and SampleFromPrior

function AbstractMCMC.step!(
rng::Random.AbstractRNG,
model::Model,
sampler::Union{SampleFromUniform,SampleFromPrior},
::Integer,
transition;
kwargs...
)
vi = VarInfo()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should store a VarInfo(model) inside the SampleFromPrior struct to use the TypedVarInfo.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it's worth constructing a TypedVarInfo object from the UntypedVarInfo object in every sample step. If you use Turing, by default a Chains object will be constructed from the vector of UntypedVarInfo anyways, and otherwise you can still convert the vector to a vector of TypedVarInfo afterwards if needed.

model(vi, sampler)
return vi
end
22 changes: 8 additions & 14 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,37 +114,31 @@ function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector, n::In
return r
end


# ROBUST INITIALISATIONS
# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
# Uniform random numbers with range 4 for robust initializations
# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
randrealuni() = 4 * rand() - 2
randrealuni(args...) = 4 .* rand(args...) .- 2

const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution}

const Transformable = Union{PositiveDistribution,UnitDistribution,TransformDistribution,
SimplexDistribution,PDMatDistribution}
istransformable(dist) = false
istransformable(::Transformable) = true

#################################
# Single-sample initialisations #
#################################

init(dist::Transformable) = inittrans(dist)
init(dist::Distribution) = rand(dist)

inittrans(dist::UnivariateDistribution) = invlink(dist, randrealuni())
inittrans(dist::MultivariateDistribution) = invlink(dist, randrealuni(size(dist)[1]))
inittrans(dist::MultivariateDistribution) = invlink(dist, randrealuni(size(dist, 1)))
inittrans(dist::MatrixDistribution) = invlink(dist, randrealuni(size(dist)...))


################################
# Multi-sample initialisations #
################################

init(dist::Transformable, n::Int) = inittrans(dist, n)
init(dist::Distribution, n::Int) = rand(dist, n)

inittrans(dist::UnivariateDistribution, n::Int) = invlink(dist, randrealuni(n))
function inittrans(dist::MultivariateDistribution, n::Int)
return invlink(dist, randrealuni(size(dist)[1], n))
return invlink(dist, randrealuni(size(dist, 1), n))
end
function inittrans(dist::MatrixDistribution, n::Int)
return invlink(dist, [randrealuni(size(dist)...) for _ in 1:n])
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ turnprogress(false)
include("utils.jl")
include("compiler.jl")
include("varinfo.jl")
include("sampler.jl")
include("prob_macro.jl")
include("independence.jl")
end
42 changes: 42 additions & 0 deletions test/sampler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using DynamicPPL
using Distributions
using AbstractMCMC: sample

using Random
using Statistics
using Test

Random.seed!(100)

@testset "AbstractMCMC interface" begin
@model gdemo(x, y) = begin
s ~ InverseGamma(2, 3)
m ~ Normal(2.0, sqrt(s))
x ~ Normal(m, sqrt(s))
y ~ Normal(m, sqrt(s))
end

model = gdemo(1.0, 2.0)
N = 1_000

chains = sample(model, SampleFromPrior(), N; progress = false)
@test chains isa Vector{<:VarInfo}
@test length(chains) == N

# Expected value of ``X`` where ``X ~ N(2, ...)`` is 2.
@test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1

# Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3.
@test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.1

chains = sample(model, SampleFromUniform(), N; progress = false)
@test chains isa Vector{<:VarInfo}
@test length(chains) == N

# Expected value of ``X`` where ``X ~ U[-2, 2]`` is ≈ 0.
@test mean(vi[@varname(m)] for vi in chains) ≈ 0 atol = 0.1

# Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8.
@test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1
end

23 changes: 12 additions & 11 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,34 +165,35 @@ include(dir*"/test/test_utils/AllUtils.jl")

vi = VarInfo()
meta = vi.metadata

model(vi, SampleFromUniform())
@test all(x -> !istrans(vi, x), meta.vns)

@test all(x -> istrans(vi, x), meta.vns)
alg = HMC(0.1, 5)
spl = Sampler(alg, model)
v = copy(meta.vals)
invlink!(vi, spl)
@test all(x -> ~istrans(vi, x), meta.vns)
link!(vi, spl)
@test all(x -> istrans(vi, x), meta.vns)
@test norm(meta.vals - v) <= 1e-6
invlink!(vi, spl)
@test all(x -> !istrans(vi, x), meta.vns)
@test meta.vals == v

vi = TypedVarInfo(vi)
meta = vi.metadata
alg = HMC(0.1, 5)
spl = Sampler(alg, model)
@test all(x -> istrans(vi, x), meta.s.vns)
@test all(x -> istrans(vi, x), meta.m.vns)
@test all(x -> !istrans(vi, x), meta.s.vns)
@test all(x -> !istrans(vi, x), meta.m.vns)
v_s = copy(meta.s.vals)
v_m = copy(meta.m.vals)
invlink!(vi, spl)
@test all(x -> ~istrans(vi, x), meta.s.vns)
@test all(x -> ~istrans(vi, x), meta.m.vns)
link!(vi, spl)
@test all(x -> istrans(vi, x), meta.s.vns)
@test all(x -> istrans(vi, x), meta.m.vns)
@test norm(meta.s.vals - v_s) <= 1e-6
@test norm(meta.m.vals - v_m) <= 1e-6
invlink!(vi, spl)
@test all(x -> ~istrans(vi, x), meta.s.vns)
@test all(x -> ~istrans(vi, x), meta.m.vns)
@test meta.s.vals == v_s
@test meta.m.vals == v_m
end
@testset "setgid!" begin
vi = VarInfo()
Expand Down