Skip to content

Commit f5ace8d

Browse files
mengxr authored and mateiz committed
[SPARK-1225, 1241] [MLLIB] Add AreaUnderCurve and BinaryClassificationMetrics
This PR implements a generic version of `AreaUnderCurve` using the `RDD.sliding` implementation from #136 . It also contains refactoring of #160 for binary classification evaluation. Author: Xiangrui Meng <meng@databricks.com> Closes #364 from mengxr/auc and squashes the following commits: a05941d [Xiangrui Meng] replace TP/FP/TN/FN by their full names 3f42e98 [Xiangrui Meng] add (0, 0), (1, 1) to roc, and (0, 1) to pr fb4b6d2 [Xiangrui Meng] rename Evaluator to Metrics and add more metrics b1b7dab [Xiangrui Meng] fix code styles 9dc3518 [Xiangrui Meng] add tests for BinaryClassificationEvaluator ca31da5 [Xiangrui Meng] remove PredictionAndResponse 3d71525 [Xiangrui Meng] move binary evalution classes to evaluation.binary 8f78958 [Xiangrui Meng] add PredictionAndResponse dda82d5 [Xiangrui Meng] add confusion matrix aa7e278 [Xiangrui Meng] add initial version of binary classification evaluator 221ebce [Xiangrui Meng] add a new test to sliding a920865 [Xiangrui Meng] Merge branch 'sliding' into auc a9b250a [Xiangrui Meng] move sliding to mllib cab9a52 [Xiangrui Meng] use last for the last element db6cb30 [Xiangrui Meng] remove unnecessary toSeq 9916202 [Xiangrui Meng] change RDD.sliding return type to RDD[Seq[T]] 284d991 [Xiangrui Meng] change SlidedRDD to SlidingRDD c1c6c22 [Xiangrui Meng] add AreaUnderCurve 65461b2 [Xiangrui Meng] Merge branch 'sliding' into auc 5ee6001 [Xiangrui Meng] add TODO d2a600d [Xiangrui Meng] add sliding to rdd
1 parent 98225a6 commit f5ace8d

File tree

9 files changed

+671
-0
lines changed

9 files changed

+671
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation
19+
20+
import org.apache.spark.rdd.RDD
21+
import org.apache.spark.mllib.rdd.RDDFunctions._
22+
23+
/**
 * Trapezoidal-rule integration of a curve given as a sequence of 2D points.
 *
 * The area is accumulated over every pair of neighbouring points, in the order
 * in which the points appear in the input.
 */
private[evaluation] object AreaUnderCurve {

  /**
   * Area of the trapezoid under the segment that joins two points.
   *
   * @param points exactly two (x, y) points; any other length fails the `require` check
   * @return (x2 - x1) * (y2 + y1) / 2
   */
  private def trapezoid(points: Seq[(Double, Double)]): Double = {
    require(points.length == 2)
    val (leftX, leftY) = points.head
    val (rightX, rightY) = points.last
    (rightX - leftX) * (rightY + leftY) / 2.0
  }

  /**
   * Returns the area under the given curve.
   *
   * @param curve a RDD of ordered 2D points stored in pairs representing a curve
   */
  def of(curve: RDD[(Double, Double)]): Double = {
    // Adds the area of one segment (a window of two neighbouring points) to the running sum.
    val addSegment = (acc: Double, segment: Seq[(Double, Double)]) => acc + trapezoid(segment)
    // Combines two partial sums.
    val mergePartials = (a: Double, b: Double) => a + b
    curve.sliding(2).aggregate(0.0)(addSegment, mergePartials)
  }

  /**
   * Returns the area under the given curve.
   *
   * @param curve an iterator over ordered 2D points stored in pairs representing a curve
   */
  def of(curve: Iterable[(Double, Double)]): Double = {
    // withPartial(false) drops a trailing window that has fewer than two points.
    val segments = curve.toIterator.sliding(2).withPartial(false)
    segments.foldLeft(0.0) { (acc, segment) => acc + trapezoid(segment) }
  }
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation.binary
19+
20+
/**
 * A scalar evaluation metric computed from a binary confusion matrix.
 *
 * Implementations are used like functions: `metric(confusionMatrix)`.
 */
private[evaluation] trait BinaryClassificationMetricComputer extends Serializable {
  /**
   * Evaluates this metric on a confusion matrix.
   *
   * @param c the confusion matrix to evaluate
   * @return the value of the metric
   */
  def apply(c: BinaryConfusionMatrix): Double
}
26+
27+
/**
 * Precision: true positives divided by everything predicted positive
 * (true positives + false positives).
 *
 * When nothing is predicted positive the ratio is undefined; 1.0 is returned
 * instead of NaN, since no prediction made was a false positive.
 */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
  override def apply(c: BinaryConfusionMatrix): Double = {
    val predictedPositives = c.numTruePositives + c.numFalsePositives
    // Guard the 0 / 0 case, which would otherwise yield NaN.
    if (predictedPositives == 0L) 1.0 else c.numTruePositives.toDouble / predictedPositives
  }
}
32+
33+
/**
 * False positive rate: false positives divided by the number of negatives.
 *
 * When there are no negatives the ratio is undefined; 0.0 is returned instead of NaN.
 */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
  override def apply(c: BinaryConfusionMatrix): Double = {
    val negatives = c.numNegatives
    // Guard the 0 / 0 case, which would otherwise yield NaN.
    if (negatives == 0L) 0.0 else c.numFalsePositives.toDouble / negatives
  }
}
38+
39+
/**
 * Recall (true positive rate): true positives divided by the number of positives.
 *
 * When there are no positives the ratio is undefined; 0.0 is returned instead of NaN.
 */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
  override def apply(c: BinaryConfusionMatrix): Double = {
    val positives = c.numPositives
    // Guard the 0 / 0 case, which would otherwise yield NaN.
    if (positives == 0L) 0.0 else c.numTruePositives.toDouble / positives
  }
}
44+
45+
/**
 * F-Measure: (1 + beta^2) * precision * recall / (beta^2 * precision + recall).
 *
 * When the denominator is zero (for example when there are no true positives, so both
 * precision and recall are 0) the value is undefined; 0.0 is returned instead of NaN.
 *
 * @param beta the beta constant in F-Measure
 * @see http://en.wikipedia.org/wiki/F1_score
 */
