1717
1818package org .apache .spark .mllib .evaluation
1919
20- import org .scalatest .FunSuite
21-
2220import org .apache .spark .mllib .util .LocalSparkContext
21+ import org .scalatest .FunSuite
2322
2423class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
2524 test(" Multiclass evaluation metrics" ) {
2625 /*
27- * Confusion matrix for 3-class classification with total 9 instances:
28- * |2|1|1| true class0 (4 instances)
29- * |1|3|0| true class1 (4 instances)
30- * |0|0|1| true class2 (1 instance)
31- *
32- */
26+ * Confusion matrix for 3-class classification with total 9 instances:
27+ * |2|1|1| true class0 (4 instances)
28+ * |1|3|0| true class1 (4 instances)
29+ * |0|0|1| true class2 (1 instance)
30+ */
31+ val confusionMatrix = Array ( Array ( 2 , 1 , 1 ), Array ( 1 , 3 , 0 ), Array ( 0 , 0 , 1 ))
3332 val labels = Array (0.0 , 1.0 , 2.0 )
34- val scoreAndLabels = sc.parallelize(
33+ val predictionAndLabels = sc.parallelize(
3534 Seq ((0.0 , 0.0 ), (0.0 , 1.0 ), (0.0 , 0.0 ), (1.0 , 0.0 ), (1.0 , 1.0 ),
3635 (1.0 , 1.0 ), (1.0 , 1.0 ), (2.0 , 2.0 ), (2.0 , 0.0 )), 2 )
37- val metrics = new MulticlassMetrics (scoreAndLabels )
36+ val metrics = new MulticlassMetrics (predictionAndLabels )
3837 val delta = 0.0000001
3938 val fpRate0 = 1.0 / (9 - 4 )
4039 val fpRate1 = 1.0 / (9 - 4 )
@@ -48,6 +47,11 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
4847 val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
4948 val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
5049 val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
50+ val f2measure0 = (1 + 2 * 2 ) * precision0 * recall0 / (2 * 2 * precision0 + recall0)
51+ val f2measure1 = (1 + 2 * 2 ) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
52+ val f2measure2 = (1 + 2 * 2 ) * precision2 * recall2 / (2 * 2 * precision2 + recall2)
53+
54+ assert(metrics.confusionMatrix.deep == confusionMatrix.deep)
5155 assert(math.abs(metrics.falsePositiveRate(0.0 ) - fpRate0) < delta)
5256 assert(math.abs(metrics.falsePositiveRate(1.0 ) - fpRate1) < delta)
5357 assert(math.abs(metrics.falsePositiveRate(2.0 ) - fpRate2) < delta)
@@ -60,17 +64,25 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
6064 assert(math.abs(metrics.fMeasure(0.0 ) - f1measure0) < delta)
6165 assert(math.abs(metrics.fMeasure(1.0 ) - f1measure1) < delta)
6266 assert(math.abs(metrics.fMeasure(2.0 ) - f1measure2) < delta)
67+ assert(math.abs(metrics.fMeasure(0.0 , 2.0 ) - f2measure0) < delta)
68+ assert(math.abs(metrics.fMeasure(1.0 , 2.0 ) - f2measure1) < delta)
69+ assert(math.abs(metrics.fMeasure(2.0 , 2.0 ) - f2measure2) < delta)
70+
6371 assert(math.abs(metrics.recall -
6472 (2.0 + 3.0 + 1.0 ) / ((2 + 3 + 1 ) + (1 + 1 + 1 ))) < delta)
6573 assert(math.abs(metrics.recall - metrics.precision) < delta)
6674 assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
6775 assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
76+ assert(math.abs(metrics.weightedFalsePositiveRate -
77+ ((4.0 / 9 ) * fpRate0 + (4.0 / 9 ) * fpRate1 + (1.0 / 9 ) * fpRate2)) < delta)
6878 assert(math.abs(metrics.weightedPrecision -
6979 ((4.0 / 9 ) * precision0 + (4.0 / 9 ) * precision1 + (1.0 / 9 ) * precision2)) < delta)
7080 assert(math.abs(metrics.weightedRecall -
7181 ((4.0 / 9 ) * recall0 + (4.0 / 9 ) * recall1 + (1.0 / 9 ) * recall2)) < delta)
7282 assert(math.abs(metrics.weightedFMeasure -
7383 ((4.0 / 9 ) * f1measure0 + (4.0 / 9 ) * f1measure1 + (1.0 / 9 ) * f1measure2)) < delta)
84+ assert(math.abs(metrics.weightedFMeasure(2.0 ) -
85+ ((4.0 / 9 ) * f2measure0 + (4.0 / 9 ) * f2measure1 + (1.0 / 9 ) * f2measure2)) < delta)
7486 assert(metrics.labels.sameElements(labels))
7587 }
7688}
0 commit comments