Skip to content

Commit 3cca196

Browse files
committed
[SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample
The current way of seed distribution makes the random sequences from partition i and i+1 offset by 1. ~~~ In [14]: import random In [15]: r1 = random.Random(10) In [16]: r1.randint(0, 1) Out[16]: 1 In [17]: r1.random() Out[17]: 0.4288890546751146 In [18]: r1.random() Out[18]: 0.5780913011344704 In [19]: r2 = random.Random(10) In [20]: r2.randint(0, 1) Out[20]: 1 In [21]: r2.randint(0, 1) Out[21]: 0 In [22]: r2.random() Out[22]: 0.5780913011344704 ~~~ Note: The new tests are not for this bug fix. Author: Xiangrui Meng <meng@databricks.com> Closes #3010 from mengxr/SPARK-4148 and squashes the following commits: 869ae4b [Xiangrui Meng] move tests to tests.py c1bacd9 [Xiangrui Meng] fix seed distribution and add some tests for rdd.sample
1 parent 2aca97c commit 3cca196

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

python/pyspark/rdd.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,9 +316,6 @@ def sample(self, withReplacement, fraction, seed=None):
316316
"""
317317
Return a sampled subset of this RDD (relies on numpy and falls back
318318
on default random generator if numpy is unavailable).
319-
320-
>>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
321-
[2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
322319
"""
323320
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
324321
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

python/pyspark/rddsampler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,13 @@ def __init__(self, withReplacement, seed=None):
4040
def initRandomGenerator(self, split):
4141
if self._use_numpy:
4242
import numpy
43-
self._random = numpy.random.RandomState(self._seed)
43+
self._random = numpy.random.RandomState(self._seed ^ split)
4444
else:
45-
self._random = random.Random(self._seed)
45+
self._random = random.Random(self._seed ^ split)
4646

47-
for _ in range(0, split):
48-
# discard the next few values in the sequence to have a
49-
# different seed for the different splits
50-
self._random.randint(0, 2 ** 32 - 1)
47+
# mixing because the initial seeds are close to each other
48+
for _ in xrange(10):
49+
self._random.randint(0, 1)
5150

5251
self._split = split
5352
self._rand_initialized = True

python/pyspark/tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,21 @@ def test_distinct(self):
648648
self.assertEquals(result.getNumPartitions(), 5)
649649
self.assertEquals(result.count(), 3)
650650

651+
def test_sample(self):
652+
rdd = self.sc.parallelize(range(0, 100), 4)
653+
wo = rdd.sample(False, 0.1, 2).collect()
654+
wo_dup = rdd.sample(False, 0.1, 2).collect()
655+
self.assertSetEqual(set(wo), set(wo_dup))
656+
wr = rdd.sample(True, 0.2, 5).collect()
657+
wr_dup = rdd.sample(True, 0.2, 5).collect()
658+
self.assertSetEqual(set(wr), set(wr_dup))
659+
wo_s10 = rdd.sample(False, 0.3, 10).collect()
660+
wo_s20 = rdd.sample(False, 0.3, 20).collect()
661+
self.assertNotEqual(set(wo_s10), set(wo_s20))
662+
wr_s11 = rdd.sample(True, 0.4, 11).collect()
663+
wr_s21 = rdd.sample(True, 0.4, 21).collect()
664+
self.assertNotEqual(set(wr_s11), set(wr_s21))
665+
651666

652667
class ProfilerTests(PySparkTestCase):
653668

0 commit comments

Comments
 (0)