private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetricComputer {
  private val beta2 = beta * beta
  override def apply(c: BinaryConfusionMatrix): Double = {
    val precision = Precision(c)
    val recall = Recall(c)
    val denominator = beta2 * precision + recall
    // Guard the 0 / 0 case, which would otherwise yield NaN.
    if (denominator == 0.0) 0.0 else (1.0 + beta2) * (precision * recall) / denominator
  }
}
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation.binary
19+
20+
import org.apache.spark.rdd.{UnionRDD, RDD}
21+
import org.apache.spark.SparkContext._
22+
import org.apache.spark.mllib.evaluation.AreaUnderCurve
23+
import org.apache.spark.Logging
24+
25+
/**
 * Confusion matrix backed by two label counters, implementing
 * [[org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix]].
 *
 * Everything counted in `count` is treated as predicted positive; the remainder of
 * `totalCount` is treated as predicted negative.
 *
 * @param count label counter for labels with scores greater than or equal to the current score
 * @param totalCount label counter for all labels
 */
private case class BinaryConfusionMatrixImpl(
    count: LabelCounter,
    totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {

  /** Positive labels inside `count`. */
  override def numTruePositives: Long = count.numPositives

  /** Negative labels inside `count`. */
  override def numFalsePositives: Long = count.numNegatives

  /** Positive labels outside `count`. */
  override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives

  /** Negative labels outside `count`. */
  override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives

  /** All positive labels, read directly from `totalCount`. */
  override def numPositives: Long = totalCount.numPositives

  /** All negative labels, read directly from `totalCount`. */
  override def numNegatives: Long = totalCount.numNegatives
}
53+
54+
/**
 * Evaluator for binary classification.
 *
 * Each distinct score is used as a threshold: the confusion matrix at a threshold treats
 * every record whose score is greater than or equal to it as predicted positive. All curves
 * returned by this class are derived from those per-threshold confusion matrices.
 *
 * @param scoreAndLabels an RDD of (score, label) pairs.
 */
class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
  extends Serializable with Logging {

  // cumulativeCounts: for each distinct score, in descending order, the label counts of all
  //   records with a score greater than or equal to it. Persisted when first computed.
  // confusions: the confusion matrix at each of those scores.
  private lazy val (
    cumulativeCounts: RDD[(Double, LabelCounter)],
    confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
    // Create a bin for each distinct score value, count positives and negatives within each bin,
    // and then sort by score values in descending order.
    val counts = scoreAndLabels.combineByKey(
      createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label,
      mergeValue = (c: LabelCounter, label: Double) => c += label,
      mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2
    ).sortByKey(ascending = false)
    // Reduce every partition to a single counter and bring those counters to the driver.
    val agg = counts.values.mapPartitions({ iter =>
      val agg = new LabelCounter()
      iter.foreach(agg += _)
      Iterator(agg)
    }, preservesPartitioning = true).collect()
    // Prefix sums over partitions: entry i holds the counts of all partitions before
    // partition i, and the last entry holds the counts of the whole data set.
    val partitionwiseCumulativeCounts =
      agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c)
    val totalCount = partitionwiseCumulativeCounts.last
    logInfo(s"Total counts: $totalCount")
    val cumulativeCounts = counts.mapPartitionsWithIndex(
      (index: Int, iter: Iterator[(Double, LabelCounter)]) => {
        // Start from the counts of all preceding partitions and keep adding bins;
        // a clone is emitted per score because cumCount is updated in place.
        val cumCount = partitionwiseCumulativeCounts(index)
        iter.map { case (score, c) =>
          cumCount += c
          (score, cumCount.clone())
        }
      }, preservesPartitioning = true)
    cumulativeCounts.persist()
    val confusions = cumulativeCounts.map { case (score, cumCount) =>
      // A typed val widens to the trait without a runtime cast.
      val confusion: BinaryConfusionMatrix = BinaryConfusionMatrixImpl(cumCount, totalCount)
      (score, confusion)
    }
    (cumulativeCounts, confusions)
  }

  /** Unpersist intermediate RDDs used in the computation. */
  def unpersist(): Unit = {
    cumulativeCounts.unpersist()
  }

  /** Returns thresholds in descending order. */
  def thresholds(): RDD[Double] = cumulativeCounts.map(_._1)

  /**
   * Returns the receiver operating characteristic (ROC) curve,
   * which is an RDD of (false positive rate, true positive rate)
   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
   * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
   */
  def roc(): RDD[(Double, Double)] = {
    val rocCurve = createCurve(FalsePositiveRate, Recall)
    val sc = confusions.context
    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
    new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
  }

  /**
   * Computes the area under the receiver operating characteristic (ROC) curve.
   */
  def areaUnderROC(): Double = AreaUnderCurve.of(roc())

  /**
   * Returns the precision-recall curve, which is an RDD of (recall, precision),
   * NOT (precision, recall), with (0.0, 1.0) prepended to it.
   * @see http://en.wikipedia.org/wiki/Precision_and_recall
   */
  def pr(): RDD[(Double, Double)] = {
    val prCurve = createCurve(Recall, Precision)
    val sc = confusions.context
    val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
    first.union(prCurve)
  }

  /**
   * Computes the area under the precision-recall curve.
   */
  def areaUnderPR(): Double = AreaUnderCurve.of(pr())

  /**
   * Returns the (threshold, F-Measure) curve.
   * @param beta the beta factor in F-Measure computation.
   * @return an RDD of (threshold, F-Measure) pairs.
   * @see http://en.wikipedia.org/wiki/F1_score
   */
  def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))

  /** Returns the (threshold, F-Measure) curve with beta = 1.0. */
  def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)

  /** Returns the (threshold, precision) curve. */
  def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision)

  /** Returns the (threshold, recall) curve. */
  def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall)

  /** Creates a curve of (threshold, metric). */
  private def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
    confusions.map { case (s, c) =>
      (s, y(c))
    }
  }

  /** Creates a curve of (metricX, metricY). */
  private def createCurve(
      x: BinaryClassificationMetricComputer,
      y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
    confusions.map { case (_, c) =>
      (x(c), y(c))
    }
  }
}
173+
174+
/**
 * Mutable tally of positive and negative labels.
 *
 * Both `+=` operations update this counter in place and return it, so calls can be chained.
 *
 * @param numPositives number of positive labels
 * @param numNegatives number of negative labels
 */
