Skip to content

Commit 20e81cc

Browse files
devmotioncpfiffer
authored andcommitted
Add elliptical slice sampling algorithm (#1000)
* Add elliptical slice sampling algorithm * Allow nonzero mean and update according to comments * Update error message Co-Authored-By: Cameron Pfiffer <cpfiffer@gmail.com> * Add _getvns methods and remove static parameters * Remove Nothing sampler again * Update implementation of elliptical slice sampling * Remove more nothing * Fix tests * Remove some Unicode characters * Overload tilde and dot_tilde and test dot notation * Fix error * Fix test errors on Julia 1.0 Co-authored-by: Cameron Pfiffer <cpfiffer@gmail.com>
1 parent 5db142b commit 20e81cc

File tree

8 files changed

+268
-15
lines changed

8 files changed

+268
-15
lines changed

src/Turing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ export @model, # modelling
214214
@sampler,
215215

216216
MH, # classic sampling
217+
ESS,
217218
Gibbs,
218219

219220
HMC, # Hamiltonian-like sampling

src/core/RandomVariables.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ export VarName,
3434
resetlogp!,
3535
set_retained_vns_del_by_spl!,
3636
is_flagged,
37+
set_flag!,
3738
unset_flag!,
3839
setgid!,
3940
updategid!,
@@ -495,11 +496,10 @@ end
495496
end
496497

497498
# Get all vns of variables belonging to spl
498-
_getvns(vi::UntypedVarInfo, spl::AbstractSampler) = view(vi.metadata.vns, _getidcs(vi, spl))
499-
function _getvns(vi::TypedVarInfo, spl::AbstractSampler)
500-
# Get a NamedTuple of the indices of variables belonging to `spl`, one entry for each symbol
501-
idcs = _getidcs(vi, spl)
502-
return _getvns(vi.metadata, idcs)
499+
_getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl)))
500+
_getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.metadata.vns, _getidcs(vi, s, space))
501+
function _getvns(vi::TypedVarInfo, s::Selector, space)
502+
return _getvns(vi.metadata, _getidcs(vi, s, space))
503503
end
504504
# Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol
505505
@generated function _getvns(metadata, idcs::NamedTuple{names}) where {names}
@@ -525,7 +525,7 @@ end
525525
#end
526526
end
527527
# Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space`
528-
@inline function _getranges(vi::AbstractVarInfo, s::Selector, space = Val(()))
528+
@inline function _getranges(vi::AbstractVarInfo, s::Selector, space)
529529
return _getranges(vi, _getidcs(vi, s, space))
530530
end
531531
@inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int})

src/inference/Inference.jl

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module Inference
33
using ..Core, ..Core.RandomVariables, ..Utilities
44
using ..Core.RandomVariables: Metadata, _tail, VarInfo, TypedVarInfo,
55
islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
6-
settrans!
6+
settrans!, _getvns, getdist
77
using ..Core: split_var_str
88
using Distributions, Libtask, Bijectors
99
using ProgressMeter, LinearAlgebra
@@ -13,7 +13,7 @@ using ..Turing: Model, runmodel!, Turing,
1313
Selector, AbstractSamplerState, DefaultContext, PriorContext,
1414
LikelihoodContext, MiniBatchContext, NamedDist, NoDist
1515
using StatsFuns: logsumexp
16-
using Random: GLOBAL_RNG, AbstractRNG
16+
using Random: GLOBAL_RNG, AbstractRNG, randexp
1717
using AbstractMCMC
1818

1919
import MCMCChains: Chains
@@ -33,6 +33,7 @@ export InferenceAlgorithm,
3333
SampleFromUniform,
3434
SampleFromPrior,
3535
MH,
36+
ESS,
3637
Gibbs, # classic sampling
3738
HMC,
3839
SGLD,
@@ -274,8 +275,8 @@ function _params_to_array(ts::Vector{T}, spl::Sampler) where {T<:AbstractTransit
274275
end
275276
push!(dicts, d)
276277
end
277-
278-
# Convert the set to an ordered vector so the parameter ordering
278+
279+
# Convert the set to an ordered vector so the parameter ordering
279280
# is deterministic.
280281
ordered_names = collect(names)
281282
vals = Matrix{Union{Real, Missing}}(undef, length(ts), length(ordered_names))
@@ -486,6 +487,7 @@ end
486487
# Concrete algorithm implementations. #
487488
#######################################
488489

490+
include("ess.jl")
489491
include("hmc.jl")
490492
include("mh.jl")
491493
include("is.jl")
@@ -498,7 +500,7 @@ include("../contrib/inference/AdvancedSMCExtensions.jl")
498500
# Typing tools #
499501
################
500502

501-
for alg in (:SMC, :PG, :PMMH, :IPMCMC, :MH, :IS, :Gibbs)
503+
for alg in (:SMC, :PG, :PMMH, :IPMCMC, :MH, :IS, :ESS, :Gibbs)
502504
@eval getspace(::$alg{space}) where {space} = space
503505
end
504506
for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC)
@@ -635,7 +637,14 @@ function assume(
635637
vi::VarInfo,
636638
)
637639
if haskey(vi, vn)
640+
if is_flagged(vi, vn, "del")
641+
unset_flag!(vi, vn, "del")
642+
r = spl isa SampleFromUniform ? init(dist) : rand(dist)
643+
vi[vn] = vectorize(dist, r)
644+
setorder!(vi, vn, vi.num_produce)
645+
else
638646
r = vi[vn]
647+
end
639648
else
640649
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
641650
push!(vi, vn, r, dist, spl)
@@ -792,9 +801,19 @@ function get_and_set_val!(
792801
)
793802
n = length(vns)
794803
if haskey(vi, vns[1])
804+
if is_flagged(vi, vns[1], "del")
805+
unset_flag!(vi, vns[1], "del")
806+
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
807+
for i in 1:n
808+
vn = vns[i]
809+
vi[vn] = vectorize(dist, r[:, i])
810+
setorder!(vi, vn, vi.num_produce)
811+
end
812+
else
795813
r = vi[vns]
814+
end
796815
else
797-
r = isa(spl, SampleFromUniform) ? init(dist, n) : rand(dist, n)
816+
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
798817
for i in 1:n
799818
push!(vi, vns[i], r[:,i], dist, spl)
800819
end
@@ -808,9 +827,21 @@ function get_and_set_val!(
808827
spl::AbstractSampler,
809828
)
810829
if haskey(vi, vns[1])
830+
if is_flagged(vi, vns[1], "del")
831+
unset_flag!(vi, vns[1], "del")
832+
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
833+
r = f.(vns, dists)
834+
for i in eachindex(vns)
835+
vn = vns[i]
836+
dist = dists isa AbstractArray ? dists[i] : dists
837+
vi[vn] = vectorize(dist, r[i])
838+
setorder!(vi, vn, vi.num_produce)
839+
end
840+
else
811841
r = reshape(vi[vec(vns)], size(vns))
842+
end
812843
else
813-
f(vn, dist) = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
844+
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
814845
r = f.(vns, dists)
815846
push!.(Ref(vi), vns, r, dists, Ref(spl))
816847
end

