Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List more math function in API docs #7211

Merged
merged 10 commits into from
Mar 27, 2024
106 changes: 74 additions & 32 deletions docs/source/api/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,56 +19,98 @@ Functions exposed in pymc namespace
invlogit
probit
invprobit
logaddexp
logsumexp


Functions exposed in pymc.math
------------------------------

.. automodule:: pymc.math
.. autosummary::
:toctree: generated/

dot
constant
flatten
zeros_like
ones_like
stack
concatenate
sum
abs
prod
lt
gt
le
ge
dot
eq
neq
switch
clip
where
and_
or_
abs
ge
gt
le
lt
exp
log
cos
sgn
sqr
sqrt
sum
ceil
floor
sin
tan
cosh
sinh
arcsin
arcsinh
cos
cosh
arccos
arccosh
tan
tanh
sqr
sqrt
erf
erfinv
dot
arctan
arctanh
cumprod
cumsum
matmul
and_
broadcast_to
clip
concatenate
flatten
or_
stack
switch
where
flatten_list
constant
max
maximum
mean
min
minimum
sgn
ceil
floor
matrix_inverse
sigmoid
round
erf
erfc
erfcinv
erfinv
log1pexp
log1mexp
logaddexp
logsumexp
invlogit
logdiffexp
logit
invlogit
probit
invprobit
sigmoid
softmax
log_softmax
logbern
full
full_like
ones
ones_like
zeros
zeros_like
kronecker
cartesian
kron_dot
kron_solve_lower
kron_solve_upper
kron_diag
flat_outer
expand_packed_triangular
batched_diag
block_diagonal
matrix_inverse
logdet
26 changes: 12 additions & 14 deletions pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
ones_like,
or_,
prod,
round,
sgn,
sigmoid,
sin,
Expand Down Expand Up @@ -178,6 +179,7 @@
"expand_packed_triangular",
"batched_diag",
"block_diagonal",
"round",
]


Expand Down Expand Up @@ -272,27 +274,18 @@ def kron_diag(*diags):
return reduce(flat_outer, diags)


def round(*args, **kwargs):
"""
Temporary function to silence round warning in PyTensor. Please remove
when the warning disappears.
"""
kwargs["mode"] = "half_to_even"
return pt.round(*args, **kwargs)


def tround(*args, **kwargs):
warnings.warn("tround is deprecated. Use round instead.")
return round(*args, **kwargs)


def logdiffexp(a, b):
"""log(exp(a) - exp(b))"""
return a + pt.log1mexp(b - a)


def logdiffexp_numpy(a, b):
"""log(exp(a) - exp(b))"""
warnings.warn(
"pymc.math.logdiffexp_numpy is being deprecated.",
FutureWarning,
stacklevel=2,
)
return a + log1mexp_numpy(b - a, negative_input=True)


Expand Down Expand Up @@ -341,6 +334,11 @@ def log1mexp_numpy(x, *, negative_input=False):
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
"""
warnings.warn(
"pymc.math.log1mexp_numpy is being deprecated.",
FutureWarning,
stacklevel=2,
)
x = np.asarray(x, dtype="float")

if not negative_input:
Expand Down
1 change: 1 addition & 0 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def scipy_log_pdf(value, a, b):
return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value**a)

def scipy_log_cdf(value, a, b):
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
return pm.math.log1mexp_numpy(b * np.log1p(-(value**a)), negative_input=True)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

check_logp(
Expand Down
12 changes: 9 additions & 3 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
kron_solve_lower,
kronecker,
log1mexp,
log1mexp_numpy,
log1mexp_numpy, # to be deprecated
logdet,
logdiffexp,
logdiffexp_numpy,
logdiffexp_numpy, # to be deprecated
probit,
)
from pymc.pytensorf import floatX
Expand Down Expand Up @@ -148,6 +148,8 @@ def test_log1mexp():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning)

warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
actual_ = log1mexp_numpy(-vals, negative_input=True)
npt.assert_allclose(actual_, expected)
# Check that input was not changed in place
Expand All @@ -158,10 +160,12 @@ def test_log1mexp_numpy_no_warning():
"""Assert RuntimeWarning is not raised for very small numbers"""
with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
log1mexp_numpy(-1e-25, negative_input=True)


def test_log1mexp_numpy_integer_input():
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval())


Expand All @@ -170,10 +174,12 @@ def test_log1mexp_deprecation_warnings():
FutureWarning,
match="pymc.math.log1mexp_numpy will expect a negative input",
):
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
res_pos = log1mexp_numpy(2)

with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.warn("pymc.math.log1mexp_numpy is being deprecated.", FutureWarning)
res_neg = log1mexp_numpy(-2, negative_input=True)

with pytest.warns(
Expand All @@ -196,7 +202,7 @@ def test_logdiffexp():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
b = np.log([0, 1, 2, 3])

warnings.warn("pymc.math.logdiffexp_numpy is being deprecated.", FutureWarning)
assert np.allclose(logdiffexp_numpy(a, b), 0)
assert np.allclose(logdiffexp(a, b).eval(), 0)

Expand Down