@@ -24,11 +24,19 @@ import org.apache.spark.mllib.util.TestingUtils._
24
24
25
25
class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
26
26
27
- def cond1 (x : (Double , Double )): Boolean = x._1 ~= (x._2) absTol 1E-5
27
+ def areWithinEpsilon (x : (Double , Double )): Boolean = x match { case (a, b) => a ~= (b absTol 1E-5) } // approximate equality of the pair's two values (absTol 1E-5)
28
28
29
- def cond2 (x : ((Double , Double ), (Double , Double ))): Boolean =
29
+ def pairsWithinEpsilon (x : ((Double , Double ), (Double , Double ))): Boolean = // first vs first and second vs second coordinate must both match under absTol 1E-5
30
30
(x._1._1 ~= x._2._1 absTol 1E-5 ) && (x._1._2 ~= x._2._2 absTol 1E-5 )
31
31
32
+ private def assertSequencesMatch (left : Seq [Double ], right : Seq [Double ]): Unit = { // asserts equal length and pairwise areWithinEpsilon
33
+ assert(left.length == right.length && left.zip(right).forall(areWithinEpsilon)) // zip alone would ignore unmatched trailing elements
34
+ }
35
+
36
+ private def assertTupleSequencesMatch (left : Seq [(Double , Double )], right : Seq [(Double , Double )]): Unit = { // asserts equal length and pairwise pairsWithinEpsilon
37
+ assert(left.length == right.length && left.zip(right).forall(pairsWithinEpsilon)) // zip alone would ignore unmatched trailing elements
38
+ }
39
+
32
40
test(" binary evaluation metrics" ) {
33
41
val scoreAndLabels = sc.parallelize(
34
42
Seq ((0.1 , 0.0 ), (0.1 , 1.0 ), (0.4 , 0.0 ), (0.6 , 0.0 ), (0.6 , 1.0 ), (0.6 , 1.0 ), (0.8 , 1.0 )), 2 )
@@ -49,15 +57,15 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
49
57
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
50
58
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
51
59
52
- assert (metrics.thresholds().collect().zip( threshold).forall(cond1) )
53
- assert (metrics.roc().collect().zip( rocCurve).forall(cond2) )
60
+ assertSequencesMatch (metrics.thresholds().collect(), threshold)
61
+ assertTupleSequencesMatch (metrics.roc().collect(), rocCurve)
54
62
assert(metrics.areaUnderROC() ~== AreaUnderCurve .of(rocCurve) absTol 1E-5 )
55
- assert (metrics.pr().collect().zip( prCurve).forall(cond2) )
63
+ assertTupleSequencesMatch (metrics.pr().collect(), prCurve)
56
64
assert(metrics.areaUnderPR() ~== AreaUnderCurve .of(prCurve) absTol 1E-5 )
57
- assert (metrics.fMeasureByThreshold().collect().zip( threshold.zip(f1)).forall(cond2 ))
58
- assert (metrics.fMeasureByThreshold(2.0 ).collect().zip( threshold.zip(f2)).forall(cond2 ))
59
- assert (metrics.precisionByThreshold().collect().zip( threshold.zip(precision)).forall(cond2 ))
60
- assert (metrics.recallByThreshold().collect().zip( threshold.zip(recall)).forall(cond2 ))
65
+ assertTupleSequencesMatch (metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
66
+ assertTupleSequencesMatch (metrics.fMeasureByThreshold(2.0 ).collect(), threshold.zip(f2))
67
+ assertTupleSequencesMatch (metrics.precisionByThreshold().collect(), threshold.zip(precision))
68
+ assertTupleSequencesMatch (metrics.recallByThreshold().collect(), threshold.zip(recall))
61
69
}
62
70
63
71
test(" binary evaluation metrics for All Positive RDD" ) {
@@ -74,15 +82,15 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
74
82
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
75
83
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
76
84
77
- assert (metrics.thresholds().collect().zip( threshold).forall(cond1) )
78
- assert (metrics.roc().collect().zip( rocCurve).forall(cond2) )
85
+ assertSequencesMatch (metrics.thresholds().collect(), threshold)
86
+ assertTupleSequencesMatch (metrics.roc().collect(), rocCurve)
79
87
assert(metrics.areaUnderROC() ~== AreaUnderCurve .of(rocCurve) absTol 1E-5 )
80
- assert (metrics.pr().collect().zip( prCurve).forall(cond2) )
88
+ assertTupleSequencesMatch (metrics.pr().collect(), prCurve)
81
89
assert(metrics.areaUnderPR() ~== AreaUnderCurve .of(prCurve) absTol 1E-5 )
82
- assert (metrics.fMeasureByThreshold().collect().zip( threshold.zip(f1)).forall(cond2 ))
83
- assert (metrics.fMeasureByThreshold(2.0 ).collect().zip( threshold.zip(f2)).forall(cond2 ))
84
- assert (metrics.precisionByThreshold().collect().zip( threshold.zip(precision)).forall(cond2 ))
85
- assert (metrics.recallByThreshold().collect().zip( threshold.zip(recall)).forall(cond2 ))
90
+ assertTupleSequencesMatch (metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
91
+ assertTupleSequencesMatch (metrics.fMeasureByThreshold(2.0 ).collect(), threshold.zip(f2))
92
+ assertTupleSequencesMatch (metrics.precisionByThreshold().collect(), threshold.zip(precision))
93
+ assertTupleSequencesMatch (metrics.recallByThreshold().collect(), threshold.zip(recall))
86
94
}
87
95
88
96
test(" binary evaluation metrics for All Negative RDD" ) {
@@ -105,14 +113,14 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
105
113
case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
106
114
}
107
115
108
- assert (metrics.thresholds().collect().zip( threshold).forall(cond1) )
109
- assert (metrics.roc().collect().zip( rocCurve).forall(cond2) )
116
+ assertSequencesMatch (metrics.thresholds().collect(), threshold)
117
+ assertTupleSequencesMatch (metrics.roc().collect(), rocCurve)
110
118
assert(metrics.areaUnderROC() ~== AreaUnderCurve .of(rocCurve) absTol 1E-5 )
111
- assert (metrics.pr().collect().zip( prCurve).forall(cond2) )
119
+ assertTupleSequencesMatch (metrics.pr().collect(), prCurve)
112
120
assert(metrics.areaUnderPR() ~== AreaUnderCurve .of(prCurve) absTol 1E-5 )
113
- assert (metrics.fMeasureByThreshold().collect().zip( threshold.zip(f1)).forall(cond2 ))
114
- assert (metrics.fMeasureByThreshold(2.0 ).collect().zip( threshold.zip(f2)).forall(cond2 ))
115
- assert (metrics.precisionByThreshold().collect().zip( threshold.zip(precision)).forall(cond2 ))
116
- assert (metrics.recallByThreshold().collect().zip( threshold.zip(recall)).forall(cond2 ))
121
+ assertTupleSequencesMatch (metrics.fMeasureByThreshold().collect(), threshold.zip(f1))
122
+ assertTupleSequencesMatch (metrics.fMeasureByThreshold(2.0 ).collect(), threshold.zip(f2))
123
+ assertTupleSequencesMatch (metrics.precisionByThreshold().collect(), threshold.zip(precision))
124
+ assertTupleSequencesMatch (metrics.recallByThreshold().collect(), threshold.zip(recall))
117
125
}
118
126
}
0 commit comments