Skip to content

Commit 64b9bc7

Browse files
gokuldricardoV94
authored andcommitted
Testing a fix by truncating cdf to precision.
1 parent 603afad commit 64b9bc7

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

pymc/testing.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,13 +710,21 @@ def check_selfconsistency_discrete_icdf(
710710
distribution: Distribution,
711711
domain: Domain,
712712
paramdomains: Dict[str, Domain],
713+
decimal: Optional[int] = None,
713714
n_samples: int = 100,
714715
) -> None:
715716
"""
716717
Check that the icdf and logcdf functions of the distribution are
717718
consistent for a sample of values in the domain of the
718719
distribution.
719720
"""
721+
722+
def ftrunc(values, decimal=0):
723+
return np.trunc(values * 10**decimal) / (10**decimal)
724+
725+
if decimal is None:
726+
decimal = select_by_precision(float64=6, float32=3)
727+
720728
dist = create_dist_from_paramdomains(distribution, paramdomains)
721729

722730
value = pt.TensorType(dtype="float64", shape=[])("value")
@@ -737,7 +745,7 @@ def check_selfconsistency_discrete_icdf(
737745
with pytensor.config.change_flags(mode=Mode("py")):
738746
expected_value = value
739747
computed_value = dist_icdf_fn(
740-
**point, value=np.exp(dist_logcdf_fn(**point, value=value))
748+
**point, value=ftrunc(np.exp(dist_logcdf_fn(**point, value=value)), decimal=decimal)
741749
)
742750
npt.assert_almost_equal(
743751
expected_value,

0 commit comments

Comments
 (0)