Skip to content
Merged
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.38.4

Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on.

## 0.38.3

Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.38.3"
version = "0.38.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ DynamicPPL.reset!
DynamicPPL.update!
DynamicPPL.insert!
DynamicPPL.loosen_types!!
DynamicPPL.tighten_types
Comment on lines 416 to -417
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two don't actually ever mutate their inputs. Hence the !! is questionable. However, they sometimes return the original object, sometimes an object that shares memory with the original object, so you should use them like you use !! functions: Always catch the return value and never assume that the return value is independent of the input.

DynamicPPL.tighten_types!!
```

```@docs
Expand Down
4 changes: 3 additions & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ function tilde_assume!!(
end
# Neither of these set the `trans` flag so we have to do it manually if
# necessary.
insert_transformed_value && set_transformed!!(vi, true, vn)
if insert_transformed_value
vi = set_transformed!!(vi, true, vn)
end
# `accumulate_assume!!` wants untransformed values as the second argument.
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
# We always return the untransformed value here, as that will determine
Expand Down
4 changes: 2 additions & 2 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ add_io_context(io::IO) = IOContext(io, :compact => true, :limit => true)
show_varname(io::IO, varname::VarName) = print(io, varname)
function show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
# Attempt to make the type concrete in case the symbol is shared.
return _show_varname(io, map(identity, varname))
return _show_varname(io, [vn for vn in varname])
end
function _show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
# Print the first and last element of the array.
Expand Down Expand Up @@ -407,7 +407,7 @@ julia> @model function demo_incorrect()
end
demo_incorrect (generic function with 2 methods)

julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
# alert us to the issue of `x` being sampled twice.
model = demo_incorrect(); varinfo = VarInfo(model);

Expand Down
6 changes: 3 additions & 3 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ box:
- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring
any effects of linking
- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected
by linking, since transforms are only applied to random variables)
by linking, since transforms are only applied to random variables)

!!! note
By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the
Expand Down Expand Up @@ -146,7 +146,7 @@ struct LogDensityFunction{
is_supported(adtype) ||
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
# Get a set of dummy params to use for prep
x = map(identity, varinfo[:])
x = [val for val in varinfo[:]]
if use_closure(adtype)
prep = DI.prepare_gradient(
LogDensityAt(model, getlogdensity, varinfo), adtype, x
Expand Down Expand Up @@ -282,7 +282,7 @@ function LogDensityProblems.logdensity_and_gradient(
) where {M,F,V,AD<:ADTypes.AbstractADType}
f.prep === nothing &&
error("Gradient preparation not available; this should not happen")
x = map(identity, x) # Concretise type
x = [val for val in x] # Concretise type
# Make branching statically inferrable, i.e. type-stable (even if the two
# branches happen to return different types)
return if use_closure(f.adtype)
Expand Down
1 change: 1 addition & 0 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
"Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
)
end
return vi
end

is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
Expand Down
4 changes: 2 additions & 2 deletions src/test_utils/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ Everything else is optional, and can be categorised into several groups:
1. _How to specify the results to compare against._

Once logp and its gradient has been calculated with the specified `adtype`,
it can optionally be tested for correctness. The exact way this is tested
it can optionally be tested for correctness. The exact way this is tested
is specified in the `test` parameter.

There are several options for this:
Expand Down Expand Up @@ -260,7 +260,7 @@ function run_ad(
if isnothing(params)
params = varinfo[:]
end
params = map(identity, params) # Concretise
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did a little benchmark of which of these two ways is faster for turning e.g. Vector{Any}[1, 2] into Vector{Int}[1, 2]. The answer is that the comprehension is a tiny bit faster. Since I was changing that in VNV, I also changed it everywhere else where it's used.

params = [p for p in params] # Concretise

# Calculate log-density and gradient with the backend of interest
verbose && @info "Running AD on $(model.f) with $(adtype)\n"
Expand Down
72 changes: 50 additions & 22 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ function untyped_vector_varinfo(
model::Model,
init_strategy::AbstractInitStrategy=InitFromPrior(),
)
return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy))
return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy))
end
function untyped_vector_varinfo(
model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
Expand Down Expand Up @@ -789,18 +789,24 @@ function setval!(md::Metadata, val, vn::VarName)
return md.vals[getrange(md, vn)] = tovec(val)
end

function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName)
md = set_transformed!!(getmetadata(vi, vn), val, vn)
return Accessors.@set vi.metadata[getsym(vn)] = md
end

function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName)
set_transformed!!(getmetadata(vi, vn), val, vn)
return vi
md = set_transformed!!(getmetadata(vi, vn), val, vn)
return VarInfo(md, vi.accs)
end

function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName)
metadata.is_transformed[getidx(metadata, vn)] = val
return metadata
end

function set_transformed!!(vi::VarInfo, val::Bool)
for vn in keys(vi)
set_transformed!!(vi, val, vn)
vi = set_transformed!!(vi, val, vn)
end
Comment on lines 807 to 810
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how Julia works on this front, but is it ever possible that changing what vi is bound to might invalidate the iteration over keys(vi)? I think probably not, but ...??

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's not so much rebinding vi that's the issue, but if set_transformed!! were to mutate vi, then that might be problematic. I suppose it can't mutate it in a way such that keys(vi) would result in something different, so this probably can't error. But I'm a bit paranoid about this kind of code structure. Obviously, this predates this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine, the keys(vi) call I think is evaluated once when the loop is started, and it'll keep its reference to the original vi even if the name vi is then bound to something else. Like in

