Skip to content

Commit 2cc99b3

Browse files
authored
Try #278:
2 parents b82459a + 22fcae8 commit 2cc99b3

File tree

9 files changed

+181
-16
lines changed

9 files changed

+181
-16
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.12.3"
55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
77
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
8+
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
89
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
910
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1011
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/DynamicPPL.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Distributions
66
using Bijectors
77

88
using AbstractMCMC: AbstractMCMC
9+
using BangBang: BangBang
910
using ChainRulesCore: ChainRulesCore
1011
using MacroTools: MacroTools
1112
using ZygoteRules: ZygoteRules
@@ -67,6 +68,7 @@ export AbstractVarInfo,
6768
vectorize,
6869
# Model
6970
Model,
71+
ContextualModel,
7072
getmissings,
7173
getargnames,
7274
generated_quantities,
@@ -81,6 +83,7 @@ export AbstractVarInfo,
8183
PriorContext,
8284
MiniBatchContext,
8385
PrefixContext,
86+
ConditionContext,
8487
assume,
8588
dot_assume,
8689
observe,
@@ -99,6 +102,8 @@ export AbstractVarInfo,
99102
logprior,
100103
logjoint,
101104
pointwise_loglikelihoods,
105+
condition,
106+
decondition,
102107
# Convenience macros
103108
@addlogprob!,
104109
@submodel
@@ -129,5 +134,6 @@ include("prob_macro.jl")
129134
include("compat/ad.jl")
130135
include("loglikelihoods.jl")
131136
include("submodel_macro.jl")
137+
include("contextual_model.jl")
132138

133139
end # module

src/compiler.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ function isassumption(expr::Union{Symbol,Expr})
2323
# This branch should compile nicely in all cases except for partial missing data
2424
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
2525
if !$(DynamicPPL.inargnames)($vn, __model__) ||
26-
$(DynamicPPL.inmissings)($vn, __model__)
26+
$(DynamicPPL.inmissings)($vn, __model__) ||
27+
$(DynamicPPL.contextual_isassumption)(__context__, $vn)
2728
true
2829
else
2930
# Evaluate the LHS
@@ -33,6 +34,9 @@ function isassumption(expr::Union{Symbol,Expr})
3334
end
3435
end
3536

37+
contextual_isassumption(context::AbstractContext, vn) = false
38+
contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn))
39+
3640
# failsafe: a literal is never an assumption
3741
isassumption(expr) = :(false)
3842

