Skip to content

Commit 80cf12d

Browse files
mhaurupenelopeysm
andauthored
Remove unnecessary consistency checks for VarNamedVector (#1092)
* Remove unnecessary consistency checks for VarNamedVector * Fix benchmark setting Co-authored-by: Penelope Yong <penelopeysm@gmail.com> * Fix typo * Add two benchmarks * Improvements to VarNamedVector (#1098) * Change VNV to use Dict rather than OrderedDict * Change concretisation from map(identity, x) to a comprehension * Improve tighten_types!! and loosen_types!! * Fix use of set_transformed!! * Fix push!! for VarInfos * Change the default element types in VNV to be Union{} * In untyped_vector_varinfo, don't rely on Metadata * Code style * Run formatter * In VNV, use typejoin rather than promote_type * Bump patch version to 0.38.4 --------- Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
1 parent 90b591b commit 80cf12d

File tree

14 files changed

+350
-178
lines changed

14 files changed

+350
-178
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.38.4
4+
5+
Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on.
6+
37
## 0.38.3
48

59
Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.3"
3+
version = "0.38.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

benchmarks/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ LogDensityProblems = "2.1.2"
3030
Mooncake = "0.4"
3131
PrettyTables = "3"
3232
ReverseDiff = "1.15.3"
33-
StableRNGs = "1"
33+
StableRNGs = "1"

benchmarks/benchmarks.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ chosen_combinations = [
6262
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
6363
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
6464
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
65+
("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
66+
("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
6567
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
6668
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
6769
("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true),

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8080
retvals = model(rng)
8181
vns = [VarName{k}() for k in keys(retvals)]
8282
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
83+
elseif varinfo_choice == :typed_vector
84+
DynamicPPL.typed_vector_varinfo(rng, model)
85+
elseif varinfo_choice == :untyped_vector
86+
DynamicPPL.untyped_vector_varinfo(rng, model)
8387
else
8488
error("Unknown varinfo choice: $varinfo_choice")
8589
end

docs/src/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ DynamicPPL.reset!
414414
DynamicPPL.update!
415415
DynamicPPL.insert!
416416
DynamicPPL.loosen_types!!
417-
DynamicPPL.tighten_types
417+
DynamicPPL.tighten_types!!
418418
```
419419

420420
```@docs

src/contexts/init.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ function tilde_assume!!(
180180
end
181181
# Neither of these set the `trans` flag so we have to do it manually if
182182
# necessary.
183-
insert_transformed_value && set_transformed!!(vi, true, vn)
183+
if insert_transformed_value
184+
vi = set_transformed!!(vi, true, vn)
185+
end
184186
# `accumulate_assume!!` wants untransformed values as the second argument.
185187
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
186188
# We always return the untransformed value here, as that will determine

src/debug_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ add_io_context(io::IO) = IOContext(io, :compact => true, :limit => true)
2727
show_varname(io::IO, varname::VarName) = print(io, varname)
2828
function show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
2929
# Attempt to make the type concrete in case the symbol is shared.
30-
return _show_varname(io, map(identity, varname))
30+
return _show_varname(io, [vn for vn in varname])
3131
end
3232
function _show_varname(io::IO, varname::Array{<:VarName,N}) where {N}
3333
# Print the first and last element of the array.
@@ -407,7 +407,7 @@ julia> @model function demo_incorrect()
407407
end
408408
demo_incorrect (generic function with 2 methods)
409409
410-
julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
410+
julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
411411
# alert us to the issue of `x` being sampled twice.
412412
model = demo_incorrect(); varinfo = VarInfo(model);
413413

src/logdensityfunction.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ box:
4949
- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring
5050
any effects of linking
5151
- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected
52-
by linking, since transforms are only applied to random variables)
52+
by linking, since transforms are only applied to random variables)
5353
5454
!!! note
5555
By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the
@@ -146,7 +146,7 @@ struct LogDensityFunction{
146146
is_supported(adtype) ||
147147
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
148148
# Get a set of dummy params to use for prep
149-
x = map(identity, varinfo[:])
149+
x = [val for val in varinfo[:]]
150150
if use_closure(adtype)
151151
prep = DI.prepare_gradient(
152152
LogDensityAt(model, getlogdensity, varinfo), adtype, x
@@ -282,7 +282,7 @@ function LogDensityProblems.logdensity_and_gradient(
282282
) where {M,F,V,AD<:ADTypes.AbstractADType}
283283
f.prep === nothing &&
284284
error("Gradient preparation not available; this should not happen")
285-
x = map(identity, x) # Concretise type
285+
x = [val for val in x] # Concretise type
286286
# Make branching statically inferrable, i.e. type-stable (even if the two
287287
# branches happen to return different types)
288288
return if use_closure(f.adtype)

src/simple_varinfo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
484484
"Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
485485
)
486486
end
487+
return vi
487488
end
488489

489490
is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)

0 commit comments

Comments
 (0)