Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ export @model, # modelling
@varname,
DynamicPPL,

Prior, # Sampling from the prior

MH, # classic sampling
RWMH,
ESS,
Expand Down
2 changes: 2 additions & 0 deletions src/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ function additional_parameters(::Type{<:ParticleTransition})
return [:lp,:le, :weight]
end

# The log joint density of a particle transition is stored in its `lp` field.
function DynamicPPL.getlogp(t::ParticleTransition)
    return t.lp
end

####
#### Generic Sequential Monte Carlo sampler.
####
Expand Down
141 changes: 90 additions & 51 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export InferenceAlgorithm,
SMC,
CSMC,
PG,
Prior,
assume,
dot_assume,
observe,
Expand All @@ -70,6 +71,9 @@ abstract type AdaptiveHamiltonian{AD} <: Hamiltonian{AD} end
getchunksize(::Type{<:Hamiltonian{AD}}) where AD = getchunksize(AD)
getADbackend(::Hamiltonian{AD}) where AD = AD()

# Algorithm for sampling from the prior
"""
    Prior

Inference algorithm that draws samples from the model's prior.

`sample` calls with this algorithm are dispatched to DynamicPPL's
`SampleFromPrior` sampler rather than running an MCMC kernel.
"""
struct Prior <: InferenceAlgorithm end

"""
mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real)

Expand Down Expand Up @@ -107,6 +111,8 @@ function additional_parameters(::Type{<:Transition})
return [:lp]
end

# Expose the log density carried by a `Transition` via DynamicPPL's accessor.
function DynamicPPL.getlogp(t::Transition)
    return t.lp
end

##########################################
# Internal variable names for MCMCChains #
##########################################
Expand Down Expand Up @@ -158,7 +164,7 @@ end
function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
sampler::Sampler,
sampler::Sampler{<:InferenceAlgorithm},
N::Integer;
chain_type=MCMCChains.Chains,
resume_from=nothing,
Expand All @@ -173,6 +179,24 @@ function AbstractMCMC.sample(
end
end

# Sampling with the `Prior` algorithm: either resume a previous run or draw
# fresh samples by delegating to DynamicPPL's `SampleFromPrior` sampler.
function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    alg::Prior,
    N::Integer;
    chain_type=MCMCChains.Chains,
    resume_from=nothing,
    progress=PROGRESS[],
    kwargs...
)
    # Guard clause: a provided chain short-circuits into `resume`.
    resume_from === nothing || return resume(
        resume_from, N;
        chain_type=chain_type, progress=progress, kwargs...
    )
    return AbstractMCMC.mcmcsample(
        rng, model, SampleFromPrior(), N;
        chain_type=chain_type, progress=progress, kwargs...
    )
end

function AbstractMCMC.sample(
model::AbstractModel,
alg::InferenceAlgorithm,
Expand Down Expand Up @@ -201,7 +225,7 @@ end
function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
sampler::Sampler,
sampler::Sampler{<:InferenceAlgorithm},
parallel::AbstractMCMC.AbstractMCMCParallel,
N::Integer,
n_chains::Integer;
Expand All @@ -213,10 +237,25 @@ function AbstractMCMC.sample(
chain_type=chain_type, progress=progress, kwargs...)
end

# Parallel sampling with the `Prior` algorithm: forward to the generic
# multi-chain `sample` using DynamicPPL's `SampleFromPrior` sampler.
function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    alg::Prior,
    parallel::AbstractMCMC.AbstractMCMCParallel,
    N::Integer,
    n_chains::Integer;
    chain_type=MCMCChains.Chains,
    progress=PROGRESS[],
    kwargs...
)
    prior_sampler = SampleFromPrior()
    return AbstractMCMC.sample(
        rng, model, prior_sampler, parallel, N, n_chains;
        chain_type=chain_type, progress=progress, kwargs...
    )
end

function AbstractMCMC.sample_init!(
::AbstractRNG,
model::Model,
spl::Sampler,
model::AbstractModel,
spl::Sampler{<:InferenceAlgorithm},
N::Integer;
kwargs...
)
Expand All @@ -227,17 +266,6 @@ function AbstractMCMC.sample_init!(
initialize_parameters!(spl; kwargs...)
end

# Intentional no-op: overrides the default `sample_end!` hook of the
# AbstractMCMC API so nothing extra happens after the final iteration.
function AbstractMCMC.sample_end!(
    ::AbstractRNG,
    ::Model,
    ::Sampler,
    ::Integer,
    ::Vector;
    kwargs...
)
    return nothing
end

function initialize_parameters!(
spl::Sampler;
init_theta::Union{Nothing,Vector}=nothing,
Expand Down Expand Up @@ -268,19 +296,27 @@ end
# Chain making utilities #
##########################

function _params_to_array(ts::Vector, spl::Sampler)
"""
getparams(t)

