Skip to content

Commit 9bee395

Browse files
committed
Merge branch 'tor/simple-varinfo-v2' into tor/immutable-varinfo-support
2 parents d1638a8 + be35be0 commit 9bee395

File tree

4 files changed

+306
-2
lines changed

4 files changed

+306
-2
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1212
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1415
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1516

1617
[compat]

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ export AbstractVarInfo,
3434
VarInfo,
3535
UntypedVarInfo,
3636
TypedVarInfo,
37+
SimpleVarInfo,
3738
push!!,
3839
empty!!,
3940
getlogp,
@@ -135,6 +136,7 @@ include("varname.jl")
135136
include("distribution_wrappers.jl")
136137
include("contexts.jl")
137138
include("varinfo.jl")
139+
include("simple_varinfo.jl")
138140
include("threadsafe.jl")
139141
include("context_implementations.jl")
140142
include("compiler.jl")

src/simple_varinfo.jl

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
using Setfield
2+
3+
"""
4+
SimpleVarInfo{NT,T} <: AbstractVarInfo
5+
6+
A simple wrapper of the parameters with a `logp` field for
7+
accumulation of the logdensity.
8+
9+
Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`.
10+
11+
# Notes
12+
The major differences between this and `TypedVarInfo` are:
13+
1. `SimpleVarInfo` does not require linearization.
14+
2. `SimpleVarInfo` can use more efficient bijectors.
15+
3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either
16+
a) no indexing is used in tilde-statements, or
17+
b) the values have been specified with the corret shapes.
18+
19+
# Examples
20+
```jldoctest; setup=:(using Distributions)
21+
julia> using StableRNGs
22+
23+
julia> @model function demo()
24+
m ~ Normal()
25+
x = Vector{Float64}(undef, 2)
26+
for i in eachindex(x)
27+
x[i] ~ Normal()
28+
end
29+
return x
30+
end
31+
demo (generic function with 1 method)
32+
33+
julia> m = demo();
34+
35+
julia> rng = StableRNG(42);
36+
37+
julia> ### Sampling ###
38+
ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext());
39+
40+
julia> # In the `NamedTuple` version we need to provide the place-holder values for
41+
# the variablse which are using "containers", e.g. `Array`.
42+
# In this case, this means that we need to specify `x` but not `m`.
43+
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi
44+
SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [1.6642061055583879, 1.796319600944139], m = -0.16796295277202952), -5.769094411622931)
45+
46+
julia> # (✓) Vroom, vroom! FAST!!!
47+
DynamicPPL.getval(vi, @varname(x[1]))
48+
1.6642061055583879
49+
50+
julia> # We can also access arbitrary varnames pointing to `x`, e.g.
51+
DynamicPPL.getval(vi, @varname(x))
52+
2-element Vector{Float64}:
53+
1.6642061055583879
54+
1.796319600944139
55+
56+
julia> DynamicPPL.getval(vi, @varname(x[1:2]))
57+
2-element view(::Vector{Float64}, 1:2) with eltype Float64:
58+
1.6642061055583879
59+
1.796319600944139
60+
61+
julia> # (×) If we don't provide the container...
62+
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi
63+
ERROR: type NamedTuple has no field x
64+
[...]
65+
66+
julia> # If one does not know the varnames, we can use a `Dict` instead.
67+
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi
68+
SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.192696983568277, x[2] => 0.4914514300738121, m => 0.25572200616753643), -3.6215377732004237)
69+
70+
julia> # (✓) Sort of fast, but only possible at runtime.
71+
DynamicPPL.getval(vi, @varname(x[1]))
72+
1.192696983568277
73+
74+
julia> # In addtion, we can only access varnames as they appear in the model!
75+
DynamicPPL.getval(vi, @varname(x))
76+
ERROR: KeyError: key x not found
77+
[...]
78+
79+
julia> julia> DynamicPPL.getval(vi, @varname(x[1:2]))
80+
ERROR: KeyError: key x[1:2] not found
81+
[...]
82+
```
83+
"""
84+
struct SimpleVarInfo{NT,T} <: AbstractVarInfo
85+
θ::NT
86+
logp::T
87+
end
88+
89+
SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T))
90+
SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ)
91+
SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(NamedTuple())
92+
SimpleVarInfo() = SimpleVarInfo{Float64}()
93+
94+
# Constructor from `Model`.
95+
SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...)
96+
function SimpleVarInfo{T}(model::Model, args...) where {T<:Real}
97+
_, svi = DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...)
98+
return svi
99+
end
100+
101+
# Constructor from `VarInfo`.
102+
function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
103+
return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
104+
end
105+
function SimpleVarInfo{T}(
106+
vi::VarInfo{<:NamedTuple{names}}, ::Type{D}
107+
) where {T<:Real,names,D}
108+
values = values_as(vi, D)
109+
return SimpleVarInfo{T}(values)
110+
end
111+
112+
getlogp(vi::SimpleVarInfo) = vi.logp
113+
setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp)
114+
acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp)
115+
116+
function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
117+
vi.logp[] = logp
118+
return vi
119+
end
120+
121+
function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
122+
vi.logp[] += logp
123+
return vi
124+
end
125+
126+
function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym}
127+
# Use `getproperty` instead of `getfield`
128+
value = getproperty(nt, sym)
129+
# Note that this will return a `view`, even if the resulting value is 0-dim.
130+
# This makes it possible to call `setindex!` on the result later to update
131+
# in place even in the case where are retrieving a single element, e.g. `x[1]`.
132+
return _getindex(value, inds)
133+
end
134+
135+
# `NamedTuple`
136+
function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}) where {sym}
137+
return maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing))
138+
end
139+
140+
# `Dict`
141+
function getval(vi::SimpleVarInfo{<:Dict}, vn::VarName)
142+
return vi.θ[vn]
143+
end
144+
145+
# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than
146+
# just `Vector`.
147+
getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns)
148+
# To disambiguiate.
149+
getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns)
150+
151+
haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn))
152+
153+
istrans(::SimpleVarInfo, vn::VarName) = false
154+
155+
getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ
156+
getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ
157+
# TODO: Should we do better?
158+
getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ
159+
getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn)
160+
getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns)
161+
# HACK: Need to disambiguiate.
162+
getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns)
163+
164+
# Necessary for `matchingvalue` to work properly.
165+
function Base.eltype(
166+
vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior}
167+
) where {T}
168+
return T
169+
end
170+
171+
# `NamedTuple`
172+
function push!!(
173+
vi::SimpleVarInfo{<:NamedTuple},
174+
vn::VarName{sym,Tuple{}},
175+
value,
176+
dist::Distribution,
177+
gidset::Set{Selector},
178+
) where {sym}
179+
@set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,)))
180+
end
181+
function push!!(
182+
vi::SimpleVarInfo{<:NamedTuple},
183+
vn::VarName{sym},
184+
value,
185+
dist::Distribution,
186+
gidset::Set{Selector},
187+
) where {sym}
188+
# We update in place.
189+
# We need a view into the array, hence we call `_getvalue` directly
190+
# rather than `getval`.
191+
current = _getvalue(vi.θ, Val{sym}(), vn.indexing)
192+
current .= value
193+
return vi
194+
end
195+
196+
# `Dict`
197+
function push!!(
198+
vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
199+
)
200+
vi.θ[vn] = r
201+
return vi
202+
end
203+
204+
# Context implementations
205+
function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo)
206+
value, logp, vi_new = tilde_assume(context, right, vn, inds, vi)
207+
return value, acclogp!!(vi_new, logp)
208+
end
209+
210+
function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo)
211+
left = vi[vn]
212+
return left, Distributions.loglikelihood(dist, left), vi
213+
end
214+
215+
function assume(
216+
rng::Random.AbstractRNG,
217+
sampler::SampleFromPrior,
218+
dist::Distribution,
219+
vn::VarName,
220+
vi::SimpleVarInfo,
221+
)
222+
value = init(rng, dist, sampler)
223+
vi = push!!(vi, vn, value, dist, sampler)
224+
vi = settrans!!(vi, false, vn)
225+
return value, Distributions.loglikelihood(dist, value), vi
226+
end
227+
228+
# function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo)
229+
# throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi)))
230+
# end
231+
232+
function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo)
233+
value, logp, vi_new = dot_tilde_assume(context, right, left, vn, inds, vi)
234+
# Mutation of `value` no longer occurs in main body, so we do it here.
235+
left .= value
236+
return value, acclogp!!(vi_new, logp)
237+
end
238+
239+
function dot_assume(
240+
dist::MultivariateDistribution,
241+
var::AbstractMatrix,
242+
vns::AbstractVector{<:VarName},
243+
vi::SimpleVarInfo,
244+
)
245+
@assert length(dist) == size(var, 1)
246+
# NOTE: We cannot work with `var` here because we might have a model of the form
247+
#
248+
# m = Vector{Float64}(undef, n)
249+
# m .~ Normal()
250+
#
251+
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
252+
value = vi[vns]
253+
lp = sum(zip(vns, eachcol(value))) do vn, val
254+
return Distributions.logpdf(dist, val)
255+
end
256+
return value, lp, vi
257+
end
258+
259+
function dot_assume(
260+
dists::Union{Distribution,AbstractArray{<:Distribution}},
261+
var::AbstractArray,
262+
vns::AbstractArray{<:VarName},
263+
vi::SimpleVarInfo{<:NamedTuple},
264+
)
265+
# NOTE: We cannot work with `var` here because we might have a model of the form
266+
#
267+
# m = Vector{Float64}(undef, n)
268+
# m .~ Normal()
269+
#
270+
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
271+
value = vi[vns]
272+
lp = sum(Distributions.logpdf.(dists, value))
273+
return value, lp, vi
274+
end
275+
276+
# HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals.
277+
increment_num_produce!(::SimpleVarInfo) = nothing
278+
settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi

src/varinfo.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,7 +1398,7 @@ function setval!(
13981398
return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
13991399
end
14001400

1401-
function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
1401+
function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys)
14021402
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
14031403
if !isempty(indices)
14041404
val = reduce(vcat, values[indices])
@@ -1479,7 +1479,7 @@ function setval_and_resample!(
14791479
return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
14801480
end
14811481

1482-
function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
1482+
function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys)
14831483
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
14841484
if !isempty(indices)
14851485
val = reduce(vcat, values[indices])
@@ -1493,3 +1493,26 @@ function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values,
14931493

14941494
return indices
14951495
end
1496+
1497+
"""
1498+
values_as(vi::TypedVarInfo, ::Type{NamedTuple})
1499+
values_as(vi::TypedVarInfo, ::Type{Dict})
1500+
1501+
Return values in `vi` as the specified type, e.g. `NamedTuple` is returned if
1502+
"""
1503+
function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names}
1504+
iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
1505+
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
1506+
end
1507+
1508+
function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names}
1509+
iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
1510+
return Dict(iter)
1511+
end
1512+
1513+
function values_from_metadata(md::Metadata)
1514+
return (
1515+
vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for
1516+
vn in md.vns
1517+
)
1518+
end

0 commit comments

Comments
 (0)