Skip to content

Commit fe5f30b

Browse files
committed
Added icdf - logcdf consistency tests.
1 parent 69058cf commit fe5f30b

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

pymc/testing.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,82 @@ def check_selfconsistency_discrete_logcdf(
657657
)
658658

659659

660+
def check_selfconsistency_continuous_icdf(
661+
distribution: Distribution,
662+
paramdomains: Dict[str, Domain],
663+
decimal: Optional[int] = None,
664+
n_samples: int = 100,
665+
) -> None:
666+
"""
667+
Check that the icdf and logcdf functions of the distribution are consistent for a sample of probability values.
668+
"""
669+
if decimal is None:
670+
decimal = select_by_precision(float64=6, float32=3)
671+
672+
dist = create_dist_from_paramdomains(distribution, paramdomains)
673+
value = dist.type()
674+
value.name = "value"
675+
676+
dist_icdf = icdf(dist, value)
677+
dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)
678+
679+
dist_logcdf = logcdf(dist, value)
680+
dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)
681+
682+
domains = paramdomains.copy()
683+
domains["value"] = Domain(np.linspace(0, 1, 10))
684+
685+
for point in product(domains, n_samples=n_samples):
686+
point = dict(point)
687+
value = point.pop("value")
688+
689+
with pytensor.config.change_flags(mode=Mode("py")):
690+
npt.assert_almost_equal(
691+
value,
692+
np.exp(dist_logcdf_fn(**point, value=dist_icdf_fn(**point, value=value))),
693+
decimal=decimal,
694+
err_msg=f"point: {point}, value: {value}",
695+
)
696+
697+
698+
def check_selfconsistency_discrete_icdf(
699+
distribution: Distribution,
700+
domain: Domain,
701+
paramdomains: Dict[str, Domain],
702+
n_samples: int = 100,
703+
) -> None:
704+
"""
705+
Check that the icdf and logcdf functions of the distribution are
706+
consistent for a sample of values in the domain of the
707+
distribution.
708+
"""
709+
dist = create_dist_from_paramdomains(distribution, paramdomains)
710+
711+
value = pt.TensorType(dtype="float64", shape=[])("value")
712+
713+
dist_icdf = icdf(dist, value)
714+
dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)
715+
716+
dist_logcdf = logcdf(dist, value)
717+
dist_logcdf_fn = compile_pymc(list(inputvars(dist_logcdf)), dist_logcdf)
718+
719+
domains = paramdomains.copy()
720+
domains["value"] = domain
721+
722+
for point in product(domains, n_samples=n_samples):
723+
point = dict(point)
724+
value = point.pop("value")
725+
726+
with pytensor.config.change_flags(mode=Mode("py")):
727+
expected_value = value
728+
computed_value = dist_icdf_fn(
729+
**point, value=np.exp(dist_logcdf_fn(**point, value=value))
730+
)
731+
assert (
732+
expected_value == computed_value
733+
), f"expected_value = {expected_value}, computed_value = {computed_value}, {point}"
734+
735+
660736
def assert_moment_is_expected(model, expected, check_finite_logp=True):
661737
fn = make_initial_point_fn(
662738
model=model,

tests/distributions/test_continuous.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
check_icdf,
4747
check_logcdf,
4848
check_logp,
49+
check_selfconsistency_continuous_icdf,
4950
continuous_random_tester,
5051
seeded_numpy_distribution_builder,
5152
seeded_scipy_distribution_builder,
@@ -424,6 +425,10 @@ def scipy_log_cdf(value, a, b):
424425
{"a": Rplus, "b": Rplus},
425426
scipy_log_cdf,
426427
)
428+
check_selfconsistency_continuous_icdf(
429+
pm.Kumaraswamy,
430+
{"a": Rplusbig, "b": Rplusbig},
431+
)
427432

428433
def test_exponential(self):
429434
check_logp(

tests/distributions/test_discrete.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
check_icdf,
5252
check_logcdf,
5353
check_logp,
54+
check_selfconsistency_discrete_icdf,
5455
check_selfconsistency_discrete_logcdf,
5556
seeded_numpy_distribution_builder,
5657
seeded_scipy_distribution_builder,
@@ -119,6 +120,11 @@ def test_discrete_unif(self):
119120
lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1),
120121
skip_paramdomain_outside_edge_test=True,
121122
)
123+
check_selfconsistency_discrete_icdf(
124+
pm.DiscreteUniform,
125+
Rdunif,
126+
{"lower": -Rplusdunif, "upper": Rplusdunif},
127+
)
122128
# Custom logp / logcdf check for invalid parameters
123129
invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0)
124130
with pytensor.config.change_flags(mode=Mode("py")):
@@ -152,6 +158,11 @@ def test_geometric(self):
152158
{"p": Unit},
153159
st.geom.ppf,
154160
)
161+
check_selfconsistency_discrete_icdf(
162+
pm.Geometric,
163+
Nat,
164+
{"p": Unit},
165+
)
155166

156167
def test_hypergeometric(self):
157168
def modified_scipy_hypergeom_logcdf(value, N, k, n):

0 commit comments

Comments
 (0)