Skip to content

Commit eff89e2

Browse files
committed
addressed reviewer comments.
Note that logging isn't added to rdd.py because it seemed to conflict with the log4j logs
1 parent ecab508 commit eff89e2

File tree

5 files changed

+54
-36
lines changed

5 files changed

+54
-36
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 13 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -389,8 +389,6 @@ abstract class RDD[T: ClassTag](
389389
def takeSample(withReplacement: Boolean,
390390
num: Int,
391391
seed: Long = Utils.random.nextLong): Array[T] = {
392-
var fraction = 0.0
393-
var total = 0
394392
val numStDev = 10.0
395393
val initialCount = this.count()
396394

@@ -407,27 +405,30 @@ abstract class RDD[T: ClassTag](
407405
"sampling without replacement")
408406
}
409407

410-
if (initialCount > Int.MaxValue - 1) {
411-
val maxSelected = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
412-
if (num > maxSelected) {
413-
throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
414-
s"$numStDev * math.sqrt(Int.MaxValue)")
415-
}
408+
val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
409+
if (num > maxSampleSize) {
410+
throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
411+
s"$numStDev * math.sqrt(Int.MaxValue)")
416412
}
417413

418-
fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, withReplacement)
419-
total = num
414+
val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
415+
withReplacement)
420416

421417
val rand = new Random(seed)
422418
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
423419

424420
// If the first sample didn't turn out large enough, keep trying to take samples;
425421
// this shouldn't happen often because we use a big multiplier for the initial size
426-
while (samples.length < total) {
422+
var numIters = 0
423+
while (samples.length < num) {
424+
if (numIters > 0) {
425+
logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
426+
}
427427
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
428+
numIters += 1
428429
}
429430

430-
Utils.randomizeInPlace(samples, rand).take(total)
431+
Utils.randomizeInPlace(samples, rand).take(num)
431432
}
432433

433434
/**

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 4 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -20,6 +20,10 @@ package org.apache.spark.util.random
2020
private[spark] object SamplingUtils {
2121

2222
/**
23+
* Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
24+
* the time.
25+
*
26+
* How the sampling rate is determined:
2327
* Let p = num / total, where num is the sample size and total is the total number of
2428
* datapoints in the RDD. We're trying to compute q > p such that
2529
* - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -547,8 +547,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
547547
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
548548
}
549549
for (seed <- 1 to 5) {
550-
val sample = data.takeSample(withReplacement=true, 2*n, seed)
551-
assert(sample.size === 2*n) // Got exactly 200 elements
550+
val sample = data.takeSample(withReplacement=true, 2 * n, seed)
551+
assert(sample.size === 2 * n) // Got exactly 200 elements
552552
// Chance of getting all distinct elements is still quite low, so test we got < 100
553553
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
554554
}

core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -20,10 +20,10 @@ package org.apache.spark.util.random
2020
import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
2121
import org.scalatest.FunSuite
2222

23-
class SamplingUtilsSuite extends FunSuite{
23+
class SamplingUtilsSuite extends FunSuite {
2424

2525
test("computeFraction") {
26-
// test that the computed fraction guarantees enough datapoints
26+
// test that the computed fraction guarantees enough data points
2727
// in the sample with a failure rate <= 0.0001
2828
val n = 100000
2929

python/pyspark/rdd.py

Lines changed: 33 additions & 20 deletions
Original file line number · Diff line number · Diff line change
@@ -15,8 +15,6 @@
1515
# limitations under the License.
1616
#
1717

18-
from base64 import standard_b64encode as b64enc
19-
import copy
2018
from collections import defaultdict
2119
from collections import namedtuple
2220
from itertools import chain, ifilter, imap
@@ -364,8 +362,8 @@ def takeSample(self, withReplacement, num, seed=None):
364362
[4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
365363
"""
366364

367-
fraction = 0.0
368-
total = 0
365+
#TODO remove
366+
logging.basicConfig(level=logging.INFO)
369367
numStDev = 10.0
370368
initialCount = self.count()
371369

@@ -378,38 +376,53 @@ def takeSample(self, withReplacement, num, seed=None):
378376
if (not withReplacement) and num > initialCount:
379377
raise ValueError
380378

381-
if initialCount > sys.maxint - 1:
382-
maxSelected = sys.maxint - int(numStDev * sqrt(sys.maxint))
383-
if num > maxSelected:
384-
raise ValueError
385-
386-
fraction = self._computeFraction(num, initialCount, withReplacement)
387-
total = num
379+
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
380+
if num > maxSampleSize:
381+
raise ValueError
388382

383+
fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement)
384+
389385
samples = self.sample(withReplacement, fraction, seed).collect()
390386

391387
# If the first sample didn't turn out large enough, keep trying to take samples;
392388
# this shouldn't happen often because we use a big multiplier for their initial size.
393389
# See: scala/spark/RDD.scala
394390
rand = Random(seed)
395-
while len(samples) < total:
391+
while len(samples) < num:
396392
samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()
397393

398394
sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
399395
sampler.shuffle(samples)
400-
return samples[0:total]
401-
402-
def _computeFraction(self, num, total, withReplacement):
403-
fraction = float(num)/total
396+
return samples[0:num]
397+
398+
@staticmethod
399+
def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement):
400+
"""
401+
Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
402+
the time.
403+
404+
How the sampling rate is determined:
405+
Let p = num / total, where num is the sample size and total is the total number of
406+
datapoints in the RDD. We're trying to compute q > p such that
407+
- when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
408+
where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to
409+
total), i.e. the failure rate of not having a sufficiently large sample < 0.0001.
410+
Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
411+
num > 12, but we need a slightly larger q (9 empirically determined).
412+
- when sampling without replacement, we're drawing each datapoint with prob_i
413+
~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
414+
rate, where success rate is defined the same as in sampling with replacement.
415+
"""
416+
fraction = float(sampleSizeLowerBound) / total
404417
if withReplacement:
405418
numStDev = 5
406-
if (num < 12):
419+
if (sampleSizeLowerBound < 12):
407420
numStDev = 9
408-
return fraction + numStDev * sqrt(fraction/total)
421+
return fraction + numStDev * sqrt(fraction / total)
409422
else:
410423
delta = 0.00005
411-
gamma = - log(delta)/total
412-
return min(1, fraction + gamma + sqrt(gamma * gamma + 2* gamma * fraction))
424+
gamma = - log(delta) / total
425+
return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))
413426

414427
def union(self, other):
415428
"""

0 commit comments

Comments (0)