Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
84d2883
set delete flag if value not provided for a variable in setval
torfjelde Mar 23, 2021
9c9fb9d
added setval_and_resample!, docstrings and doctest
torfjelde Mar 28, 2021
8b585b2
added doctest for setval! while I was at it
torfjelde Mar 28, 2021
62b031d
refactored setval! and setval_and_reparam! plus added some testing
torfjelde Mar 28, 2021
e3d842a
Apply suggestions from code review
torfjelde Mar 28, 2021
6028473
Merge branch 'master' into tor/minor-change-to-setval
torfjelde Apr 2, 2021
71adccf
introduce StableRNGs to allow proper doctests
torfjelde Apr 2, 2021
cf5b094
Merge branch 'tor/minor-change-to-setval' of github.com:TuringLang/Dy…
torfjelde Apr 2, 2021
19d8f72
changed setval! to setval_and_resample! where appropriate
torfjelde Apr 2, 2021
e47a6a7
StableRNGs is now only test-dependency
torfjelde Apr 2, 2021
2b9fc9b
Update src/varname.jl
torfjelde Apr 2, 2021
c015cbe
uses Fix1 instead of separte impl for partial applied subsumes_string
torfjelde Apr 2, 2021
28a421c
Update src/varinfo.jl
torfjelde Apr 2, 2021
7bb3a40
removed Function restriction for _apply!
torfjelde Apr 2, 2021
0d48add
Update src/varname.jl
torfjelde Apr 2, 2021
21f3a97
Merge branch 'tor/minor-change-to-setval' of github.com:TuringLang/Dy…
torfjelde Apr 2, 2021
721bb47
will now warn if keys which are not present in vi are provided to _ap…
torfjelde Apr 2, 2021
8812dc3
added docstrings outlining some limiations with setval! and similars
torfjelde Apr 2, 2021
4d2deb8
added compat entry for StableRNGs
torfjelde Apr 2, 2021
f42af3a
Update src/varinfo.jl
torfjelde Apr 2, 2021
ff44adc
Update src/varinfo.jl
torfjelde Apr 2, 2021
945d5da
convert keys to strings before applying kernel! in _apply!
torfjelde Apr 2, 2021
e6a7bbd
added Base.keys for TypedVarInfo
torfjelde Apr 3, 2021
ef447d0
only keep track of number of variables seen in _apply! rather than se…
torfjelde Apr 3, 2021
e76edfe
added test for keys in addition to fix for keys being sets in _apply!
torfjelde Apr 3, 2021
e6b5eb6
bump version
torfjelde Apr 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.8"
version = "0.10.9"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 1 addition & 1 deletion src/loglikelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ function pointwise_loglikelihoods(
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
for (sample_idx, chain_idx) in iters
# Update the values
setval!(vi, chain, sample_idx, chain_idx)
setval_and_resample!(vi, chain, sample_idx, chain_idx)

# Execute model
model(vi, spl, ctx)
Expand Down
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ function generated_quantities(model::Model, chain::AbstractChains)
varinfo = VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
setval!(varinfo, chain, sample_idx, chain_idx)
setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
model(varinfo)
end
end
7 changes: 7 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,10 @@ end
function inittrans(rng, dist::MatrixDistribution, n::Int)
return invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n])
end


#######################
# Convenience methods #
#######################
collectmaybe(x) = x
collectmaybe(x::Base.AbstractSet) = collect(x)
232 changes: 211 additions & 21 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,22 @@ end

# Functions defined only for UntypedVarInfo
"""
keys(vi::UntypedVarInfo)
keys(vi::AbstractVarInfo)

Return an iterator over all `vns` in `vi`.
"""
keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)
Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)

@generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add some tests for keys I guess.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. Also added a collectmaybe method to deal with the case where keys are AbstractSet, in which case map is not defined.

expr = Expr(:call)
push!(expr.args, :vcat)

for n in names
push!(expr.args, :(vi.metadata.$n.vns))
end

return expr
end

"""
setgid!(vi::VarInfo, gid::Selector, vn::VarName)
Expand Down Expand Up @@ -1165,19 +1176,39 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler)
end
end

setval!(vi::AbstractVarInfo, x) = _setval!(vi, values(x), keys(x))
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
end
# TODO: Maybe rename or something?
"""
_apply!(kernel!, vi::AbstractVarInfo, values, keys)

Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`.
"""
function _apply!(kernel!, vi::AbstractVarInfo, values, keys)
keys_strings = map(string, collectmaybe(keys))
num_indices_seen = 0

function _setval!(vi::AbstractVarInfo, values, keys)
for vn in Base.keys(vi)
_setval_kernel!(vi, vn, values, keys)
indices_found = kernel!(vi, vn, values, keys_strings)
if indices_found !== nothing
num_indices_seen += length(indices_found)
end
end

if length(keys) > num_indices_seen
# Some keys have not been seen, i.e. attempted to set variables which
# we were not able to locate in `vi`.
# Find the ones we missed so we can warn the user.
unused_keys = _find_missing_keys(vi, keys_strings)
@warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)"
end

return vi
end
_setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, values, keys)
@generated function _typed_setval!(

_apply!(kernel!, vi::TypedVarInfo, values, keys) = _typed_apply!(
kernel!, vi, vi.metadata, values, collectmaybe(keys))

@generated function _typed_apply!(
kernel!,
vi::TypedVarInfo,
metadata::NamedTuple{names},
values,
Expand All @@ -1186,30 +1217,189 @@ _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, value
updates = map(names) do n
quote
for vn in metadata.$n.vns
_setval_kernel!(vi, vn, values, keys)
indices_found = kernel!(vi, vn, values, keys_strings)
if indices_found !== nothing
num_indices_seen += length(indices_found)
end
end
end
end

return quote
keys_strings = map(string, keys)
num_indices_seen = 0

$(updates...)

if length(keys) > num_indices_seen
# Some keys have not been seen, i.e. attempted to set variables which
# we were not able to locate in `vi`.
# Find the ones we missed so we can warn the user.
unused_keys = _find_missing_keys(vi, keys_strings)
@warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)"
end

return vi
end
end

function _find_missing_keys(vi::AbstractVarInfo, keys)
string_vns = map(string, collectmaybe(Base.keys(vi)))
# If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`.
missing_keys = filter(keys) do key
!any(Base.Fix2(subsumes_string, key), string_vns)
end

