Skip to content

Commit fc32398

Browse files
authored
AbstractPPL 0.11 + change prefixing behaviour (#830)
* AbstractPPL 0.11; change prefixing behaviour * Use DynamicPPL.prefix rather than overloading
1 parent bb59885 commit fc32398

16 files changed

+180
-140
lines changed

HISTORY.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,54 @@
11
# DynamicPPL Changelog
22

3+
## 0.36.0
4+
5+
**Breaking changes**
6+
7+
### VarName prefixing behaviour
8+
9+
The way in which VarNames in submodels are prefixed has been changed.
10+
This is best explained through an example.
11+
Consider this model and submodel:
12+
13+
```julia
14+
using DynamicPPL, Distributions
15+
@model inner() = x ~ Normal()
16+
@model outer() = a ~ to_submodel(inner())
17+
```
18+
19+
In previous versions, the inner variable `x` would be saved as `a.x`.
20+
However, this was represented as a single symbol `Symbol("a.x")`:
21+
22+
```julia
23+
julia> dump(keys(VarInfo(outer()))[1])
24+
VarName{Symbol("a.x"), typeof(identity)}
25+
optic: identity (function of type typeof(identity))
26+
```
27+
28+
Now, the inner variable is stored as a field `x` on the VarName `a`:
29+
30+
```julia
31+
julia> dump(keys(VarInfo(outer()))[1])
32+
VarName{:a, Accessors.PropertyLens{:x}}
33+
optic: Accessors.PropertyLens{:x} (@o _.x)
34+
```
35+
36+
In practice, this means that if you are trying to condition a variable in the submodel, you now need to use
37+
38+
```julia
39+
outer() | (@varname(a.x) => 1.0,)
40+
```
41+
42+
instead of either of these (which would have worked previously)
43+
44+
```julia
45+
outer() | (@varname(var"a.x") => 1.0,)
46+
outer() | (a.x=1.0,)
47+
```
48+
49+
If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
50+
(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.)
51+
352
## 0.35.5
453

554
Several internal methods have been removed:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
4444
[compat]
4545
ADTypes = "1"
4646
AbstractMCMC = "5"
47-
AbstractPPL = "0.10.1"
47+
AbstractPPL = "0.11"
4848
Accessors = "0.1"
4949
BangBang = "0.4.1"
5050
Bijectors = "0.13.18, 0.14, 0.15"

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
44
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
55
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
66
DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
7+
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
78
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
89
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
910
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"

docs/src/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ In the past, one would instead embed sub-models using [`@submodel`](@ref), which
149149
In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing:
150150

151151
```@docs
152-
prefix
152+
DynamicPPL.prefix
153153
```
154154

155155
Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else

src/DynamicPPL.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ using DocStringExtensions
2121

2222
using Random: Random
2323

24+
# For extending
25+
import AbstractPPL: predict
26+
2427
# TODO: Remove these when it's possible.
2528
import Bijectors: link, invlink
2629

@@ -39,8 +42,6 @@ import Base:
3942
keys,
4043
haskey
4144

42-
import AbstractPPL: predict
43-
4445
# VarInfo
4546
export AbstractVarInfo,
4647
VarInfo,

src/contexts.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -260,25 +260,21 @@ function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix}
260260
return PrefixContext{Prefix}(child)
261261
end
262262

263-
const PREFIX_SEPARATOR = Symbol(".")
264-
265-
@generated function PrefixContext{PrefixOuter}(
266-
context::PrefixContext{PrefixInner}
267-
) where {PrefixOuter,PrefixInner}
268-
return :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}(
269-
context.context
270-
))
271-
end
263+
"""
264+
prefix(ctx::AbstractContext, vn::VarName)
272265
266+
Apply the prefixes in the context `ctx` to the variable name `vn`.
267+
"""
273268
function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
274-
vn_prefixed_inner = prefix(childcontext(ctx), vn)
275-
return VarName{Symbol(Prefix, PREFIX_SEPARATOR, getsym(vn_prefixed_inner))}(
276-
getoptic(vn_prefixed_inner)
277-
)
269+
return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Symbol(Prefix)}())
270+
end
271+
function prefix(ctx::AbstractContext, vn::VarName)
272+
return prefix(NodeTrait(ctx), ctx, vn)
278273
end
279-
prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn)
280274
prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn
281-
prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn)
275+
function prefix(::IsParent, ctx::AbstractContext, vn::VarName)
276+
return prefix(childcontext(ctx), vn)
277+
end
282278

