7575from pymc .step_methods .arraystep import BlockedStep , PopulationArrayStepShared
7676from pymc .step_methods .hmc import quadpotential
7777from pymc .util import (
78- chains_and_samples ,
7978 dataset_to_point_list ,
8079 get_default_varnames ,
8180 get_untransformed_name ,
@@ -1765,6 +1764,7 @@ def sample_posterior_predictive(
17651764 trace ,
17661765 model : Optional [Model ] = None ,
17671766 var_names : Optional [List [str ]] = None ,
1767+ sample_dims : Optional [List [str ]] = None ,
17681768 random_seed : RandomState = None ,
17691769 progressbar : bool = True ,
17701770 return_inferencedata : bool = True ,
@@ -1785,6 +1785,10 @@ def sample_posterior_predictive(
17851785 generally be the model used to generate the ``trace``, but it doesn't need to be.
17861786 var_names : Iterable[str]
17871787 Names of variables for which to compute the posterior predictive samples.
1788+ sample_dims : list of str, optional
1789+ Dimensions over which to loop and generate posterior predictive samples.
1790+ When `sample_dims` is ``None`` (default) both "chain" and "draw" are considered sample
1791+ dimensions. Only taken into account when `trace` is InferenceData or Dataset.
17881792 random_seed : int, RandomState or Generator, optional
17891793 Seed for the random number generator.
17901794 progressbar : bool
@@ -1821,6 +1825,14 @@ def sample_posterior_predictive(
18211825 thinned_idata = idata.sel(draw=slice(None, None, 5))
18221826 with model:
18231827 idata.extend(pymc.sample_posterior_predictive(thinned_idata))
1828+
1829+ Generate 5 posterior predictive samples per posterior sample.
1830+
1831+ .. code:: python
1832+
1833+ expanded_data = idata.posterior.expand_dims(pred_id=5)
1834+ with model:
1835+ idata.extend(pymc.sample_posterior_predictive(expanded_data))
18241836 """
18251837
18261838 _trace : Union [MultiTrace , PointList ]
@@ -1829,36 +1841,34 @@ def sample_posterior_predictive(
18291841 idata_kwargs = {}
18301842 else :
18311843 idata_kwargs = idata_kwargs .copy ()
1844+ if sample_dims is None :
1845+ sample_dims = ["chain" , "draw" ]
18321846 constant_data : Dict [str , np .ndarray ] = {}
18331847 trace_coords : Dict [str , np .ndarray ] = {}
18341848 if "coords" not in idata_kwargs :
18351849 idata_kwargs ["coords" ] = {}
1850+ idata : Optional [InferenceData ] = None
1851+ stacked_dims = None
18361852 if isinstance (trace , InferenceData ):
1837- idata_kwargs ["coords" ].setdefault ("draw" , trace ["posterior" ]["draw" ])
1838- idata_kwargs ["coords" ].setdefault ("chain" , trace ["posterior" ]["chain" ])
18391853 _constant_data = getattr (trace , "constant_data" , None )
18401854 if _constant_data is not None :
18411855 trace_coords .update ({str (k ): v .data for k , v in _constant_data .coords .items ()})
18421856 constant_data .update ({str (k ): v .data for k , v in _constant_data .items ()})
1843- trace_coords .update ({str (k ): v .data for k , v in trace ["posterior" ].coords .items ()})
1844- _trace = dataset_to_point_list (trace ["posterior" ])
1845- nchain , len_trace = chains_and_samples (trace )
1846- elif isinstance (trace , xarray .Dataset ):
1847- idata_kwargs ["coords" ].setdefault ("draw" , trace ["draw" ])
1848- idata_kwargs ["coords" ].setdefault ("chain" , trace ["chain" ])
1857+ idata = trace
1858+ trace = trace ["posterior" ]
1859+ if isinstance (trace , xarray .Dataset ):
18491860 trace_coords .update ({str (k ): v .data for k , v in trace .coords .items ()})
1850- _trace = dataset_to_point_list (trace )
1851- nchain , len_trace = chains_and_samples ( trace )
1861+ _trace , stacked_dims = dataset_to_point_list (trace , sample_dims )
1862+ nchain = 1
18521863 elif isinstance (trace , MultiTrace ):
18531864 _trace = trace
18541865 nchain = _trace .nchains
1855- len_trace = len (_trace )
18561866 elif isinstance (trace , list ) and all (isinstance (x , dict ) for x in trace ):
18571867 _trace = trace
18581868 nchain = 1
1859- len_trace = len (_trace )
18601869 else :
18611870 raise TypeError (f"Unsupported type for `trace` argument: { type (trace )} ." )
1871+ len_trace = len (_trace )
18621872
18631873 if isinstance (_trace , MultiTrace ):
18641874 samples = sum (len (v ) for v in _trace ._straces .values ())
@@ -1961,23 +1971,30 @@ def sample_posterior_predictive(
19611971 ppc_trace = ppc_trace_t .trace_dict
19621972
19631973 for k , ary in ppc_trace .items ():
1964- ppc_trace [k ] = ary .reshape ((nchain , len_trace , * ary .shape [1 :]))
1974+ if stacked_dims is not None :
1975+ ppc_trace [k ] = ary .reshape (
1976+ (* [len (coord ) for coord in stacked_dims .values ()], * ary .shape [1 :])
1977+ )
1978+ else :
1979+ ppc_trace [k ] = ary .reshape ((nchain , len_trace , * ary .shape [1 :]))
19651980
19661981 if not return_inferencedata :
19671982 return ppc_trace
19681983 ikwargs : Dict [str , Any ] = dict (model = model , ** idata_kwargs )
1984+ ikwargs .setdefault ("sample_dims" , sample_dims )
1985+ if stacked_dims is not None :
1986+ coords = ikwargs .get ("coords" , {})
1987+ ikwargs ["coords" ] = {** stacked_dims , ** coords }
19691988 if predictions :
19701989 if extend_inferencedata :
1971- ikwargs .setdefault ("idata_orig" , trace )
1990+ ikwargs .setdefault ("idata_orig" , idata )
19721991 ikwargs .setdefault ("inplace" , True )
19731992 return pm .predictions_to_inference_data (ppc_trace , ** ikwargs )
1974- converter = pm .backends .arviz .InferenceDataConverter (posterior_predictive = ppc_trace , ** ikwargs )
1975- converter .nchains = nchain
1976- converter .ndraws = len_trace
1977- idata_pp = converter .to_inference_data ()
1978- if extend_inferencedata :
1979- trace .extend (idata_pp )
1980- return trace
1993+ idata_pp = pm .to_inference_data (posterior_predictive = ppc_trace , ** ikwargs )
1994+
1995+ if extend_inferencedata and idata is not None :
1996+ idata .extend (idata_pp )
1997+ return idata
19811998 return idata_pp
19821999
19832000
0 commit comments