julia> d = Dict(:a => 1)
Dict{Symbol, Int64} with 1 entry:
  :a => 1

julia> ks = keys(d)
KeySet for a Dict{Symbol, Int64} with 1 entry. Keys:
  :a

julia> d = Dict(:b => 1)
Dict{Symbol, Int64} with 1 entry:
  :b => 1

julia> collect(ks)
1-element Vector{Symbol}:
 :a

keys(vi) is a lazy iterator though, so what I would feel dicey about is changing the set of keys in vi mid-iteration. This is a thing:

julia> d = Dict(:a => 1)
Dict{Symbol, Int64} with 1 entry:
  :a => 1

julia> ks = keys(d)
KeySet for a Dict{Symbol, Int64} with 1 entry. Keys:
  :a

julia> d[:b] = 2
2

julia> collect(ks)
2-element Vector{Symbol}:
 :a
 :b

which can result in weirdness:

julia> d = Dict{Any,Int}(:a => 1, :b => 2)
Dict{Any, Int64} with 2 entries:
  :a => 1
  :b => 2

julia> for k in keys(d)
           d[repr(k)] = d[k]
       end

julia> d
Dict{Any, Int64} with 7 entries:
  :a               => 1
  :b               => 2
  "\":b\""         => 2
  ":b"             => 2
  "\"\\\":a\\\"\"" => 1
  ":a"             => 1
  "\":a\""         => 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think from a programming perspective, maybe the safest way would be to do a version of keys(vi) that's eagerly evaluated, maybe collect(keys(vi))?

But also from a human perspective it's quite easy to see that keys(vi) can't be changed, so also happy to let it slide.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep it as is, because I'm pretty convinced it's okay (and not due to an implementation detail, but because set_transformed!! has no business touching the keys), and lazy iterators are nice. Could add the call to collect if it bothers you.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't bother me enough to want to do it.


return vi
Expand Down Expand Up @@ -977,7 +983,7 @@ function filter_subsumed(filter_vns, filtered_vns)
end

@generated function _link!!(
::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names}
) where {metadata_names,vns_names}
expr = Expr(:block)
for f in metadata_names
Expand All @@ -988,7 +994,7 @@ end
expr.args,
quote
f_vns = vi.metadata.$f.vns
f_vns = filter_subsumed(vns.$f, f_vns)
f_vns = filter_subsumed(varnames.$f, f_vns)
if !isempty(f_vns)
if !is_transformed(vi, f_vns[1])
# Iterate over all `f_vns` and transform
Expand Down Expand Up @@ -1652,30 +1658,47 @@ end
Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to
the `VarInfo` `vi`, mutating if it makes sense.
"""
function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
if vi isa UntypedVarInfo
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist"
elseif vi isa NTVarInfo
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist"
end
function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution)
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist"
md = push!!(getmetadata(vi, vn), vn, val, dist)
return VarInfo(md, vi.accs)
end

function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution)
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist"
sym = getsym(vn)
if vi isa NTVarInfo && ~haskey(vi.metadata, sym)
meta = if ~haskey(vi.metadata, sym)
# The NamedTuple doesn't have an entry for this variable, let's add one.
val = tovec(r)
md = Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false]))
vi = Accessors.@set vi.metadata[sym] = md
_new_submetadata(vi, vn, val, dist)
else
meta = getmetadata(vi, vn)
push!(meta, vn, r, dist)
push!!(getmetadata(vi, vn), vn, val, dist)
end

vi = Accessors.@set vi.metadata[sym] = meta
return vi
end

function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...)
push!(getmetadata(vi, vn), vn, val, args...)
return vi
"""
_new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas}

Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing
SubMetas.
"""
@generated function _new_submetadata(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generated function is needed because we need to check the types in the NamedTuple to see if any of the other sub-metadatas are VNVs, to know whether to make a new one be a VNV or not. This should have always been done, you can see how Metadata used to be hardcoded as the sub-metadata type in push!! before this PR. I hope to get rid of this later by either having only one possible sub-metadata type, or having the VarInfo have a field that specifies this (like what make_leaf does in #1074).

vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist
) where {Names,SubMetas}
has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters)
return if has_vnv
:(return _new_vnv_submetadata(vn, r, dist))
else
:(return _new_metadata_submetadata(vn, r, dist))
end
end

_new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r])

function _new_metadata_submetadata(vn, r, dist)
val = tovec(r)
return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false]))
end

function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...)
Expand All @@ -1700,6 +1723,11 @@ function Base.push!(meta::Metadata, vn, r, dist)
return meta
end

function BangBang.push!!(meta::Metadata, vn, r, dist)
push!(meta, vn, r, dist)
return meta
end

function Base.delete!(vi::VarInfo, vn::VarName)
delete!(getmetadata(vi, vn), vn)
return vi
Expand Down
Loading