|
17 | 17 |
|
18 | 18 | package org.apache.spark.mllib.evaluation
|
19 | 19 |
|
20 |
| -import org.apache.spark.Logging |
| 20 | +import scala.collection.Map |
| 21 | + |
21 | 22 | import org.apache.spark.SparkContext._
|
22 | 23 | import org.apache.spark.annotation.Experimental
|
23 | 24 | import org.apache.spark.mllib.linalg.{Matrices, Matrix}
|
24 | 25 | import org.apache.spark.rdd.RDD
|
25 | 26 |
|
26 |
| -import scala.collection.Map |
27 |
| - |
28 | 27 | /**
|
29 | 28 | * ::Experimental::
|
30 | 29 | * Evaluator for multiclass classification.
|
@@ -57,12 +56,12 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
|
57 | 56 | * as in "labels"
|
58 | 57 | */
|
59 | 58 | lazy val confusionMatrix: Matrix = {
|
60 |
| - val transposedMatrix = Array.ofDim[Double](labels.size, labels.size) |
| 59 | + val transposedFlatMatrix = Array.ofDim[Double](labels.size * labels.size) |
61 | 60 | for (i <- 0 to labels.size - 1; j <- 0 to labels.size - 1) {
|
62 |
| - transposedMatrix(i)(j) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble |
| 61 | + transposedFlatMatrix(i * labels.size + j) |
| 62 | + = confusions.getOrElse((labels(i), labels(j)), 0).toDouble |
63 | 63 | }
|
64 |
| - val flatMatrix = transposedMatrix.flatMap(arr => arr) |
65 |
| - Matrices.dense(transposedMatrix.length, transposedMatrix(0).length, flatMatrix) |
| 64 | + Matrices.dense(labels.size, labels.size, transposedFlatMatrix) |
66 | 65 | }
|
67 | 66 |
|
68 | 67 | /**
|
|
0 commit comments