@@ -133,6 +133,28 @@ function tilde_assume!(context, right, vn, inds, vi)
133133 return value
134134end
135135
136+ function tilde_assume! (context:: ConditionContext , right, vn, inds, vi)
137+ if haskey (context, vn)
138+ # Extract value.
139+ value = if inds isa Tuple{}
140+ getfield (context. values, getsym (vn))
141+ else
142+ _getindex (getfield (context. values, getsym (vn)), inds)
143+ end
144+
145+ # Should we even do this?
146+ if haskey (vi, vn)
147+ vi[vn] = value
148+ end
149+
150+ tilde_observe! (context. context, right, value, vn, inds, vi)
151+ else
152+ value = tilde_assume! (context. context, right, vn, inds, vi)
153+ end
154+
155+ return value
156+ end
157+
136158# observe
137159"""
138160 tilde_observe(context::SamplingContext, right, left, vname, vinds, vi)
@@ -217,6 +239,10 @@ function tilde_observe!(context, right, left, vi)
217239 return left
218240end
219241
242+ function tilde_observe! (context:: ConditionContext , right, left, vi)
243+ return tilde_observe! (context. context, right, left, vi)
244+ end
245+
220246function assume (rng, spl:: Sampler , dist)
221247 return error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " )
222248end
@@ -419,6 +445,28 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi)
419445 return value
420446end
421447
448+ function dot_tilde_assume! (context:: ConditionContext , right, left, vn, inds, vi)
449+ if haskey (context, vn)
450+ # Extract value.
451+ value = if inds isa Tuple{}
452+ getfield (context. values, sym)
453+ else
454+ _getindex (getfield (context. values, sym), inds)
455+ end
456+
457+ # Should we even do this?
458+ if haskey (vi, vn)
459+ vi[vn] = value
460+ end
461+
462+ dot_tilde_observe! (context. context, right, value, vn, inds, vi)
463+ else
464+ value = dot_tilde_assume! (context. context, right, left, vn, inds, vi)
465+ end
466+
467+ return value
468+ end
469+
422470# `dot_assume`
423471function dot_assume (
424472 dist:: MultivariateDistribution , var:: AbstractMatrix , vns:: AbstractVector{<:VarName} , vi
@@ -637,6 +685,9 @@ function dot_tilde_observe!(context, right, left, vi)
637685 acclogp! (vi, logp)
638686 return left
639687end
688+ function dot_tilde_observe! (context:: ConditionContext , right, left, vi)
689+ return dot_tilde_observe! (context. context, right, left, vi)
690+ end
640691
641692# Falls back to non-sampler definition.
642693function dot_observe (:: AbstractSampler , dist, value, vi)
0 commit comments