|  | 
| 1 | 1 | # assume | 
| 2 |  | -function tilde_assume(context::AbstractContext, args...) | 
| 3 |  | -    return tilde_assume(childcontext(context), args...) | 
|  | 2 | +function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi) | 
|  | 3 | +    return tilde_assume!!(childcontext(context), right, vn, vi) | 
| 4 | 4 | end | 
| 5 |  | -function tilde_assume(::DefaultContext, right, vn, vi) | 
|  | 5 | +function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi) | 
| 6 | 6 |     y = getindex_internal(vi, vn) | 
| 7 | 7 |     f = from_maybe_linked_internal_transform(vi, vn, right) | 
| 8 | 8 |     x, inv_logjac = with_logabsdet_jacobian(f, y) | 
| 9 | 9 |     vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) | 
| 10 | 10 |     return x, vi | 
| 11 | 11 | end | 
| 12 |  | -function tilde_assume(context::PrefixContext, right, vn, vi) | 
|  | 12 | +function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi) | 
| 13 | 13 |     # Note that we can't use something like this here: | 
| 14 | 14 |     #     new_vn = prefix(context, vn) | 
| 15 |  | -    #     return tilde_assume(childcontext(context), right, new_vn, vi) | 
|  | 15 | +    #     return tilde_assume!!(childcontext(context), right, new_vn, vi) | 
| 16 | 16 |     # This is because `prefix` applies _all_ prefixes in a given context to a | 
| 17 | 17 |     # variable name. Thus, if we had two levels of nested prefixes e.g. | 
| 18 | 18 |     # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the | 
| 19 | 19 |     # first call would apply the prefix `a.b._`, and the recursive call | 
| 20 | 20 |     # would apply the prefix `b._`, resulting in `b.a.b._`. | 
| 21 | 21 |     # This is why we need a special function, `prefix_and_strip_contexts`. | 
| 22 | 22 |     new_vn, new_context = prefix_and_strip_contexts(context, vn) | 
| 23 |  | -    return tilde_assume(new_context, right, new_vn, vi) | 
|  | 23 | +    return tilde_assume!!(new_context, right, new_vn, vi) | 
| 24 | 24 | end | 
| 25 | 25 | 
 | 
| 26 | 26 | """ | 
| 27 | 27 |     tilde_assume!!(context, right, vn, vi) | 
| 28 | 28 | 
 | 
| 29 | 29 | Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), | 
| 30 | 30 | accumulate the log probability, and return the sampled value and updated `vi`. | 
| 31 |  | -
 | 
| 32 |  | -By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log | 
| 33 |  | -probability of `vi` with the returned value. | 
| 34 | 31 | """ | 
| 35 |  | -function tilde_assume!!(context, right, vn, vi) | 
| 36 |  | -    return if right isa DynamicPPL.Submodel | 
| 37 |  | -        _evaluate!!(right, vi, context, vn) | 
| 38 |  | -    else | 
| 39 |  | -        tilde_assume(context, right, vn, vi) | 
| 40 |  | -    end | 
|  | 32 | +function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi) | 
|  | 33 | +    return _evaluate!!(right, vi, context, vn) | 
| 41 | 34 | end | 
| 42 | 35 | 
 | 
| 43 | 36 | # observe | 
|  | 
0 commit comments