-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from devmotion/abstractmcmc
Switch to AbstractMCMC interface
- Loading branch information
Showing
12 changed files
with
332 additions
and
210 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,17 @@ | ||
module EllipticalSliceSampling | ||
|
||
using ArrayInterface | ||
using Distributions | ||
using Parameters | ||
using ProgressLogging | ||
import AbstractMCMC | ||
import ArrayInterface | ||
import Distributions | ||
|
||
using Random | ||
import Random | ||
import Statistics | ||
|
||
export ESS_mcmc, ESS_mcmc_sampler | ||
export ESS_mcmc | ||
|
||
include("utils.jl") | ||
include("types.jl") | ||
include("iterator.jl") | ||
include("abstractmcmc.jl") | ||
include("model.jl") | ||
include("distributions.jl") | ||
include("interface.jl") | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# elliptical slice sampler | ||
struct EllipticalSliceSampler <: AbstractMCMC.AbstractSampler end | ||
|
||
# state of the elliptical slice sampler | ||
struct EllipticalSliceSamplerState{S,L} | ||
"Sample of the elliptical slice sampler." | ||
sample::S | ||
"Log-likelihood of the sample." | ||
loglikelihood::L | ||
end | ||
|
||
# first step of the elliptical slice sampler | ||
function AbstractMCMC.step!( | ||
rng::Random.AbstractRNG, | ||
model::AbstractMCMC.AbstractModel, | ||
::EllipticalSliceSampler, | ||
N::Integer, | ||
::Nothing; | ||
kwargs... | ||
) | ||
# initial sample from the Gaussian prior | ||
f = initial_sample(rng, model) | ||
|
||
# compute log-likelihood of the initial sample | ||
loglikelihood = Distributions.loglikelihood(model, f) | ||
|
||
return EllipticalSliceSamplerState(f, loglikelihood) | ||
end | ||
|
||
# subsequent steps of the elliptical slice sampler | ||
function AbstractMCMC.step!( | ||
rng::Random.AbstractRNG, | ||
model::AbstractMCMC.AbstractModel, | ||
::EllipticalSliceSampler, | ||
N::Integer, | ||
state::EllipticalSliceSamplerState; | ||
kwargs... | ||
) | ||
# sample from Gaussian prior | ||
ν = sample_prior(rng, model) | ||
|
||
# sample log-likelihood threshold | ||
loglikelihood = state.loglikelihood | ||
threshold = loglikelihood - Random.randexp(rng) | ||
|
||
# sample initial angle | ||
θ = 2 * π * rand(rng) | ||
θmin = θ - 2 * π | ||
θmax = θ | ||
|
||
# compute the proposal | ||
f = state.sample | ||
fnext = proposal(model, f, ν, θ) | ||
|
||
# compute the log-likelihood of the proposal | ||
loglikelihood = Distributions.loglikelihood(model, fnext) | ||
|
||
# stop if the log-likelihood threshold is reached | ||
while loglikelihood < threshold | ||
# shrink the bracket | ||
if θ < zero(θ) | ||
θmin = θ | ||
else | ||
θmax = θ | ||
end | ||
|
||
# sample angle | ||
θ = θmin + rand(rng) * (θmax - θmin) | ||
|
||
# recompute the proposal | ||
if ArrayInterface.ismutable(fnext) | ||
proposal!(fnext, model, f, ν, θ) | ||
else | ||
fnext = proposal(model, f, ν, θ) | ||
end | ||
|
||
# compute the log-likelihood of the proposal | ||
loglikelihood = Distributions.loglikelihood(model, fnext) | ||
end | ||
|
||
return EllipticalSliceSamplerState(fnext, loglikelihood) | ||
end | ||
|
||
# only save the samples by default | ||
function AbstractMCMC.transitions_init( | ||
state::EllipticalSliceSamplerState, | ||
model::AbstractMCMC.AbstractModel, | ||
::EllipticalSliceSampler, | ||
N::Integer; | ||
kwargs... | ||
) | ||
return Vector{typeof(state.sample)}(undef, N) | ||
end | ||
|
||
function AbstractMCMC.transitions_save!( | ||
samples::AbstractVector{S}, | ||
iteration::Integer, | ||
state::EllipticalSliceSamplerState{S}, | ||
model::AbstractMCMC.AbstractModel, | ||
::EllipticalSliceSampler, | ||
N::Integer; | ||
kwargs... | ||
) where S | ||
samples[iteration] = state.sample | ||
return | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# define the element type of the samples | ||
randtype(::Type{D}) where {D<:Distributions.MultivariateDistribution} = Vector{eltype(D)} | ||
randtype(::Type{D}) where {D<:Distributions.MatrixDistribution} = Matrix{eltype(D)} | ||
function randtype( | ||
::Type{D} | ||
) where {D<:Distributions.Sampleable{Distributions.Multivariate}} | ||
return Vector{eltype(D)} | ||
end | ||
function randtype( | ||
::Type{D} | ||
) where {D<:Distributions.Sampleable{Distributions.Matrixvariate}} | ||
return Matrix{eltype(D)} | ||
end | ||
|
||
# define trait for Gaussian distributions | ||
isgaussian(::Type{<:Distributions.Normal}) = true | ||
isgaussian(::Type{<:Distributions.NormalCanon}) = true | ||
isgaussian(::Type{<:Distributions.AbstractMvNormal}) = true | ||
|
||
# compute the proposal of the next sample | ||
function proposal(prior::Distributions.Normal, f::Real, ν::Real, θ) | ||
sinθ, cosθ = sincos(θ) | ||
μ = prior.μ | ||
if iszero(μ) | ||
return cosθ * f + sinθ * ν | ||
else | ||
a = 1 - (sinθ + cosθ) | ||
return cosθ * f + sinθ * ν + a * μ | ||
end | ||
end | ||
|
||
function proposal( | ||
prior::Distributions.MvNormal, | ||
f::AbstractVector{<:Real}, | ||
ν::AbstractVector{<:Real}, | ||
θ | ||
) | ||
sinθ, cosθ = sincos(θ) | ||
a = 1 - (sinθ + cosθ) | ||
return @. cosθ * f + sinθ * ν + a * prior.μ | ||
end | ||
|
||
function proposal!( | ||
out::AbstractVector{<:Real}, | ||
prior::Distributions.MvNormal, | ||
f::AbstractVector{<:Real}, | ||
ν::AbstractVector{<:Real}, | ||
θ | ||
) | ||
sinθ, cosθ = sincos(θ) | ||
a = 1 - (sinθ + cosθ) | ||
@. out = cosθ * f + sinθ * ν + a * prior.μ | ||
return out | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,72 @@ | ||
# perform elliptical slice sampling for a fixed number of iterations | ||
ESS_mcmc(prior, loglikelihood, N::Int; kwargs...) = | ||
ESS_mcmc(Random.GLOBAL_RNG, prior, loglikelihood, N; kwargs...) | ||
# public interface | ||
|
||
function ESS_mcmc(rng::AbstractRNG, prior, loglikelihood, N::Int; burnin::Int = 0) | ||
# define the internal model | ||
""" | ||
ESS_mcmc([rng, ]prior, loglikelihood, N; kwargs...) | ||
Create a Markov chain of `N` samples for a model with given `prior` and `loglikelihood` | ||
functions using the elliptical slice sampling algorithm. | ||
""" | ||
function ESS_mcmc( | ||
rng::Random.AbstractRNG, | ||
prior, | ||
loglikelihood, | ||
N::Integer; | ||
kwargs... | ||
) | ||
model = Model(prior, loglikelihood) | ||
return AbstractMCMC.sample(rng, model, EllipticalSliceSampler(), N; kwargs...) | ||
end | ||
|
||
function ESS_mcmc(prior, loglikelihood, N::Integer; kwargs...) | ||
return ESS_mcmc(Random.GLOBAL_RNG, prior, loglikelihood, N; kwargs...) | ||
end | ||
|
||
# private interface | ||
|
||
""" | ||
initial_sample(rng, model) | ||
Return the initial sample for the `model` using the random number generator `rng`. | ||
# create the sampler | ||
sampler = EllipticalSliceSampler(rng, model) | ||
|
||
# create MCMC chain | ||
chain = Vector{eltype(sampler)}(undef, N) | ||
niters = N + burnin | ||
@withprogress name = "Performing elliptical slice sampling" begin | ||
# discard burnin phase | ||
for (i, _) in zip(1:burnin, sampler) | ||
@logprogress i / niters | ||
end | ||
|
||
for (i, f) in zip(1:N, sampler) | ||
@inbounds chain[i] = f | ||
@logprogress (i + burnin) / niters | ||
end | ||
end | ||
|
||
chain | ||
By default, sample from the prior by calling [`sample_prior(rng, model)`](@ref). | ||
""" | ||
function initial_sample(rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel) | ||
return sample_prior(rng, model) | ||
end | ||
|
||
# create an elliptical slice sampler | ||
ESS_mcmc_sampler(prior, loglikelihood) = ESS_mcmc_sampler(Random.GLOBAL_RNG, prior, loglikelihood) | ||
ESS_mcmc_sampler(rng::AbstractRNG, prior, loglikelihood) = | ||
EllipticalSliceSampler(rng, Model(prior, loglikelihood)) | ||
""" | ||
sample_prior(rng, model) | ||
Sample from the prior of the `model` using the random number generator `rng`. | ||
""" | ||
function sample_prior(::Random.AbstractRNG, ::AbstractMCMC.AbstractModel) end | ||
|
||
""" | ||
proposal(model, f, ν, θ) | ||
Compute the proposal for the next sample in the elliptical slice sampling algorithm for the | ||
`model` from the previous sample `f`, the sample `ν` from the Gaussian prior, and the angle | ||
`θ`. | ||
Mathematically, the proposal can be computed as | ||
```math | ||
\\cos θ f + ν \\sin θ ν + μ (1 - \\sin θ + \\cos θ), | ||
``` | ||
where ``μ`` is the mean of the Gaussian prior. | ||
""" | ||
function proposal(model::AbstractMCMC.AbstractModel, f, ν, θ) end | ||
|
||
""" | ||
proposal!(out, model, f, ν, θ) | ||
Compute the proposal for the next sample in the elliptical slice sampling algorithm for the | ||
`model` from the previous sample `f`, the sample `ν` from the Gaussian prior, and the angle | ||
`θ`, and save it to `out`. | ||
Mathematically, the proposal can be computed as | ||
```math | ||
\\cos θ f + ν \\sin θ ν + μ (1 - \\sin θ + \\cos θ), | ||
``` | ||
where ``μ`` is the mean of the Gaussian prior. | ||
""" | ||
function proposal!(out, model::AbstractMCMC.AbstractModel, f, ν, θ) end |
Oops, something went wrong.
2028fb2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
2028fb2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/10519
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via: