@@ -3,7 +3,7 @@ module Inference
33using .. Core, .. Core. RandomVariables, .. Utilities
44using .. Core. RandomVariables: Metadata, _tail, VarInfo, TypedVarInfo,
55 islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
6- settrans!
6+ settrans!, _getvns, getdist
77using .. Core: split_var_str
88using Distributions, Libtask, Bijectors
99using ProgressMeter, LinearAlgebra
@@ -13,7 +13,7 @@ using ..Turing: Model, runmodel!, Turing,
1313 Selector, AbstractSamplerState, DefaultContext, PriorContext,
1414 LikelihoodContext, MiniBatchContext, NamedDist, NoDist
1515using StatsFuns: logsumexp
16- using Random: GLOBAL_RNG, AbstractRNG
16+ using Random: GLOBAL_RNG, AbstractRNG, randexp
1717using AbstractMCMC
1818
1919import MCMCChains: Chains
@@ -33,6 +33,7 @@ export InferenceAlgorithm,
3333 SampleFromUniform,
3434 SampleFromPrior,
3535 MH,
36+ ESS,
3637 Gibbs, # classic sampling
3738 HMC,
3839 SGLD,
@@ -274,8 +275,8 @@ function _params_to_array(ts::Vector{T}, spl::Sampler) where {T<:AbstractTransit
274275 end
275276 push! (dicts, d)
276277 end
277-
278- # Convert the set to an ordered vector so the parameter ordering
278+
279+ # Convert the set to an ordered vector so the parameter ordering
279280 # is deterministic.
280281 ordered_names = collect (names)
281282 vals = Matrix {Union{Real, Missing}} (undef, length (ts), length (ordered_names))
486487# Concrete algorithm implementations. #
487488# ######################################
488489
490+ include (" ess.jl" )
489491include (" hmc.jl" )
490492include (" mh.jl" )
491493include (" is.jl" )
@@ -498,7 +500,7 @@ include("../contrib/inference/AdvancedSMCExtensions.jl")
498500# Typing tools #
499501# ###############
500502
501- for alg in (:SMC , :PG , :PMMH , :IPMCMC , :MH , :IS , :Gibbs )
503+ for alg in (:SMC , :PG , :PMMH , :IPMCMC , :MH , :IS , :ESS , : Gibbs )
502504 @eval getspace (:: $alg{space} ) where {space} = space
503505end
504506for alg in (:HMC , :HMCDA , :NUTS , :SGLD , :SGHMC )
@@ -635,7 +637,14 @@ function assume(
635637 vi:: VarInfo ,
636638)
637639 if haskey (vi, vn)
640+ if is_flagged (vi, vn, " del" )
641+ unset_flag! (vi, vn, " del" )
642+ r = spl isa SampleFromUniform ? init (dist) : rand (dist)
643+ vi[vn] = vectorize (dist, r)
644+ setorder! (vi, vn, vi. num_produce)
645+ else
638646 r = vi[vn]
647+ end
639648 else
640649 r = isa (spl, SampleFromUniform) ? init (dist) : rand (dist)
641650 push! (vi, vn, r, dist, spl)
@@ -792,9 +801,19 @@ function get_and_set_val!(
792801)
793802 n = length (vns)
794803 if haskey (vi, vns[1 ])
804+ if is_flagged (vi, vns[1 ], " del" )
805+ unset_flag! (vi, vns[1 ], " del" )
806+ r = spl isa SampleFromUniform ? init (dist, n) : rand (dist, n)
807+ for i in 1 : n
808+ vn = vns[i]
809+ vi[vn] = vectorize (dist, r[:, i])
810+ setorder! (vi, vn, vi. num_produce)
811+ end
812+ else
795813 r = vi[vns]
814+ end
796815 else
797- r = isa ( spl, SampleFromUniform) ? init (dist, n) : rand (dist, n)
816+ r = spl isa SampleFromUniform ? init (dist, n) : rand (dist, n)
798817 for i in 1 : n
799818 push! (vi, vns[i], r[:,i], dist, spl)
800819 end
@@ -808,9 +827,21 @@ function get_and_set_val!(
808827 spl:: AbstractSampler ,
809828)
810829 if haskey (vi, vns[1 ])
830+ if is_flagged (vi, vns[1 ], " del" )
831+ unset_flag! (vi, vns[1 ], " del" )
832+ f = (vn, dist) -> spl isa SampleFromUniform ? init (dist) : rand (dist)
833+ r = f .(vns, dists)
834+ for i in eachindex (vns)
835+ vn = vns[i]
836+ dist = dists isa AbstractArray ? dists[i] : dists
837+ vi[vn] = vectorize (dist, r[i])
838+ setorder! (vi, vn, vi. num_produce)
839+ end
840+ else
811841 r = reshape (vi[vec (vns)], size (vns))
842+ end
812843 else
813- f (vn, dist) = isa ( spl, SampleFromUniform) ? init (dist) : rand (dist)
844+ f = (vn, dist) -> spl isa SampleFromUniform ? init (dist) : rand (dist)
814845 r = f .(vns, dists)
815846 push! .(Ref (vi), vns, r, dists, Ref (spl))
816847 end
0 commit comments