Skip to content

Commit 04b01bb

Browse files
avulanov authored and mengxr committed
[MLLIB] [SPARK-2222] Add multiclass evaluation metrics
Adding two classes: 1) MulticlassMetrics implements various multiclass evaluation metrics 2) MulticlassMetricsSuite implements unit tests for MulticlassMetrics Author: Alexander Ulanov <nashb@yandex.ru> Author: unknown <ulanov@ULANOV1.emea.hpqcorp.net> Author: Xiangrui Meng <meng@databricks.com> Closes apache#1155 from avulanov/master and squashes the following commits: 2eae80f [Alexander Ulanov] Merge pull request #1 from mengxr/avulanov-master 5ebeb08 [Xiangrui Meng] minor updates 79c3555 [Alexander Ulanov] Addressing reviewers comments mengxr 0fa9511 [Alexander Ulanov] Addressing reviewers comments mengxr f0dadc9 [Alexander Ulanov] Addressing reviewers comments mengxr 4811378 [Alexander Ulanov] Removing println 87fb11f [Alexander Ulanov] Addressing reviewers comments mengxr. Added confusion matrix e3db569 [Alexander Ulanov] Addressing reviewers comments mengxr. Added true positive rate and false positive rate. Test suite code style. a7e8bf0 [Alexander Ulanov] Addressing reviewers comments mengxr c3a77ad [Alexander Ulanov] Addressing reviewers comments mengxr e2c91c3 [Alexander Ulanov] Fixes to mutliclass metics d5ce981 [unknown] Comments about Double a5c8ba4 [unknown] Unit tests. Class rename fcee82d [unknown] Unit tests. Class rename d535d62 [unknown] Multiclass evaluation
1 parent 6555618 commit 04b01bb

File tree

2 files changed

+280
-0
lines changed

2 files changed

+280
-0
lines changed
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation
19+
20+
import scala.collection.Map
21+
22+
import org.apache.spark.SparkContext._
23+
import org.apache.spark.annotation.Experimental
24+
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
25+
import org.apache.spark.rdd.RDD
26+
27+
/**
 * ::Experimental::
 * Evaluator for multiclass classification.
 *
 * All per-class counts are computed lazily from `predictionAndLabels` the first time a
 * metric that needs them is requested. Only classes that occur as a true label are
 * reported by `labels` and represented in the confusion matrix.
 *
 * @param predictionAndLabels an RDD of (prediction, label) pairs.
 */