src/inference/ess.jl

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""
2+
ESS
3+
4+
Elliptical slice sampling algorithm.
5+
6+
# Examples
7+
```jldoctest; setup = :(Random.seed!(1))
8+
julia> @model gdemo(x) = begin
9+
m ~ Normal()
10+
x ~ Normal(m, 0.5)
11+
end
12+
gdemo (generic function with 2 methods)
13+
14+
julia> sample(gdemo(1.0), ESS(), 1_000) |> mean
15+
Mean
16+
17+
│ Row │ parameters │ mean │
18+
│ │ Symbol │ Float64 │
19+
├─────┼────────────┼──────────┤
20+
│ 1 │ m │ 0.824853 │
21+
```
22+
"""
23+
struct ESS{space} <: InferenceAlgorithm end
24+
25+
ESS() = ESS{()}()
26+
ESS(space::Symbol) = ESS{(space,)}()
27+
28+
mutable struct ESSState{V<:VarInfo} <: AbstractSamplerState
29+
vi::V
30+
end
31+
32+
function Sampler(alg::ESS, model::Model, s::Selector)
33+
# sanity check
34+
vi = VarInfo(model)
35+
space = getspace(alg)
36+
vns = _getvns(vi, s, Val(space))
37+
length(vns) == 1 ||
38+
error("[ESS] does only support one variable ($(length(vns)) variables specified)")
39+
for vn in vns[1]
40+
dist = getdist(vi, vn)
41+
isgaussian(dist) ||
42+
error("[ESS] only supports Gaussian prior distributions")
43+
end
44+
45+
state = ESSState(vi)
46+
info = Dict{Symbol, Any}()
47+
48+
return Sampler(alg, info, s, state)
49+
end
50+
51+
isgaussian(dist) = false
52+
isgaussian(::Normal) = true
53+
isgaussian(::NormalCanon) = true
54+
isgaussian(::AbstractMvNormal) = true
55+
56+
# always accept in the first step
57+
function step!(::AbstractRNG, model::Model, spl::Sampler{<:ESS}, ::Integer; kwargs...)
58+
return Transition(spl)
59+
end
60+
61+
function step!(
62+
rng::AbstractRNG,
63+
model::Model,
64+
spl::Sampler{<:ESS},
65+
::Integer,
66+
::Transition;
67+
kwargs...
68+
)
69+
# obtain mean of distribution
70+
vi = spl.state.vi
71+
vns = _getvns(vi, spl)
72+
μ = mapreduce(vcat, vns[1]) do vn
73+
dist = getdist(vi, vn)
74+
vectorize(dist, mean(dist))
75+
end
76+
77+
# obtain previous sample
78+
f = vi[spl]
79+
80+
# recompute log-likelihood in logp
81+
if spl.selector.tag !== :default
82+
runmodel!(model, vi, spl)
83+
end
84+
85+
# sample log-likelihood threshold for the next sample
86+
threshold = getlogp(vi) - randexp(rng)
87+
88+
# sample from the prior
89+
set_flag!(vi, vns[1][1], "del")
90+
runmodel!(model, vi, spl)
91+
ν = vi[spl]
92+
93+
# sample initial angle
94+
θ = 2 * π * rand(rng)
95+
θmin = θ - 2 * π
96+
θmax = θ
97+
98+
while true
99+
# compute proposal and apply correction for distributions with nonzero mean
100+
sinθ, cosθ = sincos(θ)
101+
a = 1 - (sinθ + cosθ)
102+
vi[spl] = @. f * cosθ + ν * sinθ + μ * a
103+
104+
# recompute log-likelihood and check if threshold is reached
105+
runmodel!(model, vi, spl)
106+
if getlogp(vi) > threshold
107+
break
108+
end
109+
110+
# shrink the bracket
111+
if θ < 0
112+
θmin = θ
113+
else
114+
θmax = θ
115+
end
116+
117+
# sample new angle
118+
θ = θmin + rand(rng) * (θmax - θmin)
119+
end
120+
121+
return Transition(spl)
122+
end
123+
124+
function tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi)
125+
if vn in getspace(sampler)
126+
return tilde(LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi)
127+
else
128+
return tilde(ctx, SampleFromPrior(), right, vn, inds, vi)
129+
end
130+
end
131+
132+
function tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
133+
return tilde(ctx, SampleFromPrior(), right, left, vi)
134+
end
135+
136+
function dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi)
137+
if vn in getspace(sampler)
138+
return dot_tilde(LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi)
139+
else
140+
return dot_tilde(ctx, SampleFromPrior(), right, left, vn, inds, vi)
141+
end
142+
end
143+
144+
function dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
145+
return dot_tilde(ctx, SampleFromPrior(), right, left, vi)
146+
end

src/inference/gibbs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
### Gibbs samplers / compositional samplers.
33
###
44

5-
const GibbsComponent = Union{Hamiltonian,MH,PG}
5+
const GibbsComponent = Union{Hamiltonian,MH,ESS,PG}
66

77
"""
88
Gibbs(algs...)
@@ -150,7 +150,7 @@ function step!(
150150
# Uncomment when developing thinning functionality.
151151
# Retrieve symbol to store this subsample.
152152
# symbol_id = Symbol(local_spl.selector.gid)
153-
153+
154154
# # Store the subsample.
155155
# spl.state.subsamples[symbol_id][] = trans
156156

test/inference/ess.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using Turing, Random, Test
2+
3+
dir = splitdir(splitdir(pathof(Turing))[1])[1]
4+
include(dir*"/test/test_utils/AllUtils.jl")
5+
6+
@testset "ESS" begin
7+
@model demo(x) = begin
8+
m ~ Normal()
9+
x ~ Normal(m, 0.5)
10+
end
11+
demo_default = demo(1.0)
12+
13+
@model demodot(x) = begin
14+
m = Vector{Float64}(undef, 2)
15+
@. m ~ Normal()
16+
x ~ Normal(m[2], 0.5)
17+
end
18+
demodot_default = demodot(1.0)
19+
20+
@turing_testset "ESS constructor" begin
21+
Random.seed!(0)
22+
N = 500
23+
s1 = ESS()
24+
s2 = ESS(:m)
25+
s3 = Gibbs(ESS(:m), MH(:s))
26+
27+
c1 = sample(demo_default, s1, N)
28+
c2 = sample(demo_default, s2, N)
29+
c3 = sample(demodot_default, s1, N)
30+
c4 = sample(demodot_default, s2, N)
31+
c5 = sample(gdemo_default, s3, N)
32+
end
33+
34+
@numerical_testset "ESS inference" begin
35+
Random.seed!(1)
36+
chain = sample(demo_default, ESS(), 5_000)
37+
check_numerical(chain, [:m], [0.8], atol = 0.1)
38+
39+
Random.seed!(1)
40+
chain = sample(demodot_default, ESS(), 5_000)
41+
check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8], atol = 0.1)
42+
43+
Random.seed!(100)
44+
alg = Gibbs(
45+
CSMC(15, :s),
46+
ESS(:m))
47+
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
48+
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)
49+
50+
# MoGtest
51+
Random.seed!(125)
52+
alg = Gibbs(
53+
CSMC(15, :z1, :z2, :z3, :z4),
54+
ESS(:mu1), ESS(:mu2))
55+
chain = sample(MoGtest_default, alg, 6000)
56+
check_MoGtest_default(chain, atol = 0.1)
57+
end
58+
end

0 commit comments

Comments
 (0)