condition and decondition
#278
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Thanks, Tor, it looks good overall. I like this approach. See below for some minor comments.
| vectorize, | ||
| # Model | ||
| Model, | ||
| ContextualModel, |
Maybe keep this internal until we have a good case of exporting it.
I think any case that can be made for exporting Model now also can be made for ContextualModel, don't you think?
| PriorContext, | ||
| MiniBatchContext, | ||
| PrefixContext, | ||
| ConditionContext, |
Similarly, maybe keep this internal until we have a good case of exporting it.
The reason why I exported it is because essentially all other contexts are being exported. It also makes show nicer, tbh. But I'm also fine with not exporting, up to you 👍
src/compiler.jl
Outdated
| $(DynamicPPL.Model)( | ||
| $(QuoteNode(modeldef[:name])), | ||
| $evaluator, | ||
| $allargs_namedtuple, |
We are passing allargs_namedtuple to both the model constructor and the condition call - is it because we are still trying to be backwards compatible with the current design?
Exactly 👍
BUT I would say that we should remove this initially. If the function defined by @model suddenly returns ContextualModel, I'm going to wager that a lot of dispatches in the ecosystem will be messed up.
So IMO:
- Introduce this PR but without conditioning by default, i.e. `@model` still defines a method returning `Model`.
- We get to try out `condition` and `decondition` to see how things work.
- If everything works out nicely, we fully replace `Model` (then we could even add `context` as an additional field to `Model` rather than introducing `ContextualModel`).
src/context_implementations.jl
Outdated
| return value | ||
| end | ||
|
|
||
| function tilde_assume!(context::ConditionContext, right, vn, inds, vi) |
| function tilde_assume!(context::ConditionContext, right, vn, inds, vi) | |
| function tilde_assume_or_observe!(context::ConditionContext, right, vn, inds, vi) |
Sorry, what's the intention with this renaming?
I assume (:stuck_out_tongue_closed_eyes:) it's because now the function does not only implement assumption anymore. But it is a weirdly mixed case anyway, since missings are either defined by model arguments or by the context, right?
IMHO the cleanest solution would be to handle the distinction entirely in isassumption. Isn't that possible? Is there a case when e.g. the isassumption check takes the true branch, but then this tilde_assume! call resorts to tilde_observe! anyway?
(See also my non-review comment below.)
Nice one 😎
But no that won't work unfortunately. We can't call tilde_observe! without left, and left isn't defined until we've extracted it from VarInfo or __context__.
Right, I see. But that starts to get rather unclean. Why not rewrite tildes somewhat like that, calculating the "observation value" beforehand or setting it to missing:
var"##vn#261" = (VarName){:x}()
var"##observationvalue#263" = begin
getobsvalueormissing(__model__, __context__, var"##vn#261")
end
if var"##observationvalue#263" === missing
x = (DynamicPPL.tilde_assume!)(
__context__,
(DynamicPPL.unwrap_right_vn)((DynamicPPL.check_tilde_rhs)(filldist(Normal(m, sqrt(s)), length(y))), var"##vn#261")...,
var"##inds#262",
__varinfo__
)
else
(DynamicPPL.tilde_observe!)(
__context__,
(DynamicPPL.check_tilde_rhs)(filldist(Normal(m, sqrt(s)), length(y))),
var"##observationvalue#263",
var"##vn#261",
var"##inds#262",
__varinfo__
)
end
So I actually started writing a comment immediately after posting my previous one where I started saying "Of course we can extract the value in the model-macro before the tilde-statement and then do something nifty there" but gave up because I'm struggling to come up with a sufficiently robust approach. And IMO it needs to be a very good approach vs. intercepting the tilde-calls because intercepting the tilde-calls is non-intrusive while doing something like the above touches the DPPL-compiler directly.
And I'm not 100% sure I agree it's "unclean". IMO we should be allowed to switch between sampling and evaluation anywhere in the tilde-pipeline, and evaluation is essentially just an observe-statement with the exception that we might extract the value from VarInfo.
src/context_implementations.jl
Outdated
| return value | ||
| end | ||
|
|
||
| function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) |
| function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) | |
| function dot_tilde_assume_or_observe!(context::ConditionContext, right, left, vn, inds, vi) |
Same as above: a bit confused what you mean here.
src/contexts.jl
Outdated
| end | ||
| end | ||
|
|
||
| struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext |
Maybe
| struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext | |
| struct ConditionContext{VarNames, Values,Ctx<:AbstractContext} <: AbstractContext |
| # When no second argument is given, we remove _all_ conditioned variables. | ||
| # TODO: Should we remove this and just return `context.context`? | ||
| # That will work better if `Model` becomes like `ContextualModel`. | ||
| decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) |
Maybe rename this to _decondition to avoid confusion with decondition(m, ...)?
IMO the different types for the arguments already disambiguate the two 😕
| @@ -1,3 +1,5 @@ | |||
| abstract type AbstractModel <: AbstractProbabilisticProgram end | |||
Can we set up a type alias instead of creating a new level of the hierarchy?
Unfortunately not because then methods such as
function (model::AbstractModel)(args...)
return model(Random.GLOBAL_RNG, args...)
end
become type-piracy.
So I thought further on the "maybe we should extract the value in the model rather than in the tilde-statement?" idea. This is now what I've done in this PR.
All in all, this means that now the following two instances of models are equivalent:
julia> @model function demo_noargs()
m ~ Normal()
x ~ Normal(m, 1)
return (; m, x)
end
demo (generic function with 2 methods)
julia> condition(demo_noargs(), x=1.0)()
(m = 1.3184454383513486, x = 1.0)
julia> @model function demo(x)
m ~ Normal()
x ~ Normal(m, 1)
return (; m, x)
end
demo (generic function with 2 methods)
julia> demo(1.0)()
(m = 0.8496181905704883, x = 1.0)
The only drawback is that we cannot use this to change model arguments, i.e. neither of the following has any effect on `x`:
julia> decondition(demo(1.0), :x)()
(m = 1.6693305580373252, x = 1.0)
julia> condition(demo(1.0), x=2.0)()
(m = 1.3370773180628, x = 1.0)
In contrast, both of these work for the "no-argument" model:
julia> decondition(condition(demo_noargs(), x=1.0), :x)()
(m = 0.16963191023957844, x = -0.5504657603545225)
julia> condition(condition(demo_noargs(), x=1.0), x=2.0)()
(m = -0.5525977807677651, x = 2.0)
Essentially, we'd get everything that we desire (i.e. ability to […]
We could potentially start experimenting with approaches that could allow this, e.g. we can imagine overloading […]
On the flip-side this works very well even for submodels, etc.:
julia> @model function inner(x)
m ~ Normal()
x ~ Normal(m, 1)
return (; m, x)
end
inner (generic function with 1 method)
julia> @model function outer(x)
y = Vector(undef, length(x))
for i in eachindex(x)
y[i] = @submodel $(Symbol("y[$i]")) inner(x[i])
end
return y
end
outer (generic function with 1 method)
julia> m = outer([1.0, 2.0])
Model{var"#43#44", (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}}(:outer, var"#43#44"(), (x = [1.0, 2.0],), NamedTuple())
julia> condition(m, var"y[1].m" = 100.0)()
2-element Vector{Any}:
(m = 100.0, x = 1.0)
(m = -0.002383907227840093, x = 2.0)
Drawbacks
Benefits
| # TODO: Support by adding context to model, and use `model.args` | ||
| # as the default conditioning. Then we no longer need to check `inargnames` | ||
| # since it will all be handled by `contextual_isassumption`. |
IMO this should be deferred for now. It will require a bit of thinking I believe.
| # If `vn` is not in `argnames`, we need to make sure that the variable is defined. | ||
| if !$(DynamicPPL.inargnames)($vn, __model__) | ||
| $left = $(DynamicPPL.getvalue)(__context__, $vn) | ||
| end |
This will compile away when possible, just as isassumption, and will preserve existing functionality/meaning of model arguments hence will be non-breaking.
| function getvalue(context::ConditionContext, vn) | ||
| return if haskey(context, vn) | ||
| _getvalue(context.values, vn) | ||
| else | ||
| getvalue(context.context, vn) | ||
| end | ||
| end |
This currently behaves a bit strangely in certain cases:
julia> @model function outer_with_inner_condition(x)
y = Vector(undef, length(x))
for i in eachindex(x)
# Here we'll end up with `PrefixContext{..., <:ConditionContext}`
y[i] = @submodel $(Symbol("y[$i]")) condition(inner(x[i]), m=10.0)
end
return y
end
outer_with_inner_condition (generic function with 1 method)
julia> m = outer_with_inner_condition([1.0, 2.0])
Model{var"#49#50", (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}}(:outer_with_inner_condition, var"#49#50"(), (x = [1.0, 2.0],), NamedTuple())
julia> m() # (✓) both `m` are set as specified within `outer_with_inner_condition`
2-element Vector{Any}:
(m = 10.0, x = 1.0)
(m = 10.0, x = 2.0)
julia> condition(m, var"y[1].m"=5.0)() # (×) try to override one of the inner `m`
2-element Vector{Any}:
(m = 10.0, x = 1.0)
(m = 10.0, x = 2.0)
But we can change getvalue to:
function getvalue(context::ConditionContext, vn)
# Return early if we've already found our value, thus giving precedence
# to the inner-most `ConditionContext`.
maybeval = getvalue(context.context, vn)
maybeval === nothing || return maybeval
return haskey(context, vn) ? _getvalue(context.values, vn) : nothing
end
for which we then get the "desired" behavior in the above example:
julia> condition(m, var"y[1].m"=5.0)() # (✓) try to override one of the inner `m`
2-element Vector{Any}:
(m = 5.0, x = 1.0)
(m = 10.0, x = 2.0)
but this does mean that we give precedence to the inner-most ConditionContext rather than the outermost, i.e. the one that was applied most recently, which is counter-intuitive and can be difficult to work with.
We could potentially introduce functionality that will recurse into the context-tree to remove all other mentions of the variables, i.e. at most one ConditionContext in the context-tree contains a mention of any specific variable. This also makes sense when considering decondition(ctx, :x) since we'd expect this to traverse the entire context-tree and remove every mention of x rather than only doing so in the outer-most ConditionContext. Buuuut this brings us back to #254 again, since it would require functionality for such a tree-traversal and other things.
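To make the two lookup orders concrete, here is a minimal, self-contained sketch. It does not use DynamicPPL; `ToyContext`, `ToyLeaf`, `ToyCondition`, `lookup_outer` and `lookup_inner` are hypothetical stand-ins, and variables are plain `Symbol`s rather than `VarName`s. It only illustrates how the order of "check myself" vs. "ask my child" decides which wrapper in the context-tree wins.

```julia
# Toy stand-ins for contexts; every name here is hypothetical.
abstract type ToyContext end

struct ToyLeaf <: ToyContext end

struct ToyCondition{V<:NamedTuple,C<:ToyContext} <: ToyContext
    values::V
    child::C
end

# Outer-most precedence: look in the current wrapper first, then recurse.
lookup_outer(::ToyLeaf, name::Symbol) = nothing
function lookup_outer(ctx::ToyCondition, name::Symbol)
    haskey(ctx.values, name) && return ctx.values[name]
    return lookup_outer(ctx.child, name)
end

# Inner-most precedence: recurse first, fall back to the current wrapper.
lookup_inner(::ToyLeaf, name::Symbol) = nothing
function lookup_inner(ctx::ToyCondition, name::Symbol)
    maybeval = lookup_inner(ctx.child, name)
    maybeval === nothing || return maybeval
    return haskey(ctx.values, name) ? ctx.values[name] : nothing
end

# The wrapper holding `m = 10.0` sits closest to the leaf.
inner_ctx = ToyCondition((m = 10.0,), ToyLeaf())
outer_ctx = ToyCondition((m = 5.0,), inner_ctx)

outer_first = lookup_outer(outer_ctx, :m)  # value from the outer wrapper
inner_first = lookup_inner(outer_ctx, :m)  # value from the inner wrapper
```

Neither order is right in general, which is why a tree-traversal that leaves at most one mention of each variable looks attractive: it would make the order irrelevant.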
src/contextual_model.jl
Outdated
| function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) | ||
| # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as | ||
| # `ConditionContext` child. | ||
| return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) |
[JuliaFormatter] reported by reviewdog 🐶
| return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) | |
| return _evaluate( | |
| cmodel.model, varinfo, ConditionContext(cmodel.context.values, context) | |
| ) |
src/contexts.jl
Outdated
| @generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} | ||
| names_expr = Expr(:tuple) | ||
| values_expr = Expr(:tuple) | ||
|
|
||
| for (n, v) in zip(names, values.parameters) | ||
| if !(v <: Missing) | ||
| push!(names_expr.args, QuoteNode(n)) | ||
| push!(values_expr.args, :(nt.$n)) | ||
| end | ||
| end | ||
|
|
||
| return :(NamedTuple{$names_expr}($values_expr)) | ||
| end |
Not sure if this is useful in this PR anymore. It's useful if we want to replace missing with !haskey(conditioncontext) but as I've said before, I think there are a bit too many corner cases that needs to be resolved before we make such an intrusive change.
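For reference, here is the helper from the diff above as a standalone snippet with a usage example. The function body follows the diff; the example inputs are made up. Note that it filters on the field *type* being `Missing`, so the decision is made at compile time.

```julia
# Standalone copy of the helper from the diff above, for experimentation.
@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values}
    names_expr = Expr(:tuple)
    values_expr = Expr(:tuple)

    # Inside a generated function we only see types, so we filter on the
    # field types rather than on runtime values.
    for (n, v) in zip(names, values.parameters)
        if !(v <: Missing)
            push!(names_expr.args, QuoteNode(n))
            push!(values_expr.args, :(nt.$n))
        end
    end

    return :(NamedTuple{$names_expr}($values_expr))
end

# Hypothetical model arguments where `y` is to be sampled rather than observed.
args = (x = 1.0, y = missing, z = "a")
conditioned = drop_missings(args)
```

Here `conditioned` only keeps the fields whose type is not `Missing`, which is what one would want if `model.args` were used as the default conditioning.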
| The default implementation for `AbstractContext` always returns `true`. | ||
| """ | ||
| contextual_isassumption(context::AbstractContext, vn) = true |
This is too general, e.g. we want this for DefaultContext and other "leaf" contexts but not for "parent" contexts, e.g. MiniBatchContext.
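One way to get "`true` for leaf contexts, defer to the child for parent contexts" is a small trait. The sketch below is self-contained and does not use DynamicPPL: the `Toy*` types, the lowercase `nodetrait`, and this `isassumption` are hypothetical, and only the leaf/parent idea is borrowed from the discussion. Variables are plain `Symbol`s.

```julia
# Hypothetical toy contexts; nothing here is DynamicPPL API.
abstract type ToyContext end
struct IsLeaf end
struct IsParent end

struct ToyDefault <: ToyContext end
struct ToyMiniBatch{C<:ToyContext} <: ToyContext
    child::C
end
struct ToyCondition{V<:NamedTuple,C<:ToyContext} <: ToyContext
    values::V
    child::C
end

# Each context declares whether it is a leaf or wraps a child.
nodetrait(::ToyDefault) = IsLeaf()
nodetrait(::ToyMiniBatch) = IsParent()
nodetrait(::ToyCondition) = IsParent()
childcontext(ctx::Union{ToyMiniBatch,ToyCondition}) = ctx.child

# Generic entry point: dispatch on the trait.
isassumption(ctx::ToyContext, name::Symbol) = isassumption(nodetrait(ctx), ctx, name)
# Leaves treat every variable as an assumption.
isassumption(::IsLeaf, ctx, name) = true
# Parents defer to their child unless they define something more specific.
isassumption(::IsParent, ctx, name) = isassumption(childcontext(ctx), name)

# Only the conditioning context needs its own method.
function isassumption(ctx::ToyCondition, name::Symbol)
    haskey(ctx.values, name) && return false
    return isassumption(childcontext(ctx), name)
end

ctx = ToyMiniBatch(ToyCondition((x = 1.0,), ToyDefault()))
x_is_assumed = isassumption(ctx, :x)  # conditioned further down the tree
m_is_assumed = isassumption(ctx, :m)  # not conditioned anywhere
```

The point of the design is that `ToyMiniBatch` never has to mention `isassumption` at all; it inherits the "ask my child" behavior from the trait.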
| # TODO: Can we maybe do this in a better way? | ||
| # When no second argument is given, we remove _all_ conditioned variables. | ||
| # TODO: Should we remove this and just return `context.context`? | ||
| # That will work better if `Model` becomes like `ContextualModel`. |
This can be addressed nicely with #286
I've been saying that, in different terms, a couple of times already :P The "automatically deduce everything from arguments" style is kinda nice from a user perspective, but makes things hard to handle internally.
And now we're at the point of stringly submodel variables, of which I have also warned... I understand that the NamedTuple + […] I love where this is going to, though.
Yeah, yeah am aware:) This was referring to this particular PR and its implementation, just making explicit the limitations of it. But I think it's fine here because everything else is left as-is, meaning the current behavior of doing […]
Exactly: we can easily make this easier to specify, i.e. using a macro. The point here is to introduce the functionality, then we can make it nicer later.
❤️ Btw, take a look at #287. This PR has a lot of shortcomings in how it interacts with other PRs, i.e. there will be errors. That PR addresses those nicely.
Closed in favour of #294
This is motivated by the potential introduction of more contexts, e.g. #278, and has been brought up as an alternative (and better) approach to achieve parts of what we want to achieve in #254. (I hope you're proud of me @devmotion)

## Current state of things

Currently, if one wants to implement a new `AbstractContext` one _at least_ has to implement the following methods:

```julia
tilde_assume(...)
tilde_observe(...)
dot_tilde_assume(...)
dot_tilde_observe(...)
```

But there are also other methods that _should_ be implemented but generally aren't properly handled, e.g. `matchingvalue`. And there might be more methods in the future, e.g. `contextual_isassumption` in #254.

## This sucks

This means that:

1. Implementing a new behavior for `AbstractContext`, e.g. `contextual_isassumption`, requires you to:
   1. Find all implementations of `AbstractContext`, which is non-trivial! Most are here in DPPL, but some are in Turing.jl, and eventually we also want packages outside of the Turing.jl-umbrella to extend DPPL using contexts.
   2. Implement the method for that particular context.
2. Implementing a new `AbstractContext`, e.g. `Turing.OptimizationContext`, requires you to find all the methods to implement and then do so. Again, non-trivial.

This combinatorial blow up essentially means that we're super-reluctant to introduce new behaviors or new contexts, for good reasons. And the stupid thing is that in most cases a context is only trying to modify maybe one or two "behaviors", e.g. `MiniBatchContext` only wants to change the `*tilde_observe` methods, and otherwise just defer to whatever implementation is available for its "childcontext" `minibatchcontext.context`.

## Goal

A new `AbstractContext` should only (up to an additive factor) have to implement the behavior it _wants to change_, not _all_ behaviors. E.g. `MiniBatchContext` should really only have to overload `*tilde_observe`.

## Solution (this PR)

The above was the motivation for #254, but there we wanted a rather strict separation between certain types of contexts which we're reluctant to add due to its restrictive nature (in particular given how "recently" contexts were introduced). This PR takes a more extensible and less restrictive approach of introducing some traits for `AbstractContext`.

As a starter I've just introduced a couple of traits to allow code-sharing between "parent contexts", e.g. `MiniBatchContext` and `ConditionContext` (from #278), and "leaf-contexts", e.g. `DefaultContext` and `PriorContext` (which IMO should have been a wrapper-context itself).

Ideally we'd also define a promotion-system, e.g. what do we do if we're asked to combine a `MiniBatchContext` and `DefaultContext`? Well in that case we could either

1. replace `minibatchcontext.context` with `DefaultContext()`, or
2. recursively "rewrap" `DefaultContext()` in `minibatch.context`.

(1) has the issue that `MiniBatchContext` might be wrapping another context, e.g. `PrefixContext`, and so just replacing `.context` is dangerous. (2) seems like a better idea since `DefaultContext` should always be at the end of the stack, i.e. a "leaf", since it always exits the tilde-callstack (e.g. in `tilde_assume` we call `assume`, etc.). Such a promotion system will require some thought though, but this PR will allow us to experiment with this (on top of providing a good approach to code-sharing).

## Examples

### `GeneratedQuantitiesContext`

In DPPL we have the `generated_quantities` but it sort of sucks because often these quantities are not relevant for sampling, thus we're adding unnecessary computation to the sampling process. One might want to introduce a `@generated_quantities` macro that will only be executed when called from `generated_quantities`.
This can easily be achieved with contexts:

```julia
using DynamicPPL, Distributions, Random
using DynamicPPL: AbstractContext, IsLeaf, IsParent, childcontext

struct GeneratedQuantitiesContext{Ctx} <: AbstractContext
    context::Ctx
end
GeneratedQuantitiesContext() = GeneratedQuantitiesContext(DefaultContext())

# Define the `NodeTrait` for `GeneratedQuantitiesContext`.
DynamicPPL.NodeTrait(context::GeneratedQuantitiesContext) = IsParent()
DynamicPPL.childcontext(context::GeneratedQuantitiesContext) = context.context

"""
    isgeneratedquantities(context)

Return `true` if `context` wants evaluation of model to execute the `@generated_quantities` block.
"""
function isgeneratedquantities(context::AbstractContext)
    return isgeneratedquantities(DynamicPPL.NodeTrait(isgeneratedquantities, context), context)
end

# Define the behavior for the different `NodeType`s.
isgeneratedquantities(::IsLeaf, context::AbstractContext) = false
function isgeneratedquantities(::IsParent, context::AbstractContext)
    return isgeneratedquantities(childcontext(context))
end

# Specific implementations of `isgeneratedquantities`.
isgeneratedquantities(context::GeneratedQuantitiesContext) = true

"""
    @generated_quantities f(x)

Specify that `f(x)` should only be executed if the model is run using `GeneratedQuantitiesContext`.
"""
macro generated_quantities(expr)
    return esc(generated_quantities_expr(expr))
end

function generated_quantities_expr(expr)
    return quote
        if isgeneratedquantities(__context__)
            $expr
        end
    end
end
```

And usage would be as follows:

```julia
julia> @model function demo(x)
           @generated_quantities a = []
           m ~ Normal()

           # "Expensive" piece of code that we don't want to compute
           # unless we're computing the generated quantities.
           @generated_quantities for i = 1:100
               push!(a, m + randn())
           end

           # Observe.
           x ~ Normal(m, 1.0)

           # Return additional fields if we're computing the generated quantities.
           @generated_quantities return (; x, m, logp = getlogp(__varinfo__), a)

           return nothing
       end
demo (generic function with 1 method)

julia> m = demo(1.0); var_info = VarInfo(m);

julia> m(var_info) # (✓) returns nothing

julia> m(var_info, GeneratedQuantitiesContext()) # (✓) returns everything
(x = 1.0, m = 0.7451107426181028, logp = -2.147956342556143, a = Any[0.929285513523637, 0.5979631005709274, 3.1944251790696794, -0.611727924858586, 1.0561547812845788, 0.9694358994096283, 0.9130715096769692, 0.018196803783751103, 0.919216919507385, 0.5382931170515716 … 0.24049636440948963, 0.7868300598622132, 0.18210113151764207, 1.9568444848346909, 0.4512687443970105, 0.5155449377942058, 0.32900420588898294, 1.4274186203915957, 0.5599167955770905, -0.5351355146677438])

julia> m(Random.GLOBAL_RNG, GeneratedQuantitiesContext()) # (✓) just works even when wrapped in `SamplingContext`
(x = 1.0, m = -0.7568553457880578, logp = -3.6675624266453637, a = Any[-0.4715652454870858, 2.0553700887015776, -0.6055293756513751, -1.6153963580018202, -0.6316036328835652, -0.9694585645577235, -1.234406947321252, 0.47075652867281415, -2.0403875768252187, -1.332736944079995 … -1.123179524380186, -0.4445013585772502, -1.2366853403014721, -0.2726171409415225, 0.050910231154342234, -1.937702826603367, -2.109933658872988, -0.8300603173278494, 0.076480161355589, -1.2666452596129179])
```

This would also _just work_ for submodels, etc.

This "conditional execution" could of course be generalized too. Such a conditional execution is very useful when you also want to work with Zygote, e.g. you don't want mutations in the model but for the post-processing steps, e.g. `predict` and `generated_quantities`, you do need it.

But whether or not we want this particular `@generated_quantities` macro in DPPL is not the point; the point is that we _can_ implement such a thing. Even more importantly, I can easily implement this from the "outside", not needing to touch DPPL to do it.

### `ConditionContext`

See #278.
Co-authored-by: Hong Ge <hg344@cam.ac.uk>
This PR implements `condition` and `decondition` taking the context-approach.

This approach works as follows:

- `ConditionContext` specifies which variables are conditioned and their values.
- `ContextualModel` pairs an `AbstractContext` with `Model`.
  - `condition(model, ...)`, `prior(model, ...)`, `do(model, ...)`, etc.
- `condition` and `decondition` are implemented by wrapping/unwrapping `model` with `ConditionContext`.
- `isassumption` check returns `true` if the variable is in `__context__` and `__context__ isa ConditionContext` (IMO this can be made prettier).
- `tilde_assume!` for `ConditionContext` instead calls `tilde_observe!` if `vn` is in `ConditionContext` (deferring the rest of the tilde-pipeline to its child context), etc.
- `@model` returns a constructor that wraps the resulting `Model` in a `ConditionContext` where the conditioned values are `model.args` with the `missing` removed.

This PR has the benefit that we can even introduce it without any changes to `Model`, thus making it non-breaking if we drop the last of the above steps. This means that we can try the approach out a bit before fully committing.

See #279 for an alternative.
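The wrap/unwrap idea can be tried out without DynamicPPL. The following is only a toy sketch under stated assumptions: `ToyModel` and its hand-written evaluator are hypothetical, the conditioned values live in a plain `NamedTuple`, and the `condition` / `decondition` defined here are local functions that merely mimic the semantics described above (re-conditioning overrides, deconditioning removes).

```julia
using Random

# Hypothetical toy model: an evaluator plus the currently conditioned values.
struct ToyModel{F,C<:NamedTuple}
    f::F
    conditioned::C
end
ToyModel(f) = ToyModel(f, NamedTuple())

# Running the model hands the conditioned values to the evaluator.
(model::ToyModel)(rng::AbstractRNG=Random.default_rng()) = model.f(rng, model.conditioned)

# `condition` adds (or overrides) values; later calls take precedence.
condition(model::ToyModel; kwargs...) =
    ToyModel(model.f, merge(model.conditioned, (; kwargs...)))

# `decondition` without names removes everything ...
decondition(model::ToyModel) = ToyModel(model.f)
# ... and with names removes only the given variables.
function decondition(model::ToyModel, name::Symbol, names::Symbol...)
    drop = (name, names...)
    keep = filter(k -> !(k in drop), keys(model.conditioned))
    return ToyModel(model.f, NamedTuple{keep}(model.conditioned))
end

# Hand-written stand-in for `m ~ Normal(); x ~ Normal(m, 1)`.
demo_noargs = ToyModel() do rng, given
    m = haskey(given, :m) ? given.m : randn(rng)
    x = haskey(given, :x) ? given.x : m + randn(rng)
    return (; m, x)
end

conditioned_model = condition(demo_noargs, x = 1.0)
reconditioned_model = condition(conditioned_model, x = 2.0)
deconditioned_model = decondition(conditioned_model, :x)
```

In this toy, `conditioned_model()` always returns `x = 1.0`, `reconditioned_model()` returns `x = 2.0`, and `deconditioned_model()` samples `x` again. The sketch sidesteps the hard parts of the real PR (model arguments, `VarName`s, nested contexts).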
Example