-
Notifications
You must be signed in to change notification settings - Fork 37
Improvements to VarNamedVector #1098
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
150de71
30ac1d0
4ae0c6d
4c8b006
c8b0b88
1f7152b
61c96b0
2a10be9
bb83d93
99a7d32
513edc5
d30eca8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -197,7 +197,7 @@ Everything else is optional, and can be categorised into several groups: | |
| 1. _How to specify the results to compare against._ | ||
|
|
||
| Once logp and its gradient has been calculated with the specified `adtype`, | ||
| it can optionally be tested for correctness. The exact way this is tested | ||
| it can optionally be tested for correctness. The exact way this is tested | ||
| is specified in the `test` parameter. | ||
|
|
||
| There are several options for this: | ||
|
|
@@ -260,7 +260,7 @@ function run_ad( | |
| if isnothing(params) | ||
| params = varinfo[:] | ||
| end | ||
| params = map(identity, params) # Concretise | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did a little benchmark of which of these two ways is faster for turning e.g. |
||
| params = [p for p in params] # Concretise | ||
|
|
||
| # Calculate log-density and gradient with the backend of interest | ||
| verbose && @info "Running AD on $(model.f) with $(adtype)\n" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -315,7 +315,7 @@ function untyped_vector_varinfo( | |
| model::Model, | ||
| init_strategy::AbstractInitStrategy=InitFromPrior(), | ||
| ) | ||
| return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) | ||
| return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy)) | ||
| end | ||
| function untyped_vector_varinfo( | ||
| model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() | ||
|
|
@@ -789,18 +789,24 @@ function setval!(md::Metadata, val, vn::VarName) | |
| return md.vals[getrange(md, vn)] = tovec(val) | ||
| end | ||
|
|
||
| function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName) | ||
| md = set_transformed!!(getmetadata(vi, vn), val, vn) | ||
| return Accessors.@set vi.metadata[getsym(vn)] = md | ||
| end | ||
|
|
||
| function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName) | ||
| set_transformed!!(getmetadata(vi, vn), val, vn) | ||
| return vi | ||
| md = set_transformed!!(getmetadata(vi, vn), val, vn) | ||
| return VarInfo(md, vi.accs) | ||
| end | ||
|
|
||
| function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName) | ||
| metadata.is_transformed[getidx(metadata, vn)] = val | ||
| return metadata | ||
| end | ||
|
|
||
| function set_transformed!!(vi::VarInfo, val::Bool) | ||
| for vn in keys(vi) | ||
| set_transformed!!(vi, val, vn) | ||
| vi = set_transformed!!(vi, val, vn) | ||
| end | ||
|
Comment on lines
807
to
810
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know how Julia works on this front, but is it ever possible that changing what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it's not so much rebinding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is fine, the julia> d = Dict(:a => 1)
Dict{Symbol, Int64} with 1 entry:
:a => 1
julia> ks = keys(d)
KeySet for a Dict{Symbol, Int64} with 1 entry. Keys:
:a
julia> d = Dict(:b => 1)
Dict{Symbol, Int64} with 1 entry:
:b => 1
julia> collect(ks)
1-element Vector{Symbol}:
:a
julia> d = Dict(:a => 1)
Dict{Symbol, Int64} with 1 entry:
:a => 1
julia> ks = keys(d)
KeySet for a Dict{Symbol, Int64} with 1 entry. Keys:
:a
julia> d[:b] = 2
2
julia> collect(ks)
2-element Vector{Symbol}:
:a
:bwhich can result in weirdness: julia> d = Dict{Any,Int}(:a => 1, :b => 2)
Dict{Any, Int64} with 2 entries:
:a => 1
:b => 2
julia> for k in keys(d)
d[repr(k)] = d[k]
end
julia> d
Dict{Any, Int64} with 7 entries:
:a => 1
:b => 2
"\":b\"" => 2
":b" => 2
"\"\\\":a\\\"\"" => 1
":a" => 1
"\":a\"" => 1There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think from a programming perspective, maybe the safest way would be to do a version of But also from a human perspective it's quite easy to see that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would keep it as is, because I'm pretty convinced it's okay (and not due to an implementation detail, but because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It doesn't bother me enough to want to do it. |
||
|
|
||
| return vi | ||
|
|
@@ -977,7 +983,7 @@ function filter_subsumed(filter_vns, filtered_vns) | |
| end | ||
|
|
||
| @generated function _link!!( | ||
| ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} | ||
| ::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names} | ||
| ) where {metadata_names,vns_names} | ||
| expr = Expr(:block) | ||
| for f in metadata_names | ||
|
|
@@ -988,7 +994,7 @@ end | |
| expr.args, | ||
| quote | ||
| f_vns = vi.metadata.$f.vns | ||
| f_vns = filter_subsumed(vns.$f, f_vns) | ||
| f_vns = filter_subsumed(varnames.$f, f_vns) | ||
| if !isempty(f_vns) | ||
| if !is_transformed(vi, f_vns[1]) | ||
| # Iterate over all `f_vns` and transform | ||
|
|
@@ -1652,30 +1658,47 @@ end | |
| Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to | ||
| the `VarInfo` `vi`, mutating if it makes sense. | ||
| """ | ||
| function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) | ||
| if vi isa UntypedVarInfo | ||
| @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" | ||
| elseif vi isa NTVarInfo | ||
| @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" | ||
| end | ||
| function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution) | ||
| @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" | ||
| md = push!!(getmetadata(vi, vn), vn, val, dist) | ||
| return VarInfo(md, vi.accs) | ||
| end | ||
|
|
||
| function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution) | ||
| @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" | ||
| sym = getsym(vn) | ||
| if vi isa NTVarInfo && ~haskey(vi.metadata, sym) | ||
| meta = if ~haskey(vi.metadata, sym) | ||
| # The NamedTuple doesn't have an entry for this variable, let's add one. | ||
| val = tovec(r) | ||
| md = Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) | ||
| vi = Accessors.@set vi.metadata[sym] = md | ||
| _new_submetadata(vi, vn, val, dist) | ||
| else | ||
| meta = getmetadata(vi, vn) | ||
| push!(meta, vn, r, dist) | ||
| push!!(getmetadata(vi, vn), vn, val, dist) | ||
| end | ||
|
|
||
| vi = Accessors.@set vi.metadata[sym] = meta | ||
| return vi | ||
| end | ||
|
|
||
| function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) | ||
| push!(getmetadata(vi, vn), vn, val, args...) | ||
| return vi | ||
| """ | ||
| _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas} | ||
|
|
||
| Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing | ||
| SubMetas. | ||
| """ | ||
| @generated function _new_submetadata( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The generated function is needed because we need to check the types in the NamedTuple to see if any of the other sub-metadatas are VNVs, to know whether to make a new one be a VNV or not. This should have always been done, you can see how Metadata used to be hardcoded as the sub-metadata type in |
||
| vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist | ||
| ) where {Names,SubMetas} | ||
| has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters) | ||
| return if has_vnv | ||
| :(return _new_vnv_submetadata(vn, r, dist)) | ||
| else | ||
| :(return _new_metadata_submetadata(vn, r, dist)) | ||
| end | ||
| end | ||
|
|
||
| _new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r]) | ||
|
|
||
| function _new_metadata_submetadata(vn, r, dist) | ||
| val = tovec(r) | ||
| return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) | ||
| end | ||
|
|
||
| function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) | ||
|
|
@@ -1700,6 +1723,11 @@ function Base.push!(meta::Metadata, vn, r, dist) | |
| return meta | ||
| end | ||
|
|
||
| function BangBang.push!!(meta::Metadata, vn, r, dist) | ||
| push!(meta, vn, r, dist) | ||
| return meta | ||
| end | ||
|
|
||
| function Base.delete!(vi::VarInfo, vn::VarName) | ||
| delete!(getmetadata(vi, vn), vn) | ||
| return vi | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two don't actually ever mutate their inputs. Hence the
!!is questionable. However, they sometimes return the original object, sometimes an object that shares memory with the original object, so you should use them like you use!!functions: Always catch the return value and never assume that the return value is independent of the input.