private class LabelCounter(
    var numPositives: Long = 0L,
    var numNegatives: Long = 0L) extends Serializable {

  /**
   * Counts a single label.
   *
   * @param label the label to count
   * @return this counter
   */
  def +=(label: Double): LabelCounter = {
    // Labels are assumed to be 1.0 (positive) or 0.0 (negative); comparing against 0.5
    // also counts -1.0 as negative.
    val isPositive = label > 0.5
    if (isPositive) {
      numPositives += 1L
    } else {
      numNegatives += 1L
    }
    this
  }

  /**
   * Adds the tallies of another counter to this one.
   *
   * @param other the counter to merge in; it is not modified
   * @return this counter
   */
  def +=(other: LabelCounter): LabelCounter = {
    numNegatives += other.numNegatives
    numPositives += other.numPositives
    this
  }

  /** Returns a new counter holding the current tallies. */
  override def clone: LabelCounter = new LabelCounter(numPositives, numNegatives)

  override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation.binary
19+
20+
/**
 * The four cells of a binary confusion matrix, plus the label totals derived from them.
 */
private[evaluation] trait BinaryConfusionMatrix {
  /** Count of true positives. */
  def numTruePositives: Long

  /** Count of false positives. */
  def numFalsePositives: Long

  /** Count of false negatives. */
  def numFalseNegatives: Long

  /** Count of true negatives. */
  def numTrueNegatives: Long

  /** Count of positives: true positives plus false negatives. */
  def numPositives: Long = numTruePositives + numFalseNegatives

  /** Count of negatives: false positives plus true negatives. */
  def numNegatives: Long = numFalsePositives + numTrueNegatives
}

0 commit comments

Comments
 (0)