@@ -24,39 +24,104 @@ import org.apache.spark.mllib.util.TestingUtils._
24
24
25
25
class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {
26
26
27
- def cond1 (x : (Double , Double )): Boolean = x._1 ~= (x._2) absTol 1E-5
27
+ private def areWithinEpsilon (x : (Double , Double )): Boolean = x._1 ~= (x._2) absTol 1E-5
28
28
29
- def cond2 (x : ((Double , Double ), (Double , Double ))): Boolean =
29
+ private def pairsWithinEpsilon (x : ((Double , Double ), (Double , Double ))): Boolean =
30
30
(x._1._1 ~= x._2._1 absTol 1E-5 ) && (x._1._2 ~= x._2._2 absTol 1E-5 )
31
31
32
+ private def assertSequencesMatch (left : Seq [Double ], right : Seq [Double ]): Unit = {
33
+ assert(left.zip(right).forall(areWithinEpsilon))
34
+ }
35
+
36
+ private def assertTupleSequencesMatch (left : Seq [(Double , Double )],
37
+ right : Seq [(Double , Double )]): Unit = {
38
+ assert(left.zip(right).forall(pairsWithinEpsilon))
39
+ }
40
+
41
+ private def validateMetrics (metrics : BinaryClassificationMetrics ,
42
+ expectedThresholds : Seq [Double ],
43
+ expectedROCCurve : Seq [(Double , Double )],
44
+ expectedPRCurve : Seq [(Double , Double )],
45
+ expectedFMeasures1 : Seq [Double ],
46
+ expectedFmeasures2 : Seq [Double ],
47
+ expectedPrecisions : Seq [Double ],
48
+ expectedRecalls : Seq [Double ]) = {
49
+
50
+ assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
51
+ assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
52
+ assert(metrics.areaUnderROC() ~== AreaUnderCurve .of(expectedROCCurve) absTol 1E-5 )
53
+ assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve)
54
+ assert(metrics.areaUnderPR() ~== AreaUnderCurve .of(expectedPRCurve) absTol 1E-5 )
55
+ assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(),
56
+ expectedThresholds.zip(expectedFMeasures1))
57
+ assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0 ).collect(),
58
+ expectedThresholds.zip(expectedFmeasures2))
59
+ assertTupleSequencesMatch(metrics.precisionByThreshold().collect(),
60
+ expectedThresholds.zip(expectedPrecisions))
61
+ assertTupleSequencesMatch(metrics.recallByThreshold().collect(),
62
+ expectedThresholds.zip(expectedRecalls))
63
+ }
64
+
32
65
test(" binary evaluation metrics" ) {
33
66
val scoreAndLabels = sc.parallelize(
34
67
Seq ((0.1 , 0.0 ), (0.1 , 1.0 ), (0.4 , 0.0 ), (0.6 , 0.0 ), (0.6 , 1.0 ), (0.6 , 1.0 ), (0.8 , 1.0 )), 2 )
35
68
val metrics = new BinaryClassificationMetrics (scoreAndLabels)
36
- val threshold = Seq (0.8 , 0.6 , 0.4 , 0.1 )
69
+ val thresholds = Seq (0.8 , 0.6 , 0.4 , 0.1 )
37
70
val numTruePositives = Seq (1 , 3 , 3 , 4 )
38
71
val numFalsePositives = Seq (0 , 1 , 2 , 3 )
39
72
val numPositives = 4
40
73
val numNegatives = 3
41
- val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
74
+ val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
42
75
t.toDouble / (t + f)
43
76
}
44
- val recall = numTruePositives.map(t => t.toDouble / numPositives)
77
+ val recalls = numTruePositives.map(t => t.toDouble / numPositives)
45
78
val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
46
- val rocCurve = Seq ((0.0 , 0.0 )) ++ fpr.zip(recall ) ++ Seq ((1.0 , 1.0 ))
47
- val pr = recall .zip(precision )
79
+ val rocCurve = Seq ((0.0 , 0.0 )) ++ fpr.zip(recalls ) ++ Seq ((1.0 , 1.0 ))
80
+ val pr = recalls .zip(precisions )
48
81
val prCurve = Seq ((0.0 , 1.0 )) ++ pr
49
82
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
50
83
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
51
84
52
- assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
53
- assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
54
- assert(metrics.areaUnderROC() ~== AreaUnderCurve .of(rocCurve) absTol 1E-5 )
55
- assert(metrics.pr().collect().zip(prCurve).forall(cond2))
56
- assert(metrics.areaUnderPR() ~== AreaUnderCurve .of(prCurve) absTol 1E-5 )
57
- assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
58
- assert(metrics.fMeasureByThreshold(2.0 ).collect().zip(threshold.zip(f2)).forall(cond2))
59
- assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
60
- assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
85
+ validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
86
+ }
87
+
88
+ test(" binary evaluation metrics for RDD where all examples have positive label" ) {
89
+ val scoreAndLabels = sc.parallelize(Seq ((0.5 , 1.0 ), (0.5 , 1.0 )), 2 )
90
+ val metrics = new BinaryClassificationMetrics (scoreAndLabels)
91
+
92
+ val thresholds = Seq (0.5 )
93
+ val precisions = Seq (1.0 )
94
+ val recalls = Seq (1.0 )
95
+ val fpr = Seq (0.0 )
96
+ val rocCurve = Seq ((0.0 , 0.0 )) ++ fpr.zip(recalls) ++ Seq ((1.0 , 1.0 ))
97
+ val pr = recalls.zip(precisions)
98
+ val prCurve = Seq ((0.0 , 1.0 )) ++ pr
99
+ val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
100
+ val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
101
+
102
+ validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
103
+ }
104
+
105
+ test(" binary evaluation metrics for RDD where all examples have negative label" ) {
106
+ val scoreAndLabels = sc.parallelize(Seq ((0.5 , 0.0 ), (0.5 , 0.0 )), 2 )
107
+ val metrics = new BinaryClassificationMetrics (scoreAndLabels)
108
+
109
+ val thresholds = Seq (0.5 )
110
+ val precisions = Seq (0.0 )
111
+ val recalls = Seq (0.0 )
112
+ val fpr = Seq (1.0 )
113
+ val rocCurve = Seq ((0.0 , 0.0 )) ++ fpr.zip(recalls) ++ Seq ((1.0 , 1.0 ))
114
+ val pr = recalls.zip(precisions)
115
+ val prCurve = Seq ((0.0 , 1.0 )) ++ pr
116
+ val f1 = pr.map {
117
+ case (0 , 0 ) => 0.0
118
+ case (r, p) => 2.0 * (p * r) / (p + r)
119
+ }
120
+ val f2 = pr.map {
121
+ case (0 , 0 ) => 0.0
122
+ case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
123
+ }
124
+
125
+ validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
61
126
}
62
127
}
0 commit comments