56
56
_logprob_helper ,
57
57
)
58
58
from pymc .logprob .rewriting import measurable_ir_rewrites_db
59
+ from pymc .math import logdiffexp
59
60
from pymc .pytensorf import constant_fold
60
61
61
62
@@ -66,6 +67,13 @@ class MeasurableMax(Max):
66
67
MeasurableVariable .register (MeasurableMax )
67
68
68
69
70
+ class MeasurableMaxDiscrete (Max ):
71
+ """A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables"""
72
+
73
+
74
+ MeasurableVariable .register (MeasurableMaxDiscrete )
75
+
76
+
69
77
@node_rewriter ([Max ])
70
78
def find_measurable_max (fgraph : FunctionGraph , node : Node ) -> Optional [List [TensorVariable ]]:
71
79
rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
@@ -87,10 +95,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
87
95
if not (isinstance (base_var .owner .op , RandomVariable ) and base_var .owner .op .ndim_supp == 0 ):
88
96
return None
89
97
90
- # TODO: We are currently only supporting continuous rvs
91
- if isinstance (base_var .owner .op , RandomVariable ) and base_var .owner .op .dtype .startswith ("int" ):
92
- return None
93
-
94
98
# univariate i.i.d. test which also rules out other distributions
95
99
for params in base_var .owner .inputs [3 :]:
96
100
if params .type .ndim != 0 :
@@ -102,7 +106,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
102
106
if axis != base_var_dims :
103
107
return None
104
108
105
- measurable_max = MeasurableMax (list (axis ))
109
+ # distinguish measurable discrete and continuous (because logprob is different)
110
+ if base_var .owner .op .dtype .startswith ("int" ):
111
+ measurable_max = MeasurableMaxDiscrete (list (axis ))
112
+ else :
113
+ measurable_max = MeasurableMax (list (axis ))
114
+
106
115
max_rv_node = measurable_max .make_node (base_var )
107
116
max_rv = max_rv_node .outputs
108
117
@@ -131,6 +140,26 @@ def max_logprob(op, values, base_rv, **kwargs):
131
140
return logprob
132
141
133
142
143
+ @_logprob .register (MeasurableMaxDiscrete )
144
+ def max_logprob_discrete (op , values , base_rv , ** kwargs ):
145
+ r"""Compute the log-likelihood graph for the `Max` operation.
146
+
147
+ The formula that we use here is :
148
+ .. math::
149
+ \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
150
+ where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
151
+ """
152
+ (value ,) = values
153
+ logcdf = _logcdf_helper (base_rv , value )
154
+ logcdf_prev = _logcdf_helper (base_rv , value - 1 )
155
+
156
+ [n ] = constant_fold ([base_rv .size ])
157
+
158
+ logprob = logdiffexp (n * logcdf , n * logcdf_prev )
159
+
160
+ return logprob
161
+
162
+
134
163
class MeasurableMaxNeg (Max ):
135
164
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
136
165
This shows up in the graph of min, which is (neg(max(neg(x)))."""
0 commit comments