22from typing import Sequence , Tuple , Union
33
44import numpy as np
5+ import pymc
56import pytensor .tensor as pt
7+ from arviz import dict_to_dataset
68from pymc import SymbolicRandomVariable
9+ from pymc .backends .arviz import coords_and_dims_for_inferencedata
710from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
811from pymc .distributions .transforms import Chain
912from pymc .logprob .abstract import _logprob
1013from pymc .logprob .basic import conditional_logp
1114from pymc .logprob .transforms import IntervalTransform
1215from pymc .model import Model
13- from pymc .pytensorf import constant_fold , inputvars
16+ from pymc .pytensorf import compile_pymc , constant_fold , inputvars
17+ from pymc .util import _get_seeds_per_chain , dataset_to_point_list , treedict
1418from pytensor import Mode
1519from pytensor .compile import SharedVariable
1620from pytensor .compile .builders import OpFromGraph
17- from pytensor .graph import Constant , FunctionGraph , ancestors , clone_replace
21+ from pytensor .graph import (
22+ Constant ,
23+ FunctionGraph ,
24+ ancestors ,
25+ clone_replace ,
26+ vectorize_graph ,
27+ )
1828from pytensor .scan import map as scan_map
1929from pytensor .tensor import TensorVariable
2030from pytensor .tensor .elemwise import Elemwise
31+ from pytensor .tensor .shape import Shape
32+ from pytensor .tensor .special import log_softmax
2133
2234__all__ = ["MarginalModel" ]
2335
24- from pytensor .tensor .shape import Shape
25-
2636
2737class MarginalModel (Model ):
2838 """Subclass of PyMC Model that implements functionality for automatic
@@ -74,6 +84,7 @@ class MarginalModel(Model):
7484 def __init__ (self , * args , ** kwargs ):
7585 super ().__init__ (* args , ** kwargs )
7686 self .marginalized_rvs = []
87+ self ._marginalized_named_vars_to_dims = treedict ()
7788
7889 def _delete_rv_mappings (self , rv : TensorVariable ) -> None :
7990 """Remove all model mappings referring to rv
@@ -205,8 +216,9 @@ def clone(self):
205216 vars = self .basic_RVs + self .potentials + self .deterministics + self .marginalized_rvs
206217 cloned_vars = clone_replace (vars )
207218 vars_to_clone = {var : cloned_var for var , cloned_var in zip (vars , cloned_vars )}
219+ m .vars_to_clone = vars_to_clone
208220
209- m .named_vars = {name : vars_to_clone [var ] for name , var in self .named_vars .items ()}
221+ m .named_vars = treedict ( {name : vars_to_clone [var ] for name , var in self .named_vars .items ()})
210222 m .named_vars_to_dims = self .named_vars_to_dims
211223 m .values_to_rvs = {i : vars_to_clone [rv ] for i , rv in self .values_to_rvs .items ()}
212224 m .rvs_to_values = {vars_to_clone [rv ]: i for rv , i in self .rvs_to_values .items ()}
@@ -220,11 +232,18 @@ def clone(self):
220232 m .deterministics = [vars_to_clone [det ] for det in self .deterministics ]
221233
222234 m .marginalized_rvs = [vars_to_clone [rv ] for rv in self .marginalized_rvs ]
235+ m ._marginalized_named_vars_to_dims = self ._marginalized_named_vars_to_dims
223236 return m
224237
225- def marginalize (self , rvs_to_marginalize : Union [TensorVariable , Sequence [TensorVariable ]]):
238+ def marginalize (
239+ self ,
240+ rvs_to_marginalize : Union [TensorVariable , str , Sequence [TensorVariable ], Sequence [str ]],
241+ ):
226242 if not isinstance (rvs_to_marginalize , Sequence ):
227243 rvs_to_marginalize = (rvs_to_marginalize ,)
244+ rvs_to_marginalize = [
245+ self [var ] if isinstance (var , str ) else var for var in rvs_to_marginalize
246+ ]
228247
229248 supported_dists = (Bernoulli , Categorical , DiscreteUniform )
230249 for rv_to_marginalize in rvs_to_marginalize :
@@ -238,12 +257,233 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
238257 f"Supported distribution include { supported_dists } "
239258 )
240259
260+ if rv_to_marginalize .name in self .named_vars_to_dims :
261+ dims = self .named_vars_to_dims [rv_to_marginalize .name ]
262+ self ._marginalized_named_vars_to_dims [rv_to_marginalize .name ] = dims
263+
241264 self ._delete_rv_mappings (rv_to_marginalize )
242265 self .marginalized_rvs .append (rv_to_marginalize )
243266
244267 # Raise errors and warnings immediately
245268 self .clone ()._marginalize (user_warnings = True )
246269
270+ def _to_transformed (self ):
271+ "Create a function from the untransformed space to the transformed space"
272+ transformed_rvs = []
273+ transformed_names = []
274+
275+ for rv in self .free_RVs :
276+ transform = self .rvs_to_transforms .get (rv )
277+ if transform is None :
278+ transformed_rvs .append (rv )
279+ transformed_names .append (rv .name )
280+ else :
281+ transformed_rv = transform .forward (rv , * rv .owner .inputs )
282+ transformed_rvs .append (transformed_rv )
283+ transformed_names .append (self .rvs_to_values [rv ].name )
284+
285+ fn = self .compile_fn (inputs = self .free_RVs , outs = transformed_rvs )
286+ return fn , transformed_names
287+
288+ def unmarginalize (self , rvs_to_unmarginalize ):
289+ for rv in rvs_to_unmarginalize :
290+ self .marginalized_rvs .remove (rv )
291+ if rv .name in self ._marginalized_named_vars_to_dims :
292+ dims = self ._marginalized_named_vars_to_dims .pop (rv .name )
293+ else :
294+ dims = None
295+ self .register_rv (rv , name = rv .name , dims = dims )
296+
297+ def recover_marginals (
298+ self ,
299+ idata ,
300+ var_names = None ,
301+ return_samples = True ,
302+ extend_inferencedata = True ,
303+ random_seed = None ,
304+ ):
305+ """Computes posterior log-probabilities and samples of marginalized variables
306+ conditioned on parameters of the model given InferenceData with posterior group
307+
308+ When there are multiple marginalized variables, each marginalized variable is
309+ conditioned on both the parameters and the other variables still marginalized
310+
311+ All log-probabilities are within the transformed space
312+
313+ Parameters
314+ ----------
315+ idata : InferenceData
316+ InferenceData with posterior group
317+ var_names : sequence of str, optional
318+ List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
319+ return_samples : bool, default True
320+ If True, also return samples of the marginalized variables
321+ extend_inferencedata : bool, default True
322+ Whether to extend the original InferenceData or return a new one
323+ random_seed: int, array-like of int or SeedSequence, optional
324+ Seed used to generating samples
325+
326+ Returns
327+ -------
328+ idata : InferenceData
329+ InferenceData with where a lp_{varname} and {varname} for each marginalized variable in var_names added to the posterior group
330+
331+ .. code-block:: python
332+
333+ import pymc as pm
334+ from pymc_experimental import MarginalModel
335+
336+ with MarginalModel() as m:
337+ p = pm.Beta("p", 1, 1)
338+ x = pm.Bernoulli("x", p=p, shape=(3,))
339+ y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
340+
341+ m.marginalize([x])
342+
343+ idata = pm.sample()
344+ m.recover_marginals(idata, var_names=["x"])
345+
346+
347+ """
348+ if var_names is None :
349+ var_names = [var .name for var in self .marginalized_rvs ]
350+
351+ var_names = [var if isinstance (var , str ) else var .name for var in var_names ]
352+ vars_to_recover = [v for v in self .marginalized_rvs if v .name in var_names ]
353+ missing_names = [v .name for v in vars_to_recover if v not in self .marginalized_rvs ]
354+ if missing_names :
355+ raise ValueError (f"Unrecognized var_names: { missing_names } " )
356+
357+ if return_samples and random_seed is not None :
358+ seeds = _get_seeds_per_chain (random_seed , len (vars_to_recover ))
359+ else :
360+ seeds = [None ] * len (vars_to_recover )
361+
362+ posterior = idata .posterior
363+
364+ # Remove Deterministics
365+ posterior_values = posterior [
366+ [rv .name for rv in self .free_RVs if rv not in self .marginalized_rvs ]
367+ ]
368+
369+ sample_dims = ("chain" , "draw" )
370+ posterior_pts , stacked_dims = dataset_to_point_list (posterior_values , sample_dims )
371+
372+ # Handle Transforms
373+ transform_fn , transform_names = self ._to_transformed ()
374+
375+ def transform_input (inputs ):
376+ return dict (zip (transform_names , transform_fn (inputs )))
377+
378+ posterior_pts = [transform_input (vs ) for vs in posterior_pts ]
379+
380+ rv_dict = {}
381+ rv_dims = {}
382+ for seed , rv in zip (seeds , vars_to_recover ):
383+ supported_dists = (Bernoulli , Categorical , DiscreteUniform )
384+ if not isinstance (rv .owner .op , supported_dists ):
385+ raise NotImplementedError (
386+ f"RV with distribution { rv .owner .op } cannot be recovered. "
387+ f"Supported distribution include { supported_dists } "
388+ )
389+
390+ m = self .clone ()
391+ rv = m .vars_to_clone [rv ]
392+ m .unmarginalize ([rv ])
393+ dependent_vars = find_conditional_dependent_rvs (rv , m .basic_RVs )
394+ joint_logps = m .logp (vars = dependent_vars + [rv ], sum = False )
395+
396+ marginalized_value = m .rvs_to_values [rv ]
397+ other_values = [v for v in m .value_vars if v is not marginalized_value ]
398+
399+ # Handle batch dims for marginalized value and its dependent RVs
400+ joint_logp = joint_logps [- 1 ]
401+ for dv in joint_logps [:- 1 ]:
402+ dbcast = dv .type .broadcastable
403+ mbcast = marginalized_value .type .broadcastable
404+ mbcast = (True ,) * (len (dbcast ) - len (mbcast )) + mbcast
405+ values_axis_bcast = [
406+ i for i , (m , v ) in enumerate (zip (mbcast , dbcast )) if m and not v
407+ ]
408+ joint_logp += dv .sum (values_axis_bcast )
409+
410+ rv_shape = constant_fold (tuple (rv .shape ))
411+ rv_domain = get_domain_of_finite_discrete_rv (rv )
412+ rv_domain_tensor = pt .moveaxis (
413+ pt .full (
414+ (* rv_shape , len (rv_domain )),
415+ rv_domain ,
416+ dtype = rv .dtype ,
417+ ),
418+ - 1 ,
419+ 0 ,
420+ )
421+
422+ joint_logps = vectorize_graph (
423+ joint_logp ,
424+ replace = {marginalized_value : rv_domain_tensor },
425+ )
426+ joint_logps = pt .moveaxis (joint_logps , 0 , - 1 )
427+
428+ rv_loglike_fn = None
429+ joint_logps_norm = log_softmax (joint_logps , axis = - 1 )
430+ if return_samples :
431+ sample_rv_outs = pymc .Categorical .dist (logit_p = joint_logps )
432+ if isinstance (rv .owner .op , DiscreteUniform ):
433+ sample_rv_outs += rv_domain [0 ]
434+
435+ rv_loglike_fn = compile_pymc (
436+ inputs = other_values ,
437+ outputs = [joint_logps_norm , sample_rv_outs ],
438+ on_unused_input = "ignore" ,
439+ random_seed = seed ,
440+ )
441+ else :
442+ rv_loglike_fn = compile_pymc (
443+ inputs = other_values ,
444+ outputs = joint_logps_norm ,
445+ on_unused_input = "ignore" ,
446+ random_seed = seed ,
447+ )
448+
449+ logvs = [rv_loglike_fn (** vs ) for vs in posterior_pts ]
450+
451+ logps = None
452+ samples = None
453+ if return_samples :
454+ logps , samples = zip (* logvs )
455+ logps = np .array (logps )
456+ samples = np .array (samples )
457+ rv_dict [rv .name ] = samples .reshape (
458+ tuple (len (coord ) for coord in stacked_dims .values ()) + samples .shape [1 :],
459+ )
460+ else :
461+ logps = np .array (logvs )
462+
463+ rv_dict ["lp_" + rv .name ] = logps .reshape (
464+ tuple (len (coord ) for coord in stacked_dims .values ()) + logps .shape [1 :],
465+ )
466+ if rv .name in m .named_vars_to_dims :
467+ rv_dims [rv .name ] = list (m .named_vars_to_dims [rv .name ])
468+ rv_dims ["lp_" + rv .name ] = rv_dims [rv .name ] + ["lp_" + rv .name + "_dim" ]
469+
470+ coords , dims = coords_and_dims_for_inferencedata (self )
471+ dims .update (rv_dims )
472+ rv_dataset = dict_to_dataset (
473+ rv_dict ,
474+ library = pymc ,
475+ dims = dims ,
476+ coords = coords ,
477+ default_dims = list (sample_dims ),
478+ skip_event_dims = True ,
479+ )
480+
481+ if extend_inferencedata :
482+ idata .posterior = idata .posterior .assign (rv_dataset )
483+ return idata
484+ else :
485+ return rv_dataset
486+
247487
248488class MarginalRV (SymbolicRandomVariable ):
249489 """Base class for Marginalized RVs"""
@@ -444,14 +684,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
444684 # PyMC does not allow RVs in the logp graph, even if we are just using the shape
445685 marginalized_rv_shape = constant_fold (tuple (marginalized_rv .shape ))
446686 marginalized_rv_domain = get_domain_of_finite_discrete_rv (marginalized_rv )
447- marginalized_rv_domain_tensor = pt .swapaxes (
687+ marginalized_rv_domain_tensor = pt .moveaxis (
448688 pt .full (
449689 (* marginalized_rv_shape , len (marginalized_rv_domain )),
450690 marginalized_rv_domain ,
451691 dtype = marginalized_rv .dtype ,
452692 ),
453- axis1 = 0 ,
454- axis2 = - 1 ,
693+ - 1 ,
694+ 0 ,
455695 )
456696
457697 # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
0 commit comments