|
29 | 29 | import pymc as pm |
30 | 30 |
|
31 | 31 | from pymc.distributions.discrete import Geometric, _OrderedLogistic, _OrderedProbit |
32 | | -from pymc.logprob.abstract import logcdf |
| 32 | +from pymc.logprob.abstract import icdf, logcdf |
33 | 33 | from pymc.logprob.joint_logprob import logp |
34 | 34 | from pymc.logprob.utils import ParameterValueError |
35 | 35 | from pymc.pytensorf import floatX |
@@ -118,13 +118,21 @@ def test_discrete_unif(self): |
118 | 118 | Domain([-10, 0, 10], "int64"), |
119 | 119 | {"lower": -Rplusdunif, "upper": Rplusdunif}, |
120 | 120 | ) |
| 121 | + check_icdf( |
| 122 | + pm.DiscreteUniform, |
| 123 | + {"lower": -Rplusdunif, "upper": Rplusdunif}, |
| 124 | + lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1), |
| 125 | + skip_paramdomain_outside_edge_test=True, |
| 126 | + ) |
121 | 127 | # Custom logp / logcdf check for invalid parameters |
122 | 128 | invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0) |
123 | 129 | with pytensor.config.change_flags(mode=Mode("py")): |
124 | 130 | with pytest.raises(ParameterValueError): |
125 | 131 | logp(invalid_dist, 0.5).eval() |
126 | 132 | with pytest.raises(ParameterValueError): |
127 | 133 | logcdf(invalid_dist, 2).eval() |
| 134 | + with pytest.raises(ParameterValueError): |
| 135 | + icdf(invalid_dist, np.array(1)).eval() |
128 | 136 |
|
129 | 137 | def test_geometric(self): |
130 | 138 | check_logp( |
|
0 commit comments