@@ -437,11 +441,14 @@ function build_output(modelinfo, linenumbernode)
437441
modeldef[:body] = MacroTools.@q begin
438442
$(linenumbernode)
439443
$evaluator = $(MacroTools.combinedef(evaluatordef))
440-
return $(DynamicPPL.Model)(
441-
$(QuoteNode(modeldef[:name])),
442-
$evaluator,
444+
return $(DynamicPPL.condition)(
445+
$(DynamicPPL.Model)(
446+
$(QuoteNode(modeldef[:name])),
447+
$evaluator,
448+
$allargs_namedtuple,
449+
$defaults_namedtuple,
450+
),
443451
$allargs_namedtuple,
444-
$defaults_namedtuple,
445452
)
446453
end
447454

src/context_implementations.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,28 @@ function tilde_assume!(context, right, vn, inds, vi)
133133
return value
134134
end
135135

136+
function tilde_assume!(context::ConditionContext, right, vn, inds, vi)
137+
if haskey(context, vn)
138+
# Extract value.
139+
value = if inds isa Tuple{}
140+
getfield(context.values, getsym(vn))
141+
else
142+
_getindex(getfield(context.values, getsym(vn)), inds)
143+
end
144+
145+
# Should we even do this?
146+
if haskey(vi, vn)
147+
vi[vn] = value
148+
end
149+
150+
tilde_observe!(context.context, right, value, vn, inds, vi)
151+
else
152+
value = tilde_assume!(context.context, right, vn, inds, vi)
153+
end
154+
155+
return value
156+
end
157+
136158
# observe
137159
"""
138160
tilde_observe(context::SamplingContext, right, left, vname, vinds, vi)
@@ -217,6 +239,10 @@ function tilde_observe!(context, right, left, vi)
217239
return left
218240
end
219241

242+
function tilde_observe!(context::ConditionContext, right, left, vi)
243+
return tilde_observe!(context.context, right, left, vi)
244+
end
245+
220246
function assume(rng, spl::Sampler, dist)
221247
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
222248
end
@@ -419,6 +445,28 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi)
419445
return value
420446
end
421447

448+
function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi)
449+
if haskey(context, vn)
450+
# Extract value.
451+
value = if inds isa Tuple{}
452+
getfield(context.values, sym)
453+
else
454+
_getindex(getfield(context.values, sym), inds)
455+
end
456+
457+
# Should we even do this?
458+
if haskey(vi, vn)
459+
vi[vn] = value
460+
end
461+
462+
dot_tilde_observe!(context.context, right, value, vn, inds, vi)
463+
else
464+
value = dot_tilde_assume!(context.context, right, left, vn, inds, vi)
465+
end
466+
467+
return value
468+
end
469+
422470
# `dot_assume`
423471
function dot_assume(
424472
dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi
@@ -637,6 +685,9 @@ function dot_tilde_observe!(context, right, left, vi)
637685
acclogp!(vi, logp)
638686
return left
639687
end
688+
function dot_tilde_observe!(context::ConditionContext, right, left, vi)
689+
return dot_tilde_observe!(context.context, right, left, vi)
690+
end
640691

641692
# Falls back to non-sampler definition.
642693
function dot_observe(::AbstractSampler, dist, value, vi)

src/contexts.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,73 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
106106
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing)
107107
end
108108
end
109+
110+
struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext
111+
values::Values
112+
context::Ctx
113+
114+
function ConditionContext{Values}(
115+
values::Values, context::AbstractContext
116+
) where {names,Values<:NamedTuple{names}}
117+
return new{names,typeof(values),typeof(context)}(values, context)
118+
end
119+
end
120+
121+
@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values}
122+
names_expr = Expr(:tuple)
123+
values_expr = Expr(:tuple)
124+
125+
for (n, v) in zip(names, values.parameters)
126+
if !(v <: Missing)
127+
push!(names_expr.args, QuoteNode(n))
128+
push!(values_expr.args, :(nt.$n))
129+
end
130+
end
131+
132+
return :(NamedTuple{$names_expr}($values_expr))
133+
end
134+
135+
function ConditionContext(context::ConditionContext, child_context::AbstractContext)
136+
return ConditionContext(context.values, child_context)
137+
end
138+
function ConditionContext(values::NamedTuple)
139+
return ConditionContext(values, DefaultContext())
140+
end
141+
142+
function ConditionContext(values::NamedTuple, context::AbstractContext)
143+
values_wo_missing = drop_missings(values)
144+
return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context)
145+
end
146+
147+
# Try to avoid nested `ConditionContext`.
148+
function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars}
149+
# Note that this potentially overrides values from `context`, thus giving
150+
# precedence to the outmost `ConditionContext`.
151+
return ConditionContext(merge(context.values, values), context.context)
152+
end
153+
154+
function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym}
155+
# TODO: Add possibility of indexed variables, e.g. `x[1]`, etc.
156+
return sym in vars
157+
end
158+
159+
function Base.haskey(
160+
context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}
161+
) where {vars,sym}
162+
# TODO: Add possibility of indexed variables, e.g. `x[1]`, etc.
163+
return sym in vars
164+
end
165+
166+
# TODO: Can we maybe do this in a better way?
167+
# When no second argument is given, we remove _all_ conditioned variables.
168+
# TODO: Should we remove this and just return `context.context`?
169+
# That will work better if `Model` becomes like `ContextualModel`.
170+
decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context)
171+
function decondition(context::ConditionContext, sym)
172+
return ConditionContext(BangBang.delete!!(context.values, sym), context.context)
173+
end
174+
function decondition(context::ConditionContext, sym, syms...)
175+
return decondition(
176+
ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms...
177+
)
178+
end

src/contextual_model.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel
2+
context::Ctx
3+
model::M
4+
end
5+
6+
function contextualize(model::AbstractModel, context::AbstractContext)
7+
return ContextualModel(context, model)
8+
end
9+
10+
# TODO: What do we do for other contexts? Could handle this in general if we had a
11+
# notion of wrapper-, primitive-context, etc.
12+
function (cmodel::ContextualModel{<:ConditionContext})(
13+
varinfo::AbstractVarInfo, context::AbstractContext
14+
)
15+
# Wrap `context` in the model-associated `ConditionContext`, but now using `context` as
16+
# `ConditionContext` child.
17+
return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context))
18+
end
19+
20+
condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values))
21+
condition(model::AbstractModel; values...) = condition(model, (; values...))
22+
function condition(cmodel::ContextualModel{<:ConditionContext}, values)
23+
return contextualize(cmodel.model, ConditionContext(values, cmodel.context))
24+
end
25+
26+
decondition(model::AbstractModel, args...) = model
27+
function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...)
28+
return contextualize(cmodel.model, decondition(cmodel.context, syms...))
29+
end

src/model.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
abstract type AbstractModel <: AbstractProbabilisticProgram end
2+
13
"""
24
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults}
35
name::Symbol
@@ -32,8 +34,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
3234
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
3335
```
3436
"""
35-
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <:
36-
AbstractProbabilisticProgram
37+
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel
3738
name::Symbol
3839
f::F
3940
args::NamedTuple{argnames,Targs}
@@ -82,7 +83,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a
8283
The method resets the log joint probability of `varinfo` and increases the evaluation
8384
number of `sampler`.
8485
"""
85-
function (model::Model)(
86+
function (model::AbstractModel)(
8687
rng::Random.AbstractRNG,
8788
varinfo::AbstractVarInfo=VarInfo(),
8889
sampler::AbstractSampler=SampleFromPrior(),
@@ -91,7 +92,7 @@ function (model::Model)(
9192
return model(varinfo, SamplingContext(rng, sampler, context))
9293
end
9394

94-
(model::Model)(context::AbstractContext) = model(VarInfo(), context)
95+
(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context)
9596
function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext)
9697
if Threads.nthreads() == 1
9798
return evaluate_threadunsafe(model, varinfo, context)
@@ -100,17 +101,17 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext)
100101
end
101102
end
102103

103-
function (model::Model)(args...)
104+
function (model::AbstractModel)(args...)
104105
return model(Random.GLOBAL_RNG, args...)
105106
end
106107

107108
# without VarInfo
108-
function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...)
109+
function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...)
109110
return model(rng, VarInfo(), sampler, args...)
110111
end
111112

112113
# without VarInfo and without AbstractSampler
113-
function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext)
114+
function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext)
114115
return model(rng, VarInfo(), SampleFromPrior(), context)
115116
end
116117

src/varinfo.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,18 +124,18 @@ end
124124

125125
function VarInfo(
126126
rng::Random.AbstractRNG,
127-
model::Model,
127+
model::AbstractModel,
128128
sampler::AbstractSampler=SampleFromPrior(),
129129
context::AbstractContext=DefaultContext(),
130130
)
131131
varinfo = VarInfo()
132132
model(rng, varinfo, sampler, context)
133133
return TypedVarInfo(varinfo)
134134
end
135-
VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...)
135+
VarInfo(model::AbstractModel, args...) = VarInfo(Random.GLOBAL_RNG, model, args...)
136136

137137
# without AbstractSampler
138-
function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext)
138+
function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext)
139139
return VarInfo(rng, model, SampleFromPrior(), context)
140140
end
141141

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2222
AbstractMCMC = "2.1, 3.0"
2323
AbstractPPL = "0.1.3"
2424
Bijectors = "0.9.5"
25-
Distributions = "0.24, 0.25"
25+
Distributions = "< 0.25.11"
2626
DistributionsAD = "0.6.3"
2727
Documenter = "0.26.1, 0.27"
2828
ForwardDiff = "0.10.12"

0 commit comments

Comments
 (0)