|
17 | 17 |
|
18 | 18 | package org.apache.spark.mllib.evaluation.binary
|
19 | 19 |
|
20 |
| -import org.apache.spark.rdd.RDD |
| 20 | +import org.apache.spark.rdd.{UnionRDD, RDD} |
21 | 21 | import org.apache.spark.SparkContext._
|
22 | 22 | import org.apache.spark.mllib.evaluation.AreaUnderCurve
|
23 | 23 | import org.apache.spark.Logging
|
@@ -103,22 +103,34 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
|
103 | 103 |
|
104 | 104 | /**
|
105 | 105 | * Returns the receiver operating characteristic (ROC) curve,
|
106 |
| - * which is an RDD of (false positive rate, true positive rate). |
| 106 | + * which is an RDD of (false positive rate, true positive rate) |
| 107 | + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. |
107 | 108 | * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
|
108 | 109 | */
|
109 |
| - def roc(): RDD[(Double, Double)] = createCurve(FalsePositiveRate, Recall) |
| 110 | + def roc(): RDD[(Double, Double)] = { |
| 111 | + val rocCurve = createCurve(FalsePositiveRate, Recall) |
| 112 | + val sc = confusions.context |
| 113 | + val first = sc.makeRDD(Seq((0.0, 0.0)), 1) |
| 114 | + val last = sc.makeRDD(Seq((1.0, 1.0)), 1) |
| 115 | + new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last)) |
| 116 | + } |
110 | 117 |
|
111 | 118 | /**
|
112 | 119 | * Computes the area under the receiver operating characteristic (ROC) curve.
|
113 | 120 | */
|
114 | 121 | def areaUnderROC(): Double = AreaUnderCurve.of(roc())
|
115 | 122 |
|
116 | 123 | /**
|
117 |
| - * Returns the precision-recall curve, |
118 |
| - * which is an RDD of (recall, precision), NOT (precision, recall). |
| 124 | + * Returns the precision-recall curve, which is an RDD of (recall, precision), |
| 125 | + * NOT (precision, recall), with (0.0, 1.0) prepended to it. |
119 | 126 | * @see http://en.wikipedia.org/wiki/Precision_and_recall
|
120 | 127 | */
|
121 |
| - def pr(): RDD[(Double, Double)] = createCurve(Recall, Precision) |
| 128 | + def pr(): RDD[(Double, Double)] = { |
| 129 | + val prCurve = createCurve(Recall, Precision) |
| 130 | + val sc = confusions.context |
| 131 | + val first = sc.makeRDD(Seq((0.0, 1.0)), 1) |
| 132 | + first.union(prCurve) |
| 133 | + } |
122 | 134 |
|
123 | 135 | /**
|
124 | 136 | * Computes the area under the precision-recall curve.
|
|
0 commit comments