@@ -362,44 +362,50 @@ def takeSample(self, withReplacement, num, seed=None):
362
362
Return a fixed-size sampled subset of this RDD (currently requires
363
363
numpy).
364
364
365
- >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
366
- [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
365
+ >>> rdd = sc.parallelize(range(0, 10))
366
+ >>> len(rdd.takeSample(True, 20, 1))
367
+ 20
368
+ >>> len(rdd.takeSample(False, 5, 2))
369
+ 5
370
+ >>> len(rdd.takeSample(False, 15, 3))
371
+ 10
367
372
"""
368
373
369
- numStDev = 10.0
370
- initialCount = self .count ()
371
-
372
374
if num < 0 :
373
- raise ValueError
375
+ raise ValueError ("Sample size cannot be negative." )
376
+ elif num == 0 :
377
+ return []
374
378
375
- if initialCount == 0 or num == 0 :
376
- return list ()
379
+ initialCount = self .count ()
380
+ if initialCount == 0 :
381
+ return []
377
382
378
383
rand = Random (seed )
379
- if (not withReplacement ) and num > initialCount :
384
+
385
+ if (not withReplacement ) and num >= initialCount :
380
386
# shuffle current RDD and return
381
387
samples = self .collect ()
382
- fraction = float (num ) / initialCount
383
- num = initialCount
384
- else :
385
- maxSampleSize = sys .maxint - int (numStDev * sqrt (sys .maxint ))
386
- if num > maxSampleSize :
387
- raise ValueError
388
-
389
- fraction = self ._computeFractionForSampleSize (num , initialCount , withReplacement )
388
+ rand .shuffle (samples )
389
+ return samples
390
390
391
+ numStDev = 10.0
392
+ maxSampleSize = sys .maxint - int (numStDev * sqrt (sys .maxint ))
393
+ if num > maxSampleSize :
394
+ raise ValueError ("Sample size cannot be greater than %d." % maxSampleSize )
395
+
396
+ fraction = RDD ._computeFractionForSampleSize (num , initialCount , withReplacement )
397
+ samples = self .sample (withReplacement , fraction , seed ).collect ()
398
+
399
+ # If the first sample didn't turn out large enough, keep trying to take samples;
400
+ # this shouldn't happen often because we use a big multiplier for the initial sample size.
401
+ # See: scala/spark/RDD.scala
402
+ while len (samples ) < num :
403
+ # TODO: log a warning when more than one sampling iteration was needed
404
+ seed = rand .randint (0 , sys .maxint )
391
405
samples = self .sample (withReplacement , fraction , seed ).collect ()
392
406
393
- # If the first sample didn't turn out large enough, keep trying to take samples;
394
- # this shouldn't happen often because we use a big multiplier for their initial size.
395
- # See: scala/spark/RDD.scala
396
- while len (samples ) < num :
397
- #TODO add log warning for when more than one iteration was run
398
- seed = rand .randint (0 , sys .maxint )
399
- samples = self .sample (withReplacement , fraction , seed ).collect ()
407
+ rand .shuffle (samples )
400
408
401
- sampler = RDDSampler (withReplacement , fraction , rand .randint (0 , sys .maxint ))
402
- sampler .shuffle (samples )
403
409
return samples [0 :num ]
404
410
405
411
@staticmethod
0 commit comments