Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "1.0"
Bijectors = "0.5.2, 0.6, 0.7"
Distributions = "0.22, 0.23"
MacroTools = "0.5.1"
Requires = "0.5, 1.0"
ZygoteRules = "0.2"
julia = "1"

Expand All @@ -37,7 +39,7 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -48,4 +50,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "ReverseDiff", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
4 changes: 3 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module DynamicPPL

using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
using Requires
using Distributions
using Bijectors
using MacroTools
Expand Down Expand Up @@ -36,13 +37,14 @@ export AbstractVarInfo,
set_num_produce!,
reset_num_produce!,
increment_num_produce!,
getmode,
set_retained_vns_del_by_spl!,
is_flagged,
unset_flag!,
setgid!,
updategid!,
setorder!,
istrans,
islinked_and_trans,
link!,
invlink!,
tonamedtuple,
Expand Down
3 changes: 3 additions & 0 deletions src/compat/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ ZygoteRules.@adjoint function push!(
)
return push!(vi, vn, r, dist, gidset), _ -> nothing
end
ZygoteRules.@adjoint function zygote_setval!(vi, val, vn)
return zygote_setval!(vi, val, vn), _ -> nothing
end
50 changes: 25 additions & 25 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ function tilde(ctx::DefaultContext, sampler, right, vn::VarName, _, vi)
return _tilde(sampler, right, vn, vi)
end
function tilde(ctx::PriorContext, sampler, right, vn::VarName, inds, vi)
@assert !islinked(vi)
if ctx.vars !== nothing
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
settrans!(vi, false, vn)
vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds)
end
return _tilde(sampler, right, vn, vi)
end
function tilde(ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi)
@assert !islinked(vi)
if ctx.vars !== nothing
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
settrans!(vi, false, vn)
vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds)
end
return _tilde(sampler, NoDist(right), vn, vi)
end
Expand Down Expand Up @@ -125,20 +125,20 @@ function assume(
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vn, "del")
@assert !islinked(vi)
unset_flag!(vi, vn, "del")
r = init(dist, spl)
vi[vn] = vectorize(dist, r)
settrans!(vi, false, vn)
vi[vn, dist] = r
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
r = vi[vn, dist]
end
else
@assert !islinked(vi)
r = init(dist, spl)
push!(vi, vn, r, dist, spl)
settrans!(vi, false, vn)
end
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
return r, Bijectors.logpdf_with_trans(dist, r, islinked_and_trans(vi, vn))
end

function observe(
Expand Down Expand Up @@ -167,11 +167,11 @@ function dot_tilde(
inds,
vi,
)
@assert !islinked(vi)
if ctx.vars !== nothing
var = _getindex(getfield(ctx.vars, getsym(vn)), inds)
vns, dist = get_vns_and_dist(right, var, vn)
set_val!(vi, vns, dist, var)
settrans!.(Ref(vi), false, vns)
else
vns, dist = get_vns_and_dist(right, left, vn)
end
Expand All @@ -189,11 +189,11 @@ function dot_tilde(
inds,
vi,
)
@assert !islinked(vi)
if ctx.vars !== nothing
var = _getindex(getfield(ctx.vars, getsym(vn)), inds)
vns, dist = get_vns_and_dist(right, var, vn)
set_val!(vi, vns, dist, var)
settrans!.(Ref(vi), false, vns)
else
vns, dist = get_vns_and_dist(right, left, vn)
end
Expand All @@ -214,14 +214,12 @@ function dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi)
return value
end


function get_vns_and_dist(dist::NamedDist, var, vn::VarName)
return get_vns_and_dist(dist.dist, var, dist.name)
end
function get_vns_and_dist(dist::MultivariateDistribution, var::AbstractMatrix, vn::VarName)
getvn = i -> VarName(vn, (vn.indexing..., (Colon(), i)))
return getvn.(1:size(var, 2)), dist

