Skip to content

Commit 0e18a2d

Browse files
committed
add histogram() API
1 parent 050f8d0 commit 0e18a2d

File tree

2 files changed

+191
-1
lines changed

2 files changed

+191
-1
lines changed

python/pyspark/rdd.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import heapq
3333
import bisect
3434
from random import Random
35-
from math import sqrt, log
35+
from math import sqrt, log, isinf, isnan
3636

3737
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
3838
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -856,6 +856,104 @@ def redFunc(left_counter, right_counter):
856856

857857
return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)
858858

859+
def histogram(self, buckets, evenBuckets=False):
    """
    Compute a histogram using the provided buckets. The buckets
    are all open to the right except for the last, which is closed.
    e.g. [1, 10, 20, 50] means the buckets are [1,10) [10,20) [20,50],
    which means 1<=x<10, 10<=x<20, 20<=x<=50. On the input of 1
    and 50 we would have a histogram of 1,0,1.

    If your buckets are evenly spaced (e.g. [0, 10, 20, 30]),
    bucket lookup can be switched from an O(log n) binary search
    to O(1) per element (where n = # buckets) by setting
    `evenBuckets` to True.

    Buckets must be sorted, must not contain any duplicates, and
    must have at least two elements.

    If `buckets` is a number, that many buckets will be generated,
    evenly spaced between the minimum and maximum of the RDD. For
    example, if the min value is 0 and the max is 100, given 2 as
    `buckets`, the resulting buckets will be [0,50) [50,100].
    `buckets` must be at least 1. None and NaN elements are ignored
    while computing the range; a ValueError is raised if the range
    is infinite or if the RDD has no usable elements. If the
    elements in the RDD do not vary (max == min), a single bucket
    is always returned.

    It will return a tuple of buckets and histogram.

    >>> rdd = sc.parallelize(range(51))
    >>> rdd.histogram(2)
    ([0, 25, 50], [25, 26])
    >>> rdd.histogram([0, 5, 25, 50])
    ([0, 5, 25, 50], [5, 20, 26])
    >>> rdd.histogram([0, 15, 30, 45, 60], True)
    ([0, 15, 30, 45, 60], [15, 15, 15, 6])
    """
    try:
        # Python 2 distinguishes int and long; Python 3 folds long
        # into int, so fall back to int alone there.
        integral = (int, long)
    except NameError:
        integral = (int,)

    rdd = self
    if isinstance(buckets, integral):
        if buckets < 1:
            raise ValueError("buckets should not be less than 1")

        # filter out elements that cannot be ordered
        rdd = rdd.filter(lambda x: x is not None and not isnan(x))

        # single pass for both extremes; faster than stats()
        def minmax(a, b):
            return min(a[0], b[0]), max(a[1], b[1])

        try:
            minv, maxv = rdd.map(lambda x: (x, x)).reduce(minmax)
        except TypeError as e:
            # reduce() raises TypeError on an empty sequence; report
            # that as a ValueError, but re-raise anything else (e.g.
            # elements that cannot be compared). Substring match is
            # used because the exact wording differs across Python
            # versions ("sequence" vs "iterable").
            if "empty" in str(e):
                raise ValueError("can not generate buckets from empty RDD")
            raise

        if minv == maxv or buckets == 1:
            return [minv, maxv], [rdd.count()]

        inc = (maxv - minv) / buckets
        if isinf(inc):
            raise ValueError("Can not generate buckets with infinite value")

        # keep the boundaries as integers if possible
        if inc * buckets != maxv - minv:
            inc = (maxv - minv) * 1.0 / buckets

        buckets = [i * inc + minv for i in range(buckets)]
        buckets.append(maxv)  # fix accumulated floating point error
        evenBuckets = True

    else:
        if len(buckets) < 2:
            raise ValueError("buckets should have more than one value")

        if any(i is None or isnan(i) for i in buckets):
            raise ValueError("can not have None or NaN in buckets")

        # list() lets callers pass any sorted sequence (e.g. a tuple),
        # not just a list
        if sorted(buckets) != list(buckets):
            raise ValueError("buckets should be sorted")

        minv = buckets[0]
        maxv = buckets[-1]
        inc = buckets[1] - buckets[0] if evenBuckets else None

    def countPartition(iterator):
        # one counter per boundary; the last two slots are merged
        # below so that the final bucket is closed on the right
        counters = [0] * len(buckets)
        for x in iterator:
            if x is None or isnan(x) or x > maxv or x < minv:
                continue
            t = (int((x - minv) / inc) if evenBuckets
                 else bisect.bisect_right(buckets, x) - 1)
            counters[t] += 1
        # add last two together
        last = counters.pop()
        counters[-1] += last
        return [counters]

    def mergeCounters(a, b):
        return [i + j for i, j in zip(a, b)]

    return buckets, rdd.mapPartitions(countPartition).reduce(mergeCounters)
