11#  assume
2- """ 
3-     tilde_assume(context::SamplingContext, right, vn, vi) 
4- 
5- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), 
6- accumulate the log probability, and return the sampled value with a context associated 
7- with a sampler. 
8- 
9- Falls back to 
10- ```julia 
11- tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) 
12- ``` 
13- """ 
14- function  tilde_assume (context:: SamplingContext , right, vn, vi)
15-     return  tilde_assume (context. rng, context. context, context. sampler, right, vn, vi)
16- end 
17- 
182function  tilde_assume (context:: AbstractContext , args... )
193    return  tilde_assume (childcontext (context), args... )
204end 
215function  tilde_assume (:: DefaultContext , right, vn, vi)
22-     return  assume (right, vn, vi)
23- end 
24- 
25- function  tilde_assume (rng:: Random.AbstractRNG , context:: AbstractContext , args... )
26-     return  tilde_assume (rng, childcontext (context), args... )
27- end 
28- function  tilde_assume (rng:: Random.AbstractRNG , :: DefaultContext , sampler, right, vn, vi)
29-     return  assume (rng, sampler, right, vn, vi)
30- end 
31- function  tilde_assume (rng:: Random.AbstractRNG , :: InitContext , sampler, right, vn, vi)
32-     @warn (
33-         " Encountered SamplingContext->InitContext. This method will be removed in the next PR." 
34-     )
35-     #  just pretend the `InitContext` isn't there for now.
36-     return  assume (rng, sampler, right, vn, vi)
37- end 
38- function  tilde_assume (:: DefaultContext , sampler, right, vn, vi)
39-     #  same as above but no rng
40-     return  assume (Random. default_rng (), sampler, right, vn, vi)
6+     y =  getindex_internal (vi, vn)
7+     f =  from_maybe_linked_internal_transform (vi, vn, right)
8+     x, logjac =  with_logabsdet_jacobian (f, y)
9+     vi =  accumulate_assume!! (vi, x, logjac, vn, right)
10+     return  x, vi
4111end 
42- 
4312function  tilde_assume (context:: PrefixContext , right, vn, vi)
4413    #  Note that we can't use something like this here:
4514    #      new_vn = prefix(context, vn)
@@ -53,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
5322    new_vn, new_context =  prefix_and_strip_contexts (context, vn)
5423    return  tilde_assume (new_context, right, new_vn, vi)
5524end 
56- function  tilde_assume (
57-     rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, vn, vi
58- )
59-     new_vn, new_context =  prefix_and_strip_contexts (context, vn)
60-     return  tilde_assume (rng, new_context, sampler, right, new_vn, vi)
61- end 
6225
6326""" 
6427    tilde_assume!!(context, right, vn, vi) 
@@ -78,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
7841end 
7942
8043#  observe
81- """ 
82-     tilde_observe!!(context::SamplingContext, right, left, vi) 
83- 
84- Handle observed constants with a `context` associated with a sampler. 
85- 
86- Falls back to `tilde_observe!!(context.context, right, left, vi)`. 
87- """ 
88- function  tilde_observe!! (context:: SamplingContext , right, left, vn, vi)
89-     return  tilde_observe!! (context. context, right, left, vn, vi)
90- end 
91- 
9244function  tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
9345    return  tilde_observe!! (childcontext (context), right, left, vn, vi)
9446end 
@@ -121,59 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
12173    vi =  accumulate_observe!! (vi, right, left, vn)
12274    return  left, vi
12375end 
124- 
125- function  assume (:: Random.AbstractRNG , spl:: Sampler , dist)
126-     return  error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " 
127- end 
128- 
129- #  fallback without sampler
130- function  assume (dist:: Distribution , vn:: VarName , vi)
131-     y =  getindex_internal (vi, vn)
132-     f =  from_maybe_linked_internal_transform (vi, vn, dist)
133-     x, logjac =  with_logabsdet_jacobian (f, y)
134-     vi =  accumulate_assume!! (vi, x, logjac, vn, dist)
135-     return  x, vi
136- end 
137- 
138- #  TODO : Remove this thing.
139- #  SampleFromPrior and SampleFromUniform
140- function  assume (
141-     rng:: Random.AbstractRNG ,
142-     sampler:: Union{SampleFromPrior,SampleFromUniform} ,
143-     dist:: Distribution ,
144-     vn:: VarName ,
145-     vi:: VarInfoOrThreadSafeVarInfo ,
146- )
147-     if  haskey (vi, vn)
148-         #  Always overwrite the parameters with new ones for `SampleFromUniform`.
149-         if  sampler isa  SampleFromUniform ||  is_flagged (vi, vn, " del" 
150-             #  TODO (mhauru) Is it important to unset the flag here? The `true` allows us
151-             #  to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
152-             #  if that's okay.
153-             unset_flag! (vi, vn, " del" true )
154-             r =  init (rng, dist, sampler)
155-             f =  to_maybe_linked_internal_transform (vi, vn, dist)
156-             #  TODO (mhauru) This should probably be call a function called setindex_internal!
157-             vi =  BangBang. setindex!! (vi, f (r), vn)
158-             setorder! (vi, vn, get_num_produce (vi))
159-         else 
160-             #  Otherwise we just extract it.
161-             r =  vi[vn, dist]
162-         end 
163-     else 
164-         r =  init (rng, dist, sampler)
165-         if  istrans (vi)
166-             f =  to_linked_internal_transform (vi, vn, dist)
167-             vi =  push!! (vi, vn, f (r), dist)
168-             #  By default `push!!` sets the transformed flag to `false`.
169-             vi =  settrans!! (vi, true , vn)
170-         else 
171-             vi =  push!! (vi, vn, r, dist)
172-         end 
173-     end 
174- 
175-     #  HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
176-     logjac =  logabsdetjac (istrans (vi, vn) ?  link_transform (dist) :  identity, r)
177-     vi =  accumulate_assume!! (vi, r, - logjac, vn, dist)
178-     return  r, vi
179- end 
0 commit comments