@@ -108,86 +108,14 @@ function DynamicPPL.generated_quantities(
108108 varinfo = DynamicPPL. VarInfo (model)
109109 iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
110110 return map (iters) do (sample_idx, chain_idx)
111- if DynamicPPL. supports_varname_indexing (chain)
112- varname_pairs = _varname_pairs_with_varname_indexing (
113- chain, varinfo, sample_idx, chain_idx
114- )
115- else
116- varname_pairs = _varname_pairs_without_varname_indexing (
117- chain, varinfo, sample_idx, chain_idx
118- )
119- end
120- fixed_model = DynamicPPL. fix (model, Dict (varname_pairs))
121- return fixed_model ()
111+ # TODO : Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
112+ # Update the varinfo with the current sample and make variables not present in `chain`
113+ # to be sampled.
114+ DynamicPPL. setval_and_resample! (varinfo, chain, sample_idx, chain_idx)
115+ # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
116+ # `deepcopy` the `varinfo` before passing it to the `model`.
117+ model (deepcopy (varinfo))
122118 end
123119end
124120
125- """
126- _varname_pairs_with_varname_indexing(
127- chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
128- )
129-
130- Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
131- from the chain.
132-
133- This implementation assumes `chain` can be indexed using variable names, and is the
134- preffered implementation.
135- """
136- function _varname_pairs_with_varname_indexing (
137- chain:: MCMCChains.Chains , varinfo, sample_idx, chain_idx
138- )
139- vns = DynamicPPL. varnames (chain)
140- vn_parents = Iterators. map (vns) do vn
141- # The call nested_setindex_maybe! is used to handle cases where vn is not
142- # the variable name used in the model, but rather subsumed by one. Except
143- # for the subsumption part, this could be
144- # vn => getindex_varname(chain, sample_idx, vn, chain_idx)
145- # TODO (mhauru) This call to nested_setindex_maybe! is unintuitive.
146- DynamicPPL. nested_setindex_maybe! (
147- varinfo, DynamicPPL. getindex_varname (chain, sample_idx, vn, chain_idx), vn
148- )
149- end
150- varname_pairs = Iterators. map (Iterators. filter (! isnothing, vn_parents)) do vn_parent
151- vn_parent => varinfo[vn_parent]
152- end
153- return varname_pairs
154- end
155-
156- """
157- Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.
158-
159- The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
160- won't catch all cases. We should get rid of this if we can.
161- """
162- # TODO (mhauru) See docstring above.
163- function _vcat_subsumed_values (vn_string, values, key_strings)
164- indices = findall (Base. Fix1 (DynamicPPL. subsumes_string, vn_string), key_strings)
165- return ! isempty (indices) ? reduce (vcat, values[indices]) : nothing
166- end
167-
168- """
169- _varname_pairs_without_varname_indexing(
170- chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
171- )
172-
173- Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
174- from the chain.
175-
176- This implementation does not assume that `chain` can be indexed using variable names. It is
177- thus not guaranteed to work in cases where the variable names have complex subsumption
178- patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
179- """
180- function _varname_pairs_without_varname_indexing (
181- chain:: MCMCChains.Chains , varinfo, sample_idx, chain_idx
182- )
183- values = chain. value[sample_idx, :, chain_idx]
184- keys = Base. keys (chain)
185- keys_strings = map (string, keys)
186- varname_pairs = [
187- vn => _vcat_subsumed_values (string (vn), values, keys_strings) for
188- vn in Base. keys (varinfo)
189- ]
190- return varname_pairs
191- end
192-
193121end
0 commit comments