+import warnings
+
 from collections.abc import Sequence
 
 import numpy as np
 import pytensor.tensor as pt
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.distribution import _support_point, support_point
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
 from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.mode import Mode
-from pytensor.graph import Op, vectorize_graph
+from pytensor.graph import FunctionGraph, Op, vectorize_graph
+from pytensor.graph.basic import equal_computations
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc_extras.distributions import DiscreteMarkovChain
 
 
 class MarginalRV(OpFromGraph, MeasurableOp):
     """Base class for Marginalized RVs"""
 
-    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+    def __init__(
+        self,
+        *args,
+        dims_connections: tuple[tuple[int | None], ...],
+        dims: tuple[Variable, ...],
+        **kwargs,
+    ) -> None:
         self.dims_connections = dims_connections
+        self.dims = dims
         super().__init__(*args, **kwargs)
 
     @property
@@ -43,6 +55,74 @@ def support_axes(self) -> tuple[tuple[int]]:
             )
         return tuple(support_axes_vars)
 
+    def __eq__(self, other):
+        # Just to allow easy testing of equivalent models.
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed.
+        if type(self) is not type(other):
+            return False
+
+        return equal_computations(
+            self.inner_outputs,
+            other.inner_outputs,
+            self.inner_inputs,
+            other.inner_inputs,
+        )
+
+    def __hash__(self):
+        # Just to allow easy testing of equivalent models.
+        # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed.
+        return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
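+
+    # A minimal sketch of the intended semantics (toy graphs, not from the
+    # library): two ops wrapping structurally identical inner graphs compare
+    # equal even though they were built from distinct variables.
+    #   x = pt.vector("x")
+    #   y = x.clone()
+    #   op1 = MarginalRV([x], [x + 1], dims_connections=((None,),), dims=())
+    #   op2 = MarginalRV([y], [y + 1], dims_connections=((None,),), dims=())
+    #   assert op1 == op2 and hash(op1) == hash(op2)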
+
+
+@_support_point.register
+def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
+    """Support point for a marginalized RV.
+
+    The support point of a marginalized RV is the support point of the inner RV,
+    conditioned on the marginalized RV taking its support point.
+    """
+    outputs = rv.owner.outputs
+
+    inner_rv = op.inner_outputs[outputs.index(rv)]
+    marginalized_inner_rv, *other_dependent_inner_rvs = (
+        out
+        for out in op.inner_outputs
+        if out is not inner_rv and not isinstance(out.type, RandomType)
+    )
+
+    # Replace references to inner RVs with dummy variables (including the marginalized RV).
+    # This is necessary because the inner RVs may depend on each other.
+    marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
+    other_dependent_inner_rv_to_dummies = {
+        inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
+    }
+    inner_rv = clone_replace(
+        inner_rv,
+        replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
+        | other_dependent_inner_rv_to_dummies,
+    )
+
+    # Get the support points of the inner RV and the marginalized RV
+    inner_rv_support_point = support_point(inner_rv)
+    marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)
+
+    replacements = [
+        # Replace the marginalized RV dummy with its support point
+        (marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
+        # Replace the other dependent RVs' dummies with the respective outer outputs.
+        # PyMC will replace those with their support points later.
+        *(
+            (v, outputs[op.inner_outputs.index(k)])
+            for k, v in other_dependent_inner_rv_to_dummies.items()
+        ),
+        # Replace the outer input RVs
+        *zip(op.inner_inputs, inputs),
+    ]
+    fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
+    fgraph.replace_all(replacements, import_missing=True)
+    [rv_support_point] = fgraph.outputs
+    return rv_support_point
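+
+# A minimal sketch of the intended behaviour (model and names hypothetical):
+# for
+#   idx ~ Bernoulli(0.7)       # marginalized
+#   y   ~ Normal(mu[idx], 1)   # dependent on idx
+# the support point of `y` evaluates the inner graph at the support point of
+# the marginalized RV (the Bernoulli mode, 1 here), so it reduces to mu[1].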
+
 
 class MarginalFiniteDiscreteRV(MarginalRV):
     """Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +212,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
     Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
     the inner graph.
     """
-    return clone_replace(
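+    # Note: unlike the default strict mode, graph_replace(..., strict=False)
+    # silently skips replacement pairs whose key does not appear in the graph,
+    # e.g. an inner input that is unused by the inner outputs.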
+    return graph_replace(
         op.inner_outputs,
         replace=tuple(zip(op.inner_inputs, inputs)),
+        strict=False,
     )
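+
+# Hypothetical illustration of the "unwrapping": if `op` wraps the inner graph
+# `inner_x + 1` with single inner input `inner_x`, then
+# `inline_ofg_outputs(op, [x])` returns the graph of `x + 1` directly, with no
+# OpFromGraph node left in the result.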
 
 
+class NonSeparableLogpWarning(UserWarning):
+    pass
+
+
+def warn_non_separable_logp(values):
+    if len(values) > 1:
+        warnings.warn(
+            "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
+            f"Their joint logp terms will be assigned to the first value: {values[0]}.",
+            NonSeparableLogpWarning,
+            stacklevel=2,
+        )
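+    # Example of the non-separable case (variable names hypothetical): if both
+    # y ~ Normal(mu[idx], 1) and z ~ Normal(tau[idx], 1) depend on a
+    # marginalized idx, their logps cannot be factored apart after
+    # marginalization, so the joint term is assigned to y and z gets a dummy
+    # zero logp.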
+
+
 DUMMY_ZERO = pt.constant(0, name="dummy_zero")
 
 
@@ -199,6 +294,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
     # Align logp with non-collapsed batch dimensions of first RV
     joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
 
+    warn_non_separable_logp(values)
     # We have to add dummy logps for the remaining value variables; otherwise PyMC will raise
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
@@ -272,5 +368,6 @@ def step_alpha(logp_emission, log_alpha, log_P):
 
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # returned value is the joint probability of everything together, but PyMC still expects one logp per emission stream.
+    warn_non_separable_logp(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps