@@ -42,6 +42,148 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
4242 return keys (c. info. varname_to_symbol)
4343end
4444
"""
    predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed
to each sample in `chain`, and return the resulting `Chains`.

The `model` passed to `predict` is often different from the one used to generate `chain`.
Typically, the model from which `chain` originated treats certain variables as observed
(i.e., data points), while the model you pass to `predict` may mark these same variables as
missing or unobserved. Calling `predict` then leverages the previously inferred parameter
values to simulate what new, unobserved data might look like, given your posterior beliefs.

For each parameter configuration in `chain`:
1. All random variables present in `chain` are fixed to their sampled values.
2. Any variables not included in `chain` are sampled from their prior distributions.

# Keywords
- `include_all=false`: if `false`, the returned `Chains` contains only those variables that
  were not fixed by the samples in `chain`, which is useful when you want to sample only new
  variables from the posterior predictive distribution. If `true`, every variable in the
  `:parameters` section of the predictive chain is returned.

# Examples
```jldoctest
using AbstractMCMC, Distributions, DynamicPPL, MCMCChains, Random

@model function linear_reg(x, y, σ = 0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

# Generate synthetic chain using known ground truth parameter
ground_truth_β = 2.0

# Create chain of samples from a normal distribution centered on ground truth
β_chain = MCMCChains.Chains(
    rand(Normal(ground_truth_β, 0.002), 1000), [:β,]
)

# Generate predictions for two test points
xs_test = [10.1, 10.2]

# `y` is passed as `missing`, so it is sampled rather than conditioned on
m_predict = linear_reg(xs_test, fill(missing, length(xs_test)))

predictions = DynamicPPL.AbstractPPL.predict(
    Random.default_rng(), m_predict, β_chain
)

ys_pred = vec(mean(Array(predictions); dims=1))

# Check if predictions match expected values within tolerance. The tolerance is many
# standard errors of the Monte Carlo mean (about 0.003 here), so the unseeded example
# does not fail intermittently.
(
    isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.05),
    isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.05)
)

# output

(true, true)
```
"""
function DynamicPPL.predict(
    rng::DynamicPPL.Random.AbstractRNG,
    model::DynamicPPL.Model,
    chain::MCMCChains.Chains;
    include_all=false,
)
    parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
    # A single VarInfo is created up front and overwritten in place for every draw.
    varinfo = DynamicPPL.VarInfo(model)

    # Visit every (iteration, chain) pair; `map` over the product yields a matrix with one
    # row per iteration and one column per chain.
    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    predictive_samples = map(iters) do (sample_idx, chain_idx)
        DynamicPPL.setval_and_resample!(
            varinfo, parameter_only_chain, sample_idx, chain_idx
        )
        model(rng, varinfo, DynamicPPL.SampleFromPrior())

        # Flatten every (possibly container-valued) variable into scalar leaves so each
        # leaf can become one column of the resulting chain.
        vals = DynamicPPL.values_as_in_model(model, varinfo)
        varname_vals = mapreduce(
            collect,
            vcat,
            map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
        )

        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
    end

    # Build one `Chains` per input chain (one column of the sample matrix) and concatenate
    # them along the chain dimension.
    chain_result = reduce(
        MCMCChains.chainscat,
        [
            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
            chain_idx in axes(predictive_samples, 2)
        ],
    )

    predicted_names = MCMCChains.names(chain_result, :parameters)
    parameter_names = if include_all
        predicted_names
    else
        # Names fixed by `chain` are collected once rather than once per candidate name.
        fixed_names = Set(MCMCChains.names(parameter_only_chain, :parameters))
        filter(name -> !(name in fixed_names), predicted_names)
    end
    return chain_result[parameter_names]
end
147+
"""
    _predictive_samples_to_arrays(predictive_samples)

Convert a collection of predictive samples into a pair `(variable_names, variable_values)`.

Each sample must expose a `varname_and_values` field holding `(varname, value)` pairs.
`variable_names` is the ordered union of all variable names seen across the samples, and
`variable_values[i, j]` is the value of `variable_names[j]` in sample `i`, or `missing` when
that sample does not contain the variable.
"""
function _predictive_samples_to_arrays(predictive_samples)
    seen_names = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

    # One ordered lookup table per sample; the set of all names is accumulated on the way.
    per_sample_lookup = map(predictive_samples) do sample
        name_value_pairs = sample.varname_and_values
        names_in_sample = map(first, name_value_pairs)
        vals_in_sample = map(last, name_value_pairs)
        foreach(name -> push!(seen_names, name), names_in_sample)

        return DynamicPPL.OrderedCollections.OrderedDict(
            zip(names_in_sample, vals_in_sample)
        )
    end

    variable_names = collect(seen_names)
    # Rows index samples, columns index variables; absent entries become `missing`.
    variable_values = [
        get(per_sample_lookup[row], name, missing) for
        row in eachindex(per_sample_lookup), name in variable_names
    ]

    return variable_names, variable_values
end
170+
"""
    _predictive_samples_to_chains(predictive_samples)

Assemble an `MCMCChains.Chains` from a collection of predictive samples.

The variable columns come from `_predictive_samples_to_arrays`; the `logp` field of each
sample is appended as a final column named `:lp`, which is registered in the `internals`
section of the resulting chain.
"""
function _predictive_samples_to_chains(predictive_samples)
    varnames, value_matrix = _predictive_samples_to_arrays(predictive_samples)

    internal_names = [:lp]
    # Column vector of log-probabilities so it can be appended to the value matrix.
    logp_column = reshape([sample.logp for sample in predictive_samples], :, 1)

    column_names = vcat(map(Symbol, varnames), internal_names)
    column_values = MCMCChains.concretize(hcat(value_matrix, logp_column))

    return MCMCChains.Chains(column_values, column_names, (internals=internal_names,))
end
186+
45187"""
46188 returned(model::Model, chain::MCMCChains.Chains)
47189
0 commit comments