Skip to content

Commit f411e70

Browse files
author
Andrew Bullen
committed
[SPARK-4256] Define precision as 1.0 when there are no positive examples; update code formatting per pull request comments
1 parent d9a09ef commit f411e70

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,38 +24,43 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
2424
def apply(c: BinaryConfusionMatrix): Double
2525
}
2626

27-
/** Precision. */
27+
/** Precision. Defined as 1.0 when there are no positive examples. */
2828
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
29-
override def apply(c: BinaryConfusionMatrix): Double =
30-
if (c.numTruePositives + c.numFalsePositives == 0) {
31-
0.0
29+
override def apply(c: BinaryConfusionMatrix): Double = {
30+
val totalPositives = c.numTruePositives + c.numFalsePositives
31+
if (totalPositives == 0) {
32+
1.0
3233
} else {
33-
c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
34+
c.numTruePositives.toDouble / totalPositives
3435
}
36+
}
3537
}
3638

37-
/** False positive rate. */
39+
/** False positive rate. Defined as 0.0 when there are no negative examples. */
3840
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
39-
override def apply(c: BinaryConfusionMatrix): Double =
41+
override def apply(c: BinaryConfusionMatrix): Double = {
4042
if (c.numNegatives == 0) {
4143
0.0
4244
} else {
4345
c.numFalsePositives.toDouble / c.numNegatives
4446
}
47+
}
4548
}
4649

47-
/** Recall. */
50+
/** Recall. Defined as 0.0 when there are no positive examples. */
4851
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
49-
override def apply(c: BinaryConfusionMatrix): Double =
52+
override def apply(c: BinaryConfusionMatrix): Double = {
5053
if (c.numPositives == 0) {
5154
0.0
5255
} else {
5356
c.numTruePositives.toDouble / c.numPositives
5457
}
58+
}
5559
}
5660

5761
/**
58-
* F-Measure.
62+
* F-Measure. Defined as 0 if both precision and recall are 0. EG in the case that all examples
63+
* are false positives.
5964
* @param beta the beta constant in F-Measure
6065
* @see http://en.wikipedia.org/wiki/F1_score
6166
*/

mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
6161
}
6262

6363
test("binary evaluation metrics for All Positive RDD") {
64-
val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0)), 2)
64+
val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
6565
val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels)
6666

6767
val threshold = Seq(0.5)
@@ -86,7 +86,7 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
8686
}
8787

8888
test("binary evaluation metrics for All Negative RDD") {
89-
val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0)), 2)
89+
val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2)
9090
val metrics: BinaryClassificationMetrics = new BinaryClassificationMetrics(scoreAndLabels)
9191

9292
val threshold = Seq(0.5)
@@ -97,11 +97,11 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
9797
val pr = recall.zip(precision)
9898
val prCurve = Seq((0.0, 1.0)) ++ pr
9999
val f1 = pr.map {
100-
case (0,0) => 0.0
100+
case (0, 0) => 0.0
101101
case (r, p) => 2.0 * (p * r) / (p + r)
102102
}
103103
val f2 = pr.map {
104-
case (0,0) => 0.0
104+
case (0, 0) => 0.0
105105
case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
106106
}
107107

0 commit comments

Comments
 (0)