Skip to content

Commit 4d2f79a

Browse files
author
Andrew Bullen
committed
[SPARK-4256] Refactor classification metrics tests - extract comparison functions in test
1 parent f411e70 commit 4d2f79a

File tree

1 file changed

+31
-23
lines changed

1 file changed

+31
-23
lines changed

mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,19 @@ import org.apache.spark.mllib.util.TestingUtils._
2424

2525
class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
2626

27-
def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
27+
def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
2828

29-
def cond2(x: ((Double, Double), (Double, Double))): Boolean =
29+
def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
3030
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
3131

32+
private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
33+
assert(left.zip(right).forall(areWithinEpsilon))
34+
}
35+
36+
private def assertTupleSequencesMatch(left: Seq[(Double, Double)], right: Seq[(Double, Double)]): Unit = {
37+
assert(left.zip(right).forall(pairsWithinEpsilon))
38+
}
39+
3240
test("binary evaluation metrics") {
3341
val scoreAndLabels = sc.parallelize(
3442
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
@@ -49,15 +57,15 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
4957
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
5058
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
5159

52-
assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
53-
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
60+
assertSequencesMatch(metrics.thresholds().collect(), threshold)
61+
assertTupleSequencesMatch(metrics.roc().collect(), rocCurve)
5462
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
55-
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
63+
assertTupleSequencesMatch(metrics.pr().collect(), prCurve)
5664
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
57-
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
58-
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
59-
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
60-
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
65+
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
66+
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))
67+
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), threshold.zip(precision))
68+
assertTupleSequencesMatch(metrics.recallByThreshold().collect(), threshold.zip(recall))
6169
}
6270

6371
test("binary evaluation metrics for All Positive RDD") {
@@ -74,15 +82,15 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
7482
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
7583
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
7684

77-
assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
78-
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
85+
assertSequencesMatch(metrics.thresholds().collect(), threshold)
86+
assertTupleSequencesMatch(metrics.roc().collect(), rocCurve)
7987
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
80-
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
88+
assertTupleSequencesMatch(metrics.pr().collect(), prCurve)
8189
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
82-
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
83-
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
84-
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
85-
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
90+
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
91+
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))
92+
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), threshold.zip(precision))
93+
assertTupleSequencesMatch(metrics.recallByThreshold().collect(), threshold.zip(recall))
8694
}
8795

8896
test("binary evaluation metrics for All Negative RDD") {
@@ -105,14 +113,14 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
105113
case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
106114
}
107115

108-
assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
109-
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
116+
assertSequencesMatch(metrics.thresholds().collect(), threshold)
117+
assertTupleSequencesMatch(metrics.roc().collect(), rocCurve)
110118
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
111-
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
119+
assertTupleSequencesMatch(metrics.pr().collect(), prCurve)
112120
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
113-
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
114-
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
115-
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
116-
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
121+
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
122+
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2))
123+
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), threshold.zip(precision))
124+
assertTupleSequencesMatch(metrics.recallByThreshold().collect(), threshold.zip(recall))
117125
}
118126
}

0 commit comments

Comments
 (0)