Commit c0b7fa4

Switch FoldedRDD to use BernoulliSampler and PartitionwiseSampledRDD
Parent commit: 08f8e4d
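
The previous implementation filtered each partition element-by-element inside compute (rand.nextInt(folds) == fold-1 for the test fold, != for the training fold); this commit instead expresses fold membership as a range test over a uniform draw, which is what BernoulliSampler provides, and delegates the per-partition plumbing to PartitionwiseSampledRDD. A minimal sketch of that membership rule, as an illustration rather than code from the commit:

    // Illustration only (not commit code): give each element an independent
    // uniform draw x in [0, 1); the element belongs to fold k (1-based, out
    // of `folds`) exactly when (k-1)/folds <= x < k/folds.
    def inFold(x: Double, fold: Int, folds: Int): Boolean =
      (fold - 1).toDouble / folds <= x && x < fold.toDouble / folds

Because the ranges [(k-1)/folds, k/folds) for k = 1..folds tile [0, 1) without overlap, every element lands in exactly one fold.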

2 files changed: +23 -28 lines

core/src/main/scala/org/apache/spark/rdd/FoldedRDD.scala

Lines changed: 7 additions & 26 deletions
@@ -24,6 +24,7 @@ import cern.jet.random.Poisson
 import cern.jet.random.engine.DRand
 
 import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.util.random.BernoulliSampler
 
 private[spark]
 class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
@@ -32,24 +33,10 @@ class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition w
 
 class FoldedRDD[T: ClassTag](
     prev: RDD[T],
-    fold: Int,
-    folds: Int,
+    fold: Float,
+    folds: Float,
     seed: Int)
-  extends RDD[T](prev) {
-
-  override def getPartitions: Array[Partition] = {
-    val rg = new Random(seed)
-    firstParent[T].partitions.map(x => new FoldedRDDPartition(x, rg.nextInt))
-  }
-
-  override def getPreferredLocations(split: Partition): Seq[String] =
-    firstParent[T].preferredLocations(split.asInstanceOf[FoldedRDDPartition].prev)
-
-  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
-    val split = splitIn.asInstanceOf[FoldedRDDPartition]
-    val rand = new Random(split.seed)
-    firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) == fold-1))
-  }
+  extends PartitionwiseSampledRDD[T, T](prev, new BernoulliSampler((fold-1)/folds,fold/folds, false), seed) {
 }
 
 /**
@@ -58,14 +45,8 @@ class FoldedRDD[T: ClassTag](
  */
 class CompositeFoldedRDD[T: ClassTag](
     prev: RDD[T],
-    fold: Int,
-    folds: Int,
+    fold: Float,
+    folds: Float,
     seed: Int)
-  extends FoldedRDD[T](prev, fold, folds, seed) {
-
-  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
-    val split = splitIn.asInstanceOf[FoldedRDDPartition]
-    val rand = new Random(split.seed)
-    firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) != fold-1))
-  }
+  extends PartitionwiseSampledRDD[T, T](prev, new BernoulliSampler((fold-1)/folds, fold/folds, true), seed) {
 }
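
For context, the sampler contract this diff relies on (the assumed semantics of org.apache.spark.util.random.BernoulliSampler, sketched here rather than quoted from Spark's source): each element receives an independent uniform draw in [0, 1) and is kept when the draw falls inside [lb, ub), or outside it when complement is true.

    import scala.util.Random

    // Sketch of the assumed BernoulliSampler behaviour (not Spark's source).
    class RangeSamplerSketch[T](lb: Double, ub: Double, complement: Boolean) {
      def sample(items: Iterator[T], rng: Random): Iterator[T] =
        items.filter { _ =>
          val x = rng.nextDouble()          // one uniform draw per element
          (x >= lb && x < ub) ^ complement  // XOR flips the keep decision when complement = true
        }
    }

Since FoldedRDD uses the range [(fold-1)/folds, fold/folds) with complement = false and CompositeFoldedRDD uses the same range with complement = true, and PartitionwiseSampledRDD seeds both samplers identically per partition, the two RDDs split the parent into disjoint test and training sets whose union is the full dataset.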

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 16 additions & 2 deletions
@@ -513,14 +513,28 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("FoldedRDD") {
+    val data = sc.parallelize(1 to 100, 2)
+    val lowerFoldedRdd = new FoldedRDD(data, 1, 2, 1)
+    val upperFoldedRdd = new FoldedRDD(data, 2, 2, 1)
+    val lowerCompositeFoldedRdd = new CompositeFoldedRDD(data, 1, 2, 1)
+    assert(lowerFoldedRdd.collect().sorted.size == 50)
+    assert(lowerCompositeFoldedRdd.collect().sorted.size == 50)
+    assert(lowerFoldedRdd.subtract(lowerCompositeFoldedRdd).collect().sorted ===
+      lowerFoldedRdd.collect().sorted)
+    assert(upperFoldedRdd.collect().sorted.size == 50)
+  }
+
   test("kfoldRdd") {
     val data = sc.parallelize(1 to 100, 2)
-    for (folds <- 1 to 10) {
+    val collectedData = data.collect().sorted
+    for (folds <- 2 to 10) {
       for (seed <- 1 to 5) {
         val foldedRdds = data.kFoldRdds(folds, seed)
         assert(foldedRdds.size === folds)
         foldedRdds.map{case (test, train) =>
-          assert(test.union(train).collect().sorted === data.collect().sorted,
+          val result = test.union(train).collect().sorted
+          assert(result === collectedData,
             "Each training+test set combined contains all of the data")
         }
         // K fold cross validation should only have each element in the test set exactly once
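
For reference, the kFoldRdds helper exercised above composes naturally from the two classes in this commit. A hypothetical sketch (the name and call shape come from the test; the body is my assumption, not the repository's implementation):

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD

    // Hypothetical kFoldRdds built on this commit's classes. For fold k the
    // test set is FoldedRDD (draws in [(k-1)/folds, k/folds)) and the training
    // set is CompositeFoldedRDD (the complement under the same seed), so each
    // pair partitions the data and every element appears in exactly one test set.
    def kFoldRdds[T: ClassTag](rdd: RDD[T], folds: Int, seed: Int): List[(RDD[T], RDD[T])] =
      (1 to folds).toList.map { fold =>
        (new FoldedRDD(rdd, fold, folds, seed): RDD[T],
         new CompositeFoldedRDD(rdd, fold, folds, seed): RDD[T])
      }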