Return a named tuple of parameters.
"""
getparams(t) = t.θ
getparams(t::VarInfo) = tonamedtuple(TypedVarInfo(t))

function _params_to_array(ts)
names_set = Set{String}()
# Extract the parameter names and values from each transition.
dicts = map(ts) do t
nms, vs = flatten_namedtuple(t.θ)
nms, vs = flatten_namedtuple(getparams(t))
for nm in nms
push!(names_set, nm)
end
# Convert the names and values to a single dictionary.
return Dict(nms[j] => vs[j] for j in 1:length(vs))
end
names = collect(names_set)
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
(j, key) in enumerate(names)]

return names, vals
Expand All @@ -300,7 +336,12 @@ function flatten_namedtuple(nt::NamedTuple)
return [vn[1] for vn in names_vals], [vn[2] for vn in names_vals]
end

function get_transition_extras(ts::Vector)
function get_transition_extras(ts::AbstractVector{<:VarInfo})
valmat = reshape([getlogp(t) for t in ts], :, 1)
return ["lp"], valmat
end

function get_transition_extras(ts::AbstractVector)
# Get the extra field names from the sampler state type.
# This handles things like :lp or :weight.
extra_params = additional_parameters(eltype(ts))
Expand Down Expand Up @@ -340,50 +381,46 @@ function get_transition_extras(ts::Vector)
return extra_names, valmat
end

# Samplers without an evidence estimate report `missing`.
getlogevidence(sampler) = missing

# A `Sampler` keeps its log-evidence estimate in the state, under either an
# `average_logevidence` or a `final_logevidence` field; fall back to
# `missing` when neither field exists.
function getlogevidence(spl::Sampler)
    state = spl.state
    if isdefined(state, :average_logevidence)
        return state.average_logevidence
    elseif isdefined(state, :final_logevidence)
        return state.final_logevidence
    end
    return missing
end

# Default MCMCChains.Chains constructor.
# This is type piracy (at least for SampleFromPrior).
function AbstractMCMC.bundle_samples(
rng::AbstractRNG,
model::Model,
spl::Sampler,
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
N::Integer,
ts::Vector,
chain_type::Type{MCMCChains.Chains};
discard_adapt::Bool=true,
save_state=false,
save_state = false,
kwargs...
)
# Check if we have adaptation samples.
if discard_adapt && :n_adapts in fieldnames(typeof(spl.alg))
ts = ts[(spl.alg.n_adapts+1):end]
end

# Convert transitions to array format.
# Also retrieve the variable names.
nms, vals = _params_to_array(ts, spl)
nms, vals = _params_to_array(ts)

# Get the values of the extra parameters in each Transition struct.
# Get the values of the extra parameters in each transition.
extra_params, extra_values = get_transition_extras(ts)

# Extract names & construct param array.
nms = [nms; extra_params]
parray = hcat(vals, extra_values)

# If the state field has average_logevidence or final_logevidence, grab that.
le = missing
if :average_logevidence in fieldnames(typeof(spl.state))
le = getproperty(spl.state, :average_logevidence)
elseif :final_logevidence in fieldnames(typeof(spl.state))
le = getproperty(spl.state, :final_logevidence)
end

# Check whether to invlink! the varinfo
if islinked(spl.state.vi, spl)
invlink!(spl.state.vi, spl)
end
# Get the average or final log evidence, if it exists.
le = getlogevidence(spl)

# Set up the info tuple.
if save_state
info = (range = rng, model = model, spl = spl, vi = spl.state.vi)
info = (range = rng, model = model, spl = spl)
else
info = NamedTuple()
end
Expand All @@ -402,10 +439,11 @@ function AbstractMCMC.bundle_samples(
)
end

# This is type piracy (for SampleFromPrior).
function AbstractMCMC.bundle_samples(
rng::AbstractRNG,
model::Model,
spl::Sampler,
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
N::Integer,
ts::Vector,
chain_type::Type{Vector{NamedTuple}};
Expand All @@ -415,17 +453,18 @@ function AbstractMCMC.bundle_samples(
)
nts = Vector{NamedTuple}(undef, N)

for (i,t) in enumerate(ts)
k = collect(keys(t.θ))
for (i, t) in enumerate(ts)
params = getparams(t)

k = collect(keys(params))
vs = []
for v in values(t.θ)
for v in values(params)
push!(vs, v[1])
end

push!(k, :lp)


nts[i] = NamedTuple{tuple(k...)}(tuple(vs..., t.lp))

nts[i] = NamedTuple{tuple(k...)}(tuple(vs..., getlogp(t)))
end

return map(identity, nts)
Expand Down
4 changes: 3 additions & 1 deletion src/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ function additional_parameters(::Type{<:GibbsTransition})
return [:lp]
end

# A Gibbs transition records the log density in its `lp` field.
function DynamicPPL.getlogp(t::GibbsTransition)
    return t.lp
end

# Initialize the Gibbs sampler.
function AbstractMCMC.sample_init!(
rng::AbstractRNG,
Expand Down Expand Up @@ -207,4 +209,4 @@ function AbstractMCMC.transitions_save!(
)
transitions[iteration] = Transition(transition.θ, transition.lp)
return
end
end
40 changes: 39 additions & 1 deletion src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ function additional_parameters(::Type{<:HamiltonianTransition})
return [:lp,:stat]
end

# HMC transitions keep the log density in their `lp` field.
function DynamicPPL.getlogp(t::HamiltonianTransition)
    return t.lp
end

###
### Hamiltonian Monte Carlo samplers.
Expand Down Expand Up @@ -101,7 +102,7 @@ end

function AbstractMCMC.sample_init!(
rng::AbstractRNG,
model::Model,
model::AbstractModel,
spl::Sampler{<:Hamiltonian},
N::Integer;
verbose::Bool=true,
Expand Down Expand Up @@ -141,6 +142,43 @@ function AbstractMCMC.sample_init!(
end
end

# Allocate the transition container for an HMC run. When adaptation draws
# are discarded, only the post-adaptation iterations need storage (clamped
# to a non-negative length).
function AbstractMCMC.transitions_init(
    transition,
    ::AbstractModel,
    sampler::Sampler{<:Hamiltonian},
    N::Integer;
    discard_adapt = true,
    kwargs...
)
    len = N
    if discard_adapt && isdefined(sampler.alg, :n_adapts)
        len = max(0, N - sampler.alg.n_adapts)
    end
    return Vector{typeof(transition)}(undef, len)
end

# Store a transition produced during an HMC run. When adaptation draws are
# discarded, warm-up iterations are dropped and saved indices are shifted so
# the first kept draw lands at position 1.
function AbstractMCMC.transitions_save!(
    transitions::AbstractVector,
    iteration::Integer,
    transition,
    ::AbstractModel,
    sampler::Sampler{<:Hamiltonian},
    ::Integer;
    discard_adapt = true,
    kwargs...
)
    # Keeping warm-up (or no `n_adapts` field): store every draw in place.
    if !(discard_adapt && isdefined(sampler.alg, :n_adapts))
        transitions[iteration] = transition
        return
    end

    n_adapts = sampler.alg.n_adapts
    if iteration > n_adapts
        transitions[iteration - n_adapts] = transition
    end
    return
end

"""
HMCDA(n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64=0.0)

Expand Down
Loading