@@ -30,31 +30,36 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
30
30
(x._1._1 ~= x._2._1 absTol 1E-5 ) && (x._1._2 ~= x._2._2 absTol 1E-5 )
31
31
32
32
/**
 * Asserts that `left` and `right` contain the same number of values and that every
 * positionally paired value is accepted by `areWithinEpsilon`.
 *
 * @param left  one sequence of values (callers in this suite pass the computed values here)
 * @param right the sequence to compare against
 */
private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
  // `zip` silently drops the surplus elements of the longer input, so without this
  // check a missing or extra value (or an entirely empty result) would pass unnoticed.
  assert(left.size === right.size)
  assert(left.zip(right).forall(areWithinEpsilon))
}
35
35
36
/**
 * Asserts that `left` and `right` contain the same number of pairs and that every
 * positionally paired couple of tuples is accepted by `pairsWithinEpsilon`.
 *
 * @param left  one sequence of (x, y) pairs (callers in this suite pass the computed pairs here)
 * @param right the sequence of (x, y) pairs to compare against
 */
private def assertTupleSequencesMatch(left: Seq[(Double, Double)],
    right: Seq[(Double, Double)]): Unit = {
  // `zip` silently drops the surplus elements of the longer input, so without this
  // check a curve with missing or extra points would still be reported as matching.
  assert(left.size === right.size)
  assert(left.zip(right).forall(pairsWithinEpsilon))
}
39
40
40
41
/**
 * Compares everything `metrics` reports against the supplied expectations: the thresholds,
 * the ROC and PR curves together with their areas, and the per-threshold F-measure
 * (no-argument call and the call with argument 2.0), precision and recall.
 *
 * Each per-threshold expectation is paired positionally with `expectedThresholds`
 * before it is compared with the corresponding "by threshold" result.
 *
 * @param metrics            the metrics object under test
 * @param expectedThresholds thresholds expected from `metrics.thresholds()`
 * @param expectedROCCurve   points expected from `metrics.roc()`
 * @param expectedPRCurve    points expected from `metrics.pr()`
 * @param expectedFMeasures1 values expected from `metrics.fMeasureByThreshold()`
 * @param expectedFmeasures2 values expected from `metrics.fMeasureByThreshold(2.0)`
 * @param expectedPrecisions values expected from `metrics.precisionByThreshold()`
 * @param expectedRecalls    values expected from `metrics.recallByThreshold()`
 */
private def validateMetrics(metrics: BinaryClassificationMetrics,
    expectedThresholds: Seq[Double],
    expectedROCCurve: Seq[(Double, Double)],
    expectedPRCurve: Seq[(Double, Double)],
    expectedFMeasures1: Seq[Double],
    expectedFmeasures2: Seq[Double],
    expectedPrecisions: Seq[Double],
    expectedRecalls: Seq[Double]) = {
  // Attaches the expected threshold to each expected per-threshold value.
  def atThresholds(values: Seq[Double]): Seq[(Double, Double)] =
    expectedThresholds.zip(values)

  val observedThresholds = metrics.thresholds().collect()
  assertSequencesMatch(observedThresholds, expectedThresholds)

  // ROC curve and the area beneath it.
  val observedRoc = metrics.roc().collect()
  assertTupleSequencesMatch(observedRoc, expectedROCCurve)
  assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5)

  // Precision-recall curve and the area beneath it.
  val observedPr = metrics.pr().collect()
  assertTupleSequencesMatch(observedPr, expectedPRCurve)
  assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5)

  // Per-threshold statistics.
  assertTupleSequencesMatch(
    metrics.fMeasureByThreshold().collect(), atThresholds(expectedFMeasures1))
  assertTupleSequencesMatch(
    metrics.fMeasureByThreshold(2.0).collect(), atThresholds(expectedFmeasures2))
  assertTupleSequencesMatch(
    metrics.precisionByThreshold().collect(), atThresholds(expectedPrecisions))
  assertTupleSequencesMatch(
    metrics.recallByThreshold().collect(), atThresholds(expectedRecalls))
}
59
64
60
65
test(" binary evaluation metrics" ) {
@@ -80,9 +85,9 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
80
85
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
81
86
}
82
87
83
- test(" binary evaluation metrics for All Positive RDD " ) {
88
+ test(" binary evaluation metrics for RDD where all examples have positive label " ) {
84
89
val scoreAndLabels = sc.parallelize(Seq ((0.5 , 1.0 ), (0.5 , 1.0 )), 2 )
85
- val metrics : BinaryClassificationMetrics = new BinaryClassificationMetrics (scoreAndLabels)
90
+ val metrics = new BinaryClassificationMetrics (scoreAndLabels)
86
91
87
92
val thresholds = Seq (0.5 )
88
93
val precisions = Seq (1.0 )
@@ -97,9 +102,9 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
97
102
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
98
103
}
99
104
100
- test(" binary evaluation metrics for All Negative RDD " ) {
105
+ test(" binary evaluation metrics for RDD where all examples have negative label " ) {
101
106
val scoreAndLabels = sc.parallelize(Seq ((0.5 , 0.0 ), (0.5 , 0.0 )), 2 )
102
- val metrics : BinaryClassificationMetrics = new BinaryClassificationMetrics (scoreAndLabels)
107
+ val metrics = new BinaryClassificationMetrics (scoreAndLabels)
103
108
104
109
val thresholds = Seq (0.5 )
105
110
val precisions = Seq (0.0 )
0 commit comments