Skip to content

Commit

Permalink
TST: add test for broadcasting in stats.nct distribution.
Browse files Browse the repository at this point in the history
  • Loading branch information
rgommers committed May 20, 2013
1 parent bc45762 commit 53d7a02
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
1 change: 1 addition & 0 deletions scipy/stats/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4844,6 +4844,7 @@ class nct_gen(rv_continuous):
"""
def _argcheck(self, df, nc):
return (df > 0) & (nc == nc)

def _rvs(self, df, nc):
return norm.rvs(loc=nc,size=self._size)*sqrt(df) / sqrt(chi2.rvs(df,size=self._size))

Expand Down
26 changes: 17 additions & 9 deletions scipy/stats/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,23 @@ def test_poisson(self):
assert_almost_equal(prob_lb, 1, decimal=14)


class TestNct(TestCase):
def test_nc_parameter(self):
# Parameter values c<=0 were not enabled (gh-2402).
# For negative values c and for c=0 results of rv.cdf(0) below were nan
rv = stats.nct(5, 0)
assert_equal(rv.cdf(0), 0.5)
rv = stats.nct(5, -1)
assert_almost_equal(rv.cdf(0), 0.841344746069, decimal=10)

def test_broadcasting(self):
res = stats.nct.pdf(5, np.arange(4,7)[:,None], np.linspace(0.1, 1, 4))
expected = array([[ 0.00321886, 0.00557466, 0.00918418, 0.01442997],
[ 0.00217142, 0.00395366, 0.00683888, 0.01126276],
[ 0.00153078, 0.00291093, 0.00525206, 0.00900815]])
assert_allclose(res, expected, rtol=1e-5)


def test_regression_ticket_1316():
# The following was raising an exception, because _construct_default_doc()
# did not handle the default keyword extradoc=None. See ticket #1316.
Expand Down Expand Up @@ -1144,14 +1161,5 @@ def test_foldnorm_zero():
assert_equal(rv.cdf(0), 0) # rv.cdf(0) previously resulted in: nan


def test_nct_ticket_1883():
# Parameter values c<=0 were not enabled
# For negative values c and for c=0 results of rv.cdf(0) below were nan
rv = stats.nct(5, 0)
assert_equal(rv.cdf(0), 0.5)
rv = stats.nct(5, -1)
assert_almost_equal(rv.cdf(0), 0.841344746069, decimal=10)


if __name__ == "__main__":
run_module_suite()

0 comments on commit 53d7a02

Please sign in to comment.