Skip to content

Commit 484fecb

Browse files
Andrew Bullenmengxr
Andrew Bullen
authored andcommitted
[SPARK-4256] Make Binary Evaluation Metrics functions defined in cases where there ar...
...e 0 positive or 0 negative examples. Author: Andrew Bullen <andrew.bullen@workday.com> Closes #3118 from abull/master and squashes the following commits: c2bf2b1 [Andrew Bullen] [SPARK-4256] Update Code formatting for BinaryClassificationMetricsSpec 36b0533 [Andrew Bullen] [SYMAN-4256] Extract BinaryClassificationMetricsSuite assertions into private method 4d2f79a [Andrew Bullen] [SPARK-4256] Refactor classification metrics tests - extract comparison functions in test f411e70 [Andrew Bullen] [SPARK-4256] Define precision as 1.0 when there are no positive examples; update code formatting per pull request comments d9a09ef [Andrew Bullen] Make Binary Evaluation Metrics functions defined in cases where there are 0 positive or 0 negative examples.
1 parent b9e1c2e commit 484fecb

File tree

2 files changed

+113
-27
lines changed

2 files changed

+113
-27
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,43 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
2424
def apply(c: BinaryConfusionMatrix): Double
2525
}
2626

27-
/** Precision. */
27+
/** Precision. Defined as 1.0 when there are no positive examples. */
2828
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
29-
override def apply(c: BinaryConfusionMatrix): Double =
30-
c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
29+
override def apply(c: BinaryConfusionMatrix): Double = {
30+
val totalPositives = c.numTruePositives + c.numFalsePositives
31+
if (totalPositives == 0) {
32+
1.0
33+
} else {
34+
c.numTruePositives.toDouble / totalPositives
35+
}
36+
}
3137
}
3238

33-
/** False positive rate. */
39+
/** False positive rate. Defined as 0.0 when there are no negative examples. */
3440
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
35-
override def apply(c: BinaryConfusionMatrix): Double =
36-
c.numFalsePositives.toDouble / c.numNegatives
41+
override def apply(c: BinaryConfusionMatrix): Double = {
42+
if (c.numNegatives == 0) {
43+
0.0
44+
} else {
45+
c.numFalsePositives.toDouble / c.numNegatives
46+
}
47+
}
3748
}
3849

39-
/** Recall. */
50+
/** Recall. Defined as 0.0 when there are no positive examples. */
4051
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
41-
override def apply(c: BinaryConfusionMatrix): Double =
42-
c.numTruePositives.toDouble / c.numPositives
52+
override def apply(c: BinaryConfusionMatrix): Double = {
53+
if (c.numPositives == 0) {
54+
0.0
55+
} else {
56+
c.numTruePositives.toDouble / c.numPositives
57+
}
58+
}
4359
}
4460

4561
/**
46-
* F-Measure.
62+
* F-Measure. Defined as 0 if both precision and recall are 0. EG in the case that all examples
63+
* are false positives.
4764
* @param beta the beta constant in F-Measure
4865
* @see http://en.wikipedia.org/wiki/F1_score
4966
*/
@@ -52,6 +69,10 @@ private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificati
5269
override def apply(c: BinaryConfusionMatrix): Double = {
5370
val precision = Precision(c)
5471
val recall = Recall(c)
55-
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
72+
if (precision + recall == 0) {
73+
0.0
74+
} else {
75+
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
76+
}
5677
}
5778
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala

Lines changed: 81 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,39 +24,104 @@ import org.apache.spark.mllib.util.TestingUtils._
2424

2525
class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {
2626

27-
def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
27+
private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
2828

29-
def cond2(x: ((Double, Double), (Double, Double))): Boolean =
29+
private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
3030
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
3131

32+
private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
33+
assert(left.zip(right).forall(areWithinEpsilon))
34+
}
35+
36+
private def assertTupleSequencesMatch(left: Seq[(Double, Double)],
37+
right: Seq[(Double, Double)]): Unit = {
38+
assert(left.zip(right).forall(pairsWithinEpsilon))
39+
}
40+
41+
private def validateMetrics(metrics: BinaryClassificationMetrics,
42+
expectedThresholds: Seq[Double],
43+
expectedROCCurve: Seq[(Double, Double)],
44+
expectedPRCurve: Seq[(Double, Double)],
45+
expectedFMeasures1: Seq[Double],
46+
expectedFmeasures2: Seq[Double],
47+
expectedPrecisions: Seq[Double],
48+
expectedRecalls: Seq[Double]) = {
49+
50+
assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
51+
assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
52+
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5)
53+
assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve)
54+
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5)
55+
assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(),
56+
expectedThresholds.zip(expectedFMeasures1))
57+
assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(),
58+
expectedThresholds.zip(expectedFmeasures2))
59+
assertTupleSequencesMatch(metrics.precisionByThreshold().collect(),
60+
expectedThresholds.zip(expectedPrecisions))
61+
assertTupleSequencesMatch(metrics.recallByThreshold().collect(),
62+
expectedThresholds.zip(expectedRecalls))
63+
}
64+
3265
test("binary evaluation metrics") {
3366
val scoreAndLabels = sc.parallelize(
3467
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
3568
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
36-
val threshold = Seq(0.8, 0.6, 0.4, 0.1)
69+
val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
3770
val numTruePositives = Seq(1, 3, 3, 4)
3871
val numFalsePositives = Seq(0, 1, 2, 3)
3972
val numPositives = 4
4073
val numNegatives = 3
41-
val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
74+
val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
4275
t.toDouble / (t + f)
4376
}
44-
val recall = numTruePositives.map(t => t.toDouble / numPositives)
77+
val recalls = numTruePositives.map(t => t.toDouble / numPositives)
4578
val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
46-
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
47-
val pr = recall.zip(precision)
79+
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
80+
val pr = recalls.zip(precisions)
4881
val prCurve = Seq((0.0, 1.0)) ++ pr
4982
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
5083
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
5184

52-
assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
53-
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
54-
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
55-
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
56-
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
57-
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
58-
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
59-
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
60-
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
85+
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
86+
}
87+
88+
test("binary evaluation metrics for RDD where all examples have positive label") {
89+
val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
90+
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
91+
92+
val thresholds = Seq(0.5)
93+
val precisions = Seq(1.0)
94+
val recalls = Seq(1.0)
95+
val fpr = Seq(0.0)
96+
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
97+
val pr = recalls.zip(precisions)
98+
val prCurve = Seq((0.0, 1.0)) ++ pr
99+
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
100+
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
101+
102+
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
103+
}
104+
105+
test("binary evaluation metrics for RDD where all examples have negative label") {
106+
val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2)
107+
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
108+
109+
val thresholds = Seq(0.5)
110+
val precisions = Seq(0.0)
111+
val recalls = Seq(0.0)
112+
val fpr = Seq(1.0)
113+
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
114+
val pr = recalls.zip(precisions)
115+
val prCurve = Seq((0.0, 1.0)) ++ pr
116+
val f1 = pr.map {
117+
case (0, 0) => 0.0
118+
case (r, p) => 2.0 * (p * r) / (p + r)
119+
}
120+
val f2 = pr.map {
121+
case (0, 0) => 0.0
122+
case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
123+
}
124+
125+
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
61126
}
62127
}

0 commit comments

Comments
 (0)