
Commit 221ebce

add a new test to sliding
1 parent a920865 commit 221ebce

File tree

3 files changed: 30 additions & 14 deletions

mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala


mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala

Lines changed: 16 additions & 9 deletions
@@ -18,15 +18,22 @@
 package org.apache.spark.mllib.evaluation
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.rdd.RDDFunctions._
 
 /**
  * Computes the area under the curve (AUC) using the trapezoidal rule.
  */
 object AreaUnderCurve {
 
-  private def trapezoid(points: Array[(Double, Double)]): Double = {
+  /**
+   * Uses the trapezoidal rule to compute the area under the line connecting the two input points.
+   * @param points two 2D points stored in Seq
+   */
+  private def trapezoid(points: Seq[(Double, Double)]): Double = {
     require(points.length == 2)
-    (points(1)._1 - points(0)._1) * (points(1)._2 + points(0)._2 ) / 2.0
+    val x = points.head
+    val y = points.last
+    (y._1 - x._1) * (y._2 + x._2) / 2.0
   }
 
   /**
@@ -36,20 +43,20 @@ object AreaUnderCurve {
    */
   def of(curve: RDD[(Double, Double)]): Double = {
     curve.sliding(2).aggregate(0.0)(
-      seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
-      combOp = (_ + _)
+      seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
+      combOp = _ + _
     )
   }
 
   /**
    * Returns the area under the given curve.
    *
-   * @param curve an iterable of ordered 2D points stored in pairs representing a curve
+   * @param curve an iterator over ordered 2D points stored in pairs representing a curve
    */
-  def of(curve: Iterable[(Double, Double)]): Double = {
-    curve.sliding(2).map(_.toArray).filter(_.size == 2).aggregate(0.0)(
-      seqop = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
-      combop = (_ + _)
+  def of(curve: Iterator[(Double, Double)]): Double = {
+    curve.sliding(2).withPartial(false).aggregate(0.0)(
+      seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
+      combop = _ + _
     )
   }
 }
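
For intuition, here is a minimal standalone sketch (not part of the commit; the object name AucSketch is purely illustrative) of the same trapezoidal AUC computation, using only the Scala standard library. As in the new Iterator-based of(), sliding(2).withPartial(false) drops any trailing window shorter than two points, which is why empty and single-point curves yield 0.0:

// Standalone sketch of the trapezoidal AUC computation (no Spark needed).
object AucSketch extends App {
  // Area of the trapezoid under the segment connecting two points.
  def trapezoid(points: Seq[(Double, Double)]): Double = {
    require(points.length == 2)
    val (x, y) = (points.head, points.last)
    (y._1 - x._1) * (y._2 + x._2) / 2.0
  }

  // withPartial(false) drops the trailing window with fewer than two points,
  // so curves with zero or one point produce an area of 0.0.
  def auc(curve: Iterator[(Double, Double)]): Double =
    curve.sliding(2).withPartial(false).foldLeft(0.0)((acc, w) => acc + trapezoid(w))

  val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
  println(auc(curve.iterator)) // 0.5 + 2.0 + 1.5 = 4.0
}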

mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -26,21 +26,21 @@ class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
   test("auc computation") {
     val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
     val auc = 4.0
-    assert(AreaUnderCurve.of(curve) === auc)
+    assert(AreaUnderCurve.of(curve.toIterator) === auc)
     val rddCurve = sc.parallelize(curve, 2)
     assert(AreaUnderCurve.of(rddCurve) == auc)
   }
 
   test("auc of an empty curve") {
     val curve = Seq.empty[(Double, Double)]
-    assert(AreaUnderCurve.of(curve) === 0.0)
+    assert(AreaUnderCurve.of(curve.toIterator) === 0.0)
     val rddCurve = sc.parallelize(curve, 2)
     assert(AreaUnderCurve.of(rddCurve) === 0.0)
   }
 
   test("auc of a curve with a single point") {
     val curve = Seq((1.0, 1.0))
-    assert(AreaUnderCurve.of(curve) === 0.0)
+    assert(AreaUnderCurve.of(curve.toIterator) === 0.0)
     val rddCurve = sc.parallelize(curve, 2)
     assert(AreaUnderCurve.of(rddCurve) === 0.0)
   }
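
For reference, the expected value 4.0 in the first test is the sum of the three trapezoids under the curve: (1 - 0)(0 + 1)/2 = 0.5, (2 - 1)(1 + 3)/2 = 2.0, and (3 - 2)(3 + 0)/2 = 1.5, so 0.5 + 2.0 + 1.5 = 4.0.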

mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala

Lines changed: 11 additions & 2 deletions
@@ -29,12 +29,21 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
     for (numPartitions <- 1 to 8) {
       val rdd = sc.parallelize(data, numPartitions)
       for (windowSize <- 1 to 6) {
-        val slided = rdd.sliding(windowSize).collect().map(_.toList).toList
+        val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList
         val expected = data.sliding(windowSize).map(_.toList).toList
-        assert(slided === expected)
+        assert(sliding === expected)
       }
       assert(rdd.sliding(7).collect().isEmpty,
         "Should return an empty RDD if the window size is greater than the number of items.")
     }
   }
+
+  test("sliding with empty partitions") {
+    val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
+    val rdd = sc.parallelize(data, data.length).flatMap(s => s)
+    assert(rdd.partitions.size === data.length)
+    val sliding = rdd.sliding(3)
+    val expected = data.flatMap(x => x).sliding(3).toList
+    assert(sliding.collect().toList === expected)
+  }
 }
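
A note on why this test actually produces empty partitions (an observation, not part of the commit): sc.parallelize(data, data.length) places each inner Seq in its own partition, and flatMap keeps the partition count, so the two empty Seqs become genuinely empty partitions. The expected windows can be checked locally with just the standard library:

// Local illustration (no Spark): the windows the new test expects.
// Note that List(3, 4, 5) draws elements from partitions 0, 2, and 4,
// skipping the empty partitions 1 and 3, which is the case the test covers.
val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7))
val expected = data.flatten.sliding(3).toList
println(expected)
// List(List(1, 2, 3), List(2, 3, 4), List(3, 4, 5), List(4, 5, 6), List(5, 6, 7))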