956+
859957
def mean(self):
860958
"""
861959
Compute the mean of this RDD's elements.

python/pyspark/tests.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,98 @@ def test_zip_with_different_number_of_items(self):
364364
self.assertEquals(a.count(), b.count())
365365
self.assertRaises(Exception, lambda: a.zip(b).count())
366366

367+
def test_histogram(self):
    """Exercise RDD.histogram() with explicit and generated buckets."""
    # empty RDD: every bucket counts zero
    nothing = self.sc.parallelize([])
    self.assertEquals([0], nothing.histogram([0, 10])[1])
    self.assertEquals([0], nothing.histogram([0, 10], True)[1])

    # values outside the bucket range are dropped
    outside = self.sc.parallelize([10.01, -0.01])
    self.assertEquals([0], outside.histogram([0, 10])[1])
    self.assertEquals([0], outside.histogram([0, 10], True)[1])

    # in range with one bucket
    small = self.sc.parallelize(range(1, 5))
    self.assertEquals([4], small.histogram([0, 10])[1])
    self.assertEquals([4], small.histogram([0, 10], True)[1])

    # in range with one bucket that matches the data exactly
    self.assertEquals([4], small.histogram([1, 4])[1])
    self.assertEquals([4], small.histogram([1, 4], True)[1])

    # out of range with two buckets
    outside = self.sc.parallelize([10.01, -0.01])
    self.assertEquals([0, 0], outside.histogram([0, 5, 10])[1])
    self.assertEquals([0, 0], outside.histogram([0, 5, 10], True)[1])

    # out of range with two uneven buckets
    outside = self.sc.parallelize([10.01, -0.01])
    self.assertEquals([0, 0], outside.histogram([0, 4, 10])[1])

    # in range with two buckets
    pair = self.sc.parallelize([1, 2, 3, 5, 6])
    self.assertEquals([3, 2], pair.histogram([0, 5, 10])[1])
    self.assertEquals([3, 2], pair.histogram([0, 5, 10], True)[1])

    # in range with two buckets; None and NaN elements are ignored
    noisy = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')])
    self.assertEquals([3, 2], noisy.histogram([0, 5, 10])[1])
    self.assertEquals([3, 2], noisy.histogram([0, 5, 10], True)[1])

    # in range with two uneven buckets
    pair = self.sc.parallelize([1, 2, 3, 5, 6])
    self.assertEquals([3, 2], pair.histogram([0, 5, 11])[1])

    # mixed range with two uneven buckets
    mixed = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01])
    self.assertEquals([4, 3], mixed.histogram([0, 5, 11])[1])

    # mixed range with four uneven buckets
    mixed = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0,
                                 199.0, 200.0, 200.1])
    self.assertEquals([4, 2, 1, 3],
                      mixed.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])

    # mixed range with uneven buckets plus None and NaN
    noisy = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0,
                                 199.0, 200.0, 200.1, None, float('nan')])
    self.assertEquals([4, 2, 1, 3],
                      noisy.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])

    # infinite bucket boundaries catch all comparable elements
    wild = self.sc.parallelize([10.01, -0.01, float('nan')])
    self.assertEquals([1, 1],
                      wild.histogram([float('-inf'), 0, float('inf')])[1])

    # invalid bucket specifications
    self.assertRaises(ValueError, lambda: wild.histogram([]))
    self.assertRaises(ValueError, lambda: wild.histogram([1]))

    # generated buckets from a bucket count
    counted = self.sc.parallelize(range(1, 5))
    self.assertEquals(([1, 4], [4]), counted.histogram(1))

    # generated buckets from a single element
    counted = self.sc.parallelize([1])
    self.assertEquals(([1, 1], [1]), counted.histogram(1))

    # generated buckets when the data has no range (max == min)
    counted = self.sc.parallelize([1] * 4)
    self.assertEquals(([1, 1], [4]), counted.histogram(1))

    # generated buckets, basic two-bucket split
    counted = self.sc.parallelize(range(1, 5))
    self.assertEquals(([1, 2.5, 4], [2, 2]), counted.histogram(2))

    # more buckets requested than there are elements
    counted = self.sc.parallelize([1, 2])
    expected_buckets = [1 + 0.2 * i for i in range(6)]
    expected_hist = [1, 0, 0, 0, 1]
    self.assertEquals((expected_buckets, expected_hist), counted.histogram(5))

    # RDDs whose range cannot produce buckets
    bad = self.sc.parallelize([1, float('inf')])
    self.assertRaises(ValueError, lambda: bad.histogram(2))
    bad = self.sc.parallelize([float('nan')])
    self.assertRaises(ValueError, lambda: bad.histogram(2))
458+
367459

368460
class TestIO(PySparkTestCase):
369461

0 commit comments

Comments
 (0)