Skip to content

Commit d535d62

Browse files
committed
Multiclass evaluation
1 parent 67fca18 commit d535d62

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation
19+
20+
import org.apache.spark.rdd.RDD
21+
import org.apache.spark.Logging
22+
import org.apache.spark.SparkContext._
23+
24+
/**
 * Evaluator for multiclass classification.
 *
 * Terminology: a "class" is a category, a "label" is the true class of an instance and a
 * "prediction" is the class assigned to that instance by the model.
 *
 * All statistics are computed lazily on first use and then cached on the driver.
 *
 * @param scoreAndLabels an RDD of (prediction, label) pairs.
 */
class MulticlassEvaluator(scoreAndLabels: RDD[(Double, Double)]) extends Logging {

  /** Number of instances per true label. */
  private lazy val labelCountByClass = scoreAndLabels.values.countByValue()

  /** Total number of instances. */
  private lazy val labelCount = labelCountByClass.values.sum

  /**
   * True positives per class, keyed by true label. Every label that occurs in the data has an
   * entry (possibly 0). Counted as Long so that very large data sets cannot overflow.
   */
  private lazy val tpByClass = scoreAndLabels
    .map { case (prediction, label) => (label, if (label == prediction) 1L else 0L) }
    .reduceByKey(_ + _)
    .collectAsMap()

  /**
   * False positives per class, keyed by predicted class. A class that was never predicted has
   * no entry here.
   */
  private lazy val fpByClass = scoreAndLabels
    .map { case (prediction, label) => (prediction, if (prediction != label) 1L else 0L) }
    .reduceByKey(_ + _)
    .collectAsMap()

  /**
   * Returns Precision for a given label (category).
   *
   * @param label the label.
   * @return Precision, or 0 when the class was never predicted (including unknown classes).
   */
  def precision(label: Double): Double = {
    // tpByClass is keyed by true label, so a class that only ever shows up as a prediction
    // has no entry there; treat that as zero true positives instead of throwing.
    val tp = tpByClass.getOrElse(label, 0L)
    val fp = fpByClass.getOrElse(label, 0L)
    if (tp + fp == 0L) 0.0 else tp.toDouble / (tp + fp)
  }

  /**
   * Returns Recall for a given label (category).
   *
   * @param label the label.
   * @return Recall, or 0 when no instance carries this label.
   */
  def recall(label: Double): Double = {
    val total = labelCountByClass.getOrElse(label, 0L)
    if (total == 0L) 0.0 else tpByClass.getOrElse(label, 0L).toDouble / total
  }

  /**
   * Returns F1-measure for a given label (category).
   *
   * @param label the label.
   * @return F1-measure, or 0 when both Precision and Recall are 0.
   */
  def f1Measure(label: Double): Double = {
    val p = precision(label)
    val r = recall(label)
    // Guard the 0 / 0 case, which would otherwise produce NaN.
    if (p + r == 0.0) 0.0 else 2 * p * r / (p + r)
  }

  /**
   * Returns micro-averaged Recall (equal to microPrecision and microF1Measure for a
   * multiclass classifier).
   *
   * @return microRecall, or 0 for an empty data set.
   */
  def microRecall: Double =
    if (labelCount == 0L) 0.0 else tpByClass.values.sum.toDouble / labelCount

  /**
   * Returns micro-averaged Precision (equal to microRecall and microF1Measure for a
   * multiclass classifier).
   *
   * @return microPrecision.
   */
  def microPrecision: Double = microRecall

  /**
   * Returns micro-averaged F1-measure (equal to microPrecision and microRecall for a
   * multiclass classifier).
   *
   * @return microF1Measure.
   */
  def microF1Measure: Double = microRecall

  /**
   * Returns Recall averaged over classes, each class weighted by its share of instances.
   *
   * @return weightedRecall.
   */
  def weightedRecall: Double = weightedAverage(recall)

  /**
   * Returns Precision averaged over classes, each class weighted by its share of instances.
   *
   * @return weightedPrecision.
   */
  def weightedPrecision: Double = weightedAverage(precision)

  /**
   * Returns F1-measure averaged over classes, each class weighted by its share of instances.
   *
   * This is the weighted mean of the per-class F1 values, not the harmonic mean of
   * weightedPrecision and weightedRecall.
   *
   * @return weightedF1Measure.
   */
  def weightedF1Measure: Double = weightedAverage(f1Measure)

  /**
   * Returns map with Precisions for individual classes.
   *
   * @return precisionPerClass, keyed by every label that occurs in the data.
   */
  def precisionPerClass: Map[Double, Double] = perClass(precision)

  /**
   * Returns map with Recalls for individual classes.
   *
   * @return recallPerClass, keyed by every label that occurs in the data.
   */
  def recallPerClass: Map[Double, Double] = perClass(recall)

  /**
   * Returns map with F1-measures for individual classes.
   *
   * @return f1MeasurePerClass, keyed by every label that occurs in the data.
   */
  def f1MeasurePerClass: Map[Double, Double] = perClass(f1Measure)

  /** Averages `metric` over all labels, weighting each label by its relative frequency. */
  private def weightedAverage(metric: Double => Double): Double =
    labelCountByClass.foldLeft(0.0) { case (acc, (category, count)) =>
      acc + metric(category) * count.toDouble / labelCount.toDouble
    }

  /** Evaluates `metric` for every label that occurs in the data. */
  private def perClass(metric: Double => Double): Map[Double, Double] =
    labelCountByClass.keys.map(category => category -> metric(category)).toMap
}

0 commit comments

Comments
 (0)