
Commit 1de1d70

dorx authored and mengxr committed
SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
Modified the takeSample method in RDD to use the ScaSRS sampling technique to improve performance. Added a private method that computes a sampling rate > sample_size / total to ensure a sufficient sample size with success rate >= 0.9999. Added a unit test for the private method to validate the choice of sampling rate.

Author: Doris Xin <doris.s.xin@gmail.com>
Author: dorx <doris.s.xin@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #916 from dorx/takeSample and squashes the following commits:

5b061ae [Doris Xin] merge master
444e750 [Doris Xin] edge cases
3de882b [dorx] Merge pull request #2 from mengxr/SPARK-1939
82dde31 [Xiangrui Meng] update pyspark's takeSample
48d954d [Doris Xin] remove unused imports from RDDSuite
fb1452f [Doris Xin] allowing num to be greater than count in all cases
1481b01 [Doris Xin] washing test tubes and making coffee
dc699f3 [Doris Xin] give back imports removed by accident in rdd.py
64e445b [Doris Xin] logwarnning as soon as it enters the while loop
55518ed [Doris Xin] added TODO for logging in rdd.py
eff89e2 [Doris Xin] addressed reviewer comments.
ecab508 [Doris Xin] "fixed checkstyle violation
0a9b3e3 [Doris Xin] "reviewer comment addressed"
f80f270 [Doris Xin] Merge branch 'master' into takeSample
ae3ad04 [Doris Xin] fixed edge cases to prevent overflow
065ebcd [Doris Xin] Merge branch 'master' into takeSample
9bdd36e [Doris Xin] Check sample size and move computeFraction
e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample
7cab53a [Doris Xin] fixed import bug in rdd.py
ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD
1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
1 parent 0154587 commit 1de1d70

File tree

8 files changed: +263 −100 lines


core/pom.xml

+5
@@ -67,6 +67,11 @@
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-lang3</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-math3</artifactId>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>com.google.code.findbugs</groupId>
       <artifactId>jsr305</artifactId>

core/src/main/scala/org/apache/spark/rdd/RDD.scala

+31-21
@@ -42,7 +42,7 @@ import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}

 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -378,46 +378,56 @@ abstract class RDD[T: ClassTag](
     }.toArray
   }

-  def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
-  {
-    var fraction = 0.0
-    var total = 0
-    val multiplier = 3.0
-    val initialCount = this.count()
-    var maxSelected = 0
+  /**
+   * Return a fixed-size sampled subset of this RDD in an array
+   *
+   * @param withReplacement whether sampling is done with replacement
+   * @param num size of the returned sample
+   * @param seed seed for the random number generator
+   * @return sample of specified size in an array
+   */
+  def takeSample(withReplacement: Boolean,
+      num: Int,
+      seed: Long = Utils.random.nextLong): Array[T] = {
+    val numStDev = 10.0

     if (num < 0) {
       throw new IllegalArgumentException("Negative number of elements requested")
+    } else if (num == 0) {
+      return new Array[T](0)
     }

+    val initialCount = this.count()
     if (initialCount == 0) {
       return new Array[T](0)
     }

-    if (initialCount > Integer.MAX_VALUE - 1) {
-      maxSelected = Integer.MAX_VALUE - 1
-    } else {
-      maxSelected = initialCount.toInt
+    val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
+    if (num > maxSampleSize) {
+      throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
+        s"$numStDev * math.sqrt(Int.MaxValue)")
     }

-    if (num > initialCount && !withReplacement) {
-      total = maxSelected
-      fraction = multiplier * (maxSelected + 1) / initialCount
-    } else {
-      fraction = multiplier * (num + 1) / initialCount
-      total = num
+    val rand = new Random(seed)
+    if (!withReplacement && num >= initialCount) {
+      return Utils.randomizeInPlace(this.collect(), rand)
     }

-    val rand = new Random(seed)
+    val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
+      withReplacement)
+
     var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()

     // If the first sample didn't turn out large enough, keep trying to take samples;
     // this shouldn't happen often because we use a big multiplier for the initial size
-    while (samples.length < total) {
+    var numIters = 0
+    while (samples.length < num) {
+      logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
       samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
+      numIters += 1
     }

-    Utils.randomizeInPlace(samples, rand).take(total)
+    Utils.randomizeInPlace(samples, rand).take(num)
   }

   /**
core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala

+1-1
@@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
   }

   /**
-   * Return a sampler with is the complement of the range specified of the current sampler.
+   * Return a sampler that is the complement of the range specified of the current sampler.
    */
   def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)


core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

+55

@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+private[spark] object SamplingUtils {
+
+  /**
+   * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
+   * the time.
+   *
+   * How the sampling rate is determined:
+   * Let p = num / total, where num is the sample size and total is the total number of
+   * datapoints in the RDD. We're trying to compute q > p such that
+   *   - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
+   *     where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to
+   *     total), i.e. the failure rate of not having a sufficiently large sample < 0.0001.
+   *     Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
+   *     num > 12, but we need a slightly larger q (9 empirically determined).
+   *   - when sampling without replacement, we're drawing each datapoint with prob_i
+   *     ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
+   *     rate, where success rate is defined the same as in sampling with replacement.
+   *
+   * @param sampleSizeLowerBound sample size
+   * @param total size of RDD
+   * @param withReplacement whether sampling with replacement
+   * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
+   */
+  def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long,
+      withReplacement: Boolean): Double = {
+    val fraction = sampleSizeLowerBound.toDouble / total
+    if (withReplacement) {
+      val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
+      fraction + numStDev * math.sqrt(fraction / total)
+    } else {
+      val delta = 1e-4
+      val gamma = - math.log(delta) / total
+      math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
+    }
+  }
+}
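Restated in math form (a sketch of what computeFractionForSampleSize returns, not text from the commit): with requested sample size $s$, RDD size $n$, naive rate $p = s/n$, and failure tolerance $\delta = 10^{-4}$,

$$q_{\mathrm{with}} = p + c\sqrt{p/n}, \qquad c = \begin{cases} 9 & s < 12 \\ 5 & s \ge 12 \end{cases}$$

$$\gamma = \frac{-\ln\delta}{n}, \qquad q_{\mathrm{without}} = \min\!\Big(1,\; p + \gamma + \sqrt{\gamma^2 + 2\gamma p}\Big)$$

The without-replacement branch has the shape of a standard lower-tail concentration bound on a Binomial count. As a worked example, for $s = 100$ and $n = 10^6$ without replacement: $p = 10^{-4}$, $\gamma = \ln(10^4)/10^6 \approx 9.21 \times 10^{-6}$, so $q \approx 1.53 \times 10^{-4}$, i.e. the job samples at roughly $1.5\times$ the naive rate so that an undersized result (and hence a second pass) is at most a $10^{-4}$ event.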

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

+18-17
@@ -505,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext {
   }

   test("takeSample") {
-    val data = sc.parallelize(1 to 100, 2)
+    val n = 1000000
+    val data = sc.parallelize(1 to n, 2)

     for (num <- List(5, 20, 100)) {
       val sample = data.takeSample(withReplacement=false, num=num)
       assert(sample.size === num)        // Got exactly num elements
       assert(sample.toSet.size === num)  // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20)         // Got exactly 20 elements
       assert(sample.toSet.size === 20)   // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=false, 200, seed)
+      val sample = data.takeSample(withReplacement=false, 100, seed)
       assert(sample.size === 100)        // Got exactly 100 elements
       assert(sample.toSet.size === 100)  // Elements are distinct
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=true, 20, seed)
       assert(sample.size === 20)         // Got exactly 20 elements
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     {
       val sample = data.takeSample(withReplacement=true, num=20)
       assert(sample.size === 20)         // Got exactly 20 elements
       assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     {
-      val sample = data.takeSample(withReplacement=true, num=100)
-      assert(sample.size === 100)        // Got exactly 100 elements
+      val sample = data.takeSample(withReplacement=true, num=n)
+      assert(sample.size === n)          // Got exactly n elements
       // Chance of getting all distinct elements is astronomically low, so test we got < n
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
-      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=true, 100, seed)
-      assert(sample.size === 100)        // Got exactly 100 elements
+      val sample = data.takeSample(withReplacement=true, n, seed)
+      assert(sample.size === n)          // Got exactly n elements
       // Chance of getting all distinct elements is astronomically low, so test we got < n
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
     }
     for (seed <- 1 to 5) {
-      val sample = data.takeSample(withReplacement=true, 200, seed)
-      assert(sample.size === 200)        // Got exactly 200 elements
+      val sample = data.takeSample(withReplacement=true, 2 * n, seed)
+      assert(sample.size === 2 * n)      // Got exactly 2 * n elements
       // Chance of getting all distinct elements is still quite low, so test we got < n
-      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
     }
   }

core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala

+46

@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
+import org.scalatest.FunSuite
+
+class SamplingUtilsSuite extends FunSuite {
+
+  test("computeFraction") {
+    // test that the computed fraction guarantees enough data points
+    // in the sample with a failure rate <= 0.0001
+    val n = 100000
+
+    for (s <- 1 to 15) {
+      val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+      val poisson = new PoissonDistribution(frac * n)
+      assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+    for (s <- List(20, 100, 1000)) {
+      val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+      val poisson = new PoissonDistribution(frac * n)
+      assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+    for (s <- List(1, 10, 100, 1000)) {
+      val frac = SamplingUtils.computeFractionForSampleSize(s, n, false)
+      val binomial = new BinomialDistribution(n, frac)
+      // inverseCumulativeProbability already returns a count of successes
+      assert(binomial.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+    }
+  }
+}
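Beyond the distribution-based assertions above, the guarantee can also be checked empirically. The following standalone Monte Carlo sketch is not part of the commit; FractionCheck and its parameter values are hypothetical, and it re-implements the without-replacement formula locally so it runs without Spark:

import scala.util.Random

// Hypothetical standalone check: estimates the failure rate of the
// without-replacement fraction by simulating Bernoulli(q) draws.
object FractionCheck {
  // Same formula as SamplingUtils.computeFractionForSampleSize(s, n, false).
  def fraction(s: Int, n: Long): Double = {
    val p = s.toDouble / n
    val gamma = -math.log(1e-4) / n
    math.min(1, p + gamma + math.sqrt(gamma * gamma + 2 * gamma * p))
  }

  def main(args: Array[String]): Unit = {
    val n = 10000
    val s = 100
    val trials = 1000
    val q = fraction(s, n)
    val rng = new Random(42)
    // A trial fails if n independent Bernoulli(q) draws yield fewer than s
    // successes, i.e. the one-pass sample would have been too small.
    val failures = (1 to trials).count { _ =>
      (1 to n).count(_ => rng.nextDouble() < q) < s
    }
    println(s"q = $q, empirical failure rate = ${failures.toDouble / trials}")
  }
}

With the promised failure rate of at most 1e-4, a run of 1000 trials should almost always report an empirical failure rate of 0.0.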

project/SparkBuild.scala

+1
@@ -349,6 +349,7 @@ object SparkBuild extends Build {
     libraryDependencies ++= Seq(
       "com.google.guava" % "guava" % "14.0.1",
       "org.apache.commons" % "commons-lang3" % "3.3.2",
+      "org.apache.commons" % "commons-math3" % "3.3" % "test",
      "com.google.code.findbugs" % "jsr305" % "1.3.9",
      "log4j" % "log4j" % "1.2.17",
      "org.slf4j" % "slf4j-api" % slf4jVersion,
