@@ -59,4 +59,60 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
59
59
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
60
60
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
61
61
}
62
+
63
+ test(" binary evaluation metrics for All Positive RDD" ) {
64
+ val scoreAndLabels = sc.parallelize(Seq ((0.5 , 1.0 )), 2 )
65
+ val metrics : BinaryClassificationMetrics = new BinaryClassificationMetrics (scoreAndLabels)
66
+
67
+ val threshold = Seq (0.5 )
68
+ val precision = Seq (1.0 )
69
+ val recall = Seq (1.0 )
70
+ val fpr = Seq (0.0 )
71
+ val rocCurve = Seq ((0.0 , 0.0 )) ++ fpr.zip(recall) ++ Seq ((1.0 , 1.0 ))
72
+ val pr = recall.zip(precision)
73
+ val prCurve = Seq ((0.0 , 1.0 )) ++ pr
74
+ val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
75
+ val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
76
+
77
+ assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
78
+ assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
79
+ assert(metrics.areaUnderROC() ~== AreaUnderCurve .of(rocCurve) absTol 1E-5 )
80
+ assert(metrics.pr().collect().zip(prCurve).forall(cond2))
81
+ assert(metrics.areaUnderPR() ~== AreaUnderCurve .of(prCurve) absTol 1E-5 )
82
+ assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
83
+ assert(metrics.fMeasureByThreshold(2.0 ).collect().zip(threshold.zip(f2)).forall(cond2))
84
+ assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
85
+ assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
86
+ }
87
+
88
+ test(" binary evaluation metrics for All Negative RDD" ) {
89
+ val scoreAndLabels = sc.parallelize(Seq ((0.5 , 0.0 )), 2 )
90
+ val metrics : BinaryClassificationMetrics = new BinaryClassificationMetrics (scoreAndLabels)
91
+
92
+ val threshold = Seq (0.5 )
93
+ val precision = Seq (0.0 )
94
+ val recall = Seq (0.0 )
95
+ val fpr = Seq (1.0 )
96
+ val rocCurve = Seq ((0.0 , 0.0 )) ++ fpr.zip(recall) ++ Seq ((1.0 , 1.0 ))
97
+ val pr = recall.zip(precision)
98
+ val prCurve = Seq ((0.0 , 1.0 )) ++ pr
99
+ val f1 = pr.map {
100
+ case (0 ,0 ) => 0.0
101
+ case (r, p) => 2.0 * (p * r) / (p + r)
102
+ }
103
+ val f2 = pr.map {
104
+ case (0 ,0 ) => 0.0
105
+ case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
106
+ }
107
+
108
+ assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
109
+ assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
110
+ assert(metrics.areaUnderROC() ~== AreaUnderCurve .of(rocCurve) absTol 1E-5 )
111
+ assert(metrics.pr().collect().zip(prCurve).forall(cond2))
112
+ assert(metrics.areaUnderPR() ~== AreaUnderCurve .of(prCurve) absTol 1E-5 )
113
+ assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
114
+ assert(metrics.fMeasureByThreshold(2.0 ).collect().zip(threshold.zip(f2)).forall(cond2))
115
+ assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
116
+ assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
117
+ }
62
118
}
0 commit comments