Skip to content

Commit

Permalink
List more math function in API docs (pymc-devs#7211)
Browse files Browse the repository at this point in the history
Also removes deprecated functions and deprecates numpy helpers


Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
  • Loading branch information
2 people authored and mkusnetsov committed Oct 26, 2024
1 parent 8a66e90 commit 16e6d69
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 67 deletions.
106 changes: 74 additions & 32 deletions docs/source/api/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,56 +19,98 @@ Functions exposed in pymc namespace
invlogit
probit
invprobit
logaddexp
logsumexp


Functions exposed in pymc.math
------------------------------

.. automodule:: pymc.math
.. autosummary::
:toctree: generated/

dot
constant
flatten
zeros_like
ones_like
stack
concatenate
sum
abs
prod
lt
gt
le
ge
dot
eq
neq
switch
clip
where
and_
or_
abs
ge
gt
le
lt
exp
log
cos
sgn
sqr
sqrt
sum
ceil
floor
sin
tan
cosh
sinh
arcsin
arcsinh
cos
cosh
arccos
arccosh
tan
tanh
sqr
sqrt
erf
erfinv
dot
arctan
arctanh
cumprod
cumsum
matmul
and_
broadcast_to
clip
concatenate
flatten
or_
stack
switch
where
flatten_list
constant
max
maximum
mean
min
minimum
sgn
ceil
floor
matrix_inverse
sigmoid
round
erf
erfc
erfcinv
erfinv
log1pexp
log1mexp
logaddexp
logsumexp
invlogit
logdiffexp
logit
invlogit
probit
invprobit
sigmoid
softmax
log_softmax
logbern
full
full_like
ones
ones_like
zeros
zeros_like
kronecker
cartesian
kron_dot
kron_solve_lower
kron_solve_upper
kron_diag
flat_outer
expand_packed_triangular
batched_diag
block_diagonal
matrix_inverse
logdet
26 changes: 12 additions & 14 deletions pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
ones_like,
or_,
prod,
round,
sgn,
sigmoid,
sin,
Expand Down Expand Up @@ -178,6 +179,7 @@
"expand_packed_triangular",
"batched_diag",
"block_diagonal",
"round",
]


Expand Down Expand Up @@ -272,27 +274,18 @@ def kron_diag(*diags):
return reduce(flat_outer, diags)


def round(*args, **kwargs):
"""
Temporary function to silence round warning in PyTensor. Please remove
when the warning disappears.
"""
kwargs["mode"] = "half_to_even"
return pt.round(*args, **kwargs)


def tround(*args, **kwargs):
warnings.warn("tround is deprecated. Use round instead.")
return round(*args, **kwargs)


def logdiffexp(a, b):
"""log(exp(a) - exp(b))"""
return a + pt.log1mexp(b - a)


def logdiffexp_numpy(a, b):
"""log(exp(a) - exp(b))"""
warnings.warn(
"pymc.math.logdiffexp_numpy is being deprecated.",
FutureWarning,
stacklevel=2,
)
return a + log1mexp_numpy(b - a, negative_input=True)


Expand Down Expand Up @@ -341,6 +334,11 @@ def log1mexp_numpy(x, *, negative_input=False):
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
"""
warnings.warn(
"pymc.math.log1mexp_numpy is being deprecated.",
FutureWarning,
stacklevel=2,
)
x = np.asarray(x, dtype="float")

if not negative_input:
Expand Down
5 changes: 4 additions & 1 deletion tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,11 @@ def test_kumaraswamy(self):
def scipy_log_pdf(value, a, b):
return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value**a)

def log1mexp(x):
return np.log1p(-np.exp(x)) if x < np.log(0.5) else np.log(-np.expm1(x))

def scipy_log_cdf(value, a, b):
return pm.math.log1mexp_numpy(b * np.log1p(-(value**a)), negative_input=True)
return log1mexp(b * np.log1p(-(value**a)))

check_logp(
pm.Kumaraswamy,
Expand Down
41 changes: 21 additions & 20 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,45 +145,46 @@ def test_log1mexp():
)
actual = pt.log1mexp(-vals).eval()
npt.assert_allclose(actual, expected)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
warnings.filterwarnings("ignore", "invalid value encountered in log", RuntimeWarning)
actual_ = log1mexp_numpy(-vals, negative_input=True)
with pytest.warns(FutureWarning, match="deprecated"):
actual_ = log1mexp_numpy(-vals, negative_input=True)
npt.assert_allclose(actual_, expected)
# Check that input was not changed in place
npt.assert_allclose(vals, vals_)


@pytest.mark.filterwarnings("error")
def test_log1mexp_numpy_no_warning():
"""Assert RuntimeWarning is not raised for very small numbers"""
with warnings.catch_warnings():
warnings.simplefilter("error")
with pytest.warns(FutureWarning, match="deprecated"):
log1mexp_numpy(-1e-25, negative_input=True)


def test_log1mexp_numpy_integer_input():
assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval())
with pytest.warns(FutureWarning, match="deprecated"):
assert np.isclose(log1mexp_numpy(-2, negative_input=True), pt.log1mexp(-2).eval())


@pytest.mark.filterwarnings("error")
def test_log1mexp_deprecation_warnings():
with pytest.warns(
FutureWarning,
match="pymc.math.log1mexp_numpy will expect a negative input",
):
res_pos = log1mexp_numpy(2)
with pytest.warns(FutureWarning, match="deprecated"):
with pytest.warns(
FutureWarning,
match="pymc.math.log1mexp_numpy will expect a negative input",
):
res_pos = log1mexp_numpy(2)

with warnings.catch_warnings():
warnings.simplefilter("error")
res_neg = log1mexp_numpy(-2, negative_input=True)

with pytest.warns(
FutureWarning,
match="pymc.math.log1mexp will expect a negative input",
):
res_pos_at = log1mexp(2).eval()
with pytest.warns(
FutureWarning,
match="pymc.math.log1mexp will expect a negative input",
):
res_pos_at = log1mexp(2).eval()

with warnings.catch_warnings():
warnings.simplefilter("error")
res_neg_at = log1mexp(-2, negative_input=True).eval()

assert np.isclose(res_pos, res_neg)
Expand All @@ -196,8 +197,8 @@ def test_logdiffexp():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
b = np.log([0, 1, 2, 3])

assert np.allclose(logdiffexp_numpy(a, b), 0)
with pytest.warns(FutureWarning, match="deprecated"):
assert np.allclose(logdiffexp_numpy(a, b), 0)
assert np.allclose(logdiffexp(a, b).eval(), 0)


Expand Down

0 comments on commit 16e6d69

Please sign in to comment.