SPARK-1310: Start adding k-fold cross validation to MLLib [adds kFold to MLUtils & fixes bug in BernoulliSampler] #18

Changes from all commits: a751ec6, 08f8e4d, c0b7fa4, dd0b737, 264502a, 91eae64, b78804e, e8741a7, 5a33f1d, 163c5b1, bb5fa56, 7ebe4d5, c5b723f, e187e35, c702a96, 2cb90b3, 150889c, 90896c7, 7157ae9, 6ddbf05, e84f2fc, 208db9b

MLUtils.scala

@@ -17,11 +17,16 @@
package org.apache.spark.mllib.util

import scala.reflect.ClassTag

Review comment: organize imports

import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}

import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PartitionwiseSampledRDD
import org.apache.spark.SparkContext._
import org.apache.spark.util.random.BernoulliSampler
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

@@ -157,6 +162,22 @@ object MLUtils {
    dataStr.saveAsTextFile(dir)
  }

  /**
   * Return a k element array of pairs of RDDs with the first element of each pair
   * containing the training data, a complement of the validation data, and the second
   * element, the validation data, containing a unique 1/kth of the data, where k = numFolds.
   */
  def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {

Review comment: It is natural to have …

    val numFoldsF = numFolds.toFloat
    (1 to numFolds).map { fold =>
      val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
        complement = false)
      val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
      val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
      (training, validation)
    }.toArray
  }

  /**
   * Returns the squared Euclidean distance between two vectors. The following formula will be used
   * if it does not introduce too much numerical error:
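
For context, here is a minimal sketch of how the new MLUtils.kFold helper could drive a cross-validation loop. Only kFold itself comes from the change above; the trainModel and evaluate functions below are hypothetical placeholders standing in for any real training routine and metric.

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.util.MLUtils

object KFoldExample {
  // Hypothetical stand-ins: any training routine and evaluation metric would do here.
  def trainModel(training: RDD[Double]): Double = training.mean()
  def evaluate(model: Double, validation: RDD[Double]): Double =
    validation.map(x => math.pow(x - model, 2)).mean()

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "kFoldExample")
    val data: RDD[Double] = sc.parallelize(1 to 100).map(_.toDouble)

    // Each pair is (training, validation); the validation folds are disjoint
    // and together cover the whole dataset.
    val folds = MLUtils.kFold(data, numFolds = 5, seed = 42)

    val scores = folds.map { case (training, validation) =>
      val model = trainModel(training)
      evaluate(model, validation)
    }
    println(s"Average cross-validation score: ${scores.sum / scores.length}")
    sc.stop()
  }
}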

MLUtilsSuite.scala

@@ -19,6 +19,9 @@ package org.apache.spark.mllib.util
|
||
import java.io.File | ||
|
||
import scala.math | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. insert an empty line between java imports and scala imports |
||
import scala.util.Random | ||
|
||
import org.scalatest.FunSuite | ||
|
||
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm, | ||
|
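
The two import-related review comments refer to the project's Scala import grouping. As a rough sketch (assuming the usual ordering of java imports, then scala imports, each group separated by a blank line), the top of the test file would end up as:

import java.io.File

import scala.math
import scala.util.Random

with the third-party imports (org.scalatest, breeze.linalg) and org.apache.spark imports following as their own groups.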

@@ -93,4 +96,40 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
      case t: Throwable =>
    }

test("kFold") { | ||
val data = sc.parallelize(1 to 100, 2) | ||
val collectedData = data.collect().sorted | ||
val twoFoldedRdd = MLUtils.kFold(data, 2, 1) | ||
assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted) | ||
assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted) | ||
for (folds <- 2 to 10) { | ||
for (seed <- 1 to 5) { | ||
val foldedRdds = MLUtils.kFold(data, folds, seed) | ||
assert(foldedRdds.size === folds) | ||
foldedRdds.map { case (training, validation) => | ||
val result = validation.union(training).collect().sorted | ||
val validationSize = validation.collect().size.toFloat | ||
assert(validationSize > 0, "empty validation data") | ||
val p = 1 / folds.toFloat | ||
// Within 3 standard deviations of the mean | ||
val range = 3 * math.sqrt(100 * p * (1 - p)) | ||
val expected = 100 * p | ||
val lowerBound = expected - range | ||
val upperBound = expected + range | ||
assert(validationSize > lowerBound, | ||
s"Validation data ($validationSize) smaller than expected ($lowerBound)" ) | ||
assert(validationSize < upperBound, | ||
s"Validation data ($validationSize) larger than expected ($upperBound)" ) | ||
assert(training.collect().size > 0, "empty training data") | ||
assert(result === collectedData, | ||
"Each training+validation set combined should contain all of the data.") | ||
} | ||
// K fold cross validation should only have each element in the validation set exactly once | ||
assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted === | ||
data.collect().sorted) | ||
} | ||
} | ||
} | ||
|
||
} |

Review comment: Add return type.
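
To make the bound used in the test concrete: each of the 100 elements lands in a given validation fold independently with probability p = 1/folds, so a fold's size is approximately binomial with mean 100p and standard deviation sqrt(100 * p * (1 - p)). A small sketch mirroring the test's arithmetic (the fold count of 5 is just an illustrative choice, not something the test fixes):

object FoldSizeBound {
  def main(args: Array[String]): Unit = {
    val n = 100                                  // number of elements in the test data
    val folds = 5                                // illustrative fold count
    val p = 1.0 / folds                          // chance an element lands in a given fold
    val expected = n * p                         // mean fold size, about 20
    val range = 3 * math.sqrt(n * p * (1 - p))   // three standard deviations, about 12
    println(s"expected fold size = $expected, accepted range = (${expected - range}, ${expected + range})")
    // prints roughly: expected fold size = 20, accepted range = (8, 32)
  }
}

So with 100 elements split five ways, a validation fold passes the size check as long as it holds roughly 8 to 32 elements, which is why the test can rely on randomized Bernoulli sampling without being flaky.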