Skip to content

Commit 01bf0bc

Browse files
authored
Replace Metadata.flags with Metadata.trans (#1060)
* Replace Medata.flags with Metadata.trans * Fix a bug * Fix a typo * Fix two bugs * Rename trans to is_transformed * Rename islinked to is_transformed, remove duplication
1 parent 0fa5540 commit 01bf0bc

18 files changed

+193
-272
lines changed

HISTORY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ The separation of these functions was primarily implemented to avoid performing
5454

5555
Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed.
5656

57+
The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus the generic functions `set_flag!`, `unset_flag!` and `is_flagged!` have also been removed in favour of more specific ones. We've also used this opportunity to name the `"trans"` flag and the corresponding `istrans` function to be more explicit. The new, exported interface consists of the `is_transformed` and `set_transformed!!` functions.
58+
5759
### Removal of `resume_from`
5860

5961
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.

docs/src/api.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,8 @@ The [Transformations section below](#Transformations) describes the methods used
345345
In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions.
346346

347347
```@docs
348-
set_flag!
349-
unset_flag!
350-
is_flagged
348+
is_transformed
349+
set_transformed!!
351350
```
352351

353352
```@docs
@@ -439,8 +438,6 @@ DynamicPPL.StaticTransformation
439438
```
440439

441440
```@docs
442-
DynamicPPL.istrans
443-
DynamicPPL.settrans!!
444441
DynamicPPL.transformation
445442
DynamicPPL.link
446443
DynamicPPL.invlink

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ else
88
using ..EnzymeCore
99
end
1010

11-
# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
11+
# Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme
1212
# only checks whether such a method exists, and never runs it.
13-
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing
13+
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) =
14+
nothing
1415

1516
end

ext/DynamicPPLMooncakeExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
module DynamicPPLMooncakeExt
22

3-
using DynamicPPL: DynamicPPL, istrans
3+
using DynamicPPL: DynamicPPL, is_transformed
44
using Mooncake: Mooncake
55

66
# This is purely an optimisation.
7-
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
7+
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg}
88

99
end # module

src/DynamicPPL.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,8 @@ export AbstractVarInfo,
7070
acclogjac!!,
7171
acclogprior!!,
7272
accloglikelihood!!,
73-
is_flagged,
74-
set_flag!,
75-
unset_flag!,
76-
istrans,
73+
is_transformed,
74+
set_transformed!!,
7775
link,
7876
link!!,
7977
invlink,

src/abstract_varinfo.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ end
769769

770770
# Transformations
771771
"""
772-
istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])
772+
is_transformed(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])
773773
774774
Return `true` if `vi` is working in unconstrained space, and `false`
775775
if `vi` is assuming realizations to be in support of the corresponding distributions.
@@ -780,27 +780,27 @@ If `vns` is provided, then only check if this/these varname(s) are transformed.
780780
Not all implementations of `AbstractVarInfo` support transforming only a subset of
781781
the variables.
782782
"""
783-
istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi)))
784-
function istrans(vi::AbstractVarInfo, vns::AbstractVector)
785-
# This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`.
783+
is_transformed(vi::AbstractVarInfo) = is_transformed(vi, collect(keys(vi)))
784+
function is_transformed(vi::AbstractVarInfo, vns::AbstractVector)
785+
# This used to be: `!isempty(vns) && all(Base.Fix1(is_transformed, vi), vns)`.
786786
# In theory that should work perfectly fine. For unbeknownst reasons,
787787
# Julia 1.10 fails to infer its return type correctly. Thus we use this
788788
# slightly longer definition.
789789
isempty(vns) && return false
790790
for vn in vns
791-
istrans(vi, vn) || return false
791+
is_transformed(vi, vn) || return false
792792
end
793793
return true
794794
end
795795

796796
"""
797-
settrans!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName])
797+
set_transformed!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName])
798798
799-
Return `vi` with `istrans(vi, vn)` evaluating to `true`.
799+
Return `vi` with `is_transformed(vi, vn)` evaluating to `true`.
800800
801-
If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variables.
801+
If `vn` is not specified, then `is_transformed(vi)` evaluates to `true` for all variables.
802802
"""
803-
function settrans!! end
803+
function set_transformed!! end
804804

805805
# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback
806806
# method for the case when no `vns` is provided, that would get all the keys from the
@@ -832,7 +832,7 @@ function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
832832
# has a dedicated implementation
833833
model = setleafcontext(model, DynamicTransformationContext{false}())
834834
vi = last(evaluate!!(model, vi))
835-
return settrans!!(vi, t)
835+
return set_transformed!!(vi, t)
836836
end
837837
function link!!(
838838
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
@@ -845,7 +845,7 @@ function link!!(
845845
if hasacc(vi, Val(:LogJacobian))
846846
vi = acclogjac!!(vi, logjac)
847847
end
848-
return settrans!!(vi, t)
848+
return set_transformed!!(vi, t)
849849
end
850850

851851
"""
@@ -894,7 +894,7 @@ function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
894894
# has a dedicated implementation
895895
model = setleafcontext(model, DynamicTransformationContext{true}())
896896
vi = last(evaluate!!(model, vi))
897-
return settrans!!(vi, NoTransformation())
897+
return set_transformed!!(vi, NoTransformation())
898898
end
899899
function invlink!!(
900900
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
@@ -910,7 +910,7 @@ function invlink!!(
910910
if hasacc(vi, Val(:LogJacobian))
911911
vi = acclogjac!!(vi, inv_logjac)
912912
end
913-
return settrans!!(vi, NoTransformation())
913+
return set_transformed!!(vi, NoTransformation())
914914
end
915915

916916
"""
@@ -1018,7 +1018,7 @@ function unflatten end
10181018
"""
10191019
to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
10201020
1021-
Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
1021+
Return reconstructed `val`, possibly linked if `is_transformed(vi, vn)` is `true`.
10221022
"""
10231023
function to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
10241024
f = to_maybe_linked_internal_transform(vi, vn, dist)
@@ -1028,7 +1028,7 @@ end
10281028
"""
10291029
from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
10301030
1031-
Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
1031+
Return reconstructed `val`, possibly invlinked if `is_transformed(vi, vn)` is `true`.
10321032
"""
10331033
function from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
10341034
f = from_maybe_linked_internal_transform(vi, vn, dist)
@@ -1085,14 +1085,14 @@ in `varinfo` to a representation compatible with `dist`.
10851085
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
10861086
"""
10871087
function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
1088-
return if istrans(varinfo, vn)
1088+
return if is_transformed(varinfo, vn)
10891089
from_linked_internal_transform(varinfo, vn, dist)
10901090
else
10911091
from_internal_transform(varinfo, vn, dist)
10921092
end
10931093
end
10941094
function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName)
1095-
return if istrans(varinfo, vn)
1095+
return if is_transformed(varinfo, vn)
10961096
from_linked_internal_transform(varinfo, vn)
10971097
else
10981098
from_internal_transform(varinfo, vn)

src/contexts/init.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ function tilde_assume!!(
163163
# If the VarInfo alrady had a value for this variable, we will
164164
# keep the same linked status as in the original VarInfo. If not, we
165165
# check the rest of the VarInfo to see if other variables are linked.
166-
# istrans(vi) returns true if vi is nonempty and all variables in vi
166+
# is_transformed(vi) returns true if vi is nonempty and all variables in vi
167167
# are linked.
168-
insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi)
168+
insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi)
169169
f = if insert_transformed_value
170170
link_transform(dist)
171171
else
@@ -181,7 +181,7 @@ function tilde_assume!!(
181181
end
182182
# Neither of these set the `trans` flag so we have to do it manually if
183183
# necessary.
184-
insert_transformed_value && settrans!!(vi, true, vn)
184+
insert_transformed_value && set_transformed!!(vi, true, vn)
185185
# `accumulate_assume!!` wants untransformed values as the second argument.
186186
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
187187
# We always return the untransformed value here, as that will determine

src/contexts/transformation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function tilde_assume!!(
2121
# vi[vn, right] always provides the value in unlinked space.
2222
x = vi[vn, right]
2323

24-
if istrans(vi, vn)
24+
if is_transformed(vi, vn)
2525
isinverse || @warn "Trying to link an already transformed variable ($vn)"
2626
else
2727
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"

src/simple_varinfo.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,23 @@ julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo());
9696
julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
9797
1.8632965762164932
9898
99-
julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
99+
julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true));
100100
101101
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
102102
-0.21080155351918753
103103
104-
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
104+
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
105105
106106
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107107
true
108108
109109
julia> # And with `OrderedDict` of course!
110-
_, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
110+
_, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111111
112112
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
113113
0.6225185067787314
114114
115-
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
115+
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
116116
117117
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
118118
true
@@ -121,15 +121,15 @@ true
121121
Evaluation in transformed space of course also works:
122122
123123
```jldoctest simplevarinfo-general
124-
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
124+
julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true)
125125
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0)))
126126
127127
julia> # (✓) Positive probability mass on negative numbers!
128128
getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi)))
129129
-1.3678794411714423
130130
131131
julia> # While if we forget to indicate that it's transformed:
132-
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
132+
vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false)
133133
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0)))
134134
135135
julia> # (✓) No probability mass on negative numbers!
@@ -466,32 +466,32 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
466466
return SimpleVarInfo(values, accs, transformation)
467467
end
468468

469-
function settrans!!(vi::SimpleVarInfo, trans)
470-
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
469+
function set_transformed!!(vi::SimpleVarInfo, trans)
470+
return set_transformed!!(vi, trans ? DynamicTransformation() : NoTransformation())
471471
end
472-
function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
472+
function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
473473
return Accessors.@set vi.transformation = transformation
474474
end
475-
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
476-
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans)
475+
function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
476+
return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans)
477477
end
478-
function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
478+
function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
479479
# We keep this method around just to obey the AbstractVarInfo interface.
480480
# However, note that this would only be a valid operation if it would be a
481481
# no-op, which we check here.
482-
if trans != istrans(vi)
482+
if trans != is_transformed(vi)
483483
error(
484-
"Individual variables in SimpleVarInfo cannot have different `settrans` statuses.",
484+
"Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
485485
)
486486
end
487487
end
488488

489-
istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
490-
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)
491-
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)
492-
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo)
493-
494-
islinked(vi::SimpleVarInfo) = istrans(vi)
489+
is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
490+
is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi)
491+
function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName)
492+
return is_transformed(vi.varinfo, vn)
493+
end
494+
is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo)
495495

496496
values_as(vi::SimpleVarInfo) = vi.values
497497
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
@@ -618,7 +618,7 @@ function link!!(
618618
if hasacc(vi_new, Val(:LogJacobian))
619619
vi_new = acclogjac!!(vi_new, logjac)
620620
end
621-
return settrans!!(vi_new, t)
621+
return set_transformed!!(vi_new, t)
622622
end
623623

624624
function invlink!!(
@@ -636,7 +636,7 @@ function invlink!!(
636636
if hasacc(vi_new, Val(:LogJacobian))
637637
vi_new = acclogjac!!(vi_new, inv_logjac)
638638
end
639-
return settrans!!(vi_new, NoTransformation())
639+
return set_transformed!!(vi_new, NoTransformation())
640640
end
641641

642642
# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything.

src/threadsafe.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
8080
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
8181
haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)
8282

83-
islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo)
83+
is_transformed(vi::ThreadSafeVarInfo) = is_transformed(vi.varinfo)
8484

8585
function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
8686
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...)
@@ -104,12 +104,12 @@ end
104104
# to define `getacc(vi)`.
105105
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
106106
model = setleafcontext(model, DynamicTransformationContext{false}())
107-
return settrans!!(last(evaluate!!(model, vi)), t)
107+
return set_transformed!!(last(evaluate!!(model, vi)), t)
108108
end
109109

110110
function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
111111
model = setleafcontext(model, DynamicTransformationContext{true}())
112-
return settrans!!(last(evaluate!!(model, vi)), NoTransformation())
112+
return set_transformed!!(last(evaluate!!(model, vi)), NoTransformation())
113113
end
114114

115115
function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
@@ -181,20 +181,15 @@ end
181181
values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo)
182182
values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T)
183183

184-
function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
185-
return unset_flag!(vi.varinfo, vn, flag)
186-
end
187-
function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
188-
return is_flagged(vi.varinfo, vn, flag)
184+
function set_transformed!!(vi::ThreadSafeVarInfo, val::Bool, vn::VarName)
185+
return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, val, vn)
189186
end
190187

191-
function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
192-
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)
188+
is_transformed(vi::ThreadSafeVarInfo, vn::VarName) = is_transformed(vi.varinfo, vn)
189+
function is_transformed(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
190+
return is_transformed(vi.varinfo, vns)
193191
end
194192

195-
istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
196-
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)
197-
198193
getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn)
199194

200195
function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)

0 commit comments

Comments
 (0)