Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,48 @@ function observe(spl::Sampler, weight)
error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))")
end

# If parameters exist, they are used and not overwritten.
function assume(
spl::Union{SampleFromPrior, SampleFromUniform},
spl::SampleFromPrior,
dist::Distribution,
vn::VarName,
vi::VarInfo,
)
if haskey(vi, vn)
if is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = spl isa SampleFromUniform ? init(dist) : rand(dist)
r = rand(dist)
vi[vn] = vectorize(dist, r)
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
r = vi[vn]
end
else
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
r = rand(dist)
push!(vi, vn, r, dist, spl)
settrans!(vi, false, vn)
end
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
end

# Always overwrites the parameters with new ones.
function assume(
spl::SampleFromUniform,
dist::Distribution,
vn::VarName,
vi::VarInfo,
)
if haskey(vi, vn)
unset_flag!(vi, vn, "del")
r = init(dist)
vi[vn] = vectorize(dist, r)
settrans!(vi, true, vn)
setorder!(vi, vn, get_num_produce(vi))
else
r = init(dist)
push!(vi, vn, r, dist, spl)
settrans!(vi, true, vn)
end
# NOTE: The importance weight is not correctly computed here because
# r is genereated from some uniform distribution which is different from the prior
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ end

# ROBUST INITIALISATIONS
# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
randrealuni() = Real(2rand())
randrealuni(args...) = map(Real, 2rand(args...))
randrealuni() = 4 * rand() - 2
randrealuni(args...) = 4 .* rand(args...) .- 2

const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution}

Expand Down
1 change: 1 addition & 0 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,7 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`.
The value(s) may or may not be transformed to Euclidean space.
"""
getindex(vi::AbstractVarInfo, spl::SampleFromPrior) = copy(getall(vi))
getindex(vi::AbstractVarInfo, spl::SampleFromUniform) = copy(getall(vi))
getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl)))
function getindex(vi::TypedVarInfo, spl::Sampler)
# Gets the ranges as a NamedTuple
Expand Down
24 changes: 12 additions & 12 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using DynamicPPL: Selector, reconstruct, invlink, CACHERESET,
set_flag!, unset_flag!, VarInfo, TypedVarInfo,
getlogp, setlogp!, resetlogp!, acclogp!, vectorize,
setorder!, updategid!
using DynamicPPL
using DynamicPPL, LinearAlgebra
using Distributions
using ForwardDiff: Dual
using Test
Expand Down Expand Up @@ -167,32 +167,32 @@ include(dir*"/test/test_utils/AllUtils.jl")
meta = vi.metadata
model(vi, SampleFromUniform())

@test all(x -> ~istrans(vi, x), meta.vns)
@test all(x -> istrans(vi, x), meta.vns)
alg = HMC(0.1, 5)
spl = Sampler(alg, model)
v = copy(meta.vals)
link!(vi, spl)
@test all(x -> istrans(vi, x), meta.vns)
invlink!(vi, spl)
@test all(x -> ~istrans(vi, x), meta.vns)
@test meta.vals == v
link!(vi, spl)
@test all(x -> istrans(vi, x), meta.vns)
@test norm(meta.vals - v) <= 1e-6

vi = TypedVarInfo(vi)
meta = vi.metadata
alg = HMC(0.1, 5)
spl = Sampler(alg, model)
@test all(x -> ~istrans(vi, x), meta.s.vns)
@test all(x -> ~istrans(vi, x), meta.m.vns)
v_s = copy(meta.s.vals)
v_m = copy(meta.m.vals)
link!(vi, spl)
@test all(x -> istrans(vi, x), meta.s.vns)
@test all(x -> istrans(vi, x), meta.m.vns)
v_s = copy(meta.s.vals)
v_m = copy(meta.m.vals)
invlink!(vi, spl)
@test all(x -> ~istrans(vi, x), meta.s.vns)
@test all(x -> ~istrans(vi, x), meta.m.vns)
@test meta.s.vals == v_s
@test meta.m.vals == v_m
link!(vi, spl)
@test all(x -> istrans(vi, x), meta.s.vns)
@test all(x -> istrans(vi, x), meta.m.vns)
@test norm(meta.s.vals - v_s) <= 1e-6
@test norm(meta.m.vals - v_m) <= 1e-6
end
@testset "setgid!" begin
vi = VarInfo()
Expand Down