@@ -366,35 +366,37 @@ def takeSample(self, withReplacement, num, seed=None):
366
366
[4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
367
367
"""
368
368
369
- #TODO remove
370
- logging .basicConfig (level = logging .INFO )
371
369
numStDev = 10.0
372
370
initialCount = self .count ()
373
371
374
372
if num < 0 :
375
373
raise ValueError
376
374
377
- if initialCount == 0 :
375
+ if initialCount == 0 or num == 0 :
378
376
return list ()
379
377
378
+ rand = Random (seed )
380
379
if (not withReplacement ) and num > initialCount :
381
- raise ValueError
380
+ # shuffle current RDD and return
381
+ samples = self .collect ()
382
+ fraction = float (num ) / initialCount
383
+ num = initialCount
384
+ else :
385
+ maxSampleSize = sys .maxint - int (numStDev * sqrt (sys .maxint ))
386
+ if num > maxSampleSize :
387
+ raise ValueError
382
388
383
- maxSampleSize = sys .maxint - int (numStDev * sqrt (sys .maxint ))
384
- if num > maxSampleSize :
385
- raise ValueError
389
+ fraction = self ._computeFractionForSampleSize (num , initialCount , withReplacement )
386
390
387
- fraction = self ._computeFractionForSampleSize (num , initialCount , withReplacement )
388
-
389
- samples = self .sample (withReplacement , fraction , seed ).collect ()
391
+ samples = self .sample (withReplacement , fraction , seed ).collect ()
390
392
391
- # If the first sample didn't turn out large enough, keep trying to take samples;
392
- # this shouldn't happen often because we use a big multiplier for their initial size.
393
- # See: scala/spark/RDD.scala
394
- rand = Random ( seed )
395
- while len ( samples ) < num :
396
- #TODO add log warning for when more than one iteration was run
397
- samples = self .sample (withReplacement , fraction , rand . randint ( 0 , sys . maxint ) ).collect ()
393
+ # If the first sample didn't turn out large enough, keep trying to take samples;
394
+ # this shouldn't happen often because we use a big multiplier for their initial size.
395
+ # See: scala/spark/RDD.scala
396
+ while len ( samples ) < num :
397
+ #TODO add log warning for when more than one iteration was run
398
+ seed = rand . randint ( 0 , sys . maxint )
399
+ samples = self .sample (withReplacement , fraction , seed ).collect ()
398
400
399
401
sampler = RDDSampler (withReplacement , fraction , rand .randint (0 , sys .maxint ))
400
402
sampler .shuffle (samples )
0 commit comments