Skip to content

Commit 3d71525

Browse files
committed
move binary evalution classes to evaluation.binary
1 parent 8f78958 commit 3d71525

File tree

3 files changed

+149
-79
lines changed

3 files changed

+149
-79
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationEvaluator.scala renamed to mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationEvaluator.scala

Lines changed: 51 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -15,76 +15,52 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.evaluation
18+
package org.apache.spark.mllib.evaluation.binary
1919

2020
import org.apache.spark.rdd.RDD
2121
import org.apache.spark.SparkContext._
22+
import org.apache.spark.mllib.evaluation.AreaUnderCurve
23+
import org.apache.spark.Logging
2224

2325
/**
24-
* Binary confusion matrix.
26+
* Implementation of [[org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix]].
2527
*
2628
* @param count label counter for labels with scores greater than or equal to the current score
27-
* @param total label counter for all labels
29+
* @param totalCount label counter for all labels
2830
*/
29-
case class BinaryConfusionMatrix(
31+
private case class BinaryConfusionMatrixImpl(
3032
private val count: LabelCounter,
31-
private val total: LabelCounter) extends Serializable {
33+
private val totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {
3234

3335
/** number of true positives */
34-
def tp: Long = count.numPositives
36+
override def tp: Long = count.numPositives
3537

3638
/** number of false positives */
37-
def fp: Long = count.numNegatives
39+
override def fp: Long = count.numNegatives
3840

3941
/** number of false negatives */
40-
def fn: Long = total.numPositives - count.numPositives
42+
override def fn: Long = totalCount.numPositives - count.numPositives
4143

4244
/** number of true negatives */
43-
def tn: Long = total.numNegatives - count.numNegatives
45+
override def tn: Long = totalCount.numNegatives - count.numNegatives
4446

4547
/** number of positives */
46-
def p: Long = total.numPositives
48+
override def p: Long = totalCount.numPositives
4749

4850
/** number of negatives */
49-
def n: Long = total.numNegatives
50-
}
51-
52-
private trait Metric {
53-
def apply(c: BinaryConfusionMatrix): Double
54-
}
55-
56-
object Precision extends Metric {
57-
override def apply(c: BinaryConfusionMatrix): Double =
58-
c.tp.toDouble / (c.tp + c.fp)
59-
}
60-
61-
object FalsePositiveRate extends Metric {
62-
override def apply(c: BinaryConfusionMatrix): Double =
63-
c.fp.toDouble / c.n
64-
}
65-
66-
object Recall extends Metric {
67-
override def apply(c: BinaryConfusionMatrix): Double =
68-
c.tp.toDouble / c.p
69-
}
70-
71-
case class FMeasure(beta: Double) extends Metric {
72-
private val beta2 = beta * beta
73-
override def apply(c: BinaryConfusionMatrix): Double = {
74-
val precision = Precision(c)
75-
val recall = Recall(c)
76-
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
77-
}
51+
override def n: Long = totalCount.numNegatives
7852
}
7953

8054
/**
8155
* Evaluator for binary classification.
8256
*
8357
* @param scoreAndlabels an RDD of (score, label) pairs.
8458
*/
85-
class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) extends Serializable {
59+
class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) extends Serializable with Logging {
8660

87-
private lazy val (cumCounts: RDD[(Double, LabelCounter)], totalCount: LabelCounter, scoreAndConfusion: RDD[(Double, BinaryConfusionMatrix)]) = {
61+
private lazy val (
62+
cumCounts: RDD[(Double, LabelCounter)],
63+
confusionByThreshold: RDD[(Double, BinaryConfusionMatrix)]) = {
8864
// Create a bin for each distinct score value, count positives and negatives within each bin,
8965
// and then sort by score values in descending order.
9066
val counts = scoreAndlabels.combineByKey(
@@ -99,6 +75,7 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
9975
}, preservesPartitioning = true).collect()
10076
val cum = agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg + c)
10177
val totalCount = cum.last
78+
logInfo(s"Total counts: totalCount")
10279
val cumCounts = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, LabelCounter)]) => {
10380
val cumCount = cum(index)
10481
iter.map { case (score, c) =>
@@ -108,76 +85,71 @@ class BinaryClassificationEvaluator(scoreAndlabels: RDD[(Double, Double)]) exten
10885
}, preservesPartitioning = true)
10986
cumCounts.persist()
11087
val scoreAndConfusion = cumCounts.map { case (score, cumCount) =>
111-
(score, BinaryConfusionMatrix(cumCount, totalCount))
88+
(score, BinaryConfusionMatrixImpl(cumCount, totalCount))
11289
}
11390
(cumCounts, totalCount, scoreAndConfusion)
11491
}
11592

93+
/** Unpersist intermediate RDDs used in the computation. */
11694
def unpersist() {
11795
cumCounts.unpersist()
11896
}
11997

98+
/**
99+
* Returns the receiver operating characteristic (ROC) curve.
100+
* @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
101+
*/
120102
def rocCurve(): RDD[(Double, Double)] = createCurve(FalsePositiveRate, Recall)
121103

104+
/**
105+
* Computes the area under the receiver operating characteristic (ROC) curve.
106+
*/
122107
def rocAUC(): Double = AreaUnderCurve.of(rocCurve())
123108

109+
/**
110+
* Returns the precision-recall curve.
111+
* @see http://en.wikipedia.org/wiki/Precision_and_recall
112+
*/
124113
def prCurve(): RDD[(Double, Double)] = createCurve(Recall, Precision)
125114