@Experimental
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {

  // Number of instances per true label.
  private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()

  // Total number of instances.
  private lazy val labelCount: Long = labelCountByClass.values.sum

  // Correct predictions per true label. Counts are accumulated as Long so that very large
  // data sets cannot silently overflow an Int. Every true label gets a key, possibly with 0.
  private lazy val tpByClass: Map[Double, Long] = predictionAndLabels
    .map { case (prediction, label) =>
      (label, if (label == prediction) 1L else 0L)
    }.reduceByKey(_ + _)
    .collectAsMap()

  // Wrong predictions per predicted label (keyed by prediction, not by true label).
  private lazy val fpByClass: Map[Double, Long] = predictionAndLabels
    .map { case (prediction, label) =>
      (prediction, if (prediction != label) 1L else 0L)
    }.reduceByKey(_ + _)
    .collectAsMap()

  // Number of instances for every observed (true label, prediction) pair.
  private lazy val confusions: Map[(Double, Double), Long] = predictionAndLabels
    .map { case (prediction, label) =>
      ((label, prediction), 1L)
    }.reduceByKey(_ + _)
    .collectAsMap()

  /**
   * Returns confusion matrix:
   * true classes are in rows, predicted classes are in columns,
   * both ordered by class label ascending, as in "labels".
   * Pairs that never occur are reported as 0.
   */
  def confusionMatrix: Matrix = {
    val n = labels.size
    val values = Array.ofDim[Double](n * n)
    var i = 0
    while (i < n) {
      var j = 0
      while (j < n) {
        // Column-major layout: entry (row i, column j) is stored at index i + j * n.
        values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0L).toDouble
        j += 1
      }
      i += 1
    }
    Matrices.dense(n, n, values)
  }

  /**
   * Returns true positive rate for a given label (category).
   * This is the same quantity as `recall(label)`.
   * @param label the label.
   */
  def truePositiveRate(label: Double): Double = recall(label)

  /**
   * Returns false positive rate for a given label (category):
   * the number of instances wrongly predicted as `label` divided by the number of
   * instances whose true label is not `label`.
   * Returns 0 when there are no such instances, instead of dividing by zero.
   * @param label the label.
   */
  def falsePositiveRate(label: Double): Double = {
    val fp = fpByClass.getOrElse(label, 0L)
    // A label that never occurs as a true label has a count of 0, so every instance
    // counts as a negative for it.
    val negatives = labelCount - labelCountByClass.getOrElse(label, 0L)
    if (negatives == 0) 0.0 else fp.toDouble / negatives
  }

  /**
   * Returns precision for a given label (category).
   * Returns 0 when the label was never predicted.
   * @param label the label.
   */
  def precision(label: Double): Double = {
    // A label may occur only as a prediction; it then has no true positives.
    val tp = tpByClass.getOrElse(label, 0L)
    val fp = fpByClass.getOrElse(label, 0L)
    if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
  }

  /**
   * Returns recall for a given label (category).
   * Returns 0 when the label never occurs as a true label.
   * @param label the label.
   */
  def recall(label: Double): Double = {
    val total = labelCountByClass.getOrElse(label, 0L)
    if (total == 0) 0.0 else tpByClass.getOrElse(label, 0L).toDouble / total
  }

  /**
   * Returns f-measure for a given label (category).
   * Returns 0 when both precision and recall are 0.
   * @param label the label.
   * @param beta the beta parameter.
   */
  def fMeasure(label: Double, beta: Double): Double = {
    val p = precision(label)
    val r = recall(label)
    val betaSqrd = beta * beta
    if (p + r == 0) 0.0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r)
  }

  /**
   * Returns f1-measure for a given label (category).
   * @param label the label.
   */
  def fMeasure(label: Double): Double = fMeasure(label, 1.0)

  /**
   * Returns precision over all instances:
   * the number of correctly predicted instances divided by the total number of instances.
   */
  lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount

  /**
   * Returns recall
   * (equals to precision for multiclass classifier
   * because sum of all false positives is equal to sum
   * of all false negatives)
   */
  lazy val recall: Double = precision

  /**
   * Returns f-measure
   * (equals to precision and recall because precision equals recall)
   */
  lazy val fMeasure: Double = precision

  /**
   * Returns weighted true positive rate.
   * This is the same quantity as `weightedRecall`.
   */
  lazy val weightedTruePositiveRate: Double = weightedRecall

  /**
   * Returns weighted false positive rate:
   * per-class false positive rates averaged with each class weighted by its share
   * of the true labels.
   */
  lazy val weightedFalsePositiveRate: Double = labelCountByClass.map { case (category, count) =>
    falsePositiveRate(category) * count.toDouble / labelCount
  }.sum

  /**
   * Returns weighted averaged recall:
   * per-class recalls averaged with each class weighted by its share of the true labels.
   */
  lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) =>
    recall(category) * count.toDouble / labelCount
  }.sum

  /**
   * Returns weighted averaged precision:
   * per-class precisions averaged with each class weighted by its share of the true labels.
   */
  lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) =>
    precision(category) * count.toDouble / labelCount
  }.sum

  /**
   * Returns weighted averaged f-measure:
   * per-class f-measures averaged with each class weighted by its share of the true labels.
   * @param beta the beta parameter.
   */
  def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) =>
    fMeasure(category, beta) * count.toDouble / labelCount
  }.sum

  /**
   * Returns weighted averaged f1-measure.
   */
  lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) =>
    fMeasure(category, 1.0) * count.toDouble / labelCount
  }.sum

  /**
   * Returns the sequence of labels in ascending order.
   * Only labels that occur as a true label are included.
   */
  lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.mllib.linalg.Matrices
