15
15
* limitations under the License.
16
16
*/
17
17
18
- package org .apache .spark .mllib .evaluation
18
+ package org .apache .spark .mllib .evaluation . binary
19
19
20
20
import org .apache .spark .rdd .RDD
21
21
import org .apache .spark .SparkContext ._
22
+ import org .apache .spark .mllib .evaluation .AreaUnderCurve
23
+ import org .apache .spark .Logging
22
24
23
25
/**
24
- * Binary confusion matrix .
26
+ * Implementation of [[ org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix ]] .
25
27
*
26
28
* @param count label counter for labels with scores greater than or equal to the current score
27
- * @param total label counter for all labels
29
+ * @param totalCount label counter for all labels
28
30
*/
29
- case class BinaryConfusionMatrix (
31
+ private case class BinaryConfusionMatrixImpl (
30
32
private val count : LabelCounter ,
31
- private val total : LabelCounter ) extends Serializable {
33
+ private val totalCount : LabelCounter ) extends BinaryConfusionMatrix with Serializable {
32
34
33
35
/** number of true positives */
34
- def tp : Long = count.numPositives
36
+ override def tp : Long = count.numPositives
35
37
36
38
/** number of false positives */
37
- def fp : Long = count.numNegatives
39
+ override def fp : Long = count.numNegatives
38
40
39
41
/** number of false negatives */
40
- def fn : Long = total .numPositives - count.numPositives
42
+ override def fn : Long = totalCount .numPositives - count.numPositives
41
43
42
44
/** number of true negatives */
43
- def tn : Long = total .numNegatives - count.numNegatives
45
+ override def tn : Long = totalCount .numNegatives - count.numNegatives
44
46
45
47
/** number of positives */
46
- def p : Long = total .numPositives
48
+ override def p : Long = totalCount .numPositives
47
49
48
50
/** number of negatives */
49
- def n : Long = total.numNegatives
50
- }
51
-
52
- private trait Metric {
53
- def apply (c : BinaryConfusionMatrix ): Double
54
- }
55
-
56
- object Precision extends Metric {
57
- override def apply (c : BinaryConfusionMatrix ): Double =
58
- c.tp.toDouble / (c.tp + c.fp)
59
- }
60
-
61
- object FalsePositiveRate extends Metric {
62
- override def apply (c : BinaryConfusionMatrix ): Double =
63
- c.fp.toDouble / c.n
64
- }
65
-
66
- object Recall extends Metric {
67
- override def apply (c : BinaryConfusionMatrix ): Double =
68
- c.tp.toDouble / c.p
69
- }
70
-
71
- case class FMeasure (beta : Double ) extends Metric {
72
- private val beta2 = beta * beta
73
- override def apply (c : BinaryConfusionMatrix ): Double = {
74
- val precision = Precision (c)
75
- val recall = Recall (c)
76
- (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
77
- }
51
+ override def n : Long = totalCount.numNegatives
78
52
}
79
53
80
54
/**
81
55
* Evaluator for binary classification.
82
56
*
83
57
* @param scoreAndlabels an RDD of (score, label) pairs.
84
58
*/
85
- class BinaryClassificationEvaluator (scoreAndlabels : RDD [(Double , Double )]) extends Serializable {
59
+ class BinaryClassificationEvaluator (scoreAndlabels : RDD [(Double , Double )]) extends Serializable with Logging {
86
60
87
- private lazy val (cumCounts : RDD [(Double , LabelCounter )], totalCount : LabelCounter , scoreAndConfusion : RDD [(Double , BinaryConfusionMatrix )]) = {
61
+ private lazy val (
62
+ cumCounts : RDD [(Double , LabelCounter )],
63
+ confusionByThreshold : RDD [(Double , BinaryConfusionMatrix )]) = {
88
64
// Create a bin for each distinct score value, count positives and negatives within each bin,
89
65
// and then sort by score values in descending order.
90
66
val counts = scoreAndlabels.combineByKey(
@@ -99,6 +75,7 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
99
75
}, preservesPartitioning = true ).collect()
100
76
val cum = agg.scanLeft(new LabelCounter ())((agg : LabelCounter , c : LabelCounter ) => agg + c)
101
77
val totalCount = cum.last
78
+ logInfo(s " Total counts: totalCount " )
102
79
val cumCounts = counts.mapPartitionsWithIndex((index : Int , iter : Iterator [(Double , LabelCounter )]) => {
103
80
val cumCount = cum(index)
104
81
iter.map { case (score, c) =>
@@ -108,76 +85,71 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
108
85
}, preservesPartitioning = true )
109
86
cumCounts.persist()
110
87
val scoreAndConfusion = cumCounts.map { case (score, cumCount) =>
111
- (score, BinaryConfusionMatrix (cumCount, totalCount))
88
+ (score, BinaryConfusionMatrixImpl (cumCount, totalCount))
112
89
}
113
90
(cumCounts, totalCount, scoreAndConfusion)
114
91
}
115
92
93
+ /** Unpersist intermediate RDDs used in the computation. */
116
94
def unpersist () {
117
95
cumCounts.unpersist()
118
96
}
119
97
98
+ /**
99
+ * Returns the receiver operating characteristic (ROC) curve.
100
+ * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
101
+ */
120
102
def rocCurve (): RDD [(Double , Double )] = createCurve(FalsePositiveRate , Recall )
121
103
104
+ /**
105
+ * Computes the area under the receiver operating characteristic (ROC) curve.
106
+ */
122
107
def rocAUC (): Double = AreaUnderCurve .of(rocCurve())
123
108
109
+ /**
110
+ * Returns the precision-recall curve.
111
+ * @see http://en.wikipedia.org/wiki/Precision_and_recall
112
+ */
124
113
def prCurve (): RDD [(Double , Double )] = createCurve(Recall , Precision )
125
114
115
+ /**
116
+ * Computes the area under the precision-recall curve.
117
+ */
126
118
def prAUC (): Double = AreaUnderCurve .of(prCurve())
127
119
120
+ /**
121
+ * Returns the (threshold, F-Measure) curve.
122
+ * @param beta the beta factor in F-Measure computation.
123
+ * @return an RDD of (threshold, F-Measure) pairs.
124
+ * @see http://en.wikipedia.org/wiki/F1_score
125
+ */
128
126
def fMeasureByThreshold (beta : Double ): RDD [(Double , Double )] = createCurve(FMeasure (beta))
129
127
128
+ /** Returns the (threshold, F-Measure) curve with beta = 1.0. */
130
129
def fMeasureByThreshold () = fMeasureByThreshold(1.0 )
131
130
132
- private def createCurve (y : Metric ): RDD [(Double , Double )] = {
133
- scoreAndConfusion.map { case (s, c) =>
131
+ /** Creates a curve of (threshold, metric). */
132
+ private def createCurve (y : BinaryClassificationMetric ): RDD [(Double , Double )] = {
133
+ confusionByThreshold.map { case (s, c) =>
134
134
(s, y(c))
135
135
}
136
136
}
137
137
138
- private def createCurve (x : Metric , y : Metric ): RDD [(Double , Double )] = {
139
- scoreAndConfusion.map { case (_, c) =>
138
+ /** Creates a curve of (metricX, metricY). */
139
+ private def createCurve (x : BinaryClassificationMetric , y : BinaryClassificationMetric ): RDD [(Double , Double )] = {
140
+ confusionByThreshold.map { case (_, c) =>
140
141
(x(c), y(c))
141
142
}
142
143
}
143
144
}
144
145
145
- class LocalBinaryClassificationEvaluator {
146
- def get (data : Iterable [(Double , Double )]) {
147
- val counts = data.groupBy(_._1).mapValues { s =>
148
- val c = new LabelCounter ()
149
- s.foreach(c += _._2)
150
- c
151
- }.toSeq.sortBy(- _._1)
152
- println(" counts: " + counts.toList)
153
- val total = new LabelCounter ()
154
- val cum = counts.map { s =>
155
- total += s._2
156
- (s._1, total.clone())
157
- }
158
- println(" cum: " + cum.toList)
159
- val roc = cum.map { case (s, c) =>
160
- (1.0 * c.numNegatives / total.numNegatives, 1.0 * c.numPositives / total.numPositives)
161
- }
162
- val rocAUC = AreaUnderCurve .of(roc)
163
- println(rocAUC)
164
- val pr = cum.map { case (s, c) =>
165
- (1.0 * c.numPositives / total.numPositives,
166
- 1.0 * c.numPositives / (c.numPositives + c.numNegatives))
167
- }
168
- val prAUC = AreaUnderCurve .of(pr)
169
- println(prAUC)
170
- }
171
- }
172
-
173
146
/**
174
147
* A counter for positives and negatives.
175
148
*
176
- * @param numPositives
177
- * @param numNegatives
149
+ * @param numPositives number of positive labels
150
+ * @param numNegatives number of negative labels
178
151
*/
179
- private [evaluation]
180
- class LabelCounter (var numPositives : Long = 0L , var numNegatives : Long = 0L ) extends Serializable {
152
+ private class LabelCounter (var numPositives : Long = 0L , var numNegatives : Long = 0L ) extends Serializable {
181
153
182
154
/** Process a label. */
183
155
def += (label : Double ): LabelCounter = {
@@ -208,6 +180,6 @@ class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) ext
208
180
new LabelCounter (numPositives, numNegatives)
209
181
}
210
182
211
- override def toString : String = s " [ $numPositives, $numNegatives] "
183
+ override def toString : String = s " {numPos: $numPositives, numNeg: $numNegatives} "
212
184
}
213
185
0 commit comments