Skip to content

Commit

Permalink
adding icdf function for Cauchy and Logistic with tests (#6747)
Browse files Browse the repository at this point in the history

Co-authored-by: seuabeia <126102654+seuabeia@users.noreply.github.com>
  • Loading branch information
amyoshino and seuabeia authored Jun 16, 2023
1 parent 69514ac commit 77f24d7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
19 changes: 18 additions & 1 deletion pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,7 +1714,6 @@ def logcdf(value, mu, sigma):
-np.inf,
normal_lcdf(mu, sigma, pt.log(value)),
)

return check_parameters(
res,
sigma > 0,
Expand Down Expand Up @@ -2039,6 +2038,15 @@ def logcdf(value, alpha, beta):
msg="beta > 0",
)

def icdf(value, alpha, beta):
res = alpha + beta * pt.tan(np.pi * (value - 0.5))
res = check_icdf_value(res, value)
return check_parameters(
res,
beta > 0,
msg="beta > 0",
)


class HalfCauchy(PositiveContinuous):
r"""
Expand Down Expand Up @@ -3357,6 +3365,15 @@ def logcdf(value, mu, s):
msg="s > 0",
)

def icdf(value, mu, s):
res = mu + s * pt.log(value / (1 - value))
res = check_icdf_value(res, value)
return check_parameters(
res,
s > 0,
msg="s > 0",
)


class LogitNormalRV(RandomVariable):
name = "logit_normal"
Expand Down
11 changes: 11 additions & 0 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,11 @@ def test_cauchy(self):
{"alpha": R, "beta": Rplusbig},
lambda value, alpha, beta: st.cauchy.logcdf(value, alpha, beta),
)
check_icdf(
pm.Cauchy,
{"alpha": R, "beta": Rplusbig},
lambda q, alpha, beta: st.cauchy.ppf(q, alpha, beta),
)

def test_half_cauchy(self):
check_logp(
Expand Down Expand Up @@ -768,6 +773,12 @@ def test_logistic(self):
lambda value, mu, s: st.logistic.logcdf(value, mu, s),
decimal=select_by_precision(float64=6, float32=1),
)
check_icdf(
pm.Logistic,
{"mu": R, "s": Rplus},
lambda q, mu, s: st.logistic.ppf(q, mu, s),
decimal=select_by_precision(float64=6, float32=1),
)

def test_logitnormal(self):
check_logp(
Expand Down

0 comments on commit 77f24d7

Please sign in to comment.