283279
"""
284280
prefix(model::Model, x)

src/debug_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ function DynamicPPL.setchildcontext(context::DebugContext, child)
183183
end
184184

185185
function record_varname!(context::DebugContext, varname::VarName, dist)
186-
prefixed_varname = prefix(context, varname)
186+
prefixed_varname = DynamicPPL.prefix(context, varname)
187187
if haskey(context.varnames_seen, prefixed_varname)
188188
if context.error_on_failure
189189
error("varname $prefixed_varname used multiple times in model")

src/model.jl

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ julia> model() ≠ 1.0
243243
true
244244
245245
julia> # To condition the variable inside `demo_inner` we need to refer to it as `inner.m`.
246-
conditioned_model = model | (var"inner.m" = 1.0, );
246+
conditioned_model = model | (@varname(inner.m) => 1.0, );
247247
248248
julia> conditioned_model()
249249
1.0
@@ -255,15 +255,6 @@ julia> conditioned_model_fail()
255255
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
256256
[...]
257257
```
258-
259-
And similarly when using `Dict`:
260-
261-
```jldoctest condition
262-
julia> conditioned_model_dict = model | (@varname(var"inner.m") => 1.0);
263-
264-
julia> conditioned_model_dict()
265-
1.0
266-
```
267258
"""
268259
function AbstractPPL.condition(model::Model, values...)
269260
# Positional arguments - need to handle cases carefully
@@ -443,16 +434,16 @@ julia> conditioned(cm)
443434
julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed,
444435
# `a.m` is treated as a random variable.
445436
keys(VarInfo(cm))
446-
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
437+
1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}:
447438
a.m
448439
449440
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
450-
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0);
441+
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0);
451442
452-
julia> conditioned(cm).x
443+
julia> conditioned(cm)[@varname(x)]
453444
100.0
454445
455-
julia> conditioned(cm).var"a.m"
446+
julia> conditioned(cm)[@varname(a.m)]
456447
1.0
457448
458449
julia> keys(VarInfo(cm)) # No variables are sampled
@@ -583,7 +574,7 @@ julia> model = demo_outer();
583574
julia> model() ≠ 1.0
584575
true
585576
586-
julia> fixed_model = fix(model, var"inner.m" = 1.0, );
577+
julia> fixed_model = fix(model, (@varname(inner.m) => 1.0, ));
587578
588579
julia> fixed_model()
589580
1.0
@@ -599,24 +590,9 @@ julia> fixed_model()
599590
2.0
600591
```
601592
602-
And similarly when using `Dict`:
603-
604-
```jldoctest fix
605-
julia> fixed_model_dict = fix(model, @varname(var"inner.m") => 1.0);
606-
607-
julia> fixed_model_dict()
608-
1.0
609-
610-
julia> fixed_model_dict = fix(model, @varname(inner) => 2.0);
611-
612-
julia> fixed_model_dict()
613-
2.0
614-
```
615-
616593
## Difference from `condition`
617594
618-
A very similar functionality is also provided by [`condition`](@ref) which,
619-
not surprisingly, _conditions_ variables instead of fixing them. The only
595+
A very similar functionality is also provided by [`condition`](@ref). The only
620596
difference between fixing and conditioning is as follows:
621597
- `condition`ed variables are considered to be observations, and are thus
622598
included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref),
@@ -798,16 +774,16 @@ julia> fixed(cm)
798774
julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed,
799775
# `a.m` is treated as a random variable.
800776
keys(VarInfo(cm))
801-
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
777+
1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}:
802778
a.m
803779
804780
julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation.
805-
cm = fix(contextualize(m, PrefixContext{:a}(fix(var"a.m"=1.0))), x=100.0);
781+
cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0);
806782
807-
julia> fixed(cm).x
783+
julia> fixed(cm)[@varname(x)]
808784
100.0
809785
810-
julia> fixed(cm).var"a.m"
786+
julia> fixed(cm)[@varname(a.m)]
811787
1.0
812788
813789
julia> keys(VarInfo(cm)) # <= no variables are sampled
@@ -1365,7 +1341,7 @@ When we sample from the model `demo2(missing, 0.4)` random variable `x` will be
13651341
```jldoctest submodel-to_submodel
13661342
julia> vi = VarInfo(demo2(missing, 0.4));
13671343
1368-
julia> @varname(var\"a.x\") in keys(vi)
1344+
julia> @varname(a.x) in keys(vi)
13691345
true
13701346
```
13711347
@@ -1379,7 +1355,7 @@ false
13791355
We can check that the log joint probability of the model accumulated in `vi` is correct:
13801356
13811357
```jldoctest submodel-to_submodel
1382-
julia> x = vi[@varname(var\"a.x\")];
1358+
julia> x = vi[@varname(a.x)];
13831359
13841360
julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
13851361
true
@@ -1417,10 +1393,10 @@ julia> @model function demo2(x, y, z)
14171393
14181394
julia> vi = VarInfo(demo2(missing, missing, 0.4));
14191395
1420-
julia> @varname(var"sub1.x") in keys(vi)
1396+
julia> @varname(sub1.x) in keys(vi)
14211397
true
14221398
1423-
julia> @varname(var"sub2.x") in keys(vi)
1399+
julia> @varname(sub2.x) in keys(vi)
14241400
true
14251401
```
14261402
@@ -1437,9 +1413,9 @@ false
14371413
We can check that the log joint probability of the model accumulated in `vi` is correct:
14381414
14391415
```jldoctest submodel-to_submodel-prefix
1440-
julia> sub1_x = vi[@varname(var"sub1.x")];
1416+
julia> sub1_x = vi[@varname(sub1.x)];
14411417
1442-
julia> sub2_x = vi[@varname(var"sub2.x")];
1418+
julia> sub2_x = vi[@varname(sub2.x)];
14431419
14441420
julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
14451421

src/submodel_macro.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,10 @@ julia> vi = VarInfo(demo2(missing, missing, 0.4));
9696
│ caller = ip:0x0
9797
└ @ Core :-1
9898
99-
julia> @varname(var"sub1.x") in keys(vi)
99+
julia> @varname(sub1.x) in keys(vi)
100100
true
101101
102-
julia> @varname(var"sub2.x") in keys(vi)
102+
julia> @varname(sub2.x) in keys(vi)
103103
true
104104
```
105105
@@ -116,9 +116,9 @@ false
116116
We can check that the log joint probability of the model accumulated in `vi` is correct:
117117
118118
```jldoctest submodelprefix
119-
julia> sub1_x = vi[@varname(var"sub1.x")];
119+
julia> sub1_x = vi[@varname(sub1.x)];
120120
121-
julia> sub2_x = vi[@varname(var"sub2.x")];
121+
julia> sub2_x = vi[@varname(sub2.x)];
122122
123123
julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
124124
@@ -157,7 +157,7 @@ julia> # Automatically determined from `a`.
157157
@model submodel_prefix_true() = @submodel prefix=true a = inner()
158158
submodel_prefix_true (generic function with 2 methods)
159159
160-
julia> @varname(var"a.x") in keys(VarInfo(submodel_prefix_true()))
160+
julia> @varname(a.x) in keys(VarInfo(submodel_prefix_true()))
161161
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
162162
│ caller = ip:0x0
163163
└ @ Core :-1
@@ -167,7 +167,7 @@ julia> # Using a static string.
167167
@model submodel_prefix_string() = @submodel prefix="my prefix" a = inner()
168168
submodel_prefix_string (generic function with 2 methods)
169169
170-
julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string()))
170+
julia> @varname(var"my prefix".x) in keys(VarInfo(submodel_prefix_string()))
171171
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
172172
│ caller = ip:0x0
173173
└ @ Core :-1
@@ -177,7 +177,7 @@ julia> # Using string interpolation.
177177
@model submodel_prefix_interpolation() = @submodel prefix="\$(nameof(inner()))" a = inner()
178178
submodel_prefix_interpolation (generic function with 2 methods)
179179
180-
julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation()))
180+
julia> @varname(inner.x) in keys(VarInfo(submodel_prefix_interpolation()))
181181
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
182182
│ caller = ip:0x0
183183
└ @ Core :-1
@@ -187,7 +187,7 @@ julia> # Or using some arbitrary expression.
187187
@model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner()
188188
submodel_prefix_expr (generic function with 2 methods)
189189
190-
julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr()))
190+
julia> @varname(var"3".x) in keys(VarInfo(submodel_prefix_expr()))
191191
┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.
192192
│ caller = ip:0x0
193193
└ @ Core :-1

