Skip to content

Commit f0dadc9

Browse files
committed
Addressing reviewer's comments (mengxr)
1 parent 4811378 commit f0dadc9

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.evaluation
2020
import org.apache.spark.Logging
2121
import org.apache.spark.SparkContext._
2222
import org.apache.spark.annotation.Experimental
23+
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
2324
import org.apache.spark.rdd.RDD
2425

2526
import scala.collection.Map
@@ -31,19 +32,19 @@ import scala.collection.Map
3132
* @param predictionAndLabels an RDD of (prediction, label) pairs.
3233
*/
3334
@Experimental
34-
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) extends Logging {
35+
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
3536

3637
private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
3738
private lazy val labelCount: Long = labelCountByClass.values.sum
3839
private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
3940
.map { case (prediction, label) =>
40-
(label, if (label == prediction) 1 else 0)
41-
}.reduceByKey(_ + _)
41+
(label, if (label == prediction) 1 else 0)
42+
}.reduceByKey(_ + _)
4243
.collectAsMap()
4344
private lazy val fpByClass: Map[Double, Int] = predictionAndLabels
4445
.map { case (prediction, label) =>
45-
(prediction, if (prediction != label) 1 else 0)
46-
}.reduceByKey(_ + _)
46+
(prediction, if (prediction != label) 1 else 0)
47+
}.reduceByKey(_ + _)
4748
.collectAsMap()
4849
private lazy val confusions = predictionAndLabels.map {
4950
case (prediction, label) => ((prediction, label), 1)
@@ -55,12 +56,13 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) extends Logg
5556
* they are ordered by class label ascending,
5657
* as in "labels"
5758
*/
58-
lazy val confusionMatrix: Array[Array[Int]] = {
59-
val matrix = Array.ofDim[Int](labels.size, labels.size)
59+
lazy val confusionMatrix: Matrix = {
60+
val transposedMatrix = Array.ofDim[Double](labels.size, labels.size)
6061
for (i <- 0 to labels.size - 1; j <- 0 to labels.size - 1) {
61-
matrix(j)(i) = confusions.getOrElse((labels(i), labels(j)), 0)
62+
transposedMatrix(i)(j) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble
6263
}
63-
matrix
64+
val flatMatrix = transposedMatrix.flatMap(arr => arr)
65+
Matrices.dense(transposedMatrix.length, transposedMatrix(0).length, flatMatrix)
6466
}
6567

6668
/**

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.evaluation
1919

20+
import org.apache.spark.mllib.linalg.Matrices
2021
import org.apache.spark.mllib.util.LocalSparkContext
2122
import org.scalatest.FunSuite
2223

@@ -28,7 +29,7 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
2829
* |1|3|0| true class1 (4 instances)
2930
* |0|0|1| true class2 (1 instance)
3031
*/
31-
val confusionMatrix = Array(Array(2, 1, 1), Array(1, 3, 0), Array(0, 0, 1))
32+
val confusionMatrix = Matrices.dense(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1))
3233
val labels = Array(0.0, 1.0, 2.0)
3334
val predictionAndLabels = sc.parallelize(
3435
Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
@@ -51,7 +52,7 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
5152
val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
5253
val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)
5354

54-
assert(metrics.confusionMatrix.deep == confusionMatrix.deep)
55+
assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
5556
assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
5657
assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
5758
assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)

0 commit comments

Comments
 (0)