Skip to content

Conversation

@torfjelde
Copy link
Member

This PR implements an alternative to #278.

The way this approach works as follows:

  • ConditionContext specifies which variables are condition and their values.
  • Model now also has a ConditionContext instead of missings: by default it holds model.args and so a variable is considered missing if it's in model.args but not in the ConditionContext. Why do we need both model.args and ConditionContext (model.context)?
    • If we set something to missing, i.e. decondition(model, :x) which results in model.context but now without x, we don't know the shape of x anymore. Hence we need to keep args around to understand what the user wants.
    • model.context wraps whatever context we provide evaluation of the model, i.e. in _evaluate. This in turn means that we're still allowed to do something like model(vi, ConditionContext(x = 1.0)), resulting in ConditionContext(model.context, ConditionContext(x = 1.0)) being used inside of the model.
    • We could alternatively have a model.conditioned_args instead of a ConditionContext, but then we'd have to add more logic to the compiler to maybe extract variables if they're not conditioned on and not in model.args rather than deferring this to tilde_assume! and others.

The drawback of this vs. #278 is that it's definitively going to a breaking change + we probably want the ContextualModel from #278 in the future anyways for other operations, e.g. do, etc. Hence I'm personally in favour of #278 and then maybe we can think about taking the approach in this PR at a later stage.

Example

julia> using DynamicPPL, Distributions

julia> @model function demo(x; y = 1.0)
           m ~ Normal()
           x ~ Normal(m, 1)
           y ~ Normal(m + 100, 1)

           return (; m, x, y, logp = getlogp(__varinfo__))
       end
demo (generic function with 1 method)

julia> m = demo(1.0, y=missing); vi = VarInfo(m);

julia> m() # `m` and `y` are sampled, as per usual
(m = 0.7477179865347953, x = 1.0, y = 100.4209457847496, logp = -3.1215698363966418)

julia> m # a `Model` but with a `ConditionContext` in there
Model{var"#2#3", (:x, :y), (:y,), Tuple{Float64, Missing}, Tuple{Float64}, (:x,), ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Float64}}, DefaultContext}}(:demo, var"#2#3"(), (x = 1.0, y = missing), (y = 1.0,), ConditionContext{(:x,), NamedTuple{(:x,), Tuple{Float64}}, DefaultContext}((x = 1.0,), DefaultContext()))

julia> condition(m, m=10.0)() # condition on `m`, `y` still sampled
(m = 10.0, x = 1.0, y = 110.1834019891443, logp = -93.27363374442507)

julia> decondition(condition(m, m=10.0), :x)() # sample `x` and `y`
(m = 10.0, x = 9.040371498852057, y = 109.56367944997731, logp = -53.312446840907796)

julia> decondition(condition(m, m=10.0), Val(:x))() # sample `x` and `y`, using `Val` to `decondition`
(m = 10.0, x = 9.711221897662238, y = 109.27183931845398, logp = -53.0636209848837)

julia> condition(m, m=10.0, y=100.0) # conditioned on everything
Model{var"#2#3", (:x, :y), (:y,), Tuple{Float64, Missing}, Tuple{Float64}, (:x, :m, :y), ConditionContext{(:x, :m, :y), NamedTuple{(:x, :m, :y), Tuple{Float64, Float64, Float64}}, DefaultContext}}(:demo, var"#2#3"(), (x = 1.0, y = missing), (y = 1.0,), ConditionContext{(:x, :m, :y), NamedTuple{(:x, :m, :y), Tuple{Float64, Float64, Float64}}, DefaultContext}((x = 1.0, m = 10.0, y = 100.0), DefaultContext()))

julia> condition(m, m=10.0, y=100.0)() # conditioned on everything
(m = 10.0, x = 1.0, y = 100.0, logp = -143.25681559961401)

julia> condition(m, m=10.0, y=100.0)() # (✓) deterministic
(m = 10.0, x = 1.0, y = 100.0, logp = -143.25681559961401)

@torfjelde torfjelde marked this pull request as draft July 24, 2021 07:02
@torfjelde
Copy link
Member Author

Closed in favour of #294

@torfjelde torfjelde closed this Aug 1, 2021
@yebai yebai deleted the tor/conditioning-alternative branch January 31, 2022 20:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants