Commit 3f42e98

add (0, 0), (1, 1) to roc, and (0, 1) to pr
1 parent fb4b6d2 commit 3f42e98

2 files changed: 24 additions and 11 deletions

mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala

Lines changed: 18 additions & 6 deletions
@@ -17,7 +17,7 @@

 package org.apache.spark.mllib.evaluation.binary

-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{UnionRDD, RDD}
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.evaluation.AreaUnderCurve
 import org.apache.spark.Logging
@@ -103,22 +103,34 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])

   /**
    * Returns the receiver operating characteristic (ROC) curve,
-   * which is an RDD of (false positive rate, true positive rate).
+   * which is an RDD of (false positive rate, true positive rate)
+   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
    * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
    */
-  def roc(): RDD[(Double, Double)] = createCurve(FalsePositiveRate, Recall)
+  def roc(): RDD[(Double, Double)] = {
+    val rocCurve = createCurve(FalsePositiveRate, Recall)
+    val sc = confusions.context
+    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
+    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
+    new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
+  }

   /**
    * Computes the area under the receiver operating characteristic (ROC) curve.
    */
   def areaUnderROC(): Double = AreaUnderCurve.of(roc())

   /**
-   * Returns the precision-recall curve,
-   * which is an RDD of (recall, precision), NOT (precision, recall).
+   * Returns the precision-recall curve, which is an RDD of (recall, precision),
+   * NOT (precision, recall), with (0.0, 1.0) prepended to it.
    * @see http://en.wikipedia.org/wiki/Precision_and_recall
    */
-  def pr(): RDD[(Double, Double)] = createCurve(Recall, Precision)
+  def pr(): RDD[(Double, Double)] = {
+    val prCurve = createCurve(Recall, Precision)
+    val sc = confusions.context
+    val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
+    first.union(prCurve)
+  }

   /**
    * Computes the area under the precision-recall curve.
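
The roc() change relies on unioning three RDDs in a fixed order so that the anchor points land at the two ends of the curve. Below is a minimal, self-contained sketch of that pattern outside MLlib: the (FPR, TPR) values in rocCurve are made up for illustration and merely stand in for the output of createCurve(FalsePositiveRate, Recall), and the local SparkContext setup and the object name RocEndpointsSketch are assumptions for this note, not part of the commit.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.{RDD, UnionRDD}

object RocEndpointsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("roc-endpoints").setMaster("local"))

    // Stand-in for the curve produced by createCurve(FalsePositiveRate, Recall):
    // (false positive rate, true positive rate) pairs; values are illustrative only.
    val rocCurve: RDD[(Double, Double)] = sc.makeRDD(Seq((0.2, 0.6), (0.5, 0.8), (0.8, 0.95)), 1)

    // Single-element, single-partition RDDs for the fixed endpoints.
    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)

    // UnionRDD keeps the operands' partition order, so the endpoints stay at the ends.
    val roc = new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))

    // Prints (0.0,0.0), then the interior points in order, then (1.0,1.0).
    roc.collect().foreach(println)

    sc.stop()
  }
}

Building the endpoints as tiny RDDs and unioning them, as roc() does here (pr() uses first.union(prCurve) for the same effect), avoids collecting the curve to the driver just to add one or two points.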

mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala

Lines changed: 6 additions & 5 deletions
@@ -35,15 +35,16 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
     val precision = tp.zip(fp).map { case (t, f) => t.toDouble / (t + f) }
     val recall = tp.map(t => t.toDouble / p)
     val fpr = fp.map(f => f.toDouble / n)
-    val roc = fpr.zip(recall)
+    val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
     val pr = recall.zip(precision)
+    val prCurve = Seq((0.0, 1.0)) ++ pr
     val f1 = pr.map { case (re, prec) => 2.0 * (prec * re) / (prec + re) }
     val f2 = pr.map { case (re, prec) => 5.0 * (prec * re) / (4.0 * prec + re)}
     assert(metrics.thresholds().collect().toSeq === score)
-    assert(metrics.roc().collect().toSeq === roc)
-    assert(metrics.areaUnderROC() === AreaUnderCurve.of(roc))
-    assert(metrics.pr().collect().toSeq === pr)
-    assert(metrics.areaUnderPR() === AreaUnderCurve.of(pr))
+    assert(metrics.roc().collect().toSeq === rocCurve)
+    assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
+    assert(metrics.pr().collect().toSeq === prCurve)
+    assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
     assert(metrics.fMeasureByThreshold().collect().toSeq === score.zip(f1))
     assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === score.zip(f2))
     assert(metrics.precisionByThreshold().collect().toSeq === score.zip(precision))
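
The updated assertions compare areaUnderROC() and areaUnderPR() against AreaUnderCurve.of applied to curves that include the anchor points. A quick local sanity check of why the anchors matter, assuming trapezoid-rule integration: the helper below is a stand-in written for this note, not MLlib's AreaUnderCurve, and the (FPR, TPR) points are made up for illustration.

object AnchorPointsSketch {
  // Trapezoid-rule area under a piecewise-linear curve given as (x, y) points.
  def trapezoidArea(curve: Seq[(Double, Double)]): Double =
    curve.sliding(2).collect { case Seq((x1, y1), (x2, y2)) =>
      (x2 - x1) * (y1 + y2) / 2.0
    }.sum

  def main(args: Array[String]): Unit = {
    // Illustrative (FPR, TPR) points between the extreme thresholds.
    val interior = Seq((0.2, 0.6), (0.5, 0.8), (0.8, 0.95))
    val withAnchors = Seq((0.0, 0.0)) ++ interior ++ Seq((1.0, 1.0))

    println(trapezoidArea(interior))    // ~0.47: misses the segments at both ends
    println(trapezoidArea(withAnchors)) // ~0.73: includes the runs down to (0, 0) and up to (1, 1)
  }
}

Without the anchors, the integral only covers the x-range spanned by the observed thresholds, so the reported area understates the curve that the ROC is conventionally drawn with.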
