Infer logcdf of discrete transformations #7444


Merged: 2 commits, Aug 19, 2024
193 changes: 43 additions & 150 deletions pymc/logprob/order.py
@@ -33,28 +33,25 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import cast

import pytensor.tensor as pt

from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Max
-from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import (
+    MeasurableElemwise,
    MeasurableOp,
    MeasurableOpMixin,
    _logcdf_helper,
    _logprob,
    _logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
-from pymc.logprob.utils import find_negated_var
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold
@@ -73,25 +70,41 @@
    if rv_map_feature is None:
        return None  # pragma: no cover

-    if isinstance(node.op, MeasurableMax):
-        return None  # pragma: no cover
+    if isinstance(node.op, MeasurableMax | MeasurableMaxDiscrete):
+        return None

-    base_var = cast(TensorVariable, node.inputs[0])
+    [base_var] = node.inputs

    if base_var.owner is None:
        return None

    if not rv_map_feature.request_measurable(node.inputs):
        return None

-    # Non-univariate distributions and non-RVs must be rejected
-    if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
-        return None
+    # We allow Max of RandomVariables or Elemwise of univariate RandomVariables
+    if isinstance(base_var.owner.op, MeasurableElemwise):
+        latent_base_vars = [
+            var
+            for var in base_var.owner.inputs
+            if (var.owner and isinstance(var.owner.op, MeasurableOp))
+        ]
+        if len(latent_base_vars) != 1:
+            return None
+        [latent_base_var] = latent_base_vars
+    else:
+        latent_base_var = base_var
+
+    latent_op = latent_base_var.owner.op
+    if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
+        return None

    # univariate i.i.d. test which also rules out other distributions
-    for params in base_var.owner.op.dist_params(base_var.owner):
-        if not all(params.type.broadcastable):
-            return None
+    if not all(
+        all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
+    ):
+        return None
+
+    base_var = cast(TensorVariable, base_var)

    if node.op.axis is None:
        axis = tuple(range(base_var.ndim))

Codecov annotation: added lines 74 and 92 of pymc/logprob/order.py were not covered by tests.
@@ -102,16 +115,11 @@
            return None

    # distinguish measurable discrete and continuous (because logprob is different)
-    measurable_max: Max
-    if base_var.type.dtype.startswith("int"):
-        measurable_max = MeasurableMaxDiscrete(axis)
-    else:
-        measurable_max = MeasurableMax(axis)
-
-    max_rv_node = measurable_max.make_node(base_var)
-    max_rv = max_rv_node.outputs
-
-    return max_rv
+    measurable_max_class = (
+        MeasurableMaxDiscrete if latent_base_var.type.dtype.startswith("int") else MeasurableMax
+    )
+    max_rv = cast(TensorVariable, measurable_max_class(axis)(base_var))
+    return [max_rv]

Review comment (Member), on the measurable_max_class lines: idem comment as below.
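The practical effect of the relaxed check above is that the max of an elemwise transformation of an i.i.d. random variable is now measurable, which the old RandomVariable-only test rejected (the deleted test_non_rv_fails asserted exactly that failure). A minimal sketch from the user side, assuming a pymc build that includes this change; the variable names are illustrative and not part of the diff:

```python
import pymc as pm
import pytensor.tensor as pt

# Max of an Elemwise(exp) of a univariate i.i.d. RandomVariable: the rewrite
# unwraps the measurable elemwise node and checks the latent RV's parameters
# (scalar, hence fully broadcastable) instead of requiring a RandomVariable.
x = pm.Normal.dist(0, 1, size=(3,))
y_max = pt.max(pt.exp(x))

value = pt.scalar("value")
logp_graph = pm.logp(y_max, value)  # previously raised "Logprob method not implemented"
```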


measurable_ir_rewrites_db.register(
@@ -127,13 +135,13 @@
    r"""Compute the log-likelihood graph for the `Max` operation."""
    (value,) = values

-    logprob = _logprob_helper(base_rv, value)
-    logcdf = _logcdf_helper(base_rv, value)
+    base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
+    bcast_value = pt.broadcast_to(value, base_rv_shape)
+    logprob = _logprob_helper(base_rv, bcast_value)[0]
+    logcdf = _logcdf_helper(base_rv, bcast_value)[0]

-    [n] = constant_fold([base_rv.size])
-    logprob = (n - 1) * logcdf + logprob + pt.math.log(n)
-
-    return logprob
+    n = pt.prod(base_rv_shape)
+    return (n - 1) * logcdf + logprob + pt.math.log(n)
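The continuous branch implements the order-statistic identity $\ln f_{(n)}(x) = \ln n + (n - 1)\ln F(x) + \ln f(x)$ for the maximum of $n$ i.i.d. variables. A quick numerical cross-check against scipy, offered as a sketch rather than part of the diff:

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import scipy.stats as st

n, test_value = 3, 0.5
x = pm.Normal.dist(0, 1, size=(n,))
x_max_logp = pm.logp(pt.max(x), test_value).eval()

# Closed form for the maximum of n i.i.d. standard normals
expected = np.log(n) + (n - 1) * st.norm.logcdf(test_value) + st.norm.logpdf(test_value)
np.testing.assert_allclose(x_max_logp, expected, rtol=1e-6)
```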


@_logprob.register(MeasurableMaxDiscrete)
@@ -146,126 +154,11 @@
    where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
    """
    (value,) = values
-    logcdf = _logcdf_helper(base_rv, value)
-    logcdf_prev = _logcdf_helper(base_rv, value - 1)
-
-    [n] = constant_fold([base_rv.size])
-
-    logprob = logdiffexp(n * logcdf, n * logcdf_prev)
-
-    return logprob
-
-
-class MeasurableMaxNeg(MeasurableOpMixin, Max):
-    """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
-    This shows up in the graph of min, which is (neg(max(neg(x)))."""
-
-
-class MeasurableDiscreteMaxNeg(MeasurableOpMixin, Max):
-    """A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""
-
-
-@node_rewriter(tracks=[Max])
-def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
-    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
-
-    if rv_map_feature is None:
-        return None  # pragma: no cover
-
-    if isinstance(node.op, MeasurableMaxNeg):
-        return None  # pragma: no cover
-
-    base_var = cast(TensorVariable, node.inputs[0])
-
-    # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
-    if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
-        return None
-
-    base_rv = find_negated_var(base_var)
-
-    # negation is rv * (-1). Hence the scalar_op must be Mul
-    if base_rv is None:
-        return None
-
-    # Non-univariate distributions and non-RVs must be rejected
-    if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
-        return None
-
-    # univariate i.i.d. test which also rules out other distributions
-    for params in base_rv.owner.op.dist_params(base_rv.owner):
-        if not all(params.type.broadcastable):
-            return None
-
-    if node.op.axis is None:
-        axis = tuple(range(base_var.ndim))
-    else:
-        # Check whether axis is supported or not
-        axis = tuple(sorted(node.op.axis))
-        if axis != tuple(range(base_var.ndim)):
-            return None
-
-    if not rv_map_feature.request_measurable([base_rv]):
-        return None
-
-    # distinguish measurable discrete and continuous (because logprob is different)
-    measurable_min: Max
-    if base_rv.type.dtype.startswith("int"):
-        measurable_min = MeasurableDiscreteMaxNeg(axis)
-    else:
-        measurable_min = MeasurableMaxNeg(axis)
-
-    return measurable_min.make_node(base_rv).outputs
-
-
-measurable_ir_rewrites_db.register(
-    "find_measurable_max_neg",
-    find_measurable_max_neg,
-    "basic",
-    "min",
-)
-
-
-@_logprob.register(MeasurableMaxNeg)
-def max_neg_logprob(op, values, base_rv, **kwargs):
-    r"""Compute the log-likelihood graph for the `Max` operation.
-    The formula that we use here is :
-    \ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
-    where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
-    """
-    (value,) = values
-
-    logprob = _logprob_helper(base_rv, -value)
-    logcdf = _logcdf_helper(base_rv, -value)
-
-    [n] = constant_fold([base_rv.size])
-    logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)
-
-    return logprob
-
-
-@_logprob.register(MeasurableDiscreteMaxNeg)
-def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
-    r"""Compute the log-likelihood graph for the `Max` operation.
-
-    The formula that we use here is :
-    .. math::
-        \ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
-    where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
-    """
-    (value,) = values
-
-    # The cdf of a negative variable is the survival at the negated value
-    logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
-    logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))
-
-    [n] = constant_fold([base_rv.size])
-
-    # Now we can use the same expression as the discrete max
-    logprob = pt.where(
-        pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
-        -pt.inf,
-        logdiffexp(n * logcdf_prev, n * logcdf),
-    )
-
-    return logprob
+    base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
+    bcast_value = pt.broadcast_to(value, base_rv_shape)
+    logcdf = _logcdf_helper(base_rv, bcast_value)[0]
+    logcdf_prev = _logcdf_helper(base_rv, bcast_value - 1)[0]
+
+    n = pt.prod(base_rv_shape)
+    return logdiffexp(n * logcdf, n * logcdf_prev)
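The discrete branch uses $P_{(n)}(x) = F(x)^n - F(x - 1)^n$ instead of the continuous density formula, since ties have positive probability for discrete variables. A cross-check against scipy's Poisson cdf, sketched here and not part of the diff:

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import scipy.stats as st

mu, n, test_value = 2.0, 3, 1
x = pm.Poisson.dist(mu=mu, size=(n,))
x_max_logp = pm.logp(pt.max(x), test_value).eval()

# P(max = x) = F(x)^n - F(x - 1)^n for n i.i.d. discrete variables
cdf = st.poisson(mu).cdf
expected = np.log(cdf(test_value) ** n - cdf(test_value - 1) ** n)
np.testing.assert_allclose(x_max_logp, expected, rtol=1e-6)
```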
13 changes: 6 additions & 7 deletions pymc/logprob/transforms.py
@@ -232,20 +232,20 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs):
    """Compute the log-CDF graph for a `MeasurabeTransform`."""
    other_inputs = list(inputs)
    measurable_input = other_inputs.pop(op.measurable_input_idx)

-    # Do not apply rewrite to discrete variables
-    if measurable_input.type.dtype.startswith("int"):
-        raise NotImplementedError("logcdf of transformed discrete variables not implemented")
-
    backward_value = op.transform_elemwise.backward(value, *other_inputs)

    # Fail if transformation is not injective
    # A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
    if isinstance(backward_value, tuple):
        raise NotImplementedError

+    is_discrete = measurable_input.type.dtype.startswith("int")
+
    logcdf = _logcdf_helper(measurable_input, backward_value)
-    logccdf = pt.log1mexp(logcdf)
+    if is_discrete:
+        logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
+    else:
+        logccdf = pt.log1mexp(logcdf)

    if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
        pass

Review comment (Member): Shall we use discrete_types here? discrete_types = bool_types | int_types
Reply (Member, PR author): No strong preference.
@@ -271,7 +271,6 @@

    # The jacobian is used to ensure a value in the supported domain was provided
    jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
-
    return pt.switch(pt.isnan(jacobian), -np.inf, logcdf)
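The F(x - 1) in the discrete branch comes from the survival function of an integer-valued variable: P(X >= k) = 1 - F(k - 1), so for a monotonically decreasing transform the logcdf needs log1mexp of the cdf evaluated one step below the backward value. A sketch of the behavior this PR enables, assuming a build that includes it (names and values are illustrative):

```python
import numpy as np
import pymc as pm
import scipy.stats as st

# logcdf of a monotonically decreasing transform of a discrete RV:
# P(-X <= v) = P(X >= -v) = 1 - F(-v - 1) for integer-valued X,
# which is why the discrete branch evaluates the cdf at backward_value - 1.
mu, v = 3.0, -2  # -X <= -2  is the same event as  X >= 2
x = pm.Poisson.dist(mu=mu)
logcdf_neg = pm.logcdf(-x, v).eval()

expected = np.log(1 - st.poisson(mu).cdf(-v - 1))  # = log P(X >= 2)
np.testing.assert_allclose(logcdf_neg, expected, rtol=1e-6)
```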
45 changes: 14 additions & 31 deletions tests/logprob/test_order.py
@@ -53,11 +53,11 @@ def test_argmax():
    """Test whether the logprob for ```pt.argmax``` is correctly rejected"""
    x = pt.random.normal(0, 1, size=(3,))
    x.name = "x"
-    x_max = pt.argmax(x, axis=-1)
-    x_max_value = pt.vector("x_max_value")
+    x_argmax = pt.argmax(x, axis=-1)
+    x_max_value = pt.scalar("x_max_value", dtype=x_argmax.type.dtype)

    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented for Argmax")):
-        x_max_logprob = logp(x_max, x_max_value)
+        logp(x_argmax, x_max_value)


@pytest.mark.parametrize(
@@ -72,26 +72,9 @@ def test_non_iid_fails(pt_op):
    x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
    x.name = "x"
    x_m = pt_op(x, axis=-1)
-    x_m_value = pt.vector("x_value")
+    x_m_value = pt.scalar("x_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
-        x_max_logprob = logp(x_m, x_m_value)
-
-
-@pytest.mark.parametrize(
-    "pt_op",
-    [
-        pt.max,
-        pt.min,
-    ],
-)
-def test_non_rv_fails(pt_op):
-    """Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
-    x = pt.exp(pt.random.beta(0, 1, size=(3,)))
-    x.name = "x"
-    x_m = pt_op(x, axis=-1)
-    x_m_value = pt.vector("x_value")
-    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
-        x_max_logprob = logp(x_m, x_m_value)
+        logp(x_m, x_m_value)


@pytest.mark.parametrize(
@@ -107,9 +90,9 @@ def test_multivariate_rv_fails(pt_op):
    x = pm.StickBreakingWeights.dist(_alpha, _k)
    x.name = "x"
    x_m = pt_op(x, axis=-1)
-    x_m_value = pt.vector("x_value")
+    x_m_value = pt.scalar("x_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
-        x_max_logprob = logp(x_m, x_m_value)
+        logp(x_m, x_m_value)


@pytest.mark.parametrize(
@@ -124,9 +107,9 @@ def test_categorical(pt_op):
    x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
    x.name = "x"
    x_m = pt_op(x, axis=-1)
-    x_m_value = pt.vector("x_value")
+    x_m_value = pt.scalar("x_value", dtype=x.type.dtype)
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
-        x_max_logprob = logp(x_m, x_m_value)
+        logp(x_m, x_m_value)


@pytest.mark.parametrize(
@@ -230,19 +213,19 @@ def test_min_non_mul_elemwise_fails():
    x = pt.log(pt.random.beta(0, 1, size=(3,)))
    x.name = "x"
    x_min = pt.min(x, axis=-1)
-    x_min_value = pt.vector("x_min_value")
+    x_min_value = pt.scalar("x_min_value")
    with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
-        x_min_logprob = logp(x_min, x_min_value)
+        logp(x_min, x_min_value)


@pytest.mark.parametrize(
    "mu, size, value, axis",
    [(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
)
def test_max_discrete(mu, size, value, axis):
-    x = pm.Poisson.dist(name="x", mu=mu, size=(size))
+    x = pm.Poisson.dist(name="x", mu=mu, size=size)
    x_max = pt.max(x, axis=axis)
-    x_max_value = pt.scalar("x_max_value")
+    x_max_value = pt.scalar("x_max_value", dtype=x.type.dtype)
    x_max_logprob = logp(x_max, x_max_value)

    test_value = value
@@ -265,7 +248,7 @@
def test_min_discrete(mu, n, test_value, axis):
    x = pm.Poisson.dist(name="x", mu=mu, size=(n,))
    x_min = pt.min(x, axis=axis)
-    x_min_value = pt.scalar("x_min_value")
+    x_min_value = pt.scalar("x_min_value", dtype=x.type.dtype)
    x_min_logprob = logp(x_min, x_min_value)

    sf_before = 1 - sp.poisson(mu).cdf(test_value - 1)
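The min tests rely on the survival-function identity P(min = x) = (1 - F(x - 1))^n - (1 - F(x))^n, which is what the sf_before/sf_after lines compute. The same check, sketched directly against scipy (illustrative, not part of the test suite):

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import scipy.stats as st

mu, n, test_value = 2.0, 3, 1
x = pm.Poisson.dist(mu=mu, size=(n,))
x_min_logp = pm.logp(pt.min(x), test_value).eval()

# P(min = x) = (1 - F(x - 1))^n - (1 - F(x))^n, with sf(k) = 1 - F(k)
sf = st.poisson(mu).sf
expected = np.log(sf(test_value - 1) ** n - sf(test_value) ** n)
np.testing.assert_allclose(x_min_logp, expected, rtol=1e-6)
```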