|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import functools as ft |
| 16 | +import warnings |
16 | 17 |
|
17 | 18 | import numpy as np |
18 | 19 | import numpy.testing as npt |
@@ -890,24 +891,26 @@ def scipy_logp(value, mu, sigma, lower, upper): |
890 | 891 | assert np.isinf(logp[2]) |
891 | 892 |
|
892 | 893 | def test_get_tau_sigma(self): |
893 | | - sigma = np.array(2) |
894 | | - npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma]) |
| 894 | + # Fail on warnings |
| 895 | + with warnings.catch_warnings(): |
| 896 | + warnings.simplefilter("error") |
895 | 897 |
|
896 | | - tau = np.array(2) |
897 | | - npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5]) |
| 898 | + sigma = np.array(2) |
| 899 | + npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma]) |
898 | 900 |
|
899 | | - tau, _ = get_tau_sigma(sigma=pt.constant(-2)) |
900 | | - with pytest.raises(ParameterValueError): |
901 | | - tau.eval() |
| 901 | + tau = np.array(2) |
| 902 | + npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5]) |
902 | 903 |
|
903 | | - _, sigma = get_tau_sigma(tau=pt.constant(-2)) |
904 | | - with pytest.raises(ParameterValueError): |
905 | | - sigma.eval() |
| 904 | + tau, _ = get_tau_sigma(sigma=pt.constant(-2)) |
| 905 | + npt.assert_almost_equal(tau.eval(), -0.25) |
906 | 906 |
|
907 | | - sigma = [1, 2] |
908 | | - npt.assert_almost_equal( |
909 | | - get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)] |
910 | | - ) |
| 907 | + _, sigma = get_tau_sigma(tau=pt.constant(-2)) |
| 908 | + npt.assert_almost_equal(sigma.eval(), -np.sqrt(1 / 2)) |
| 909 | + |
| 910 | + sigma = [1, 2] |
| 911 | + npt.assert_almost_equal( |
| 912 | + get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)] |
| 913 | + ) |
911 | 914 |
|
912 | 915 | @pytest.mark.parametrize( |
913 | 916 | "value,mu,sigma,nu,logp", |
|
0 commit comments