Skip to content

Commit 61f2de5

Browse files
committed
Testing a fix by truncating cdf to precision.
1 parent 005edbe commit 61f2de5

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

pymc/testing.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,13 +699,21 @@ def check_selfconsistency_discrete_icdf(
699699
distribution: Distribution,
700700
domain: Domain,
701701
paramdomains: Dict[str, Domain],
702+
decimal: Optional[int] = None,
702703
n_samples: int = 100,
703704
) -> None:
704705
"""
705706
Check that the icdf and logcdf functions of the distribution are
706707
consistent for a sample of values in the domain of the
707708
distribution.
708709
"""
710+
711+
def ftrunc(values, decimal=0):
712+
return np.trunc(values * 10**decimal) / (10**decimal)
713+
714+
if decimal is None:
715+
decimal = select_by_precision(float64=6, float32=3)
716+
709717
dist = create_dist_from_paramdomains(distribution, paramdomains)
710718

711719
value = pt.TensorType(dtype="float64", shape=[])("value")
@@ -726,7 +734,7 @@ def check_selfconsistency_discrete_icdf(
726734
with pytensor.config.change_flags(mode=Mode("py")):
727735
expected_value = value
728736
computed_value = dist_icdf_fn(
729-
**point, value=np.exp(dist_logcdf_fn(**point, value=value))
737+
**point, value=ftrunc(np.exp(dist_logcdf_fn(**point, value=value)), decimal=decimal)
730738
)
731739
npt.assert_almost_equal(
732740
expected_value,

0 commit comments

Comments
 (0)