Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,16 @@ def logcdf(value, a, b):
msg="a > 0, b > 0",
)

def icdf(value, a, b):
res = pt.exp(pt.reciprocal(a) * pt.log1mexp(pt.reciprocal(b) * pt.log1p(-value)))
res = check_icdf_value(res, value)
return check_icdf_parameters(
res,
a > 0,
b > 0,
msg="a > 0, b > 0",
)


class Exponential(PositiveContinuous):
r"""
Expand Down
46 changes: 46 additions & 0 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,52 @@ def check_selfconsistency_discrete_logcdf(
)


def check_selfconsistency_icdf(
distribution: Distribution,
paramdomains: dict[str, Domain],
*,
decimal: int | None = None,
n_samples: int = 100,
) -> None:
"""Check that the icdf and logcdf functions of the distribution are consistent.

Only works with continuous distributions.
"""
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)

dist = create_dist_from_paramdomains(distribution, paramdomains)
if dist.type.dtype.startswith("int"):
raise NotImplementedError(
"check_selfconsistency_icdf is not robust against discrete distributions."
)
value = dist.astype("float64").type("value")
dist_icdf = icdf(dist, value)
dist_cdf = pt.exp(logcdf(dist, value))

py_mode = Mode("py")
dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf, mode=py_mode)
dist_cdf_fn = compile(list(inputvars(dist_cdf)), dist_cdf, mode=py_mode)

domains = paramdomains.copy()
domains["value"] = Domain(np.linspace(0, 1, 10))

for point in product(domains, n_samples=n_samples):
point = dict(point)
value = point.pop("value")
icdf_value = dist_icdf_fn(**point, value=value)
recovered_value = dist_cdf_fn(
**point,
value=icdf_value,
)
np.testing.assert_almost_equal(
value,
recovered_value,
decimal=decimal,
err_msg=f"point: {point}",
)


def assert_support_point_is_expected(model, expected, check_finite_logp=True):
fn = make_initial_point_fn(
model=model,
Expand Down
5 changes: 5 additions & 0 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
check_icdf,
check_logcdf,
check_logp,
check_selfconsistency_icdf,
continuous_random_tester,
seeded_numpy_distribution_builder,
seeded_scipy_distribution_builder,
Expand Down Expand Up @@ -441,6 +442,10 @@ def scipy_log_cdf(value, a, b):
{"a": Rplus, "b": Rplus},
scipy_log_cdf,
)
check_selfconsistency_icdf(
pm.Kumaraswamy,
{"a": Rplusbig, "b": Rplusbig},
)

def test_exponential(self):
check_logp(
Expand Down