Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List more math function in API docs #7211

Merged
merged 10 commits into from
Mar 27, 2024
106 changes: 74 additions & 32 deletions docs/source/api/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,56 +19,98 @@ Functions exposed in pymc namespace
invlogit
probit
invprobit
logaddexp
logsumexp


Functions exposed in pymc.math
------------------------------

.. automodule:: pymc.math
.. autosummary::
:toctree: generated/

dot
constant
flatten
zeros_like
ones_like
stack
concatenate
sum
abs
prod
lt
gt
le
ge
dot
eq
neq
switch
clip
where
and_
or_
abs
ge
gt
le
lt
exp
log
cos
sgn
sqr
sqrt
sum
ceil
floor
sin
tan
cosh
sinh
arcsin
arcsinh
cos
cosh
arccos
arccosh
tan
tanh
sqr
sqrt
erf
erfinv
dot
arctan
arctanh
cumprod
cumsum
matmul
and_
broadcast_to
clip
concatenate
flatten
or_
stack
switch
where
flatten_list
constant
max
maximum
mean
min
minimum
sgn
ceil
floor
matrix_inverse
sigmoid
round
erf
erfc
erfcinv
erfinv
log1pexp
log1mexp
logaddexp
logsumexp
invlogit
logdiffexp
logit
invlogit
probit
invprobit
sigmoid
softmax
log_softmax
logbern
full
full_like
ones
ones_like
zeros
zeros_like
kronecker
cartesian
kron_dot
kron_solve_lower
kron_solve_upper
kron_diag
flat_outer
expand_packed_triangular
batched_diag
block_diagonal
matrix_inverse
logdet
26 changes: 12 additions & 14 deletions pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
ones_like,
or_,
prod,
round,
sgn,
sigmoid,
sin,
Expand Down Expand Up @@ -178,6 +179,7 @@
"expand_packed_triangular",
"batched_diag",
"block_diagonal",
"round",
]


Expand Down Expand Up @@ -272,27 +274,18 @@
return reduce(flat_outer, diags)


def round(*args, **kwargs):
"""
Temporary function to silence round warning in PyTensor. Please remove
when the warning disappears.
"""
kwargs["mode"] = "half_to_even"
return pt.round(*args, **kwargs)


def tround(*args, **kwargs):
warnings.warn("tround is deprecated. Use round instead.")
return round(*args, **kwargs)


def logdiffexp(a, b):
"""log(exp(a) - exp(b))"""
return a + pt.log1mexp(b - a)


def logdiffexp_numpy(a, b):
"""log(exp(a) - exp(b))"""
warnings.warn(

Check warning on line 284 in pymc/math.py

View check run for this annotation

Codecov / codecov/patch

pymc/math.py#L284

Added line #L284 was not covered by tests
"pymc.math.logdiffexp_numpy is being deprecated.",
FutureWarning,
stacklevel=2,
)
return a + log1mexp_numpy(b - a, negative_input=True)


Expand Down Expand Up @@ -341,6 +334,11 @@
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
"""
warnings.warn(

Check warning on line 337 in pymc/math.py

View check run for this annotation

Codecov / codecov/patch

pymc/math.py#L337

Added line #L337 was not covered by tests
"pymc.math.log1mexp_numpy is being deprecated.",
FutureWarning,
stacklevel=2,
)
x = np.asarray(x, dtype="float")

if not negative_input:
Expand Down
5 changes: 4 additions & 1 deletion tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,11 @@ def test_kumaraswamy(self):
def scipy_log_pdf(value, a, b):
return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value**a)

def log1mexp(x):
return np.log1p(-np.exp(x)) if x < np.log(0.5) else np.log(-np.expm1(x))

def scipy_log_cdf(value, a, b):
return pm.math.log1mexp_numpy(b * np.log1p(-(value**a)), negative_input=True)
return log1mexp(b * np.log1p(-(value**a)))

check_logp(
pm.Kumaraswamy,
Expand Down
41 changes: 21 additions & 20 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,45 +145,46 @@ def test_log1mexp():
)
actual = pt.log1mexp(-vals).eval()
npt.assert_allclose(actual, expected)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning)
actual_ = log1mexp_numpy(-vals, negative_input=True)
with pytest.warns(FutureWarning, match="deprecated"):
actual_ = log1mexp_numpy(-vals, negative_input=True)
npt.assert_allclose(actual_, expected)
# Check that input was not changed in place
npt.assert_allclose(vals, vals_)


@pytest.mark.filterwarnings("error")
def test_log1mexp_numpy_no_warning():
"""Assert RuntimeWarning is not raised for very small numbers"""
with warnings.catch_warnings():
warnings.simplefilter("error")
with pytest.warns(FutureWarning, match="deprecated"):
log1mexp_numpy(-1e-25, negative_input=True)


def test_log1mexp_numpy_integer_input():
assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval())
with pytest.warns(FutureWarning, match="deprecated"):
assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval())


@pytest.mark.filterwarnings("error")
def test_log1mexp_deprecation_warnings():
with pytest.warns(
FutureWarning,
match="pymc.math.log1mexp_numpy will expect a negative input",
):
res_pos = log1mexp_numpy(2)
with pytest.warns(FutureWarning, match="deprecated"):
with pytest.warns(
FutureWarning,
match="pymc.math.log1mexp_numpy will expect a negative input",
):
res_pos = log1mexp_numpy(2)

with warnings.catch_warnings():
warnings.simplefilter("error")
res_neg = log1mexp_numpy(-2, negative_input=True)

with pytest.warns(
FutureWarning,
match="pymc.math.log1mexp will expect a negative input",
):
res_pos_at = log1mexp(2).eval()
with pytest.warns(
FutureWarning,
match="pymc.math.log1mexp will expect a negative input",
):
res_pos_at = log1mexp(2).eval()

with warnings.catch_warnings():
warnings.simplefilter("error")
res_neg_at = log1mexp(-2, negative_input=True).eval()

assert np.isclose(res_pos, res_neg)
Expand All @@ -196,8 +197,8 @@ def test_logdiffexp():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
b = np.log([0, 1, 2, 3])

assert np.allclose(logdiffexp_numpy(a, b), 0)
with pytest.warns(FutureWarning, match="deprecated"):
assert np.allclose(logdiffexp_numpy(a, b), 0)
assert np.allclose(logdiffexp(a, b).eval(), 0)


Expand Down
Loading