1- struct  PriorExtractorContext{D<: OrderedDict{VarName,Any} ,Ctx<: AbstractContext } < :
2-        AbstractContext
1+ struct  PriorDistributionAccumulator{D<: OrderedDict{VarName,Any} } <:  AbstractAccumulator 
32    priors:: D 
4-     context:: Ctx 
53end 
64
7- PriorExtractorContext (context ) =  PriorExtractorContext (OrderedDict {VarName,Any} (), context )
5+ PriorDistributionAccumulator ( ) =  PriorDistributionAccumulator (OrderedDict {VarName,Any} ())
86
9- NodeTrait (:: PriorExtractorContext ) =  IsParent ()
10- childcontext (context:: PriorExtractorContext ) =  context. context
11- function  setchildcontext (parent:: PriorExtractorContext , child)
12-     return  PriorExtractorContext (parent. priors, child)
7+ accumulator_name (:: PriorDistributionAccumulator ) =  :PriorDistributionAccumulator 
8+ 
9+ split (acc:: PriorDistributionAccumulator ) =  PriorDistributionAccumulator (empty (acc. priors))
10+ function  combine (acc1:: PriorDistributionAccumulator , acc2:: PriorDistributionAccumulator )
11+     return  PriorDistributionAccumulator (merge (acc1. priors, acc2. priors))
1312end 
1413
15- function  setprior! (context:: PriorExtractorContext , vn:: VarName , dist:: Distribution )
16-     return  context. priors[vn] =  dist
14+ function  setprior! (acc:: PriorDistributionAccumulator , vn:: VarName , dist:: Distribution )
15+     acc. priors[vn] =  dist
16+     return  acc
1717end 
1818
1919function  setprior! (
20-     context :: PriorExtractorContext , vns:: AbstractArray{<:VarName} , dist:: Distribution 
20+     acc :: PriorDistributionAccumulator , vns:: AbstractArray{<:VarName} , dist:: Distribution 
2121)
2222    for  vn in  vns
23-         context . priors[vn] =  dist
23+         acc . priors[vn] =  dist
2424    end 
25+     return  acc
2526end 
2627
2728function  setprior! (
28-     context :: PriorExtractorContext ,
29+     acc :: PriorDistributionAccumulator ,
2930    vns:: AbstractArray{<:VarName} ,
3031    dists:: AbstractArray{<:Distribution} ,
3132)
3233    for  (vn, dist) in  zip (vns, dists)
33-         context . priors[vn] =  dist
34+         acc . priors[vn] =  dist
3435    end 
36+     return  acc
3537end 
3638
37- function  DynamicPPL. tilde_assume (context:: PriorExtractorContext , right, vn, vi)
38-     setprior! (context, vn, right)
39-     return  DynamicPPL. tilde_assume (childcontext (context), right, vn, vi)
39+ function  accumulate_assume!! (acc:: PriorDistributionAccumulator , val, logjac, vn, right)
40+     return  setprior! (acc, vn, right)
4041end 
4142
43+ accumulate_observe!! (acc:: PriorDistributionAccumulator , right, left, vn) =  acc
44+ 
4245""" 
4346    extract_priors([rng::Random.AbstractRNG, ]model::Model) 
4447
@@ -108,9 +111,13 @@ julia> length(extract_priors(rng, model)[@varname(x)])
108111extract_priors (args:: Union{Model,AbstractVarInfo} ...) = 
109112    extract_priors (Random. default_rng (), args... )
110113function  extract_priors (rng:: Random.AbstractRNG , model:: Model )
111-     context =  PriorExtractorContext (SamplingContext (rng))
112-     evaluate!! (model, VarInfo (), context)
113-     return  context. priors
114+     varinfo =  VarInfo ()
115+     #  TODO (mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
116+     #  workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
117+     #  can't push new variables without knowing the num_produce. Remove this when possible.
118+     varinfo =  setaccs!! (varinfo, (PriorDistributionAccumulator (), NumProduceAccumulator ()))
119+     varinfo =  last (evaluate!! (model, varinfo, SamplingContext (rng)))
120+     return  getacc (varinfo, Val (:PriorDistributionAccumulator )). priors
114121end 
115122
116123""" 
@@ -122,7 +129,12 @@ This is done by evaluating the model at the values present in `varinfo`
122129and recording the distributions that are present at each tilde statement. 
123130""" 
124131function  extract_priors (model:: Model , varinfo:: AbstractVarInfo )
125-     context =  PriorExtractorContext (DefaultContext ())
126-     evaluate!! (model, deepcopy (varinfo), context)
127-     return  context. priors
132+     #  TODO (mhauru) This doesn't actually need the NumProduceAccumulator, it's only a
133+     #  workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
134+     #  can't push new variables without knowing the num_produce. Remove this when possible.
135+     varinfo =  setaccs!! (
136+         deepcopy (varinfo), (PriorDistributionAccumulator (), NumProduceAccumulator ())
137+     )
138+     varinfo =  last (evaluate!! (model, varinfo, DefaultContext ()))
139+     return  getacc (varinfo, Val (:PriorDistributionAccumulator )). priors
128140end 
0 commit comments