@@ -36,113 +36,6 @@ function chain_sample_to_varname_dict(
3636 return d
3737end
3838
39- """
40- DynamicPPL.to_chains(
41- ::Type{MCMCChains.Chains},
42- params_and_stats::AbstractArray{<:ParamsWithStats}
43- )
44-
45- Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.
46- """
47- function DynamicPPL. to_chains (
48- :: Type{MCMCChains.Chains} ,
49- params_and_stats:: AbstractMatrix{<:DynamicPPL.ParamsWithStats} ,
50- )
51- # Handle parameters
52- all_vn_leaves = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
53- split_dicts = map (params_and_stats) do ps
54- # Separate into individual VarNames.
55- vn_leaves_and_vals = if isempty (ps. params)
56- Tuple{DynamicPPL. VarName,Any}[]
57- else
58- iters = map (
59- AbstractPPL. varname_and_value_leaves,
60- keys (ps. params),
61- values (ps. params),
62- )
63- mapreduce (collect, vcat, iters)
64- end
65- vn_leaves = map (first, vn_leaves_and_vals)
66- vals = map (last, vn_leaves_and_vals)
67- for vn_leaf in vn_leaves
68- push! (all_vn_leaves, vn_leaf)
69- end
70- DynamicPPL. OrderedCollections. OrderedDict (zip (vn_leaves, vals))
71- end
72- vn_leaves = collect (all_vn_leaves)
73- param_vals = [
74- get (split_dicts[i, j], key, missing ) for i in eachindex (axes (split_dicts, 1 )),
75- key in vn_leaves, j in eachindex (axes (split_dicts, 2 ))
76- ]
77- param_symbols = map (Symbol, vn_leaves)
78- # Handle statistics
79- stat_keys = DynamicPPL. OrderedCollections. OrderedSet {Symbol} ()
80- for ps in params_and_stats
81- for k in keys (ps. stats)
82- push! (stat_keys, k)
83- end
84- end
85- stat_keys = collect (stat_keys)
86- stat_vals = [
87- get (params_and_stats[i, j]. stats, key, missing ) for
88- i in eachindex (axes (params_and_stats, 1 )), key in stat_keys,
89- j in eachindex (axes (params_and_stats, 2 ))
90- ]
91- # Construct name map and info
92- name_map = (internals= stat_keys,)
93- info = (
94- varname_to_symbol= DynamicPPL. OrderedCollections. OrderedDict (
95- zip (all_vn_leaves, param_symbols)
96- ),
97- )
98- # Concatenate parameter and statistic values
99- vals = cat (param_vals, stat_vals; dims= 2 )
100- symbols = vcat (param_symbols, stat_keys)
101- return MCMCChains. Chains (MCMCChains. concretize (vals), symbols, name_map; info= info)
102- end
103- function DynamicPPL. to_chains (
104- :: Type{MCMCChains.Chains} , ps:: AbstractVector{<:DynamicPPL.ParamsWithStats}
105- )
106- return DynamicPPL. to_chains (MCMCChains. Chains, hcat (ps))
107- end
108-
109- function DynamicPPL. from_chains (
110- :: Type{T} , chain:: MCMCChains.Chains
111- ) where {T<: AbstractDict{<:DynamicPPL.VarName} }
112- idxs = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
113- matrix = map (idxs) do (sample_idx, chain_idx)
114- d = T ()
115- for vn in DynamicPPL. varnames (chain)
116- d[vn] = DynamicPPL. getindex_varname (chain, sample_idx, vn, chain_idx)
117- end
118- d
119- end
120- return matrix
121- end
122- function DynamicPPL. from_chains (:: Type{NamedTuple} , chain:: MCMCChains.Chains )
123- idxs = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
124- matrix = map (idxs) do (sample_idx, chain_idx)
125- get (chain[sample_idx, :, chain_idx], keys (chain); flatten= true )
126- end
127- return matrix
128- end
129- function DynamicPPL. from_chains (
130- :: Type{DynamicPPL.ParamsWithStats} , chain:: MCMCChains.Chains
131- )
132- idxs = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
133- internals_chain = MCMCChains. get_sections (chain, :internals )
134- params = DynamicPPL. from_chains (
135- DynamicPPL. OrderedCollections. OrderedDict{DynamicPPL. VarName,eltype (chain. value)},
136- chain,
137- )
138- stats = DynamicPPL. from_chains (NamedTuple, internals_chain)
139- return map (idxs) do (sample_idx, chain_idx)
140- DynamicPPL. ParamsWithStats (
141- params[sample_idx, chain_idx], stats[sample_idx, chain_idx]
142- )
143- end
144- end
145-
14639"""
14740 predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
14841
@@ -217,6 +110,7 @@ function DynamicPPL.predict(
217110 DynamicPPL. VarInfo (),
218111 (
219112 DynamicPPL. LogPriorAccumulator (),
113+ DynamicPPL. LogJacobianAccumulator (),
220114 DynamicPPL. LogLikelihoodAccumulator (),
221115 DynamicPPL. ValuesAsInModelAccumulator (false ),
222116 ),
@@ -235,9 +129,23 @@ function DynamicPPL.predict(
235129 varinfo,
236130 DynamicPPL. InitFromParams (values_dict, DynamicPPL. InitFromPrior ()),
237131 )
238- DynamicPPL. ParamsWithStats (varinfo, nothing )
132+ vals = DynamicPPL. getacc (varinfo, Val (:ValuesAsInModel )). values
133+ varname_vals = mapreduce (
134+ collect,
135+ vcat,
136+ map (AbstractPPL. varname_and_value_leaves, keys (vals), values (vals)),
137+ )
138+
139+ return (varname_and_values= varname_vals, logp= DynamicPPL. getlogjoint (varinfo))
239140 end
240- chain_result = DynamicPPL. to_chains (MCMCChains. Chains, predictive_samples)
141+
142+ chain_result = reduce (
143+ MCMCChains. chainscat,
144+ [
145+ _predictive_samples_to_chains (predictive_samples[:, chain_idx]) for
146+ chain_idx in 1 : size (predictive_samples, 2 )
147+ ],
148+ )
241149 parameter_names = if include_all
242150 MCMCChains. names (chain_result, :parameters )
243151 else
@@ -256,6 +164,45 @@ function DynamicPPL.predict(
256164 )
257165end
258166
167+ function _predictive_samples_to_arrays (predictive_samples)
168+ variable_names_set = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
169+
170+ sample_dicts = map (predictive_samples) do sample
171+ varname_value_pairs = sample. varname_and_values
172+ varnames = map (first, varname_value_pairs)
173+ values = map (last, varname_value_pairs)
174+ for varname in varnames
175+ push! (variable_names_set, varname)
176+ end
177+
178+ return DynamicPPL. OrderedCollections. OrderedDict (zip (varnames, values))
179+ end
180+
181+ variable_names = collect (variable_names_set)
182+ variable_values = [
183+ get (sample_dicts[i], key, missing ) for i in eachindex (sample_dicts),
184+ key in variable_names
185+ ]
186+
187+ return variable_names, variable_values
188+ end
189+
190+ function _predictive_samples_to_chains (predictive_samples)
191+ variable_names, variable_values = _predictive_samples_to_arrays (predictive_samples)
192+ variable_names_symbols = map (Symbol, variable_names)
193+
194+ internal_parameters = [:lp ]
195+ log_probabilities = reshape ([sample. logp for sample in predictive_samples], :, 1 )
196+
197+ parameter_names = [variable_names_symbols; internal_parameters]
198+ parameter_values = hcat (variable_values, log_probabilities)
199+ parameter_values = MCMCChains. concretize (parameter_values)
200+
201+ return MCMCChains. Chains (
202+ parameter_values, parameter_names, (internals= internal_parameters,)
203+ )
204+ end
205+
259206"""
260207 returned(model::Model, chain::MCMCChains.Chains)
261208
0 commit comments