1
+ /*
2
+ * Licensed to the Apache Software Foundation (ASF) under one or more
3
+ * contributor license agreements. See the NOTICE file distributed with
4
+ * this work for additional information regarding copyright ownership.
5
+ * The ASF licenses this file to You under the Apache License, Version 2.0
6
+ * (the "License"); you may not use this file except in compliance with
7
+ * the License. You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ package org .apache .spark .mllib .evaluation
19
+
20
+ import org .apache .spark .rdd .RDD
21
+ import org .apache .spark .Logging
22
+ import org .apache .spark .SparkContext ._
23
+
24
/**
 * Evaluator for multiclass classification.
 *
 * Terminology: a "class" is a category, a "label" is the true class of an instance and a
 * "prediction" is the class assigned to that instance by the model.
 *
 * The underlying counts are computed lazily on first use and cached in this instance.
 * Every metric follows one convention for undefined values: a ratio whose denominator is
 * zero (for example a class that never occurs, or an empty data set) is reported as 0.0
 * instead of NaN or an exception.
 *
 * @param scoreAndLabels an RDD of (score, label) pairs, where the score is the predicted class.
 */
class MulticlassEvaluator(scoreAndLabels: RDD[(Double, Double)]) extends Logging {

  /** Number of instances per true label. */
  private lazy val labelCountByClass = scoreAndLabels.values.countByValue()

  /** Total number of instances. */
  private lazy val labelCount = labelCountByClass.values.sum

  /**
   * True positives keyed by true label: instances whose prediction equals their label.
   * Counted as Long so that very large classes cannot overflow an Int.
   */
  private lazy val tpByClass = scoreAndLabels
    .map { case (prediction, label) => (label, if (label == prediction) 1L else 0L) }
    .reduceByKey(_ + _)
    .collectAsMap()

  /**
   * False positives keyed by predicted class: instances assigned to a class that is not
   * their true label. Classes that were never predicted have no entry.
   */
  private lazy val fpByClass = scoreAndLabels
    .map { case (prediction, label) => (prediction, if (prediction != label) 1L else 0L) }
    .reduceByKey(_ + _)
    .collectAsMap()

  /** Divides numerator by denominator, yielding 0.0 when the denominator is zero. */
  private def ratio(numerator: Double, denominator: Double): Double =
    if (denominator == 0.0) 0.0 else numerator / denominator

  /** Average of the given per-class metric, weighted by the number of true instances. */
  private def weightedAverage(metric: Double => Double): Double = {
    // When there are no instances the map is empty and the fold returns 0.0 untouched,
    // so `total` is never used as a zero divisor.
    val total = labelCount.toDouble
    labelCountByClass.foldLeft(0.0) { case (acc, (category, count)) =>
      acc + metric(category) * count.toDouble / total
    }
  }

  /** Evaluates the given per-class metric for every class that occurs as a true label. */
  private def perClass(metric: Double => Double): Map[Double, Double] =
    labelCountByClass.keys.map(category => category -> metric(category)).toMap

  /**
   * Returns Precision for a given label (category)
   * @param label the label.
   * @return Precision, or 0.0 if the class was never predicted.
   */
  def precision(label: Double): Double = {
    // A class that only shows up as a prediction has no true-positive entry, and a class
    // that was never predicted has no false-positive entry; both count as zero.
    val tp = tpByClass.getOrElse(label, 0L)
    val fp = fpByClass.getOrElse(label, 0L)
    ratio(tp.toDouble, (tp + fp).toDouble)
  }

  /**
   * Returns Recall for a given label (category)
   * @param label the label.
   * @return Recall, or 0.0 if the class never occurs as a true label.
   */
  def recall(label: Double): Double =
    ratio(tpByClass.getOrElse(label, 0L).toDouble, labelCountByClass.getOrElse(label, 0L).toDouble)

  /**
   * Returns F1-measure for a given label (category)
   * @param label the label.
   * @return F1-measure, or 0.0 if both Precision and Recall are zero.
   */
  def f1Measure(label: Double): Double = {
    val p = precision(label)
    val r = recall(label)
    ratio(2 * p * r, p + r)
  }

  /**
   * Returns micro-averaged Recall
   * (equals to microPrecision and microF1Measure for multiclass classifier)
   * @return microRecall, or 0.0 if there are no instances.
   */
  def microRecall: Double = ratio(tpByClass.values.sum.toDouble, labelCount.toDouble)

  /**
   * Returns micro-averaged Precision
   * (equals to microRecall and microF1Measure for multiclass classifier)
   * @return microPrecision.
   */
  def microPrecision: Double = microRecall

  /**
   * Returns micro-averaged F1-measure
   * (equals to microPrecision and microRecall for multiclass classifier)
   * @return microF1Measure.
   */
  def microF1Measure: Double = microRecall

  /**
   * Returns weighted averaged Recall
   * @return weightedRecall.
   */
  def weightedRecall: Double = weightedAverage(recall)

  /**
   * Returns weighted averaged Precision
   * @return weightedPrecision.
   */
  def weightedPrecision: Double = weightedAverage(precision)

  /**
   * Returns weighted averaged F1-measure, i.e. the per-class F1-measures averaged with
   * the number of true instances of each class as weights.
   * @return weightedF1Measure.
   */
  def weightedF1Measure: Double = weightedAverage(f1Measure)

  /**
   * Returns map with Precisions for individual classes
   * @return precisionPerClass.
   */
  def precisionPerClass: Map[Double, Double] = perClass(precision)

  /**
   * Returns map with Recalls for individual classes
   * @return recallPerClass.
   */
  def recallPerClass: Map[Double, Double] = perClass(recall)

  /**
   * Returns map with F1-measures for individual classes
   * @return f1MeasurePerClass.
   */
  def f1MeasurePerClass: Map[Double, Double] = perClass(f1Measure)
}
0 commit comments