Skip to content

Commit

Permalink
Merge pull request scipy#2494 from argriffing/rvs-faddeeva
Browse files Browse the repository at this point in the history
BUG: fix non-symmetry in tails of stats.truncnorm distribution.
  • Loading branch information
rgommers committed May 19, 2013
2 parents 191e7a2 + b3e9b92 commit 5db3b49
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 13 deletions.
32 changes: 22 additions & 10 deletions scipy/stats/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,26 +2076,30 @@ def _ppf(self,q):
_norm_pdf_C = math.sqrt(2*pi)
_norm_pdf_logC = math.log(_norm_pdf_C)


def _norm_pdf(x):
return exp(-x**2/2.0) / _norm_pdf_C


def _norm_logpdf(x):
return -x**2 / 2.0 - _norm_pdf_logC


def _norm_cdf(x):
return special.ndtr(x)


def _norm_logcdf(x):
return special.log_ndtr(x)


def _norm_ppf(q):
return special.ndtri(q)

def _norm_sf(x):
return special.ndtr(-x)

def _norm_logsf(x):
return special.log_ndtr(-x)

def _norm_isf(q):
return -special.ndtri(q)


class norm_gen(rv_continuous):
"""A normal continuous random variable.
Expand Down Expand Up @@ -2130,16 +2134,16 @@ def _logcdf(self, x):
return _norm_logcdf(x)

def _sf(self, x):
return _norm_cdf(-x)
return _norm_sf(x)

def _logsf(self, x):
return _norm_logcdf(-x)
return _norm_logsf(x)

def _ppf(self,q):
return _norm_ppf(q)

def _isf(self,q):
return -_norm_ppf(q)
return _norm_isf(q)

def _stats(self):
return 0.0, 1.0, 0.0, 0.0
Expand Down Expand Up @@ -5578,7 +5582,12 @@ def _argcheck(self, a, b):
self.b = b
self._nb = _norm_cdf(b)
self._na = _norm_cdf(a)
self._delta = self._nb - self._na
self._sb = _norm_sf(b)
self._sa = _norm_sf(a)
if self.a > 0:
self._delta = -(self._sb - self._sa)
else:
self._delta = self._nb - self._na
self._logdelta = log(self._delta)
return (a != b)

Expand All @@ -5594,7 +5603,10 @@ def _cdf(self, x, a, b):
return (_norm_cdf(x) - self._na) / self._delta

def _ppf(self, q, a, b):
return norm._ppf(q*self._nb + self._na*(1.0-q))
if self.a > 0:
return _norm_isf(q*self._sb + self._sa*(1.0-q))
else:
return _norm_ppf(q*self._nb + self._na*(1.0-q))

def _stats(self, a, b):
nA, nB = self._na, self._nb
Expand Down
1 change: 1 addition & 0 deletions scipy/stats/tests/test_continuous_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
['triang', (0.15785029824528218,)],
['truncexpon', (4.6907725456810478,)],
['truncnorm', (-1.0978730080013919, 2.7306754109031979)],
['truncnorm', (0.1, 2.)],
['tukeylambda', (3.1321477856738267,)],
['uniform', ()],
['vonmises', (3.9939042581071398,)],
Expand Down
29 changes: 26 additions & 3 deletions scipy/stats/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
"""
from __future__ import division, print_function, absolute_import

from numpy.testing import TestCase, run_module_suite, assert_equal, \
assert_array_equal, assert_almost_equal, assert_array_almost_equal, \
assert_allclose, assert_, assert_raises, rand, dec
from numpy.testing import (TestCase, run_module_suite, assert_equal,
assert_array_equal, assert_array_less,
assert_almost_equal, assert_array_almost_equal, assert_allclose,
assert_, assert_raises, rand, dec)
from numpy.testing.utils import WarningManager
from nose import SkipTest

Expand Down Expand Up @@ -1119,6 +1120,28 @@ def test_foldnorm_zero():
rv = stats.foldnorm(0, scale=1)
assert_equal(rv.cdf(0), 0) # rv.cdf(0) previously resulted in: nan

def test_gh_2477_small_values():
# Check a case that worked in the original issue.
low, high = -11, -10
x = stats.truncnorm.rvs(low, high, 0, 1, size=10)
assert_(low < x.min() < x.max() < high)
# Check a case that failed in the original issue.
low, high = 10, 11
x = stats.truncnorm.rvs(low, high, 0, 1, size=10)
assert_(low < x.min() < x.max() < high)

def test_gh_2477_large_values():
# Check a case that fails because of extreme tailness.
raise SkipTest('truncnorm rvs is know to fail at extreme tails')
low, high = 100, 101
x = stats.truncnorm.rvs(low, high, 0, 1, size=10)
assert_(low < x.min() < x.max() < high)

def test_gh_1489_trac_962_rvs():
# Check the original example.
low, high = 10, 15
x = stats.truncnorm.rvs(low, high, 0, 1, size=10)
assert_(low < x.min() < x.max() < high)

if __name__ == "__main__":
run_module_suite()

0 comments on commit 5db3b49

Please sign in to comment.