|
| 1 | +package org.apache.spark.mllib.evaluation |
| 2 | + |
| 3 | +import org.apache.spark.rdd.RDD |
| 4 | +import org.apache.spark.SparkContext._ |
| 5 | + |
/**
 * Placeholder evaluator for binary classification output.
 *
 * The class currently has no members and does not use its constructor argument.
 *
 * @param scoreAndLabel an RDD of pairs; going by the parameter name these are
 *                      (score, label) pairs -- confirm against callers
 */
class BinaryClassificationEvaluator(scoreAndLabel: RDD[(Double, Double)])
| 9 | + |
/**
 * Experimental helpers that print ROC and precision-recall summary values for scored binary
 * classification data. Both overloads write their intermediate and final values to standard
 * output and return nothing.
 */
object BinaryClassificationEvaluator {

  /**
   * Distributed variant: bins the input by distinct score, accumulates positive/negative counts
   * in descending score order, and prints the values returned by `AreaUnderCurve.of` for the
   * resulting ROC and precision-recall points.
   *
   * NOTE(review): the `collect()` calls used for printing pull whole RDDs onto the driver; they
   * look like debugging output and are kept only to preserve the existing behaviour.
   * NOTE(review): neither curve is anchored with an explicit starting point such as (0.0, 0.0);
   * confirm whether `AreaUnderCurve.of` expects the caller to supply one.
   *
   * @param rdd pairs whose first element is used as the score (the grouping key) and whose
   *            second element is the label; labels greater than 0.5 count as positive
   */
  def get(rdd: RDD[(Double, Double)]): Unit = {
    // Create a bin for each distinct score value, count positives and negatives within each bin,
    // and then sort by score values in descending order.
    val counts = rdd.combineByKey(
      createCombiner = (label: Double) => new Counter(0L, 0L) += label,
      mergeValue = (c: Counter, label: Double) => c += label,
      mergeCombiners = (c1: Counter, c2: Counter) => c1 += c2
    ).sortByKey(ascending = false)
    println(counts.collect().toList)

    // One aggregated counter per partition, brought back to the driver.
    val agg = counts.values.mapPartitions((iter: Iterator[Counter]) => {
      val partitionTotal = new Counter()
      iter.foreach(partitionTotal += _)
      Iterator(partitionTotal)
    }, preservesPartitioning = true).collect()
    println(agg.toList)

    // cum(i) holds the counts of every partition before partition i; the last entry is the
    // grand total.
    val cum = agg.scanLeft(new Counter())((acc: Counter, c: Counter) => acc + c)
    val total = cum.last
    println(total)
    println(cum.toList)

    val cumCountsRdd = counts.mapPartitionsWithIndex(
      (index: Int, iter: Iterator[(Double, Counter)]) => {
        // Accumulate into a private copy so the captured offsets are never mutated. The result
        // then stays correct when this RDD is evaluated more than once, without relying on each
        // task receiving its own deserialized copy of the closure.
        val running = cum(index).clone()
        iter.map { case (score, c) =>
          running += c
          (score, running.clone())
        }
      }, preservesPartitioning = true)
    println("cum: " + cumCountsRdd.collect().toList)

    // ROC points: (false positive rate, true positive rate).
    val rocAUC = AreaUnderCurve.of(cumCountsRdd.values.map((c: Counter) => {
      (1.0 * c.numNegatives / total.numNegatives,
        1.0 * c.numPositives / total.numPositives)
    }))
    println(rocAUC)

    // Precision-recall points: (recall, precision).
    val prAUC = AreaUnderCurve.of(cumCountsRdd.values.map((c: Counter) => {
      (1.0 * c.numPositives / total.numPositives,
        1.0 * c.numPositives / (c.numPositives + c.numNegatives))
    }))
    println(prAUC)
  }

  /**
   * Local variant of the same computation over an in-memory collection.
   *
   * @param data pairs whose first element is used as the score (the grouping key) and whose
   *             second element is the label; labels greater than 0.5 count as positive
   */
  def get(data: Iterable[(Double, Double)]): Unit = {
    // One counter per distinct score, ordered by descending score.
    val counts = data.groupBy(_._1).mapValues { pairs =>
      pairs.foldLeft(new Counter())((c, pair) => c += pair._2)
    }.toSeq.sortBy(- _._1)
    println("counts: " + counts.toList)

    // Running totals built with scanLeft instead of mutating a shared accumulator inside `map`,
    // so the result does not depend on the sequence being evaluated strictly and in order.
    val runningCounts = counts.scanLeft(new Counter())((acc, s) => acc + s._2).tail
    val cum = counts.map(_._1).zip(runningCounts)
    val total = runningCounts.lastOption.getOrElse(new Counter())
    println("cum: " + cum.toList)

    // ROC points: (false positive rate, true positive rate).
    val roc = cum.map { case (_, c) =>
      (1.0 * c.numNegatives / total.numNegatives, 1.0 * c.numPositives / total.numPositives)
    }
    val rocAUC = AreaUnderCurve.of(roc)
    println(rocAUC)

    // Precision-recall points: (recall, precision).
    val pr = cum.map { case (_, c) =>
      (1.0 * c.numPositives / total.numPositives,
        1.0 * c.numPositives / (c.numPositives + c.numNegatives))
    }
    val prAUC = AreaUnderCurve.of(pr)
    println(prAUC)
  }
}
| 77 | + |
/**
 * Mutable pair of tallies for positive and negative labels.
 *
 * The `+=` operators update this instance in place and return it; the `+` operators leave this
 * instance untouched and return a new counter.
 *
 * @param numPositives number of positive labels seen so far
 * @param numNegatives number of negative labels seen so far
 */
class Counter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {

  /**
   * Records a single label in place. Any label above 0.5 is tallied as positive and everything
   * else as negative, so both 0.0 and -1.0 work as the negative label.
   */
  def +=(label: Double): Counter = {
    val isPositive = label > 0.5
    if (isPositive) {
      numPositives = numPositives + 1L
    } else {
      numNegatives = numNegatives + 1L
    }
    this
  }

  /** Adds the tallies of `other` to this counter in place. */
  def +=(other: Counter): Counter = {
    numNegatives = numNegatives + other.numNegatives
    numPositives = numPositives + other.numPositives
    this
  }

  /** Returns a new counter equal to this one with `label` recorded. */
  def +(label: Double): Counter = {
    val copy = clone()
    copy += label
  }

  /** Returns a new counter holding the combined tallies of this counter and `other`. */
  def +(other: Counter): Counter = {
    val copy = clone()
    copy += other
  }

  /** Total number of labels recorded. */
  def sum: Long = numNegatives + numPositives

  /** Returns an independent counter with the same tallies. */
  override def clone(): Counter = new Counter(numPositives, numNegatives)

  /** Renders the counter as `[numPositives,numNegatives]`. */
  override def toString(): String = "[" + numPositives + "," + numNegatives + "]"
}
| 109 | + |
0 commit comments