
Commit aa7e278

add initial version of binary classification evaluator

1 parent 221ebce

4 files changed: +128 -6 lines

mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
 /**
  * Computes the area under the curve (AUC) using the trapezoidal rule.
  */
-object AreaUnderCurve {
+private[mllib] object AreaUnderCurve {
 
   /**
    * Uses the trapezoidal rule to compute the area under the line connecting the two input points.
@@ -53,8 +53,8 @@ object AreaUnderCurve {
    *
    * @param curve an iterator over ordered 2D points stored in pairs representing a curve
    */
-  def of(curve: Iterator[(Double, Double)]): Double = {
-    curve.sliding(2).withPartial(false).aggregate(0.0)(
+  def of(curve: Iterable[(Double, Double)]): Double = {
+    curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)(
       seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
       combop = _ + _
     )
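
For intuition, AreaUnderCurve sums one trapezoid per consecutive pair of points; the signature change only widens the local entry point from Iterator to Iterable so callers can pass a Seq directly, as the updated tests below do. A standalone sketch of the same rule (not part of the commit; trapezoidArea is a hypothetical helper):

    def trapezoidArea(curve: Seq[(Double, Double)]): Double =
      curve.sliding(2).collect { case Seq((x1, y1), (x2, y2)) =>
        // Area of the trapezoid under the segment from (x1, y1) to (x2, y2).
        (y1 + y2) / 2.0 * (x2 - x1)
      }.sum

    // trapezoidArea(Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))) == 4.0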

mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationEvaluator.scala

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+
+class BinaryClassificationEvaluator(scoreAndLabel: RDD[(Double, Double)]) {
+
+}
+
+object BinaryClassificationEvaluator {
+
+  def get(rdd: RDD[(Double, Double)]) {
+    // Create a bin for each distinct score value, count positives and negatives within each bin,
+    // and then sort by score values in descending order.
+    val counts = rdd.combineByKey(
+      createCombiner = (label: Double) => new Counter(0L, 0L) += label,
+      mergeValue = (c: Counter, label: Double) => c += label,
+      mergeCombiners = (c1: Counter, c2: Counter) => c1 += c2
+    ).sortByKey(ascending = false)
+    println(counts.collect().toList)
+    val agg = counts.values.mapPartitions((iter: Iterator[Counter]) => {
+      val agg = new Counter()
+      iter.foreach(agg += _)
+      Iterator(agg)
+    }, preservesPartitioning = true).collect()
+    println(agg.toList)
+    val cum = agg.scanLeft(new Counter())((agg: Counter, c: Counter) => agg + c)
+    val total = cum.last
+    println(total)
+    println(cum.toList)
+    val cumCountsRdd = counts.mapPartitionsWithIndex((index: Int, iter: Iterator[(Double, Counter)]) => {
+      val cumCount = cum(index)
+      iter.map { case (score, c) =>
+        cumCount += c
+        (score, cumCount.clone())
+      }
+    }, preservesPartitioning = true)
+    println("cum: " + cumCountsRdd.collect().toList)
+    val rocAUC = AreaUnderCurve.of(cumCountsRdd.values.map((c: Counter) => {
+      (1.0 * c.numNegatives / total.numNegatives,
+        1.0 * c.numPositives / total.numPositives)
+    }))
+    println(rocAUC)
+    val prAUC = AreaUnderCurve.of(cumCountsRdd.values.map((c: Counter) => {
+      (1.0 * c.numPositives / total.numPositives,
+        1.0 * c.numPositives / (c.numPositives + c.numNegatives))
+    }))
+    println(prAUC)
+  }
+
+  def get(data: Iterable[(Double, Double)]) {
+    val counts = data.groupBy(_._1).mapValues { s =>
+      val c = new Counter()
+      s.foreach(c += _._2)
+      c
+    }.toSeq.sortBy(- _._1)
+    println("counts: " + counts.toList)
+    val total = new Counter()
+    val cum = counts.map { s =>
+      total += s._2
+      (s._1, total.clone())
+    }
+    println("cum: " + cum.toList)
+    val roc = cum.map { case (s, c) =>
+      (1.0 * c.numNegatives / total.numNegatives, 1.0 * c.numPositives / total.numPositives)
+    }
+    val rocAUC = AreaUnderCurve.of(roc)
+    println(rocAUC)
+    val pr = cum.map { case (s, c) =>
+      (1.0 * c.numPositives / total.numPositives,
+        1.0 * c.numPositives / (c.numPositives + c.numNegatives))
+    }
+    val prAUC = AreaUnderCurve.of(pr)
+    println(prAUC)
+  }
+}
+
+class Counter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {
+
+  def +=(label: Double): Counter = {
+    // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
+    // -1.0 for negative as well.
+    if (label > 0.5) numPositives += 1L else numNegatives += 1L
+    this
+  }
+
+  def +=(other: Counter): Counter = {
+    numPositives += other.numPositives
+    numNegatives += other.numNegatives
+    this
+  }
+
+  def +(label: Double): Counter = {
+    this.clone() += label
+  }
+
+  def +(other: Counter): Counter = {
+    this.clone() += other
+  }
+
+  def sum: Long = numPositives + numNegatives
+
+  override def clone(): Counter = {
+    new Counter(numPositives, numNegatives)
+  }
+
+  override def toString(): String = s"[$numPositives,$numNegatives]"
+}
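
The distributed get(rdd) leans on one trick worth spelling out: after counting positives and negatives per distinct score, it collects only the per-partition totals to the driver, and scanLeft turns those into a running offset for each partition, so the full cumulative counts can then be finished with a single partition-local pass. A minimal sketch of that offset pattern on plain Longs (not part of the commit; the example totals are made up):

    // Per-partition totals, collected to the driver (tiny: one value per partition).
    val partitionTotals = Array(3L, 4L, 2L)

    // scanLeft yields, for each partition index i, the sum of everything before
    // partition i; the last element is the grand total.
    val offsets = partitionTotals.scanLeft(0L)(_ + _)  // Array(0L, 3L, 7L, 9L)

    assert(offsets(1) == 3L && offsets.last == 9L)

In the commit, the same idea runs over Counter values instead of Longs: cum(index) seeds each partition's running count inside mapPartitionsWithIndex, and cum.last serves as the total for normalizing the ROC and PR coordinates.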

mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -26,21 +26,21 @@ class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
   test("auc computation") {
     val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
     val auc = 4.0
-    assert(AreaUnderCurve.of(curve.toIterator) === auc)
+    assert(AreaUnderCurve.of(curve) === auc)
     val rddCurve = sc.parallelize(curve, 2)
     assert(AreaUnderCurve.of(rddCurve) == auc)
   }
 
   test("auc of an empty curve") {
     val curve = Seq.empty[(Double, Double)]
-    assert(AreaUnderCurve.of(curve.toIterator) === 0.0)
+    assert(AreaUnderCurve.of(curve) === 0.0)
     val rddCurve = sc.parallelize(curve, 2)
     assert(AreaUnderCurve.of(rddCurve) === 0.0)
   }
 
   test("auc of a curve with a single point") {
     val curve = Seq((1.0, 1.0))
-    assert(AreaUnderCurve.of(curve.toIterator) === 0.0)
+    assert(AreaUnderCurve.of(curve) === 0.0)
     val rddCurve = sc.parallelize(curve, 2)
     assert(AreaUnderCurve.of(rddCurve) === 0.0)
   }
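
The expected AUC of 4.0 in the first test can be checked by hand from the three trapezoids under the curve:

    (0.0 + 1.0)/2 * 1.0 + (1.0 + 3.0)/2 * 1.0 + (3.0 + 0.0)/2 * 1.0 = 0.5 + 2.0 + 1.5 = 4.0

The empty and single-point curves yield no complete sliding window of two points, so both correctly produce an area of 0.0.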

mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationEvaluationSuite.scala

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class BinaryClassificationEvaluationSuite extends FunSuite with LocalSparkContext {
+  test("test") {
+    val data = Seq((0.0, 0.0), (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0), (0.9, 1.0))
+    BinaryClassificationEvaluator.get(data)
+    val rdd = sc.parallelize(data, 3)
+    BinaryClassificationEvaluator.get(rdd)
+  }
+}
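
The suite only exercises the local and RDD code paths and prints the results; nothing is asserted yet. For reference, hand-tracing the nine test points through get gives the following per-score bins and cumulative [positives,negatives] counts (a sketch, not checked by the test):

    score  bin [pos,neg]  cumulative  ROC point (FPR, TPR)
    0.9    [1,0]          [1,0]       (0.00, 0.2)
    0.8    [1,0]          [2,0]       (0.00, 0.4)
    0.6    [2,1]          [4,1]       (0.25, 0.8)
    0.4    [0,1]          [4,2]       (0.50, 0.8)
    0.1    [1,1]          [5,3]       (0.75, 1.0)
    0.0    [0,1]          [5,4]       (1.00, 1.0)

With 5 positives and 4 negatives in total, the trapezoidal rule over these ROC points gives an area of 0.825.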
