Skip to content

Commit 87fb11f

Browse files
committed
Addressing reviewers comments mengxr. Added confusion matrix
1 parent e3db569 commit 87fb11f

File tree

2 files changed

+66
-23
lines changed

2 files changed

+66
-23
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,52 @@
1717

1818
package org.apache.spark.mllib.evaluation
1919

20-
import org.apache.spark.annotation.Experimental
21-
import org.apache.spark.rdd.RDD
2220
import org.apache.spark.Logging
2321
import org.apache.spark.SparkContext._
22+
import org.apache.spark.annotation.Experimental
23+
import org.apache.spark.rdd.RDD
2424

2525
import scala.collection.Map
2626

2727
/**
2828
* ::Experimental::
2929
* Evaluator for multiclass classification.
3030
*
31-
* @param predictionsAndLabels an RDD of (prediction, label) pairs.
31+
* @param predictionAndLabels an RDD of (prediction, label) pairs.
3232
*/
3333
@Experimental
34-
class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging {
34+
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) extends Logging {
3535

36-
private lazy val labelCountByClass: Map[Double, Long] = predictionsAndLabels.values.countByValue()
36+
private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
3737
private lazy val labelCount: Long = labelCountByClass.values.sum
38-
private lazy val tpByClass: Map[Double, Int] = predictionsAndLabels
38+
private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
3939
.map { case (prediction, label) =>
40-
(label, if (label == prediction) 1 else 0)
41-
}.reduceByKey(_ + _)
40+
(label, if (label == prediction) 1 else 0)
41+
}.reduceByKey(_ + _)
4242
.collectAsMap()
43-
private lazy val fpByClass: Map[Double, Int] = predictionsAndLabels
43+
private lazy val fpByClass: Map[Double, Int] = predictionAndLabels
4444
.map { case (prediction, label) =>
45-
(prediction, if (prediction != label) 1 else 0)
46-
}.reduceByKey(_ + _)
45+
(prediction, if (prediction != label) 1 else 0)
46+
}.reduceByKey(_ + _)
4747
.collectAsMap()
48+
private lazy val confusions = predictionAndLabels.map {
49+
case (prediction, label) => ((prediction, label), 1)
50+
}.reduceByKey(_ + _).collectAsMap()
51+
52+
/**
53+
* Returns confusion matrix:
54+
* predicted classes are in columns,
55+
* they are ordered by class label ascending,
56+
* as in "labels"
57+
*/
58+
lazy val confusionMatrix: Array[Array[Int]] = {
59+
val matrix = Array.ofDim[Int](labels.size, labels.size)
60+
println(matrix.length, matrix(0).length)
61+
for (i <- 0 to labels.size - 1; j <- 0 to labels.size - 1) {
62+
matrix(j)(i) = confusions.getOrElse((labels(i), labels(j)), 0)
63+
}
64+
matrix
65+
}
4866

4967
/**
5068
* Returns true positive rate for a given label (category)
@@ -103,8 +121,8 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
103121
/**
104122
* Returns recall
105123
* (equals to precision for multiclass classifier
106-
* because sum of all false positives is equal to sum
107-
* of all false negatives)
124+
* because sum of all false positives is equal to sum
125+
* of all false negatives)
108126
*/
109127
lazy val recall: Double = precision
110128

@@ -114,6 +132,19 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
114132
*/
115133
lazy val fMeasure: Double = precision
116134

135+
/**
136+
* Returns weighted true positive rate
137+
* (equals to precision, recall and f-measure)
138+
*/
139+
lazy val weightedTruePositiveRate: Double = weightedRecall
140+
141+
/**
142+
* Returns weighted false positive rate
143+
*/
144+
lazy val weightedFalsePositiveRate: Double = labelCountByClass.map { case (category, count) =>
145+
falsePositiveRate(category) * count.toDouble / labelCount
146+
}.sum
147+
117148
/**
118149
* Returns weighted averaged recall
119150
* (equals to precision, recall and f-measure)

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,23 @@
1717

1818
package org.apache.spark.mllib.evaluation
1919

20-
import org.scalatest.FunSuite
21-
2220
import org.apache.spark.mllib.util.LocalSparkContext
21+
import org.scalatest.FunSuite
2322

2423
class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
2524
test("Multiclass evaluation metrics") {
2625
/*
27-
* Confusion matrix for 3-class classification with total 9 instances:
28-
* |2|1|1| true class0 (4 instances)
29-
* |1|3|0| true class1 (4 instances)
30-
* |0|0|1| true class2 (1 instance)
31-
*
32-
*/
26+
* Confusion matrix for 3-class classification with total 9 instances:
27+
* |2|1|1| true class0 (4 instances)
28+
* |1|3|0| true class1 (4 instances)
29+
* |0|0|1| true class2 (1 instance)
30+
*/
31+
val confusionMatrix = Array(Array(2, 1, 1), Array(1, 3, 0), Array(0, 0, 1))
3332
val labels = Array(0.0, 1.0, 2.0)
34-
val scoreAndLabels = sc.parallelize(
33+
val predictionAndLabels = sc.parallelize(
3534
Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
3635
(1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
37-
val metrics = new MulticlassMetrics(scoreAndLabels)
36+
val metrics = new MulticlassMetrics(predictionAndLabels)
3837
val delta = 0.0000001
3938
val fpRate0 = 1.0 / (9 - 4)
4039
val fpRate1 = 1.0 / (9 - 4)
@@ -48,6 +47,11 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
4847
val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
4948
val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
5049
val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
50+
val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0)
51+
val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
52+
val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)
53+
54+
assert(metrics.confusionMatrix.deep == confusionMatrix.deep)
5155
assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
5256
assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
5357
assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
@@ -60,17 +64,25 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
6064
assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta)
6165
assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
6266
assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
67+
assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta)
68+
assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta)
69+
assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta)
70+
6371
assert(math.abs(metrics.recall -
6472
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
6573
assert(math.abs(metrics.recall - metrics.precision) < delta)
6674
assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
6775
assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
76+
assert(math.abs(metrics.weightedFalsePositiveRate -
77+
((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta)
6878
assert(math.abs(metrics.weightedPrecision -
6979
((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta)
7080
assert(math.abs(metrics.weightedRecall -
7181
((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta)
7282
assert(math.abs(metrics.weightedFMeasure -
7383
((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta)
84+
assert(math.abs(metrics.weightedFMeasure(2.0) -
85+
((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta)
7486
assert(metrics.labels.sameElements(labels))
7587
}
7688
}

0 commit comments

Comments
 (0)