src/utils.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,14 +1285,18 @@ broadcast_safe(x) = x
12851285
broadcast_safe(x::Distribution) = (x,)
12861286
broadcast_safe(x::AbstractContext) = (x,)
12871287

1288+
# Convert (x=1,) to Dict(@varname(x) => 1)
1289+
_nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt))
12881290
# Version of `merge` used by `conditioned` and `fixed` to handle
12891291
# the scenario where we might try to merge a dict with an empty
12901292
# tuple.
12911293
# TODO: Maybe replace the default of returning `NamedTuple` with `nothing`?
12921294
_merge(left::NamedTuple, right::NamedTuple) = merge(left, right)
12931295
_merge(left::AbstractDict, right::AbstractDict) = merge(left, right)
1294-
_merge(left::AbstractDict, right::NamedTuple{()}) = left
1295-
_merge(left::NamedTuple{()}, right::AbstractDict) = right
1296+
_merge(left::AbstractDict, ::NamedTuple{()}) = left
1297+
_merge(left::AbstractDict, right::NamedTuple) = merge(left, _nt_to_varname_dict(right))
1298+
_merge(::NamedTuple{()}, right::AbstractDict) = right
1299+
_merge(left::NamedTuple, right::AbstractDict) = merge(_nt_to_varname_dict(left), right)
12961300

12971301
"""
12981302
unique_syms(vns::T) where {T<:NTuple{N,VarName}}

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3232
[compat]
3333
ADTypes = "1"
3434
AbstractMCMC = "5"
35-
AbstractPPL = "0.10.1"
35+
AbstractPPL = "0.11"
3636
Accessors = "0.1"
3737
Aqua = "0.8"
3838
Bijectors = "0.15.1"

0 commit comments

Comments
 (0)