Member

@torfjelde torfjelde commented Jul 15, 2021

This PR implements condition and decondition taking the context-approach.

This approach works as follows:

  • ConditionContext specifies which variables are conditioned and their values.
  • ContextualModel pairs an AbstractContext with a Model.
    • Allows us to define operations using contexts to instead act on models, e.g. condition(model, ...), prior(model, ...), do(model, ...), etc.
  • condition and decondition are implemented by wrapping/unwrapping model with ConditionContext.
  • isassumption check returns true if the variable is in __context__ and __context__ isa ConditionContext (IMO this can be made prettier).
  • tilde_assume! for ConditionContext instead calls tilde_observe! if vn is in ConditionContext (deferring the rest of the tilde-pipeline to its child context), etc.
  • @model returns a constructor that wraps the resulting Model in a ConditionContext where the conditioned values are model.args with the missing removed.
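To make the wrapping/unwrapping idea concrete, here is a minimal, self-contained sketch. All types below are toy stand-ins defined only for illustration (they are not the DynamicPPL definitions), and the model is treated as an opaque value:

```julia
# Toy stand-ins for illustration only; these are NOT the DynamicPPL definitions.
abstract type AbstractContext end
struct DefaultContext <: AbstractContext end

# Holds the conditioned values and the child context it defers to.
struct ConditionContext{Values<:NamedTuple,Ctx<:AbstractContext} <: AbstractContext
    values::Values
    context::Ctx
end

# Pairs a context with an (opaque) model.
struct ContextualModel{Ctx<:AbstractContext,M}
    context::Ctx
    model::M
end

# Conditioning a bare model wraps it in a `ConditionContext`.
function condition(model; kwargs...)
    return ContextualModel(ConditionContext((; kwargs...), DefaultContext()), model)
end

# Conditioning an already conditioned model merges the new values in.
function condition(cm::ContextualModel{<:ConditionContext}; kwargs...)
    ctx = cm.context
    merged = merge(ctx.values, (; kwargs...))
    return ContextualModel(ConditionContext(merged, ctx.context), cm.model)
end

# Deconditioning drops the named variables from the context.
function decondition(cm::ContextualModel{<:ConditionContext}, names::Symbol...)
    ctx = cm.context
    kept = filter(k -> !(k in names), keys(ctx.values))
    return ContextualModel(ConditionContext(NamedTuple{kept}(ctx.values), ctx.context), cm.model)
end

m1 = condition(:toy_model; x = 1.0)
m2 = condition(m1; m = 10.0)   # values are now (x = 1.0, m = 10.0)
m3 = decondition(m2, :x)       # values are now (m = 10.0,)
```

The point of the sketch is only that `condition` and `decondition` never touch the model itself; they rebuild the context that travels with it.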

This PR has the benefit that we can even introduce it without any changes to Model, thus making it non-breaking if we drop the last of the above steps. This means that we can try the approach out a bit before fully committing.

See #279 for an alternative.

Example

julia> @model function demo(x; y = 1.0)
           m ~ Normal()
           x ~ Normal(m, 1)
           y ~ Normal(m + 100, 1)

           return (; m, x, y, logp = getlogp(__varinfo__))
       end
demo (generic function with 1 method)

julia> m = demo(1.0, y=missing); vi = VarInfo(m);

julia> m() # `m` and `y` are sampled, as per usual
(m = -0.25962116394142204, x = 1.0, y = 100.78371360219013, logp = -4.1281136294311285)

julia> m # a `Model` but with a `ConditionContext` in there
ContextualModel{ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Float64}}, DefaultContext}, Model{var"#2#3", (:x, :y), (:y,), (:y,), Tuple{Float64, Missing}, Tuple{Float64}}}(ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Float64}}, DefaultContext}((x = 1.0,), DefaultContext()), Model{var"#2#3", (:x, :y), (:y,), (:y,), Tuple{Float64, Missing}, Tuple{Float64}}(:demo, var"#2#3"(), (x = 1.0, y = missing), (y = 1.0,)))

julia> condition(m, m=10.0)() # condition on `m`, `y` still sampled
(m = 10.0, x = 1.0, y = 108.85531912531268, logp = -93.9119627520515)

julia> decondition(condition(m, m=10.0), :x)() # sample `x` and `y`
(m = 10.0, x = 8.679001479337314, y = 110.91533778314589, logp = -54.048255774037735)

julia> decondition(condition(m, m=10.0), Val(:x))() # sample `x` and `y`, using `Val` to `decondition`
(m = 10.0, x = 10.337833445626693, y = 110.17352701342462, logp = -52.828937130300055)

julia> condition(m, m=10.0, y=100.0) # conditioned on everything
ContextualModel{ConditionContext{(:x, :m, :y), NamedTuple{(:x, :m, :y), Tuple{Float64, Float64, Float64}}, DefaultContext}, Model{var"#2#3", (:x, :y), (:y,), (:y,), Tuple{Float64, Missing}, Tuple{Float64}}}(ConditionContext{(:x, :m, :y), NamedTuple{(:x, :m, :y), Tuple{Float64, Float64, Float64}}, DefaultContext}((x = 1.0, m = 10.0, y = 100.0), DefaultContext()), Model{var"#2#3", (:x, :y), (:y,), (:y,), Tuple{Float64, Missing}, Tuple{Float64}}(:demo, var"#2#3"(), (x = 1.0, y = missing), (y = 1.0,)))

julia> condition(m, m=10.0, y=100.0)() # conditioned on everything
(m = 10.0, x = 1.0, y = 100.0, logp = -143.25681559961401)

julia> condition(m, m=10.0, y=100.0)() # (✓) deterministic
(m = 10.0, x = 1.0, y = 100.0, logp = -143.25681559961401)

torfjelde and others added 2 commits July 16, 2021 01:52
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Member

@yebai yebai left a comment

Thanks, Tor, it looks good overall. I like this approach. See below for some minor comments.

vectorize,
# Model
Model,
ContextualModel,
Member

Maybe keep this internal until we have a good case of exporting it.

Member Author

@torfjelde torfjelde Jul 19, 2021

I think any case that can be made for exporting Model now also can be made for ContextualModel, don't you think?

PriorContext,
MiniBatchContext,
PrefixContext,
ConditionContext,
Member

Similarly, maybe keep this internal until we have a good case of exporting it.

Member Author

@torfjelde torfjelde Jul 19, 2021

The reason why I exported it is because essentially all other contexts are being exported. It also makes show nicer, tbh. But I'm also fine with not exporting, up to you 👍

src/compiler.jl Outdated
$(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$evaluator,
$allargs_namedtuple,
Member

We are passing allargs_namedtuple to both the model constructor and the condition call - is it because we are still trying to be backwards compatible with the current design?

Member Author

Exactly 👍

Member Author

BUT I would say that we should remove this initially. If the function defined by @model suddenly returns ContextualModel, I'm going to wager that a lot of dispatches in the ecosystem will be messed up.

So IMO:

  1. Introduce this PR but without conditioning by default, i.e. @model still defines a method returning Model.
  2. We get to try out condition and decondition to see how things work.
  3. If everything works out nicely, we fully replace Model (then we could even add context as an additional field to Model rather than introducing ContextualModel)

return value
end

function tilde_assume!(context::ConditionContext, right, vn, inds, vi)
Member

Suggested change
- function tilde_assume!(context::ConditionContext, right, vn, inds, vi)
+ function tilde_assume_or_observe!(context::ConditionContext, right, vn, inds, vi)

Member Author

Sorry, what's the intention with this renaming?

Member

I assume (:stuck_out_tongue_closed_eyes:) it's because now the function does not only implement assumption anymore. But it is a weirdly mixed case anyway, since missings are either defined by model arguments or by the context, right?

IMHO the cleanest solution would be to handle the distinction entirely in isassumption. Isn't that possible? Is there a case when e.g. the isassumption check takes the true branch, but then this tilde_assume! call resorts to tilde_observe! anyway?

(See also my non-review comment below.)

Member Author

@torfjelde torfjelde Jul 19, 2021

Nice one 😎

But no that won't work unfortunately. We can't call tilde_observe! without left, and left isn't defined until we've extracted it from VarInfo or __context__.

Member

@phipsgabler phipsgabler Jul 19, 2021

Right, I see. But that starts to get rather unclean. Why not rewrite tildes somewhat like that, calculating the "observation value" beforehand or setting it to missing:

var"##vn#261" = (VarName){:x}()
var"##observationvalue#263" = begin
    getobsvalueormissing(__model__, __context__, var"##vn#261")
end

if var"##observationvalue#263" === missing
    x = (DynamicPPL.tilde_assume!)(
        __context__,
        (DynamicPPL.unwrap_right_vn)((DynamicPPL.check_tilde_rhs)(filldist(Normal(m, sqrt(s)), length(y))), var"##vn#261")...,
        var"##inds#262",
        __varinfo__
    )
else
    (DynamicPPL.tilde_observe!)(
        __context__,
        (DynamicPPL.check_tilde_rhs)(filldist(Normal(m, sqrt(s)), length(y))),
        var"##observationvalue#263",
        var"##vn#261",
        var"##inds#262",
        __varinfo__
    )
end

Member Author

So I actually started writing a comment immediately after posting my previous one where I started saying "Of course we can extract the value in the model-macro before the tilde-statement and then do something nifty there" but gave up because I'm struggling to come up with a sufficiently robust approach. And IMO it needs to be a very good approach vs. intercepting the tilde-calls, because intercepting the tilde-calls is non-intrusive while doing something like the above touches the DPPL-compiler directly.

And I'm not 100% sure I agree it's "unclean". IMO we should be allowed to switch between sampling and evaluation anywhere in the tilde-pipeline, and evaluation is essentially just an observe-statement with the exception that we might extract the value from VarInfo.
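For readers following along, a rough sketch of what "intercepting the tilde-calls" amounts to. The types and the two-method pipeline below are simplified toys (not the DynamicPPL signatures), and `normal_logpdf` is a hypothetical helper used to avoid depending on Distributions:

```julia
# Toy stand-ins; NOT the DynamicPPL types or tilde-pipeline signatures.
abstract type AbstractContext end
struct DefaultContext <: AbstractContext end
struct ConditionContext{V<:NamedTuple,C<:AbstractContext} <: AbstractContext
    values::V
    context::C
end

# Hypothetical helper: log-density of a unit-variance normal.
normal_logpdf(mean, x) = -0.5 * (x - mean)^2 - 0.5 * log(2π)

# Leaf behaviour: "assume" samples and scores, "observe" only scores.
function tilde_assume(::DefaultContext, mean, name::Symbol)
    x = mean + randn()
    return x, normal_logpdf(mean, x)
end
tilde_observe(::DefaultContext, mean, x) = normal_logpdf(mean, x)
tilde_observe(ctx::ConditionContext, mean, x) = tilde_observe(ctx.context, mean, x)

# The conditioning context intercepts the assume-call: if the variable is
# conditioned, it switches to the observe-branch of its child context,
# otherwise it defers the assume-call unchanged.
function tilde_assume(ctx::ConditionContext, mean, name::Symbol)
    if haskey(ctx.values, name)
        x = ctx.values[name]
        return x, tilde_observe(ctx.context, mean, x)
    end
    return tilde_assume(ctx.context, mean, name)
end

ctx = ConditionContext((m = 10.0,), DefaultContext())
x_cond, lp_cond = tilde_assume(ctx, 0.0, :m)   # conditioned: value comes from `ctx`
x_free, lp_free = tilde_assume(ctx, 0.0, :y)   # not conditioned: sampled by the leaf
```

Nothing in the (imagined) compiler changes here; the switch from sampling to evaluation happens entirely inside the context's own method.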

return value
end

function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi)
Member

Suggested change
- function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi)
+ function dot_tilde_assume_or_observe!(context::ConditionContext, right, left, vn, inds, vi)

Member Author

Same as above: a bit confused what you mean here.

src/contexts.jl Outdated
end
end

struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext
Member

Maybe

Suggested change
- struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext
+ struct ConditionContext{VarNames, Values,Ctx<:AbstractContext} <: AbstractContext

# When no second argument is given, we remove _all_ conditioned variables.
# TODO: Should we remove this and just return `context.context`?
# That will work better if `Model` becomes like `ContextualModel`.
decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context)
Member

Maybe rename this to _decondition to avoid confusion with decondition(m, ...)?

@phipsgabler

Member Author

IMO the different types for the arguments already disambiguate the two 😕

@@ -1,3 +1,5 @@
abstract type AbstractModel <: AbstractProbabilisticProgram end
Member

Can we set up a type alias instead of creating a new level of the hierarchy?

Member Author

@torfjelde torfjelde Jul 19, 2021

Unfortunately not because then methods such as

function (model::AbstractModel)(args...)
    return model(Random.GLOBAL_RNG, args...)
end

become type-piracy.

Member

yebai commented Jul 18, 2021

bors try

bors bot added a commit that referenced this pull request Jul 18, 2021
Contributor

bors bot commented Jul 18, 2021

try

Build failed:

Member Author

bors try

bors bot added a commit that referenced this pull request Jul 19, 2021
Contributor

bors bot commented Jul 19, 2021

try

Build failed:

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Member Author

So I thought further on the "maybe we should extract the value in the model rather than in the tilde-statement?" idea. This is now what I've done in this PR.

  1. isassumption now includes a check called contextual_isassumption(context, vn) which checks whether context considers vn as an assumption or observation. Note that this can recurse, e.g.
    • in the case of ConditionContext, if we don't consider vn as an observation, we then check its child context,
    • in the case of PrefixContext we simply add the prefix to vn, just as we do in the tilde-pipeline, and defer the decision to its child context.
  2. In the compiler we've added code which will extract the value from the context only if the value is not present in inargnames.
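The recursion in (1) can be sketched roughly as follows. This is a toy version: variables are plain `Symbol`s rather than `VarName`s, the context types are stand-ins, and the way the prefix is joined with `"."` is an assumption made to match the `var"y[1].m"` naming used later in this thread:

```julia
# Toy stand-ins; NOT the DynamicPPL types. Variables are plain `Symbol`s here.
abstract type AbstractContext end
struct DefaultContext <: AbstractContext end
struct ConditionContext{V<:NamedTuple,C<:AbstractContext} <: AbstractContext
    values::V
    context::C
end
struct PrefixContext{C<:AbstractContext} <: AbstractContext
    prefix::Symbol
    context::C
end

# Leaf: nothing is conditioned, so everything is an assumption.
contextual_isassumption(::DefaultContext, vn::Symbol) = true

# `ConditionContext`: an observation if we hold a value for `vn`,
# otherwise ask the child context.
function contextual_isassumption(ctx::ConditionContext, vn::Symbol)
    haskey(ctx.values, vn) && return false
    return contextual_isassumption(ctx.context, vn)
end

# `PrefixContext`: add the prefix, then defer to the child context.
function contextual_isassumption(ctx::PrefixContext, vn::Symbol)
    return contextual_isassumption(ctx.context, Symbol(ctx.prefix, ".", vn))
end

# A submodel prefixed with `y[1]`, evaluated under a conditioning on `y[1].m`.
conditioned = ConditionContext(NamedTuple{(Symbol("y[1].m"),)}((100.0,)), DefaultContext())
ctx = PrefixContext(Symbol("y[1]"), conditioned)

contextual_isassumption(ctx, :m)  # false: `y[1].m` is conditioned
contextual_isassumption(ctx, :x)  # true: `y[1].x` is not
```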

All in all, this means that now the following two instances of models are equivalent:

julia> @model function demo_noargs()
           m ~ Normal()
           x ~ Normal(m, 1)
           return (; m, x)
       end
demo_noargs (generic function with 2 methods)

julia> condition(demo_noargs(), x=1.0)()
(m = 1.3184454383513486, x = 1.0)

julia> @model function demo(x)
           m ~ Normal()
           x ~ Normal(m, 1)
           return (; m, x)
       end
demo (generic function with 2 methods)

julia> demo(1.0)()
(m = 0.8496181905704883, x = 1.0)

The only drawback is that we cannot use this to change model arguments, i.e. neither condition nor decondition will do anything to model arguments:

julia> decondition(demo(1.0), :x)()
(m = 1.6693305580373252, x = 1.0)

julia> condition(demo(1.0), x=2.0)()
(m = 1.3370773180628, x = 1.0)

In contrast, both of these work for the "no-argument" model:

julia> decondition(condition(demo_noargs(), x=1.0), :x)()
(m = 0.16963191023957844, x = -0.5504657603545225)

julia> condition(condition(demo_noargs(), x=1.0), x=2.0)()
(m = -0.5525977807677651, x = 2.0)

Essentially, we'd get everything that we desire (i.e. ability to condition or decondition on any random variables) if the users only use the approach where model arguments are only meant to be constants, never variables.

We could potentially start experimenting with approaches that could allow this, e.g. we can imagine overloading matchingvalue to replace the arguments in the call to the model, but we're still left with the issue of identifying the vn corresponding to each argument, which is non-trivial, e.g. for inner in the submodel example below. I do believe it's possible though, just non-trivial, hence I would argue that it should be left out as of now.

On the flip-side this works very well even for submodels, etc.:

julia> @model function inner(x)
           m ~ Normal()
           x ~ Normal(m, 1)

           return (; m, x)
       end
inner (generic function with 1 method)

julia> @model function outer(x)
           y = Vector(undef, length(x))
           for i in eachindex(x)
               y[i] = @submodel $(Symbol("y[$i]")) inner(x[i])
           end
           return y
       end
outer (generic function with 1 method)

julia> m = outer([1.0, 2.0])
Model{var"#43#44", (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}}(:outer, var"#43#44"(), (x = [1.0, 2.0],), NamedTuple())

julia> condition(m, var"y[1].m" = 100.0)()
2-element Vector{Any}:
 (m = 100.0, x = 1.0)
 (m = -0.002383907227840093, x = 2.0)

Drawbacks

  1. Model args still take precedence, but as I've said before, model args are ambiguous, e.g. how would a user condition on x in the second call to inner within outer in the above example?
  2. Empty models can be a bit annoying. E.g. it's nicer to do length(x) in outer in the above example rather than specify a length in the argument which one cannot then change easily from the "outside" of the model, e.g. if we do outer(2) then we cannot use condition with something that has a different length.

Benefits

  1. Very non-intrusive since it's just additional functionality rather than changing existing functionality.
    • E.g. the current missing functionality and all that is still present so won't break any code. Ideally we'd move away from this, but as said before, this will require more work I think.
  2. If ConditionContext is applied after PrefixContext, i.e. PrefixContext is an ancestor of ConditionContext in the context-tree, then we're good. This is the case for submodels, hence they are supported.

Comment on lines +34 to +36
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
Member Author

IMO this should be deferred for now. It will require a bit of thinking I believe.

Comment on lines +376 to +379
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
    $left = $(DynamicPPL.getvalue)(__context__, $vn)
end
Member Author

This will compile away when possible, just as isassumption, and will preserve existing functionality/meaning of model arguments hence will be non-breaking.

Comment on lines +156 to +162
function getvalue(context::ConditionContext, vn)
    return if haskey(context, vn)
        _getvalue(context.values, vn)
    else
        getvalue(context.context, vn)
    end
end
Member Author

This currently behaves a bit strangely in certain cases:

julia> @model function outer_with_inner_condition(x)
           y = Vector(undef, length(x))
           for i in eachindex(x)
               # Here we'll end up with `PrefixContext{..., <:ConditionContext}`
               y[i] = @submodel $(Symbol("y[$i]")) condition(inner(x[i]), m=10.0)
           end
           return y
       end
outer_with_inner_condition (generic function with 1 method)

julia> m = outer_with_inner_condition([1.0, 2.0])
Model{var"#49#50", (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}}(:outer_with_inner_condition, var"#49#50"(), (x = [1.0, 2.0],), NamedTuple())

julia> m() # (✓) both `m` are set as specified within `outer_with_inner_condition`
2-element Vector{Any}:
 (m = 10.0, x = 1.0)
 (m = 10.0, x = 2.0)

julia> condition(m, var"y[1].m"=5.0)() # (×) try to override one of the inner `m`
2-element Vector{Any}:
 (m = 10.0, x = 1.0)
 (m = 10.0, x = 2.0)

But we can change getvalue to:

function getvalue(context::ConditionContext, vn)
    # Return early if we've already found our value, thus giving precedence
    # to the inner-most `ConditionContext`.
    maybeval = getvalue(context.context, vn)
    maybeval === nothing || return maybeval

    return haskey(context, vn) ? _getvalue(context.values, vn) : nothing
end

for which we then get the "desired" behavior in the above example:

julia> condition(m, var"y[1].m"=5.0)() # (✓) try to override one of the inner `m`
2-element Vector{Any}:
 (m = 5.0, x = 1.0)
 (m = 10.0, x = 2.0)

but this does mean that we give precedence to the inner-most ConditionContext rather than the outermost, i.e. the one that was applied most recently, which is counter-intuitive and can be difficult to work with.

We could potentially introduce functionality that will recurse into the context-tree to remove all other mentions of the variables, i.e. so that at most one ConditionContext in the context-tree mentions any specific variable. This also makes sense when considering decondition(ctx, :x), since we'd expect this to traverse the entire context-tree and remove every mention of x rather than only doing so in the outer-most ConditionContext. Buuuut this brings us back to #254 again, since it would require functionality for such a tree-traversal and other things.
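A rough sketch of what such a traversal could look like, using toy context types and a hypothetical `decondition_all` (neither is part of this PR). It ignores prefixing entirely, which is exactly the part that would need more thought:

```julia
# Toy stand-ins; NOT the DynamicPPL types. `decondition_all` is hypothetical.
abstract type AbstractContext end
struct DefaultContext <: AbstractContext end
struct ConditionContext{V<:NamedTuple,C<:AbstractContext} <: AbstractContext
    values::V
    context::C
end
struct MiniBatchContext{C<:AbstractContext} <: AbstractContext
    loglike_scalar::Float64
    context::C
end

# Leaf: nothing to remove.
decondition_all(ctx::DefaultContext, name::Symbol) = ctx

# Parent that holds no conditioned values: rebuild around the processed child.
function decondition_all(ctx::MiniBatchContext, name::Symbol)
    return MiniBatchContext(ctx.loglike_scalar, decondition_all(ctx.context, name))
end

# `ConditionContext`: drop `name` here, then keep walking down.
function decondition_all(ctx::ConditionContext, name::Symbol)
    kept = filter(!=(name), keys(ctx.values))
    return ConditionContext(NamedTuple{kept}(ctx.values), decondition_all(ctx.context, name))
end

inner = ConditionContext((x = 2.0, m = 0.0), DefaultContext())
outer = ConditionContext((x = 1.0,), MiniBatchContext(2.0, inner))
cleaned = decondition_all(outer, :x)  # no `ConditionContext` mentions `x` anymore
```

Note that every context type needs its own method here, which is the same maintenance problem that motivates sharing code between "parent" contexts.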

function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context)
    # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as
    # `ConditionContext` child.
    return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context))
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
-     return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context))
+     return _evaluate(
+         cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)
+     )

src/contexts.jl Outdated
Comment on lines 121 to 133
@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values}
    names_expr = Expr(:tuple)
    values_expr = Expr(:tuple)

    for (n, v) in zip(names, values.parameters)
        if !(v <: Missing)
            push!(names_expr.args, QuoteNode(n))
            push!(values_expr.args, :(nt.$n))
        end
    end

    return :(NamedTuple{$names_expr}($values_expr))
end
Member Author

Not sure if this is useful in this PR anymore. It's useful if we want to replace missing with !haskey(conditioncontext) but as I've said before, I think there are a bit too many corner cases that need to be resolved before we make such an intrusive change.
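For reference, the effect of `drop_missings` on model arguments can be illustrated with a runtime counterpart. `drop_missings_runtime` below is a hypothetical helper written for this illustration; unlike the `@generated` version it inspects values instead of types, so it is not type-stable:

```julia
# Hypothetical runtime counterpart of `drop_missings`: keeps only the fields
# whose value is not `missing`.
function drop_missings_runtime(nt::NamedTuple)
    kept = filter(k -> !ismissing(nt[k]), keys(nt))
    return NamedTuple{kept}(nt)
end

args = (x = 1.0, y = missing)
conditioned_values = drop_missings_runtime(args)  # (x = 1.0,)
```

These are the values that would end up in the `ConditionContext` if `model.args` were used as the default conditioning.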

The default implementation for `AbstractContext` always returns `true`.
"""
contextual_isassumption(context::AbstractContext, vn) = true
Member Author

This is too general, e.g. we want this for DefaultContext and other "leaf" contexts but not for "parent" contexts, e.g. MiniBatchContext.

Comment on lines +178 to +181
# TODO: Can we maybe do this in a better way?
# When no second argument is given, we remove _all_ conditioned variables.
# TODO: Should we remove this and just return `context.context`?
# That will work better if `Model` becomes like `ContextualModel`.
Member Author

This can be addressed nicely with #286

@torfjelde torfjelde marked this pull request as draft July 24, 2021 07:02
Member

> Essentially, we'd get everything that we desire (i.e. ability to condition or decondition on any random variables) if the users only use the approach where model arguments are only meant to be constants, never variables.

I've been saying that, in different terms, a couple of times already :P The "automatically deduce everything from arguments" style is kinda nice from a user perspective, but makes things hard to handle internally.

> julia> condition(m, var"y[1].m" = 100.0)()

And now we're at the point of stringly submodel variables, of which I have also warned... I understand that the NamedTuple + var is just intended to be a temporary style of writing this? This is exactly the use case I have intended (but not defined) the "probability expressions" used in the APPL interface document.

I love where this is going to, though.

Member Author

> I've been saying that, in different terms, a couple of times already :P The "automatically deduce everything from arguments" style is kinda nice from a user perspective, but makes things hard to handle internally.

Yeah, yeah, I'm aware :) This was referring to this particular PR and its implementation, just making explicit the limitations of it. But I think it's fine here because everything else is left as-is, meaning the current behavior of doing missing etc. will work as usual; we just need to make it very clear to the user that condition cannot override arguments to the model, hence if they want to do that, then they'll have to use condition all the way.

> And now we're at the point of stringly submodel variables, of which I have also warned... I understand that the NamedTuple + var is just intended to be a temporary style of writing this? This is exactly the use case I have intended (but not defined) the "probability expressions" used in the APPL interface document.

Exactly: we can easily make this easier to specify, i.e. using a macro. The point here is to introduce the functionality, then we can make it nicer later.

> I love where this is going to, though.

❤️

Btw, take a look at #287 . This PR has a lot of shortcomings in how it interacts with other PRs, i.e. there will be errors. That PR addresses those nicely.

Member Author

Closed in favour of #294

@torfjelde torfjelde closed this Aug 1, 2021
bors bot pushed a commit that referenced this pull request Aug 4, 2021
This is motivated by the potential introduction of more contexts, e.g. #278, and has been brought up as an alternative (and better) approach to achieve parts of what we want to achieve in #254 .

(I hope you're proud of me @devmotion )

## Current state of things
Currently, if one wants to implement a new `AbstractContext` one _at least_ has to implement the following methods:

```julia
tilde_assume(...)
tilde_observe(...)
dot_tilde_assume(...)
dot_tilde_observe(...)
```

But there are also other methods that _should_ be implemented but generally aren't properly handled, e.g. `matchingvalue`. And there might be more methods in the future, e.g. `contextual_isassumption` in #254.

## This sucks
This means that:
1. Implementing a new behavior for `AbstractContext`, e.g. `contextual_isassumption`, requires you to:
   1. Find all implementations of `AbstractContext`, which is non-trivial! Most are here in DPPL, but some are in Turing.jl, and eventually we also want packages outside of the Turing.jl-umbrella to extend DPPL using contexts.
   2. Implement the method for that particular context.
2. Implementing a new `AbstractContext`, e.g. `Turing.OptimizationContext`, requires you to find all the methods to implement and then do so. Again, non-trivial.

This combinatorial blow up essentially means that we're super-reluctant to introduce new behaviors or new contexts, for good reasons.

And the stupid thing is that in most cases a context is only trying to modify maybe one or two "behaviors", e.g. `MiniBatchContext` only wants to change the `*tilde_observe` methods, and otherwise just defer to whatever implementation is available for its "childcontext" `minibatchcontext.context`. 

## Goal
A new `AbstractContext` should only (up to an additive factor) have to implement the behavior it _wants to change_, not _all_ behaviors. E.g. `MiniBatchContext` should really only have to overload `*tilde_observe`.
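As a rough illustration of this goal, here is a self-contained toy in which the generic fallbacks are written once and `MiniBatchContext` overrides a single method. The names `NodeTrait`, `IsLeaf`, `IsParent` and `childcontext` follow the example further down, but the types and the two-argument `tilde_*` signatures are simplified assumptions, not the DPPL ones:

```julia
# Toy stand-ins with simplified signatures; NOT the DynamicPPL definitions.
abstract type AbstractContext end
struct IsLeaf end
struct IsParent end

struct DefaultContext <: AbstractContext end
struct MiniBatchContext{C<:AbstractContext} <: AbstractContext
    loglike_scalar::Float64
    context::C
end

NodeTrait(::DefaultContext) = IsLeaf()
NodeTrait(::MiniBatchContext) = IsParent()
childcontext(ctx::MiniBatchContext) = ctx.context

# Generic fallbacks, written once for every context: leaves terminate the
# pipeline, parents defer to their child.
tilde_assume(ctx::AbstractContext, logp) = tilde_assume(NodeTrait(ctx), ctx, logp)
tilde_assume(::IsLeaf, ctx, logp) = logp
tilde_assume(::IsParent, ctx, logp) = tilde_assume(childcontext(ctx), logp)

tilde_observe(ctx::AbstractContext, logp) = tilde_observe(NodeTrait(ctx), ctx, logp)
tilde_observe(::IsLeaf, ctx, logp) = logp
tilde_observe(::IsParent, ctx, logp) = tilde_observe(childcontext(ctx), logp)

# The only behavior `MiniBatchContext` changes: scale the observation term.
function tilde_observe(ctx::MiniBatchContext, logp)
    return ctx.loglike_scalar * tilde_observe(childcontext(ctx), logp)
end

ctx = MiniBatchContext(2.0, DefaultContext())
tilde_assume(ctx, -1.0)   # -1.0, via the generic parent fallback
tilde_observe(ctx, -1.0)  # -2.0, via the one overloaded method
```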

## Solution (this PR)
The above was the motivation for #254, but there we wanted a rather strict separation between certain types of contexts which we're reluctant to add due to its restrictive nature (in particular given how "recently" contexts were introduced).

This PR takes a more extensible and less restrictive approach of introducing some traits for `AbstractContext`.

As a starter I've just introduced a couple of traits to allow code-sharing between "parent contexts", e.g. `MiniBatchContext` and `ConditionContext` (from #278), and "leaf-contexts", e.g. `DefaultContext` and `PriorContext` (which IMO should have been a wrapper-context itself). Ideally we'd also define a promotion-system, e.g. what do we do if we're asked to combine a `MiniBatchContext` and `DefaultContext`? Well in that case we could either
1. replace `minibatchcontext.context` with `DefaultContext()`, or
2. recursively "rewrap" `DefaultContext()` in `minibatch.context`.

(1) has the issue that `MiniBatchContext` might be wrapping another context, e.g. `PrefixContext`, and so just replacing `.context` is dangerous. (2) seems like a better idea since `DefaultContext` should always be at the end of the stack, i.e. a "leaf", since it always exits the tilde-callstack (e.g. in `tilde_assume` we call `assume`, etc.).

Such a promotion system will require some thought though, but this PR will allow us to experiment with this (on top of providing a good approach to code-sharing).
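To give a feel for option (2), here is a toy sketch of "rewrapping" a leaf at the bottom of a context stack. The helper `replace_leaf` and the context types are hypothetical and only meant to illustrate the idea, not a proposed API:

```julia
# Toy stand-ins; NOT the DynamicPPL types. `replace_leaf` is hypothetical.
abstract type AbstractContext end
struct DefaultContext <: AbstractContext end
struct PriorContext <: AbstractContext end
struct MiniBatchContext{C<:AbstractContext} <: AbstractContext
    loglike_scalar::Float64
    context::C
end
struct PrefixContext{C<:AbstractContext} <: AbstractContext
    prefix::Symbol
    context::C
end

# Walk down to the leaf and swap it, rebuilding each parent on the way back up.
replace_leaf(::Union{DefaultContext,PriorContext}, leaf::AbstractContext) = leaf
function replace_leaf(ctx::MiniBatchContext, leaf::AbstractContext)
    return MiniBatchContext(ctx.loglike_scalar, replace_leaf(ctx.context, leaf))
end
function replace_leaf(ctx::PrefixContext, leaf::AbstractContext)
    return PrefixContext(ctx.prefix, replace_leaf(ctx.context, leaf))
end

stack = MiniBatchContext(2.0, PrefixContext(:a, PriorContext()))
combined = replace_leaf(stack, DefaultContext())
# `combined` keeps the `MiniBatchContext` and `PrefixContext` layers and
# only the leaf at the end of the stack is replaced.
```

In contrast to option (1), the intermediate `PrefixContext` survives the combination.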

## Examples

### `GeneratedQuantitiesContext`

In DPPL we have `generated_quantities`, but it sort of sucks because often these quantities are not relevant for sampling, thus we're adding unnecessary computation to the sampling process. One might want to introduce a `@generated_quantities` macro whose body will only be executed when called from `generated_quantities`. This can easily be achieved with contexts:

```julia
using DynamicPPL, Distributions, Random
using DynamicPPL: AbstractContext, IsLeaf, IsParent, childcontext

struct GeneratedQuantitiesContext{Ctx} <: AbstractContext
    context::Ctx
end
GeneratedQuantitiesContext() = GeneratedQuantitiesContext(DefaultContext())

# Define the `NodeTrait` for `GeneratedQuantitiesContext`.
DynamicPPL.NodeTrait(context::GeneratedQuantitiesContext) = IsParent()
DynamicPPL.childcontext(context::GeneratedQuantitiesContext) = context.context

"""
    isgeneratedquantities(context)

Return `true` if `context` wants evaluation of the model to execute
the `@generated_quantities` blocks.
"""
function isgeneratedquantities(context::AbstractContext)
    return isgeneratedquantities(DynamicPPL.NodeTrait(isgeneratedquantities, context), context)
end

# Define the behavior for the different `NodeTrait`s.
isgeneratedquantities(::IsLeaf, context::AbstractContext) = false
function isgeneratedquantities(::IsParent, context::AbstractContext)
    return isgeneratedquantities(childcontext(context))
end

# Specific implementations of `isgeneratedquantities`.
isgeneratedquantities(context::GeneratedQuantitiesContext) = true

"""
    @generated_quantities f(x)

Specify that `f(x)` should only be executed if the model is run using `GeneratedQuantitiesContext`.
"""
macro generated_quantities(expr)
    return esc(generated_quantities_expr(expr))
end

function generated_quantities_expr(expr)
    return quote
        if isgeneratedquantities(__context__)
            $expr
        end
    end
end
```

And usage would be as follows:

```julia
julia> @model function demo(x)
           @generated_quantities a = []
           m ~ Normal()

           # "Expensive" piece of code that we don't want to compute
           # unless we're computing the generated quantities.
           @generated_quantities for i = 1:100
               push!(a, m + randn())
           end

           # Observe.
           x ~ Normal(m, 1.0)

           # Return additional fields if we're computing the generated quantities.
           @generated_quantities return (; x, m, logp = getlogp(__varinfo__), a)

           return nothing
       end
demo (generic function with 1 method)

julia> m = demo(1.0); var_info = VarInfo(m);

julia> m(var_info) # (✓) returns nothing

julia> m(var_info, GeneratedQuantitiesContext()) # (✓) returns everything
(x = 1.0, m = 0.7451107426181028, logp = -2.147956342556143, a = Any[0.929285513523637, 0.5979631005709274, 3.1944251790696794, -0.611727924858586, 1.0561547812845788, 0.9694358994096283, 0.9130715096769692, 0.018196803783751103, 0.919216919507385, 0.5382931170515716  …  0.24049636440948963, 0.7868300598622132, 0.18210113151764207, 1.9568444848346909, 0.4512687443970105, 0.5155449377942058, 0.32900420588898294, 1.4274186203915957, 0.5599167955770905, -0.5351355146677438])

julia> m(Random.GLOBAL_RNG, GeneratedQuantitiesContext()) # (✓) just works even when wrapped in `SamplingContext`
(x = 1.0, m = -0.7568553457880578, logp = -3.6675624266453637, a = Any[-0.4715652454870858, 2.0553700887015776, -0.6055293756513751, -1.6153963580018202, -0.6316036328835652, -0.9694585645577235, -1.234406947321252, 0.47075652867281415, -2.0403875768252187, -1.332736944079995  …  -1.123179524380186, -0.4445013585772502, -1.2366853403014721, -0.2726171409415225, 0.050910231154342234, -1.937702826603367, -2.109933658872988, -0.8300603173278494, 0.076480161355589, -1.2666452596129179])
```
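The `return` inside `@generated_quantities` works because the macro escapes the whole generated expression, so both `__context__` and the wrapped statement are resolved in the caller's scope. Below is a minimal sketch of just that mechanism, without DynamicPPL. It assumes a stand-in predicate that treats a plain `Bool` as the "context", and `evaluate` is a hypothetical stand-in for a model body; neither is part of DPPL.

```julia
# Stand-in for the real predicate: here the "context" is just a Bool (assumption).
isgeneratedquantities(context::Bool) = context

function generated_quantities_expr(expr)
    return quote
        if isgeneratedquantities(__context__)
            $expr
        end
    end
end

macro generated_quantities(expr)
    # `esc` makes `__context__` and `expr` resolve in the caller's scope,
    # so a `return` in `expr` returns from the enclosing function.
    return esc(generated_quantities_expr(expr))
end

# Hypothetical stand-in for a model body; `__context__` is an ordinary argument here.
function evaluate(__context__)
    @generated_quantities return :everything
    return nothing
end

evaluate(true)   # returns :everything
evaluate(false)  # returns nothing
```

In a real `@model`, `__context__` is provided by the model evaluation rather than passed by hand as it is here.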

This would also _just work_ for submodels, etc. This kind of "conditional execution" could of course be generalized too. It is particularly useful if you also want to work with Zygote, which cannot differentiate through array mutation: you don't want mutation in the model itself, but you do need it for post-processing steps such as `predict` and `generated_quantities`.
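One possible generalization, sketched under assumptions: instead of one predicate per context, a single trait-based query asks whether a context of a given type sits anywhere in the stack. Everything below is a self-contained stand-in that only mirrors the names used above (these are not the DPPL definitions). `hascontext` and `WrapperContext` are hypothetical, and `NodeTrait` is simplified to take only the context.

```julia
# Stand-in types mirroring the names above; NOT the DynamicPPL definitions.
abstract type AbstractContext end
struct IsLeaf end
struct IsParent end

struct DefaultContext <: AbstractContext end
NodeTrait(::DefaultContext) = IsLeaf()

struct GeneratedQuantitiesContext{Ctx<:AbstractContext} <: AbstractContext
    context::Ctx
end
NodeTrait(::GeneratedQuantitiesContext) = IsParent()
childcontext(context::GeneratedQuantitiesContext) = context.context

# Hypothetical parent context that knows nothing about generated quantities.
struct WrapperContext{Ctx<:AbstractContext} <: AbstractContext
    context::Ctx
end
NodeTrait(::WrapperContext) = IsParent()
childcontext(context::WrapperContext) = context.context

# Hypothetical generic query: is a context of type `T` anywhere in the stack?
function hascontext(::Type{T}, context::AbstractContext) where {T<:AbstractContext}
    return context isa T || hascontext(T, NodeTrait(context), context)
end
# A leaf has no child, so the search stops here.
hascontext(::Type{T}, ::IsLeaf, ::AbstractContext) where {T<:AbstractContext} = false
# A parent defers to its child context.
function hascontext(::Type{T}, ::IsParent, context::AbstractContext) where {T<:AbstractContext}
    return hascontext(T, childcontext(context))
end

stack = WrapperContext(GeneratedQuantitiesContext(DefaultContext()))
hascontext(GeneratedQuantitiesContext, stack)                             # true
hascontext(GeneratedQuantitiesContext, WrapperContext(DefaultContext()))  # false
```

Note that `WrapperContext` only declares its trait and its child; it never mentions `hascontext`, yet the query still passes through it.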

But whether or not we want this particular `@generated_quantities` macro in DPPL is not the point; the point is that we _can_ implement such a thing. Even more importantly, I can easily implement this from the "outside", without needing to touch DPPL to do it.

### `ConditionContext`
See #278.

Co-authored-by: Hong Ge <hg344@cam.ac.uk>
bors bot pushed a commit that referenced this pull request Aug 5, 2021
@yebai deleted the tor/conditioning branch January 28, 2022 20:29