@@ -20,6 +20,7 @@ package org.apache.spark.mllib.evaluation
2020import org .apache .spark .Logging
2121import org .apache .spark .SparkContext ._
2222import org .apache .spark .annotation .Experimental
23+ import org .apache .spark .mllib .linalg .{Matrices , Matrix }
2324import org .apache .spark .rdd .RDD
2425
2526import scala .collection .Map
@@ -31,19 +32,19 @@ import scala.collection.Map
3132 * @param predictionAndLabels an RDD of (prediction, label) pairs.
3233 */
3334@ Experimental
34- class MulticlassMetrics (predictionAndLabels : RDD [(Double , Double )]) extends Logging {
35+ class MulticlassMetrics (predictionAndLabels : RDD [(Double , Double )]) {
3536
3637 private lazy val labelCountByClass : Map [Double , Long ] = predictionAndLabels.values.countByValue()
3738 private lazy val labelCount : Long = labelCountByClass.values.sum
3839 private lazy val tpByClass : Map [Double , Int ] = predictionAndLabels
3940 .map { case (prediction, label) =>
40- (label, if (label == prediction) 1 else 0 )
41- }.reduceByKey(_ + _)
41+ (label, if (label == prediction) 1 else 0 )
42+ }.reduceByKey(_ + _)
4243 .collectAsMap()
4344 private lazy val fpByClass : Map [Double , Int ] = predictionAndLabels
4445 .map { case (prediction, label) =>
45- (prediction, if (prediction != label) 1 else 0 )
46- }.reduceByKey(_ + _)
46+ (prediction, if (prediction != label) 1 else 0 )
47+ }.reduceByKey(_ + _)
4748 .collectAsMap()
4849 private lazy val confusions = predictionAndLabels.map {
4950 case (prediction, label) => ((prediction, label), 1 )
@@ -55,12 +56,13 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) extends Logg
5556 * they are ordered by class label ascending,
5657 * as in "labels"
5758 */
58- lazy val confusionMatrix : Array [ Array [ Int ]] = {
59- val matrix = Array .ofDim[Int ](labels.size, labels.size)
59+ lazy val confusionMatrix : Matrix = {
60+ val transposedMatrix = Array .ofDim[Double ](labels.size, labels.size)
6061 for (i <- 0 to labels.size - 1 ; j <- 0 to labels.size - 1 ) {
61- matrix(j)(i ) = confusions.getOrElse((labels(i), labels(j)), 0 )
62+ transposedMatrix(i)(j ) = confusions.getOrElse((labels(i), labels(j)), 0 ).toDouble
6263 }
63- matrix
64+ val flatMatrix = transposedMatrix.flatMap(arr => arr)
65+ Matrices .dense(transposedMatrix.length, transposedMatrix(0 ).length, flatMatrix)
6466 }
6567
6668 /**
0 commit comments