Skip to content

Commit efa0d34

Browse files
Logprob derivation of Max for Discrete IID distributions (#6790)
1 parent c3f93ba commit efa0d34

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

pymc/logprob/order.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
_logprob_helper,
5757
)
5858
from pymc.logprob.rewriting import measurable_ir_rewrites_db
59+
from pymc.math import logdiffexp
5960
from pymc.pytensorf import constant_fold
6061

6162

@@ -66,6 +67,13 @@ class MeasurableMax(Max):
6667
MeasurableVariable.register(MeasurableMax)
6768

6869

70+
class MeasurableMaxDiscrete(Max):
71+
"""A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables"""
72+
73+
74+
MeasurableVariable.register(MeasurableMaxDiscrete)
75+
76+
6977
@node_rewriter([Max])
7078
def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
7179
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -87,10 +95,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
8795
if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
8896
return None
8997

90-
# TODO: We are currently only supporting continuous rvs
91-
if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
92-
return None
93-
9498
# univariate i.i.d. test which also rules out other distributions
9599
for params in base_var.owner.inputs[3:]:
96100
if params.type.ndim != 0:
@@ -102,7 +106,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
102106
if axis != base_var_dims:
103107
return None
104108

105-
measurable_max = MeasurableMax(list(axis))
109+
# distinguish measurable discrete and continuous (because logprob is different)
110+
if base_var.owner.op.dtype.startswith("int"):
111+
measurable_max = MeasurableMaxDiscrete(list(axis))
112+
else:
113+
measurable_max = MeasurableMax(list(axis))
114+
106115
max_rv_node = measurable_max.make_node(base_var)
107116
max_rv = max_rv_node.outputs
108117

@@ -131,6 +140,26 @@ def max_logprob(op, values, base_rv, **kwargs):
131140
return logprob
132141

133142

143+
@_logprob.register(MeasurableMaxDiscrete)
144+
def max_logprob_discrete(op, values, base_rv, **kwargs):
145+
r"""Compute the log-likelihood graph for the `Max` operation.
146+
147+
The formula that we use here is :
148+
.. math::
149+
\ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
150+
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
151+
"""
152+
(value,) = values
153+
logcdf = _logcdf_helper(base_rv, value)
154+
logcdf_prev = _logcdf_helper(base_rv, value - 1)
155+
156+
[n] = constant_fold([base_rv.size])
157+
158+
logprob = logdiffexp(n * logcdf, n * logcdf_prev)
159+
160+
return logprob
161+
162+
134163
class MeasurableMaxNeg(Max):
135164
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
136165
This shows up in the graph of min, which is (neg(max(neg(x)))."""

tests/logprob/test_order.py

+24
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import numpy as np
4040
import pytensor.tensor as pt
4141
import pytest
42+
import scipy.stats as sp
4243

4344
import pymc as pm
4445

@@ -230,3 +231,26 @@ def test_min_non_mul_elemwise_fails():
230231
x_min_value = pt.vector("x_min_value")
231232
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
232233
x_min_logprob = logp(x_min, x_min_value)
234+
235+
236+
@pytest.mark.parametrize(
237+
"mu, size, value, axis",
238+
[(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
239+
)
240+
def test_max_discrete(mu, size, value, axis):
241+
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
242+
x_max = pt.max(x, axis=axis)
243+
x_max_value = pt.scalar("x_max_value")
244+
x_max_logprob = logp(x_max, x_max_value)
245+
246+
test_value = value
247+
248+
n = size
249+
exp_rv = sp.poisson(mu).cdf(test_value) ** n
250+
exp_rv_prev = sp.poisson(mu).cdf(test_value - 1) ** n
251+
252+
np.testing.assert_allclose(
253+
np.log(exp_rv - exp_rv_prev),
254+
(x_max_logprob.eval({x_max_value: test_value})),
255+
rtol=1e-06,
256+
)

0 commit comments

Comments
 (0)