Skip to content

Commit 85a234a

Browse files
gokuldricardoV94
andauthored
Add ICDF for the Kumaraswamy distribution (#6642)
* Added ICDF for the Kumaraswamy distribution. * Added icdf - logcdf consistency tests. * Tmp testing with np assert_almost_equal print messages. * Testing a fix by truncating cdf to precision. * Revert "Tmp testing with np assert_almost_equal print messages." This reverts commit b0c93b77e3d39c81b22325b644cf94c2beeb9e9c. Avoid mypy test failure. This is not needed, we have identified the root cause now. * Remove discrete selfconsistency icdf --------- Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
1 parent 8146a5d commit 85a234a

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,16 @@ def logcdf(value, a, b):
13241324
msg="a > 0, b > 0",
13251325
)
13261326

1327+
def icdf(value, a, b):
1328+
res = pt.exp(pt.reciprocal(a) * pt.log1mexp(pt.reciprocal(b) * pt.log1p(-value)))
1329+
res = check_icdf_value(res, value)
1330+
return check_icdf_parameters(
1331+
res,
1332+
a > 0,
1333+
b > 0,
1334+
msg="a > 0, b > 0",
1335+
)
1336+
13271337

13281338
class Exponential(PositiveContinuous):
13291339
r"""

pymc/testing.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,52 @@ def check_selfconsistency_discrete_logcdf(
668668
)
669669

670670

671+
def check_selfconsistency_icdf(
672+
distribution: Distribution,
673+
paramdomains: dict[str, Domain],
674+
*,
675+
decimal: int | None = None,
676+
n_samples: int = 100,
677+
) -> None:
678+
"""Check that the icdf and logcdf functions of the distribution are consistent.
679+
680+
Only works with continuous distributions.
681+
"""
682+
if decimal is None:
683+
decimal = select_by_precision(float64=6, float32=3)
684+
685+
dist = create_dist_from_paramdomains(distribution, paramdomains)
686+
if dist.type.dtype.startswith("int"):
687+
raise NotImplementedError(
688+
"check_selfconsistency_icdf is not robust against discrete distributions."
689+
)
690+
value = dist.astype("float64").type("value")
691+
dist_icdf = icdf(dist, value)
692+
dist_cdf = pt.exp(logcdf(dist, value))
693+
694+
py_mode = Mode("py")
695+
dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf, mode=py_mode)
696+
dist_cdf_fn = compile(list(inputvars(dist_cdf)), dist_cdf, mode=py_mode)
697+
698+
domains = paramdomains.copy()
699+
domains["value"] = Domain(np.linspace(0, 1, 10))
700+
701+
for point in product(domains, n_samples=n_samples):
702+
point = dict(point)
703+
value = point.pop("value")
704+
icdf_value = dist_icdf_fn(**point, value=value)
705+
recovered_value = dist_cdf_fn(
706+
**point,
707+
value=icdf_value,
708+
)
709+
np.testing.assert_almost_equal(
710+
value,
711+
recovered_value,
712+
decimal=decimal,
713+
err_msg=f"point: {point}",
714+
)
715+
716+
671717
def assert_support_point_is_expected(model, expected, check_finite_logp=True):
672718
fn = make_initial_point_fn(
673719
model=model,

tests/distributions/test_continuous.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
check_icdf,
4646
check_logcdf,
4747
check_logp,
48+
check_selfconsistency_icdf,
4849
continuous_random_tester,
4950
seeded_numpy_distribution_builder,
5051
seeded_scipy_distribution_builder,
@@ -441,6 +442,10 @@ def scipy_log_cdf(value, a, b):
441442
{"a": Rplus, "b": Rplus},
442443
scipy_log_cdf,
443444
)
445+
check_selfconsistency_icdf(
446+
pm.Kumaraswamy,
447+
{"a": Rplusbig, "b": Rplusbig},
448+
)
444449

445450
def test_exponential(self):
446451
check_logp(

0 commit comments

Comments
 (0)