Skip to content

Commit fb1452f

Browse files
committed
Allow num to be greater than count in all cases
1 parent 1481b01 commit fb1452f

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,25 +396,24 @@ abstract class RDD[T: ClassTag](
396396
throw new IllegalArgumentException("Negative number of elements requested")
397397
}
398398

399-
if (initialCount == 0) {
399+
if (initialCount == 0 || num == 0) {
400400
return new Array[T](0)
401401
}
402402

403-
if (!withReplacement && num > initialCount) {
404-
throw new IllegalArgumentException("Cannot create sample larger than the original when " +
405-
"sampling without replacement")
406-
}
407-
408403
val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
409404
if (num > maxSampleSize) {
410405
throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
411406
s"$numStDev * math.sqrt(Int.MaxValue)")
412407
}
413408

409+
val rand = new Random(seed)
410+
if (!withReplacement && num > initialCount) {
411+
return Utils.randomizeInPlace(this.collect(), rand)
412+
}
413+
414414
val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
415415
withReplacement)
416416

417-
val rand = new Random(seed)
418417
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
419418

420419
// If the first sample didn't turn out large enough, keep trying to take samples;

python/pyspark/rdd.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -366,35 +366,37 @@ def takeSample(self, withReplacement, num, seed=None):
366366
[4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
367367
"""
368368

369-
#TODO remove
370-
logging.basicConfig(level=logging.INFO)
371369
numStDev = 10.0
372370
initialCount = self.count()
373371

374372
if num < 0:
375373
raise ValueError
376374

377-
if initialCount == 0:
375+
if initialCount == 0 or num == 0:
378376
return list()
379377

378+
rand = Random(seed)
380379
if (not withReplacement) and num > initialCount:
381-
raise ValueError
380+
# shuffle current RDD and return
381+
samples = self.collect()
382+
fraction = float(num) / initialCount
383+
num = initialCount
384+
else:
385+
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
386+
if num > maxSampleSize:
387+
raise ValueError
382388

383-
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
384-
if num > maxSampleSize:
385-
raise ValueError
389+
fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement)
386390

387-
fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement)
388-
389-
samples = self.sample(withReplacement, fraction, seed).collect()
391+
samples = self.sample(withReplacement, fraction, seed).collect()
390392

391-
# If the first sample didn't turn out large enough, keep trying to take samples;
392-
# this shouldn't happen often because we use a big multiplier for their initial size.
393-
# See: scala/spark/RDD.scala
394-
rand = Random(seed)
395-
while len(samples) < num:
396-
#TODO add log warning for when more than one iteration was run
397-
samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()
393+
# If the first sample didn't turn out large enough, keep trying to take samples;
394+
# this shouldn't happen often because we use a big multiplier for their initial size.
395+
# See: scala/spark/RDD.scala
396+
while len(samples) < num:
397+
#TODO add log warning for when more than one iteration was run
398+
seed = rand.randint(0, sys.maxint)
399+
samples = self.sample(withReplacement, fraction, seed).collect()
398400

399401
sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
400402
sampler.shuffle(samples)

0 commit comments

Comments (0)