Skip to content

Commit d9a09ef

Browse files
author
Andrew Bullen
committed
Make Binary Evaluation Metrics functions defined in cases where there are 0 positive or 0 negative examples.
1 parent c238fb4 commit d9a09ef

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,31 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
2727
/** Precision. */
2828
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
2929
override def apply(c: BinaryConfusionMatrix): Double =
30-
c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
30+
if (c.numTruePositives + c.numFalsePositives == 0) {
31+
0.0
32+
} else {
33+
c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
34+
}
3135
}
3236

3337
/** False positive rate. */
3438
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
3539
override def apply(c: BinaryConfusionMatrix): Double =
36-
c.numFalsePositives.toDouble / c.numNegatives
40+
if (c.numNegatives == 0) {
41+
0.0
42+
} else {
43+
c.numFalsePositives.toDouble / c.numNegatives
44+
}
3745
}
3846

3947
/** Recall. */
4048
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
4149
override def apply(c: BinaryConfusionMatrix): Double =
42-
c.numTruePositives.toDouble / c.numPositives
50+
if (c.numPositives == 0) {
51+
0.0
52+
} else {
53+
c.numTruePositives.toDouble / c.numPositives
54+
}
4355
}
4456

4557
/**
@@ -52,6 +64,10 @@ private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificati
5264
override def apply(c: BinaryConfusionMatrix): Double = {
5365
val precision = Precision(c)
5466
val recall = Recall(c)
55-
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
67+
if (precision + recall == 0) {
68+
0.0
69+
} else {
70+
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
71+
}
5672
}
5773
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,60 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
5959
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
6060
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
6161
}
62+
63+
test("binary evaluation metrics for All Positive RDD") {
64+
val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0)), 2)
65+
val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels)
66+
67+
val threshold = Seq(0.5)
68+
val precision = Seq(1.0)
69+
val recall = Seq(1.0)
70+
val fpr = Seq(0.0)
71+
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
72+
val pr = recall.zip(precision)
73+
val prCurve = Seq((0.0, 1.0)) ++ pr
74+
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
75+
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
76+
77+
assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
78+
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
79+
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
80+
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
81+
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
82+
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
83+
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
84+
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
85+
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
86+
}
87+
88+
test("binary evaluation metrics for All Negative RDD") {
89+
val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0)), 2)
90+
val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels)
91+
92+
val threshold = Seq(0.5)
93+
val precision = Seq(0.0)
94+
val recall = Seq(0.0)
95+
val fpr = Seq(1.0)
96+
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
97+
val pr = recall.zip(precision)
98+
val prCurve = Seq((0.0, 1.0)) ++ pr
99+
val f1 = pr.map {
100+
case (0,0) => 0.0
101+
case (r, p) => 2.0 * (p * r) / (p + r)
102+
}
103+
val f2 = pr.map {
104+
case (0,0) => 0.0
105+
case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
106+
}
107+
108+
assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
109+
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
110+
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
111+
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
112+
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
113+
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
114+
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
115+
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
116+
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
117+
}
62118
}

0 commit comments

Comments
 (0)