[SPARK-2871] [PySpark] add countApproxDistinct() API #2142

Closed: wants to merge 9 commits
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -993,7 +993,7 @@ abstract class RDD[T: ClassTag](
*/
@Experimental
def countApproxDistinct(p: Int, sp: Int): Long = {
require(p >= 4, s"p ($p) must be greater than 0")
require(p >= 4, s"p ($p) must be at least 4")
Contributor:
Good catch!

require(sp <= 32, s"sp ($sp) cannot be greater than 32")
require(sp == 0 || p <= sp, s"p ($p) cannot be greater than sp ($sp)")
val zeroCounter = new HyperLogLogPlus(p, sp)
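Side note on why p is floored at 4: with m = 2**p registers, HyperLogLog's expected relative standard error is roughly 1.04/sqrt(m) (the textbook figure, not something stated in this diff). A minimal Python sketch of that trade-off:

import math

def hll_relative_error(p):
    """Approximate relative standard error of HyperLogLog with 2**p registers."""
    return 1.04 / math.sqrt(2 ** p)

for p in (4, 10, 14, 20):
    print("p=%2d  error ~ %.4f" % (p, hll_relative_error(p)))

# p = 4, the minimum enforced above, already means roughly 26% error;
# each extra bit of p halves the variance but doubles the register count.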
39 changes: 34 additions & 5 deletions python/pyspark/rdd.py
@@ -62,7 +62,7 @@ def portable_hash(x):

>>> portable_hash(None)
0
-    >>> portable_hash((None, 1))
+    >>> portable_hash((None, 1)) & 0xffffffff
219750521
"""
if x is None:
@@ -72,7 +72,7 @@ def portable_hash(x):
for i in x:
h ^= portable_hash(i)
h *= 1000003
h &= 0xffffffff
h &= sys.maxint
Contributor:
Why this change?

Contributor Author:
On 64-bit machines, the hash of a tuple should be 64 bits (sys.maxint is 64 bits).
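A minimal sketch of the two masks, under the assumption that sys.maxint == 2**63 - 1 (64-bit CPython 2); the 0x345678 seed is CPython's tuple-hash constant, assumed here because the collapsed diff hides portable_hash's initializer:

MASK32 = 0xffffffff      # the old mask: 2**32 - 1
MASK64 = (1 << 63) - 1   # sys.maxint on a 64-bit machine: 2**63 - 1

def tuple_hash(t, mask):
    """Simplified tuple branch of portable_hash, with the mask made explicit."""
    h = 0x345678                # assumed seed (CPython's tuple-hash constant)
    for i in t:
        h ^= hash(i)            # stand-in for the recursive portable_hash call
        h *= 1000003
        h &= mask               # keep h within the chosen hash width
    h ^= len(t)
    return -2 if h == -1 else h

# The wider mask only adds high bits; the low 32 bits agree, which is why the
# doctest above now applies & 0xffffffff to stay stable across platforms:
print(tuple_hash((None, 1), MASK32) == tuple_hash((None, 1), MASK64) & 0xffffffff)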

h ^= len(x)
if h == -1:
h = -2
@@ -1942,7 +1942,7 @@ def _is_pickled(self):
return True
return False

-    def _to_jrdd(self):
+    def _to_java_object_rdd(self):
""" Return an JavaRDD of Object by unpickling

It will convert each Python object into Java object by Pyrolite, whenever the
@@ -1977,7 +1977,7 @@ def sumApprox(self, timeout, confidence=0.95):
>>> (rdd.sumApprox(1000) - r) / r < 0.05
True
"""
-        jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_jrdd()
+        jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
r = jdrdd.sumApprox(timeout, confidence).getFinalValue()
return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
@@ -1993,11 +1993,40 @@ def meanApprox(self, timeout, confidence=0.95):
>>> (rdd.meanApprox(1000) - r) / r < 0.05
True
"""
-        jrdd = self.map(float)._to_jrdd()
+        jrdd = self.map(float)._to_java_object_rdd()
jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
r = jdrdd.meanApprox(timeout, confidence).getFinalValue()
return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
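
Both approximate methods follow the same pattern: compute doubles on the Python side, hop to the JVM via _to_java_object_rdd(), and wrap the JVM's bounded result in a BoundedFloat. A usage sketch, assuming an active SparkContext sc and that BoundedFloat exposes the confidence/low/high attributes it is constructed with:

rdd = sc.parallelize(range(1000), 4)

s = rdd.sumApprox(timeout=1000)     # give the job at most 1000 ms
m = rdd.meanApprox(timeout=1000)

# BoundedFloat subclasses float, so it compares like a plain number while
# also carrying the confidence interval computed on the JVM side.
print(float(s), s.low, s.high, s.confidence)
print(float(m), m.low, m.high, m.confidence)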

def countApproxDistinct(self, relativeSD=0.05):
"""
:: Experimental ::
Return approximate number of distinct elements in the RDD.

The algorithm used is based on streamlib's implementation of
"HyperLogLog in Practice: Algorithmic Engineering of a State
of The Art Cardinality Estimation Algorithm", available
<a href="http://dx.doi.org/10.1145/2452376.2452456">here</a>.

@param relativeSD Relative accuracy. Smaller values create
counters that require more space.
It must be greater than 0.000017.

>>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct()
Contributor:
can you add a test to make sure that if you have 1000 non-distinct elements (i.e., the same element appearing 1000 times), this doesn't return ~ 1000?

Asking because I'm not sure how PySpark interacts with Java - if it is through byte arrays, then the hashCode could be wrong for byte arrays.

Contributor Author:
Good point, it helped me find an invalid test case with tuple, which is unpickled as Object[] in Java, whose hashCode is not determined by its contents; so I changed it to set([]), which should behave similarly across Python and Java.

Contributor Author:
I changed it to do the hash calculation in Python, so it can support all hashable types in Python.
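
The pitfall behind this thread is identity hashing: a pickled tuple arrives in Java as Object[], and Java arrays hash by identity rather than by content, so the estimator would see every occurrence as distinct. Plain Python objects show the same behavior, which makes for a quick illustration:

class Box(object):           # no __eq__/__hash__: identity-based hashing,
    def __init__(self, v):   # analogous to Object[].hashCode() in Java
        self.v = v

a, b = Box(1), Box(1)
print(hash(a) == hash(b))    # False (barring a freak collision): same content, different hashes

Computing the hash with portable_hash on the Python side sidesteps this, because the hash is derived from the tuple's contents before anything crosses to the JVM.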

>>> 950 < n < 1050
True
>>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct()
>>> 18 < n < 22
True
"""
if relativeSD < 0.000017:
raise ValueError("relativeSD should be greater than 0.000017")
if relativeSD > 0.37:
raise ValueError("relativeSD should be smaller than 0.37")
# the hash space in Java is 2^32
hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)
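
For intuition about what the JVM side is doing, here is a self-contained toy HyperLogLog with the standard linear-counting correction for small cardinalities. It is an illustration only, not streamlib's HyperLogLogPlus; md5 stands in for a well-mixed 32-bit hash:

import hashlib
import math
import struct

def toy_hyperloglog(values, p=10):
    """Toy HLL estimate over 2**p registers; illustration only."""
    m = 1 << p
    q = 32 - p                                  # bits left over for the rank
    regs = [0] * m
    for v in values:
        digest = hashlib.md5(repr(v).encode()).digest()
        h = struct.unpack(">I", digest[:4])[0]  # well-mixed 32-bit hash
        idx = h >> q                            # top p bits pick a register
        w = h & ((1 << q) - 1)                  # remaining q bits
        regs[idx] = max(regs[idx], q - w.bit_length() + 1)
    alpha = 0.7213 / (1 + 1.079 / m)            # bias correction (m >= 128)
    est = alpha * m * m / sum(2.0 ** -r for r in regs)
    zeros = regs.count(0)
    if est <= 2.5 * m and zeros:                # small-range correction
        est = m * math.log(float(m) / zeros)    # linear counting
    return int(est)

print(toy_hyperloglog(range(1000)))             # typically lands near 1000

On the Scala side, relativeSD effectively chooses p: more registers mean smaller error, per the 1.04/sqrt(2**p) relationship sketched earlier.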


class PipelinedRDD(RDD):

16 changes: 16 additions & 0 deletions python/pyspark/tests.py
@@ -404,6 +404,22 @@ def test_zip_with_different_number_of_items(self):
self.assertEquals(a.count(), b.count())
self.assertRaises(Exception, lambda: a.zip(b).count())

def test_count_approx_distinct(self):
rdd = self.sc.parallelize(range(1000))
self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050)
self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050)
self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050)
self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050)

rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
self.assertTrue(18 < rdd.countApproxDistinct() < 22)
self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22)
self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22)
self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22)

self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))
self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.5))

def test_histogram(self):
# empty
rdd = self.sc.parallelize([])