23+
import org.apache.spark.mllib.util.LocalSparkContext
24+
25+
/**
 * Unit tests for MulticlassMetrics on a small, hand-computed 3-class example.
 */
class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
  test("Multiclass evaluation metrics") {
    /*
     * Confusion matrix for 3-class classification with total 9 instances:
     * |2|1|1| true class0 (4 instances)
     * |1|3|0| true class1 (4 instances)
     * |0|0|1| true class2 (1 instance)
     */
    // Matrices.dense takes values in column-major order, hence the transposed listing.
    val confusionMatrix = Matrices.dense(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1))
    val labels = Array(0.0, 1.0, 2.0)
    val predictionAndLabels = sc.parallelize(
      Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
        (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
    val metrics = new MulticlassMetrics(predictionAndLabels)
    val delta = 0.0000001

    // Compares two doubles up to `delta` and reports both values on failure.
    def assertClose(actual: Double, expected: Double): Unit = {
      assert(math.abs(actual - expected) < delta,
        "expected " + expected + " but got " + actual)
    }

    // Expected per-class values, read off the confusion matrix above.
    val fpRate0 = 1.0 / (9 - 4)
    val fpRate1 = 1.0 / (9 - 4)
    val fpRate2 = 1.0 / (9 - 1)
    val precision0 = 2.0 / (2 + 1)
    val precision1 = 3.0 / (3 + 1)
    val precision2 = 1.0 / (1 + 1)
    val recall0 = 2.0 / (2 + 2)
    val recall1 = 3.0 / (3 + 1)
    val recall2 = 1.0 / (1 + 0)
    val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
    val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
    val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
    val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0)
    val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
    val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)

    assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))

    // Per-class metrics.
    assertClose(metrics.falsePositiveRate(0.0), fpRate0)
    assertClose(metrics.falsePositiveRate(1.0), fpRate1)
    assertClose(metrics.falsePositiveRate(2.0), fpRate2)
    assertClose(metrics.truePositiveRate(0.0), recall0)
    assertClose(metrics.truePositiveRate(1.0), recall1)
    assertClose(metrics.truePositiveRate(2.0), recall2)
    assertClose(metrics.precision(0.0), precision0)
    assertClose(metrics.precision(1.0), precision1)
    assertClose(metrics.precision(2.0), precision2)
    assertClose(metrics.recall(0.0), recall0)
    assertClose(metrics.recall(1.0), recall1)
    assertClose(metrics.recall(2.0), recall2)
    assertClose(metrics.fMeasure(0.0), f1measure0)
    assertClose(metrics.fMeasure(1.0), f1measure1)
    assertClose(metrics.fMeasure(2.0), f1measure2)
    assertClose(metrics.fMeasure(0.0, 2.0), f2measure0)
    assertClose(metrics.fMeasure(1.0, 2.0), f2measure1)
    assertClose(metrics.fMeasure(2.0, 2.0), f2measure2)

    // Overall metrics: correct predictions over all instances.
    assertClose(metrics.recall,
      (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)))
    assertClose(metrics.recall, metrics.precision)
    assertClose(metrics.recall, metrics.fMeasure)
    assertClose(metrics.recall, metrics.weightedRecall)

    // Weighted metrics: each class weighted by its share of the true labels.
    assertClose(metrics.weightedFalsePositiveRate,
      (4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)
    assertClose(metrics.weightedTruePositiveRate,
      (4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)
    assertClose(metrics.weightedPrecision,
      (4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)
    assertClose(metrics.weightedRecall,
      (4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)
    assertClose(metrics.weightedFMeasure,
      (4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)
    assertClose(metrics.weightedFMeasure(2.0),
      (4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)

    assert(metrics.labels.sameElements(labels))
  }
}

0 commit comments

Comments
 (0)