return missing_keys
end

"""
setval!(vi::AbstractVarInfo, x)
setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)

Set the values in `vi` to the provided values and leave those which are not present in
`x` or `chains` unchanged.

## Notes
This is rather limited for two reasons:
1. It uses `subsumes_string(string(vn), map(string, keys))` under the hood,
and therefore suffers from the same limitations as [`subsumes_string`](@ref).
2. It will set every `vn` present in `keys`. It will NOT however
set every `k` present in `keys`. This means that if `vn == [m[1], m[2]]`,
representing some variable `m`, calling `setval!(vi, (m = [1.0, 2.0]))` will
be a no-op since it will try to find `m[1]` and `m[2]` in `keys((m = [1.0, 2.0]))`.

## Example
```jldoctest
julia> using DynamicPPL, Distributions, StableRNGs

julia> @model function demo(x)
m ~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m, 1)
end
end;

julia> rng = StableRNG(42);

julia> m = demo([missing]);

julia> var_info = DynamicPPL.VarInfo(rng, m);

julia> var_info[@varname(m)]
-0.6702516921145671

julia> var_info[@varname(x[1])]
-0.22312984965118443

julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and and keep `x[1]`

julia> var_info[@varname(m)] # [✓] changed
100.0

julia> var_info[@varname(x[1])] # [✓] unchanged
-0.22312984965118443

julia> m(rng, var_info); # rerun model

julia> var_info[@varname(m)] # [✓] unchanged
100.0

julia> var_info[@varname(x[1])] # [✓] unchanged
-0.22312984965118443
```
"""
setval!(vi::AbstractVarInfo, x) = _apply!(_setval_kernel!, vi, values(x), keys(x))
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _apply!(_setval_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains))
end

function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
string_vn = string(vn)
string_vn_indexing = string_vn * "["
indices = findall(keys) do x
string_x = string(x)
return string_x == string_vn || startswith(string_x, string_vn_indexing)
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
if !isempty(indices)
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
val = reduce(vcat, values[sorted_indices])
setval!(vi, val, vn)
settrans!(vi, false, vn)
end

return indices
end

"""
setval_and_resample!(vi::AbstractVarInfo, x)
setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx)

Set the values in `vi` to the provided values and those which are not present
in `x` or `chains` to *be* resampled.

Note that this does *not* resample the values not provided! It will call `setflag!(vi, vn, "del")`
for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these
variables will be resampled.

## Note
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.

## Example
```jldoctest
julia> using DynamicPPL, Distributions, StableRNGs

julia> @model function demo(x)
m ~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m, 1)
end
end;

julia> rng = StableRNG(42);

julia> m = demo([missing]);

julia> var_info = DynamicPPL.VarInfo(rng, m);

julia> var_info[@varname(m)]
-0.6702516921145671

julia> var_info[@varname(x[1])]
-0.22312984965118443

julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling

julia> var_info[@varname(m)] # [✓] changed
100.0

julia> var_info[@varname(x[1])] # [✓] unchanged
-0.22312984965118443

julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`

julia> var_info[@varname(m)] # [✓] unchanged
100.0

julia> var_info[@varname(x[1])] # [✓] changed
101.37363069798343
```

## See also
- [`setval!`](@ref)
"""
setval_and_resample!(vi::AbstractVarInfo, x) = _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x))
function setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _apply!(_setval_and_resample_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains))
end

function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
if !isempty(indices)
sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural)
val = mapreduce(vcat, sorted_indices) do i
values[i]
end
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
val = reduce(vcat, values[sorted_indices])
setval!(vi, val, vn)
settrans!(vi, false, vn)
else
# Ensures that we'll resample the variable corresponding to `vn` if we run
# the model on `vi` again.
set_flag!(vi, vn, "del")
end

return indices
end
17 changes: 17 additions & 0 deletions src/varname.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
"""
subsumes_string(u::String, v::String[, u_indexing])

Check whether stringified variable name `v` describes a sub-range of stringified variable `u`.

This is a very restricted version `subumes(u::VarName, v::VarName)` only really supporting:
- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc.

## Note
- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)`
for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`,
and similarly to `v`. But this is slow.
"""
function subsumes_string(u::String, v::String, u_indexing=u * "[")
return u == v || startswith(v, u_indexing)
end

"""
inargnames(varname::VarName, model::Model)

Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -26,6 +27,7 @@ Documenter = "0.26.1"
ForwardDiff = "0.10.12"
MCMCChains = "4.0.4"
MacroTools = "0.5.5"
StableRNGs = "1"
Tracker = "0.2.11"
Zygote = "0.5.4, 0.6"
julia = "1.3"
Loading