Commit 84e85fa
remove evenBuckets, add more tests (including str)
1 parent d9a0722

2 files changed: 65 additions, 24 deletions
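In short, the user-facing change: the `evenBuckets` flag is dropped and evenly spaced buckets are now detected automatically. A minimal before/after sketch (assumes a live SparkContext `sc`; the `range(51)` data is illustrative and matches the doctest results in the hunk below):

    rdd = sc.parallelize(range(51))

    # Before this commit the caller had to opt in to the O(1) fast path:
    #   rdd.histogram([0, 15, 30, 45, 60], True)
    # After this commit the same call takes no flag; even spacing is detected:
    buckets, counts = rdd.histogram([0, 15, 30, 45, 60])
    assert (buckets, counts) == ([0, 15, 30, 45, 60], [15, 15, 15, 6])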

python/pyspark/rdd.py

Lines changed: 44 additions & 15 deletions
@@ -856,7 +856,7 @@ def redFunc(left_counter, right_counter):
 
         return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)
 
-    def histogram(self, buckets, evenBuckets=False):
+    def histogram(self, buckets):
         """
         Compute a histogram using the provided buckets. The buckets
         are all open to the right except for the last which is closed.
@@ -866,7 +866,7 @@ def histogram(self, buckets, evenBuckets=False):
 
         If your histogram is evenly spaced (e.g. [0, 10, 20, 30]),
         this can be switched from an O(log n) insertion to O(1) per
-        element (where n = # buckets), if you set `even` to True.
+        element (where n = # buckets).
 
         Buckets must be sorted, must not contain any duplicates, and
         must have at least two elements.
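Aside: to illustrate the docstring's O(1) versus O(log n) point, a small standalone sketch (plain Python, not PySpark code; `index_even` and `index_general` are hypothetical helper names):

    import bisect

    buckets = [0, 10, 20, 30]          # evenly spaced, width 10
    inc = buckets[1] - buckets[0]

    def index_even(x):
        # O(1): direct arithmetic, valid only when buckets are evenly spaced
        return int((x - buckets[0]) / inc)

    def index_general(x):
        # O(log n): binary search, works for any sorted bucket boundaries
        return bisect.bisect_right(buckets, x) - 1

    assert index_even(17) == index_general(17) == 1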
@@ -886,31 +886,45 @@ def histogram(self, buckets, evenBuckets=False):
         ([0, 25, 50], [25, 26])
         >>> rdd.histogram([0, 5, 25, 50])
         ([0, 5, 25, 50], [5, 20, 26])
-        >>> rdd.histogram([0, 15, 30, 45, 60], True)
+        >>> rdd.histogram([0, 15, 30, 45, 60])  # evenly spaced buckets
         ([0, 15, 30, 45, 60], [15, 15, 15, 6])
+        >>> rdd = sc.parallelize(["ab", "ac", "b", "bd", "ef"])
+        >>> rdd.histogram(("a", "b", "c"))
+        (('a', 'b', 'c'), [2, 2])
         """
 
         if isinstance(buckets, (int, long)):
             if buckets < 1:
                 raise ValueError("number of buckets must be >= 1")
 
             # filter out non-comparable elements
-            self = self.filter(lambda x: x is not None and not isnan(x))
+            def comparable(x):
+                if x is None:
+                    return False
+                if type(x) is float and isnan(x):
+                    return False
+                return True
+
+            filtered = self.filter(comparable)
 
             # faster than stats()
             def minmax(a, b):
                 return min(a[0], b[0]), max(a[1], b[1])
             try:
-                minv, maxv = self.map(lambda x: (x, x)).reduce(minmax)
+                minv, maxv = filtered.map(lambda x: (x, x)).reduce(minmax)
             except TypeError as e:
-                if e.message == "reduce() of empty sequence with no initial value":
+                if " empty " in e.message:
                     raise ValueError("can not generate buckets from empty RDD")
                 raise
 
             if minv == maxv or buckets == 1:
-                return [minv, maxv], [self.count()]
+                return [minv, maxv], [filtered.count()]
+
+            try:
+                inc = (maxv - minv) / buckets
+            except TypeError:
+                raise TypeError("Can not generate buckets with non-number in RDD")
 
-            inc = (maxv - minv) / buckets
             if isinf(inc):
                 raise ValueError("Can not generate buckets with infinite value")
 
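Aside: the new `comparable` filter plus the `(x, x)` reduce computes both extremes in a single pass over the data. A local sketch of the same idea, with a plain Python list standing in for the RDD:

    from functools import reduce
    from math import isnan

    data = [3.0, None, float('nan'), 7.5, -2.0]

    # Drop values that do not compare sensibly, mirroring comparable() above.
    filtered = [x for x in data
                if x is not None and not (type(x) is float and isnan(x))]

    def minmax(a, b):
        # Each element is carried as a (min, max) pair, so one reduce yields both.
        return min(a[0], b[0]), max(a[1], b[1])

    minv, maxv = reduce(minmax, [(x, x) for x in filtered])
    assert (minv, maxv) == (-2.0, 7.5)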

@@ -920,28 +934,43 @@ def minmax(a, b):
 
             buckets = [i * inc + minv for i in range(buckets)]
             buckets.append(maxv)  # fix accumulated error
-            evenBuckets = True
+            even = True
 
-        else:
+        elif isinstance(buckets, (list, tuple)):
             if len(buckets) < 2:
                 raise ValueError("buckets should have more than one value")
 
-            if any(i is None or isnan(i) for i in buckets):
+            if any(i is None or isinstance(i, float) and isnan(i) for i in buckets):
                 raise ValueError("can not have None or NaN in buckets")
 
-            if sorted(buckets) != buckets:
+            if sorted(buckets) != list(buckets):
                 raise ValueError("buckets should be sorted")
 
+            if len(set(buckets)) != len(buckets):
+                raise ValueError("buckets should not contain duplicated values")
+
             minv = buckets[0]
             maxv = buckets[-1]
-            inc = buckets[1] - buckets[0] if evenBuckets else None
+            even = False
+            inc = None
+            try:
+                steps = [buckets[i + 1] - buckets[i] for i in range(len(buckets) - 1)]
+            except TypeError:
+                pass  # objects in buckets do not support '-'
+            else:
+                if max(steps) - min(steps) < 1e-10:  # handle precision errors
+                    even = True
+                    inc = (maxv - minv) / (len(buckets) - 1)
+
+        else:
+            raise TypeError("buckets should be a list or tuple or number(int or long)")
 
         def histogram(iterator):
             counters = [0] * len(buckets)
             for i in iterator:
-                if i is None or isnan(i) or i > maxv or i < minv:
+                if i is None or (type(i) is float and isnan(i)) or i > maxv or i < minv:
                     continue
-                t = (int((i - minv) / inc) if evenBuckets
+                t = (int((i - minv) / inc) if even
                      else bisect.bisect_right(buckets, i) - 1)
                 counters[t] += 1
             # add last two together
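Aside: the list/tuple branch above decides whether the O(1) path applies by comparing consecutive bucket gaps. A minimal local sketch of that detection (the `detect_even` helper is illustrative, not part of the patch):

    def detect_even(buckets):
        # Returns (even, inc): whether the gaps are uniform, and the gap size if so.
        try:
            steps = [buckets[i + 1] - buckets[i] for i in range(len(buckets) - 1)]
        except TypeError:
            return False, None                     # elements (e.g. strings) lack '-'
        if max(steps) - min(steps) < 1e-10:        # tolerate float rounding
            return True, float(buckets[-1] - buckets[0]) / (len(buckets) - 1)
        return False, None

    assert detect_even([0, 15, 30, 45, 60]) == (True, 15.0)
    assert detect_even([0, 5, 25, 50]) == (False, None)
    assert detect_even(("a", "b", "c")) == (False, None)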

python/pyspark/tests.py

Lines changed: 21 additions & 9 deletions
@@ -368,26 +368,25 @@ def test_histogram(self):
         # empty
         rdd = self.sc.parallelize([])
         self.assertEquals([0], rdd.histogram([0, 10])[1])
-        self.assertEquals([0], rdd.histogram([0, 10], True)[1])
+        self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1])
+        self.assertRaises(ValueError, lambda: rdd.histogram(1))
 
         # out of range
         rdd = self.sc.parallelize([10.01, -0.01])
         self.assertEquals([0], rdd.histogram([0, 10])[1])
-        self.assertEquals([0], rdd.histogram([0, 10], True)[1])
+        self.assertEquals([0, 0], rdd.histogram((0, 4, 10))[1])
 
         # in range with one bucket
         rdd = self.sc.parallelize(range(1, 5))
         self.assertEquals([4], rdd.histogram([0, 10])[1])
-        self.assertEquals([4], rdd.histogram([0, 10], True)[1])
+        self.assertEquals([3, 1], rdd.histogram([0, 4, 10])[1])
 
         # in range with one bucket exact match
         self.assertEquals([4], rdd.histogram([1, 4])[1])
-        self.assertEquals([4], rdd.histogram([1, 4], True)[1])
 
         # out of range with two buckets
         rdd = self.sc.parallelize([10.01, -0.01])
         self.assertEquals([0, 0], rdd.histogram([0, 5, 10])[1])
-        self.assertEquals([0, 0], rdd.histogram([0, 5, 10], True)[1])
 
         # out of range with two uneven buckets
         rdd = self.sc.parallelize([10.01, -0.01])
@@ -396,12 +395,10 @@ def test_histogram(self):
         # in range with two buckets
         rdd = self.sc.parallelize([1, 2, 3, 5, 6])
         self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1])
-        self.assertEquals([3, 2], rdd.histogram([0, 5, 10], True)[1])
 
         # in range with two buckets and None
         rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')])
         self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1])
-        self.assertEquals([3, 2], rdd.histogram([0, 5, 10], True)[1])
 
         # in range with two uneven buckets
         rdd = self.sc.parallelize([1, 2, 3, 5, 6])
@@ -421,12 +418,14 @@ def test_histogram(self):
         self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
 
         # out of range with infinite buckets
-        rdd = self.sc.parallelize([10.01, -0.01, float('nan')])
-        self.assertEquals([1, 1], rdd.histogram([float('-inf'), 0, float('inf')])[1])
+        rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")])
+        self.assertEquals([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])
 
         # invalid buckets
         self.assertRaises(ValueError, lambda: rdd.histogram([]))
         self.assertRaises(ValueError, lambda: rdd.histogram([1]))
+        self.assertRaises(ValueError, lambda: rdd.histogram(0))
+        self.assertRaises(TypeError, lambda: rdd.histogram({}))
 
         # without buckets
         rdd = self.sc.parallelize(range(1, 5))
@@ -456,6 +455,19 @@ def test_histogram(self):
         rdd = self.sc.parallelize([float('nan')])
         self.assertRaises(ValueError, lambda: rdd.histogram(2))
 
+        # string
+        rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2)
+        self.assertEquals([2, 2], rdd.histogram(["a", "b", "c"])[1])
+        self.assertEquals((["ab", "ef"], [5]), rdd.histogram(1))
+        self.assertRaises(TypeError, lambda: rdd.histogram(2))
+
+        # mixed RDD
+        rdd = self.sc.parallelize([1, 4, "ab", "ac", "b"], 2)
+        self.assertEquals([1, 1], rdd.histogram([0, 4, 10])[1])
+        self.assertEquals([2, 1], rdd.histogram(["a", "b", "c"])[1])
+        self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
+        self.assertRaises(TypeError, lambda: rdd.histogram(2))
+
 
 
 class TestIO(PySparkTestCase):