115+
/**
116+
* Computes the area under the precision-recall curve.
117+
*/
126118
def prAUC(): Double = AreaUnderCurve.of(prCurve())
127119

120+
/**
121+
* Returns the (threshold, F-Measure) curve.
122+
* @param beta the beta factor in F-Measure computation.
123+
* @return an RDD of (threshold, F-Measure) pairs.
124+
* @see http://en.wikipedia.org/wiki/F1_score
125+
*/
128126
def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))
129127

128+
/** Returns the (threshold, F-Measure) curve with beta = 1.0. */
130129
def fMeasureByThreshold() = fMeasureByThreshold(1.0)
131130

132-
private def createCurve(y: Metric): RDD[(Double, Double)] = {
133-
scoreAndConfusion.map { case (s, c) =>
131+
/** Creates a curve of (threshold, metric). */
132+
private def createCurve(y: BinaryClassificationMetric): RDD[(Double, Double)] = {
133+
confusionByThreshold.map { case (s, c) =>
134134
(s, y(c))
135135
}
136136
}
137137

138-
private def createCurve(x: Metric, y: Metric): RDD[(Double, Double)] = {
139-
scoreAndConfusion.map { case (_, c) =>
138+
/** Creates a curve of (metricX, metricY). */
139+
private def createCurve(x: BinaryClassificationMetric, y: BinaryClassificationMetric): RDD[(Double, Double)] = {
140+
confusionByThreshold.map { case (_, c) =>
140141
(x(c), y(c))
141142
}
142143
}
143144
}
144145

145-
class LocalBinaryClassificationEvaluator {
146-
def get(data: Iterable[(Double, Double)]) {
147-
val counts = data.groupBy(_._1).mapValues { s =>
148-
val c = new LabelCounter()
149-
s.foreach(c += _._2)
150-
c
151-
}.toSeq.sortBy(- _._1)
152-
println("counts: " + counts.toList)
153-
val total = new LabelCounter()
154-
val cum = counts.map { s =>
155-
total += s._2
156-
(s._1, total.clone())
157-
}
158-
println("cum: " + cum.toList)
159-
val roc = cum.map { case (s, c) =>
160-
(1.0 * c.numNegatives / total.numNegatives, 1.0 * c.numPositives / total.numPositives)
161-
}
162-
val rocAUC = AreaUnderCurve.of(roc)
163-
println(rocAUC)
164-
val pr = cum.map { case (s, c) =>
165-
(1.0 * c.numPositives / total.numPositives,
166-
1.0 * c.numPositives / (c.numPositives + c.numNegatives))
167-
}
168-
val prAUC = AreaUnderCurve.of(pr)
169-
println(prAUC)
170-
}
171-
}
172-
173146
/**
174147
* A counter for positives and negatives.
175148
*
176-
* @param numPositives
177-
* @param numNegatives
149+
* @param numPositives number of positive labels
150+
* @param numNegatives number of negative labels
178151
*/
179-
private[evaluation]
180-
class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {
152+
private class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {
181153

182154
/** Process a label. */
183155
def +=(label: Double): LabelCounter = {
@@ -208,6 +180,6 @@ class LabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) ext
208180
new LabelCounter(numPositives, numNegatives)
209181
}
210182

211-
override def toString: String = s"[$numPositives,$numNegatives]"
183+
override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
212184
}
213185

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation.binary
19+
20+
/**
21+
* Trait for a binary classification evaluation metric.
22+
*/
23+
private[evaluation] trait BinaryClassificationMetric {
24+
def apply(c: BinaryConfusionMatrix): Double
25+
}
26+
27+
/** Precision. */
28+
private[evaluation] object Precision extends BinaryClassificationMetric {
29+
override def apply(c: BinaryConfusionMatrix): Double =
30+
c.tp.toDouble / (c.tp + c.fp)
31+
}
32+
33+
/** False positive rate. */
34+
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetric {
35+
override def apply(c: BinaryConfusionMatrix): Double =
36+
c.fp.toDouble / c.n
37+
}
38+
39+
/** Recall. */
40+
private[evalution] object Recall extends BinaryClassificationMetric {
41+
override def apply(c: BinaryConfusionMatrix): Double =
42+
c.tp.toDouble / c.p
43+
}
44+
45+
/**
46+
* F-Measure.
47+
* @param beta the beta constant in F-Measure
48+
* @see http://en.wikipedia.org/wiki/F1_score
49+
*/
50+
private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetric {
51+
private val beta2 = beta * beta
52+
override def apply(c: BinaryConfusionMatrix): Double = {
53+
val precision = Precision(c)
54+
val recall = Recall(c)
55+
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
56+
}
57+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation.binary
19+
20+
/**
21+
* Trait for a binary confusion matrix.
22+
*/
23+
private[evaluation] trait BinaryConfusionMatrix {
24+
/** number of true positives */
25+
def tp: Long
26+
27+
/** number of false positives */
28+
def fp: Long
29+
30+
/** number of false negatives */
31+
def fn: Long
32+
33+
/** number of true negatives */
34+
def tn: Long
35+
36+
/** number of positives */
37+
def p: Long = tp + fn
38+
39+
/** number of negatives */
40+
def n: Long = fp + tn
41+
}

0 commit comments

Comments
 (0)