end
function get_vns_and_dist(
dist::Union{Distribution, AbstractArray{<:Distribution}},
Expand Down Expand Up @@ -256,7 +254,7 @@ function dot_assume(
)
@assert length(dist) == size(var, 1)
r = get_and_set_val!(vi, vns, dist, spl)
lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1])))
lp = sum(Bijectors.logpdf_with_trans(dist, r, islinked_and_trans(vi, vns[1])))
var .= r
return var, lp
end
Expand All @@ -269,7 +267,9 @@ function dot_assume(
)
r = get_and_set_val!(vi, vns, dists, spl)
# Make sure `r` is not a matrix for multivariate distributions
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
mode = getmode(vi)
trans = istrans(vi, vns[1]) && (mode isa LinkMode || mode isa InitLinkMode)
lp = sum(Bijectors.logpdf_with_trans.(dists, r, trans))
var .= r
return var, lp
end
Expand All @@ -293,23 +293,23 @@ function get_and_set_val!(
if haskey(vi, vns[1])
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
@assert !islinked(vi)
unset_flag!(vi, vns[1], "del")
r = init(dist, spl, n)
for i in 1:n
vn = vns[i]
vi[vn] = vectorize(dist, r[:, i])
settrans!(vi, false, vn)
vi[vn, dist] = r[:, i]
setorder!(vi, vn, get_num_produce(vi))
end
else
r = vi[vns]
r = vi[vns, dist]
end
else
@assert !islinked(vi)
r = init(dist, spl, n)
for i in 1:n
vn = vns[i]
push!(vi, vn, r[:,i], dist, spl)
settrans!(vi, false, vn)
end
end
return r
Expand All @@ -324,24 +324,24 @@ function get_and_set_val!(
if haskey(vi, vns[1])
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
@assert !islinked(vi)
unset_flag!(vi, vns[1], "del")
f = (vn, dist) -> init(dist, spl)
r = f.(vns, dists)
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
vi[vn] = vectorize(dist, r[i])
settrans!(vi, false, vn)
vi[vn, dist] = r[i]
setorder!(vi, vn, get_num_produce(vi))
end
else
r = reshape(vi[vec(vns)], size(vns))
r = vi[vns, dists]
end
else
@assert !islinked(vi)
f = (vn, dist) -> init(dist, spl)
r = f.(vns, dists)
push!.(Ref(vi), vns, r, dists, Ref(spl))
settrans!.(Ref(vi), false, vns)
end
return r
end
Expand All @@ -354,7 +354,7 @@ function set_val!(
)
@assert size(val, 2) == length(vns)
foreach(enumerate(vns)) do (i, vn)
vi[vn] = val[:,i]
vi[vn, dist] = val[:,i]
end
return val
end
Expand All @@ -367,7 +367,7 @@ function set_val!(
@assert size(val) == size(vns)
foreach(CartesianIndices(val)) do ind
dist = dists isa AbstractArray ? dists[ind] : dists
vi[vns[ind]] = vectorize(dist, val[ind])
vi[vns[ind], dist] = val[ind]
end
return val
end
Expand Down
1 change: 0 additions & 1 deletion src/prob_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ _setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c)
for vn in md.$n.vns
val = copy.(vec(c[Symbol(string(vn))].value))
setval!(vi, val, vn)
settrans!(vi, false, vn)
end
end
end...)
Expand Down
26 changes: 21 additions & 5 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ function setlogp!(vi::ThreadSafeVarInfo, logp)
return setlogp!(vi.varinfo, logp)
end

Bijectors.link(vi::ThreadSafeVarInfo) = ThreadSafeVarInfo(link(vi.varinfo), vi.logps)
Bijectors.invlink(vi::ThreadSafeVarInfo) = ThreadSafeVarInfo(invlink(vi.varinfo), vi.logps)
initlink(vi::ThreadSafeVarInfo) = ThreadSafeVarInfo(initlink(vi.varinfo), vi.logps)

getrange(vi::ThreadSafeVarInfo, vn::VarName) = getrange(vi.varinfo, vn)
get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo)
increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo)
reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo)
Expand All @@ -50,20 +55,27 @@ function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName)
setgid!(vi.varinfo, gid, vn)
end
setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index)
setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)

keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)
getmode(vi::ThreadSafeVarInfo) = getmode(vi.varinfo)
issynced(vi::ThreadSafeVarInfo) = issynced(vi.varinfo)
function setsynced!(vi::ThreadSafeVarInfo, b::Bool)
setsynced!(vi.varinfo, b)
return vi
end
getmetadata(vi::ThreadSafeVarInfo, vn::VarName) = getmetadata(vi.varinfo, vn)

link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl)
invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl)
init_dist_link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = init_dist_link!(vi.varinfo, spl)
init_dist_invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = init_dist_invlink!(vi.varinfo, spl)
islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl)
getinitdist(vi::ThreadSafeVarInfo, vn::VarName) = getinitdist(vi.varinfo, vn)
has_fixed_support(vi::ThreadSafeVarInfo) = has_fixed_support(vi.varinfo)
set_fixed_support!(vi::ThreadSafeVarInfo, b::Bool) = set_fixed_support!(vi.varinfo, b)

getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl)
getindex(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = getindex(vi.varinfo, spl)
getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, spl)
getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn)
getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns)

function setindex!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler)
setindex!(vi.varinfo, val, spl)
Expand All @@ -85,6 +97,10 @@ function empty!(vi::ThreadSafeVarInfo)
fill!(vi.logps, zero(getlogp(vi)))
return vi
end
function empty!(vi::ThreadSafeVarInfo, spl::AbstractSampler)
empty!(vi.varinfo, spl)
return vi
end

function push!(
vi::ThreadSafeVarInfo,
Expand Down
16 changes: 16 additions & 0 deletions src/varinfo/ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
function __init__()
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
end
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x)
value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x)
value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x)
end
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
value(x::Tracker.TrackedReal) = Tracker.data(x)
value(x::Tracker.TrackedArray) = Tracker.data(x)
value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x)
end
end
Comment on lines +1 to +16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These definitions should not be part of DynamicPPL but moved to some other package. The only AD-related part in DynamicPPL should be custom adjoints for functions defined in DynamicPPL if required, which should be implemented without depending on the AD packages if possible (by using, e.g., ZygoteRules or ChainRulesCore). If these definitions are removed, there's also no need for adding Requires as a dependency, AFAICT.

Loading