Skip to content

Commit ce2ea7d

Browse files
authored
Minor speed/accuracy improvement for kde() (gh-119910)
1 parent 90ec19f commit ce2ea7d

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

Lib/statistics.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -953,12 +953,14 @@ def kde(data, h, kernel='normal', *, cumulative=False):
953953

954954
case 'quartic' | 'biweight':
955955
K = lambda t: 15/16 * (1.0 - t * t) ** 2
956-
W = lambda t: 3/16 * t**5 - 5/8 * t**3 + 15/16 * t + 1/2
956+
W = lambda t: sumprod((3/16, -5/8, 15/16, 1/2),
957+
(t**5, t**3, t, 1.0))
957958
support = 1.0
958959

959960
case 'triweight':
960961
K = lambda t: 35/32 * (1.0 - t * t) ** 3
961-
W = lambda t: 35/32 * (-1/7*t**7 + 3/5*t**5 - t**3 + t) + 1/2
962+
W = lambda t: sumprod((-5/32, 21/32, -35/32, 35/32, 1/2),
963+
(t**7, t**5, t**3, t, 1.0))
962964
support = 1.0
963965

964966
case 'cosine':
@@ -974,12 +976,10 @@ def kde(data, h, kernel='normal', *, cumulative=False):
974976
if support is None:
975977

976978
def pdf(x):
977-
n = len(data)
978-
return sum(K((x - x_i) / h) for x_i in data) / (n * h)
979+
return sum(K((x - x_i) / h) for x_i in data) / (len(data) * h)
979980

980981
def cdf(x):
981-
n = len(data)
982-
return sum(W((x - x_i) / h) for x_i in data) / n
982+
return sum(W((x - x_i) / h) for x_i in data) / len(data)
983983

984984
else:
985985

@@ -1732,7 +1732,7 @@ def _quartic_invcdf_estimate(p):
17321732

17331733
_quartic_invcdf = _newton_raphson(
17341734
f_inv_estimate = _quartic_invcdf_estimate,
1735-
f = lambda t: 3/16 * t**5 - 5/8 * t**3 + 15/16 * t + 1/2,
1735+
f = lambda t: sumprod((3/16, -5/8, 15/16, 1/2), (t**5, t**3, t, 1.0)),
17361736
f_prime = lambda t: 15/16 * (1.0 - t * t) ** 2)
17371737

17381738
def _triweight_invcdf_estimate(p):
@@ -1742,7 +1742,8 @@ def _triweight_invcdf_estimate(p):
17421742

17431743
_triweight_invcdf = _newton_raphson(
17441744
f_inv_estimate = _triweight_invcdf_estimate,
1745-
f = lambda t: 35/32 * (-1/7*t**7 + 3/5*t**5 - t**3 + t) + 1/2,
1745+
f = lambda t: sumprod((-5/32, 21/32, -35/32, 35/32, 1/2),
1746+
(t**7, t**5, t**3, t, 1.0)),
17461747
f_prime = lambda t: 35/32 * (1.0 - t * t) ** 3)
17471748

17481749
_kernel_invcdfs = {

Lib/test/test_statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2444,7 +2444,7 @@ def test_kde_kernel_invcdfs(self):
24442444
with self.subTest(kernel=kernel):
24452445
cdf = kde([0.0], h=1.0, kernel=kernel, cumulative=True)
24462446
for x in xarr:
2447-
self.assertAlmostEqual(invcdf(cdf(x)), x, places=5)
2447+
self.assertAlmostEqual(invcdf(cdf(x)), x, places=6)
24482448

24492449
@support.requires_resource('cpu')
24502450
def test_kde_random(self):

0 commit comments

Comments
 (0)