Skip to content

Commit ffea61a

Browse files
committed
SPARK-1939: Refactor takeSample method in RDD
Reviewer comments addressed: - commons-math3 is now a test-only dependency, bumped up to v3.3 - comments added to explain what computeFraction is doing - fixed the unit test for computeFraction to use BinomialDistribution for sampling without replacement - stylistic fixes
1 parent 1441977 commit ffea61a

File tree

6 files changed

+44
-32
lines changed

6 files changed

+44
-32
lines changed

core/pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
<dependency>
7171
<groupId>org.apache.commons</groupId>
7272
<artifactId>commons-math3</artifactId>
73+
<scope>test</scope>
7374
</dependency>
7475
<dependency>
7576
<groupId>com.google.code.findbugs</groupId>

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ abstract class RDD[T: ClassTag](
388388
* @return sample of specified size in an array
389389
*/
390390
def takeSample(withReplacement: Boolean,
391-
num: Int,
392-
seed: Long = Utils.random.nextLong): Array[T] = {
391+
num: Int,
392+
seed: Long = Utils.random.nextLong): Array[T] = {
393393
var fraction = 0.0
394394
var total = 0
395395
val multiplier = 3.0
@@ -431,18 +431,31 @@ abstract class RDD[T: ClassTag](
431431
Utils.randomizeInPlace(samples, rand).take(total)
432432
}
433433

434-
private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean) : Double = {
434+
/**
435+
* Let p = num / total, where num is the sample size and total is the total number of
436+
* datapoints in the RDD. We're trying to compute q > p such that
437+
* - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
438+
* where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
439+
* i.e. the failure rate of not having a sufficiently large sample < 0.0001.
440+
* Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
441+
* num > 12, but we need a slightly larger q (9 empirically determined).
442+
* - when sampling without replacement, we're drawing each datapoint with prob_i
443+
* ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
444+
* rate, where success rate is defined the same as in sampling with replacement.
445+
*
446+
* @param num sample size
447+
* @param total size of RDD
448+
* @param withReplacement whether sampling with replacement
449+
* @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
450+
*/
451+
private[rdd] def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
435452
val fraction = num.toDouble / total
436453
if (withReplacement) {
437-
var numStDev = 5
438-
if (num < 12) {
439-
// special case to guarantee sample size for small s
440-
numStDev = 9
441-
}
454+
val numStDev = if (num < 12) 9 else 5
442455
fraction + numStDev * math.sqrt(fraction / total)
443456
} else {
444-
val delta = 0.00005
445-
val gamma = - math.log(delta)/total
457+
val delta = 1e-4
458+
val gamma = - math.log(delta) / total
446459
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
447460
}
448461
}

core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
7070
}
7171

7272
/**
73-
* Return a sampler which is the complement of the range specified of the current sampler.
73+
* Return a sampler that is the complement of the range specified of the current sampler.
7474
*/
7575
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
7676

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ import scala.reflect.ClassTag
2222

2323
import org.scalatest.FunSuite
2424

25+
import org.apache.commons.math3.distribution.BinomialDistribution
2526
import org.apache.commons.math3.distribution.PoissonDistribution
27+
2628
import org.apache.spark._
2729
import org.apache.spark.SparkContext._
2830
import org.apache.spark.rdd._
@@ -496,29 +498,25 @@ class RDDSuite extends FunSuite with SharedSparkContext {
496498
}
497499

498500
test("computeFraction") {
499-
// test that the computed fraction guarantees enough datapoints in the sample with a failure rate <= 0.0001
501+
// test that the computed fraction guarantees enough datapoints
502+
// in the sample with a failure rate <= 0.0001
500503
val data = new EmptyRDD[Int](sc)
501504
val n = 100000
502505

503506
for (s <- 1 to 15) {
504507
val frac = data.computeFraction(s, n, true)
505-
val qpois = new PoissonDistribution(frac * n)
506-
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
508+
val poisson = new PoissonDistribution(frac * n)
509+
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
507510
}
508-
for (s <- 1 to 15) {
509-
val frac = data.computeFraction(s, n, false)
510-
val qpois = new PoissonDistribution(frac * n)
511-
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
512-
}
513-
for (s <- List(1, 10, 100, 1000)) {
511+
for (s <- List(20, 100, 1000)) {
514512
val frac = data.computeFraction(s, n, true)
515-
val qpois = new PoissonDistribution(frac * n)
516-
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
513+
val poisson = new PoissonDistribution(frac * n)
514+
assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
517515
}
518516
for (s <- List(1, 10, 100, 1000)) {
519517
val frac = data.computeFraction(s, n, false)
520-
val qpois = new PoissonDistribution(frac * n)
521-
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
518+
val binomial = new BinomialDistribution(n, frac)
519+
assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
522520
}
523521
}
524522

@@ -530,37 +528,37 @@ class RDDSuite extends FunSuite with SharedSparkContext {
530528
val sample = data.takeSample(withReplacement=false, num=num)
531529
assert(sample.size === num) // Got exactly num elements
532530
assert(sample.toSet.size === num) // Elements are distinct
533-
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
531+
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
534532
}
535533
for (seed <- 1 to 5) {
536534
val sample = data.takeSample(withReplacement=false, 20, seed)
537535
assert(sample.size === 20) // Got exactly 20 elements
538536
assert(sample.toSet.size === 20) // Elements are distinct
539-
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
537+
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
540538
}
541539
for (seed <- 1 to 5) {
542540
val sample = data.takeSample(withReplacement=false, 100, seed)
543541
assert(sample.size === 100) // Got only 100 elements
544542
assert(sample.toSet.size === 100) // Elements are distinct
545-
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
543+
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
546544
}
547545
for (seed <- 1 to 5) {
548546
val sample = data.takeSample(withReplacement=true, 20, seed)
549547
assert(sample.size === 20) // Got exactly 20 elements
550-
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
548+
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
551549
}
552550
{
553551
val sample = data.takeSample(withReplacement=true, num=20)
554552
assert(sample.size === 20) // Got exactly 100 elements
555553
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
556-
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
554+
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
557555
}
558556
{
559557
val sample = data.takeSample(withReplacement=true, num=n)
560558
assert(sample.size === n) // Got exactly 100 elements
561559
// Chance of getting all distinct elements is astronomically low, so test we got < 100
562560
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
563-
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
561+
assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
564562
}
565563
for (seed <- 1 to 5) {
566564
val sample = data.takeSample(withReplacement=true, n, seed)

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@
248248
<dependency>
249249
<groupId>org.apache.commons</groupId>
250250
<artifactId>commons-math3</artifactId>
251-
<version>3.2</version>
251+
<version>3.3</version>
252252
</dependency>
253253
<dependency>
254254
<groupId>com.google.code.findbugs</groupId>

project/SparkBuild.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ object SparkBuild extends Build {
331331
libraryDependencies ++= Seq(
332332
"com.google.guava" % "guava" % "14.0.1",
333333
"org.apache.commons" % "commons-lang3" % "3.3.2",
334-
"org.apache.commons" % "commons-math3" % "3.2",
334+
"org.apache.commons" % "commons-math3" % "3.3" % "test",
335335
"com.google.code.findbugs" % "jsr305" % "1.3.9",
336336
"log4j" % "log4j" % "1.2.17",
337337
"org.slf4j" % "slf4j-api" % slf4jVersion,

0 commit comments

Comments
 (0)