11#  assume
2- """ 
3-     tilde_assume(context::SamplingContext, right, vn, vi) 
4- 
5- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), 
6- accumulate the log probability, and return the sampled value with a context associated 
7- with a sampler. 
8- 
9- Falls back to 
10- ```julia 
11- tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) 
12- ``` 
13- """ 
14- function  tilde_assume (context:: SamplingContext , right, vn, vi)
15-     return  tilde_assume (context. rng, context. context, context. sampler, right, vn, vi)
16- end 
17- 
182function  tilde_assume (context:: AbstractContext , args... )
193    return  tilde_assume (childcontext (context), args... )
204end 
215function  tilde_assume (:: DefaultContext , right, vn, vi)
22-     return  assume (right, vn, vi)
23- end 
24- 
25- function  tilde_assume (rng:: Random.AbstractRNG , context:: AbstractContext , args... )
26-     return  tilde_assume (rng, childcontext (context), args... )
27- end 
28- function  tilde_assume (rng:: Random.AbstractRNG , :: DefaultContext , sampler, right, vn, vi)
29-     return  assume (rng, sampler, right, vn, vi)
30- end 
31- function  tilde_assume (:: Random.AbstractRNG , :: InitContext , sampler, right, vn, vi)
32-     return  error (
33-         " Encountered SamplingContext->InitContext. This method will be removed in the next PR." 
34-     )
35- end 
36- function  tilde_assume (:: DefaultContext , sampler, right, vn, vi)
37-     #  same as above but no rng
38-     return  assume (Random. default_rng (), sampler, right, vn, vi)
6+     y =  getindex_internal (vi, vn)
7+     f =  from_maybe_linked_internal_transform (vi, vn, right)
8+     x, logjac =  with_logabsdet_jacobian (f, y)
9+     vi =  accumulate_assume!! (vi, x, logjac, vn, right)
10+     return  x, vi
3911end 
40- 
4112function  tilde_assume (context:: PrefixContext , right, vn, vi)
4213    #  Note that we can't use something like this here:
4314    #      new_vn = prefix(context, vn)
@@ -51,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
5122    new_vn, new_context =  prefix_and_strip_contexts (context, vn)
5223    return  tilde_assume (new_context, right, new_vn, vi)
5324end 
54- function  tilde_assume (
55-     rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, vn, vi
56- )
57-     new_vn, new_context =  prefix_and_strip_contexts (context, vn)
58-     return  tilde_assume (rng, new_context, sampler, right, new_vn, vi)
59- end 
6025
6126""" 
6227    tilde_assume!!(context, right, vn, vi) 
@@ -76,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
7641end 
7742
7843#  observe
79- """ 
80-     tilde_observe!!(context::SamplingContext, right, left, vi) 
81- 
82- Handle observed constants with a `context` associated with a sampler. 
83- 
84- Falls back to `tilde_observe!!(context.context, right, left, vi)`. 
85- """ 
86- function  tilde_observe!! (context:: SamplingContext , right, left, vn, vi)
87-     return  tilde_observe!! (context. context, right, left, vn, vi)
88- end 
89- 
9044function  tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
9145    return  tilde_observe!! (childcontext (context), right, left, vn, vi)
9246end 
@@ -119,59 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
11973    vi =  accumulate_observe!! (vi, right, left, vn)
12074    return  left, vi
12175end 
122- 
123- function  assume (:: Random.AbstractRNG , spl:: Sampler , dist)
124-     return  error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " 
125- end 
126- 
127- #  fallback without sampler
128- function  assume (dist:: Distribution , vn:: VarName , vi)
129-     y =  getindex_internal (vi, vn)
130-     f =  from_maybe_linked_internal_transform (vi, vn, dist)
131-     x, logjac =  with_logabsdet_jacobian (f, y)
132-     vi =  accumulate_assume!! (vi, x, logjac, vn, dist)
133-     return  x, vi
134- end 
135- 
136- #  TODO : Remove this thing.
137- #  SampleFromPrior and SampleFromUniform
138- function  assume (
139-     rng:: Random.AbstractRNG ,
140-     sampler:: Union{SampleFromPrior,SampleFromUniform} ,
141-     dist:: Distribution ,
142-     vn:: VarName ,
143-     vi:: VarInfoOrThreadSafeVarInfo ,
144- )
145-     if  haskey (vi, vn)
146-         #  Always overwrite the parameters with new ones for `SampleFromUniform`.
147-         if  sampler isa  SampleFromUniform ||  is_flagged (vi, vn, " del" 
148-             #  TODO (mhauru) Is it important to unset the flag here? The `true` allows us
149-             #  to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
150-             #  if that's okay.
151-             unset_flag! (vi, vn, " del" true )
152-             r =  init (rng, dist, sampler)
153-             f =  to_maybe_linked_internal_transform (vi, vn, dist)
154-             #  TODO (mhauru) This should probably be call a function called setindex_internal!
155-             vi =  BangBang. setindex!! (vi, f (r), vn)
156-             setorder! (vi, vn, get_num_produce (vi))
157-         else 
158-             #  Otherwise we just extract it.
159-             r =  vi[vn, dist]
160-         end 
161-     else 
162-         r =  init (rng, dist, sampler)
163-         if  istrans (vi)
164-             f =  to_linked_internal_transform (vi, vn, dist)
165-             vi =  push!! (vi, vn, f (r), dist)
166-             #  By default `push!!` sets the transformed flag to `false`.
167-             vi =  settrans!! (vi, true , vn)
168-         else 
169-             vi =  push!! (vi, vn, r, dist)
170-         end 
171-     end 
172- 
173-     #  HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
174-     logjac =  logabsdetjac (istrans (vi, vn) ?  link_transform (dist) :  identity, r)
175-     vi =  accumulate_assume!! (vi, r, - logjac, vn, dist)
176-     return  r, vi
177- end 
0 commit comments