Skip to content

Commit 82dde31

Browse files
committed
Update PySpark's RDD.takeSample
1 parent 48d954d commit 82dde31

File tree

1 file changed

+32
-26
lines changed

1 file changed

+32
-26
lines changed

python/pyspark/rdd.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
def takeSample(self, withReplacement, num, seed=None):
    """
    Return a fixed-size sampled subset of this RDD (currently requires
    numpy).

    :param withReplacement: whether elements may be sampled multiple times
    :param num: exact number of elements to return (the result always has
        exactly ``num`` elements when sampling with replacement, and
        ``min(num, count)`` elements without replacement)
    :param seed: optional seed for the random generator, for reproducibility
    :raises ValueError: if ``num`` is negative, or so large that the
        oversampling computation could overflow

    >>> rdd = sc.parallelize(range(0, 10))
    >>> len(rdd.takeSample(True, 20, 1))
    20
    >>> len(rdd.takeSample(False, 5, 2))
    5
    >>> len(rdd.takeSample(False, 15, 3))
    10
    """
    if num < 0:
        raise ValueError("Sample size cannot be negative.")
    elif num == 0:
        # Cheap early exit: no need to count or touch the cluster.
        return []

    initialCount = self.count()
    if initialCount == 0:
        return []

    rand = Random(seed)

    if (not withReplacement) and num >= initialCount:
        # Asking for at least the whole RDD without replacement: just
        # collect everything and return it in random order.
        samples = self.collect()
        rand.shuffle(samples)
        return samples

    # Guard against overflow in the oversampled fraction computation below;
    # 10 standard deviations of headroom matches the Scala implementation.
    # NOTE: sys.maxsize (available since Python 2.6) replaces the
    # Python 2-only sys.maxint, so this also works on Python 3.
    numStDev = 10.0
    maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize))
    if num > maxSampleSize:
        raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)

    fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement)
    samples = self.sample(withReplacement, fraction, seed).collect()

    # If the first sample didn't turn out large enough, keep trying to take samples;
    # this shouldn't happen often because we use a big multiplier for their initial size.
    # See: scala/spark/RDD.scala
    while len(samples) < num:
        # TODO: add log warning for when more than one iteration was run
        seed = rand.randint(0, sys.maxsize)
        samples = self.sample(withReplacement, fraction, seed).collect()

    # Shuffle so the truncation below doesn't bias toward early partitions.
    rand.shuffle(samples)

    return samples[0:num]

0 commit comments

Comments
 (0)