Skip to content

Commit

Permalink
Improve logsumexp to work with infinite values (#4360)
Browse files Browse the repository at this point in the history
* Make logsumexp work with inifinite values, matching scipy behavior

* Run pre-commit

* Add note to release_notes
  • Loading branch information
ricardoV94 authored Dec 20, 2020
1 parent 34447a7 commit 0e9b9a4
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ This is the first release to support Python3.9 and to drop Python3.6.
- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).
- In `sample_posterior_predictive` the `vars` kwarg was removed in favor of `var_names` (see [#4343](https://github.com/pymc-devs/pymc3/pull/4343)).
- The notebook gallery has been moved to https://github.com/pymc-devs/pymc-examples (see [#4348](https://github.com/pymc-devs/pymc3/pull/4348)).

- `math.logsumexp` now matches `scipy.special.logsumexp` when arrays contain infinite values (see [#4360](https://github.com/pymc-devs/pymc3/pull/4360)).

## PyMC3 3.10.0 (7 December 2020)

Expand Down
1 change: 1 addition & 0 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def tround(*args, **kwargs):
def logsumexp(x, axis=None, keepdims=True):
# Adapted from https://github.com/Theano/Theano/issues/1563
x_max = tt.max(x, axis=axis, keepdims=True)
x_max = tt.switch(tt.isinf(x_max), 0, x_max)
res = tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
return res if keepdims else res.squeeze()

Expand Down
27 changes: 27 additions & 0 deletions pymc3/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import theano
import theano.tensor as tt

from scipy.special import logsumexp as scipy_logsumexp

from pymc3.math import (
LogDet,
cartesian,
Expand All @@ -30,6 +32,7 @@
log1mexp_numpy,
log1pexp,
logdet,
logsumexp,
probit,
)
from pymc3.tests.helpers import SeededTest, verify_grad
Expand Down Expand Up @@ -207,3 +210,27 @@ def test_expand_packed_triangular():
assert np.all(expand_upper.eval({packed: upper_packed}) == upper)
assert np.all(expand_diag_lower.eval({packed: lower_packed}) == floatX(np.diag(vals)))
assert np.all(expand_diag_upper.eval({packed: upper_packed}) == floatX(np.diag(vals)))


@pytest.mark.parametrize(
"values, axis, keepdims",
[
(np.array([-4, -2]), None, True),
(np.array([-np.inf, -2]), None, True),
(np.array([-2, np.inf]), None, True),
(np.array([-np.inf, -np.inf]), None, True),
(np.array([np.inf, np.inf]), None, True),
(np.array([-np.inf, np.inf]), None, True),
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), None, True),
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), 0, True),
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), 1, True),
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), 0, False),
(np.array([[-np.inf, -np.inf], [-np.inf, -np.inf]]), 1, False),
(np.array([[-2, np.inf], [-np.inf, -np.inf]]), 0, True),
],
)
def test_logsumexp(values, axis, keepdims):
npt.assert_almost_equal(
logsumexp(values, axis=axis, keepdims=keepdims).eval(),
scipy_logsumexp(values, axis=axis, keepdims=keepdims),
)

0 comments on commit 0e9b9a4

Please sign in to comment.