Skip to content

Commit a5c8ba4

Browse files
committed
Unit tests. Class rename
1 parent fcee82d commit a5c8ba4

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation
19+
20+
import org.apache.spark.rdd.RDD
21+
import org.apache.spark.Logging
22+
import org.apache.spark.SparkContext._
23+
24+
/**
 * Evaluator for multiclass classification.
 *
 * Terminology: a class is a category; a label is the true class of an instance; a prediction
 * is the class the classifier assigned to an instance.
 *
 * @param scoreAndLabels an RDD of (prediction, label) pairs. Despite the parameter name, the
 *                       first element is compared with the label for equality, so it is
 *                       treated as a predicted class rather than as a continuous score.
 */
class MulticlassMetrics(scoreAndLabels: RDD[(Double, Double)]) extends Logging {

  /** Number of instances for each true label. */
  private lazy val labelCountByClass = scoreAndLabels.values.countByValue()

  /** Total number of instances. */
  private lazy val labelCount: Long = labelCountByClass.values.sum

  /**
   * True positives keyed by true label. Every label that occurs gets an entry (possibly 0).
   * Counted as Long so that the tally cannot overflow on large data sets.
   */
  private lazy val tpByClass = scoreAndLabels
    .map { case (prediction, label) => (label, if (label == prediction) 1L else 0L) }
    .reduceByKey(_ + _)
    .collectAsMap()

  /**
   * False positives keyed by predicted class. Classes that are never predicted have no entry.
   * Counted as Long so that the tally cannot overflow on large data sets.
   */
  private lazy val fpByClass = scoreAndLabels
    .map { case (prediction, label) => (prediction, if (prediction != label) 1L else 0L) }
    .reduceByKey(_ + _)
    .collectAsMap()

  /** Averages a per-class metric over the true labels, weighted by each label's frequency. */
  private def weightedAverage(metric: Double => Double): Double =
    labelCountByClass.foldLeft(0.0) { case (acc, (category, count)) =>
      acc + metric(category) * count.toDouble / labelCount.toDouble
    }

  /** Evaluates a per-class metric for every class that occurs as a true label. */
  private def perClass(metric: Double => Double): Map[Double, Double] =
    labelCountByClass.keys.map(category => (category, metric(category))).toMap

  /**
   * Returns Precision for a given label (category)
   * @param label the label.
   * @return Precision, or 0 when the class has neither true nor false positives.
   */
  def precision(label: Double): Double = {
    // tpByClass is keyed by true label, so a class that was only ever predicted (never
    // observed as a true label) has no entry there; it simply has zero true positives.
    val tp = tpByClass.getOrElse(label, 0L)
    val fp = fpByClass.getOrElse(label, 0L)
    if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp).toDouble
  }

  /**
   * Returns Recall for a given label (category)
   * @param label the label. It must occur at least once among the true labels; the map
   *              lookup fails otherwise.
   * @return Recall.
   */
  def recall(label: Double): Double =
    tpByClass(label).toDouble / labelCountByClass(label).toDouble

  /**
   * Returns F1-measure for a given label (category)
   * @param label the label. Same requirement as for [[recall]].
   * @return F1-measure, or 0 when both Precision and Recall are 0.
   */
  def f1Measure(label: Double): Double = {
    val p = precision(label)
    val r = recall(label)
    // Without this guard a class with no true positives yields 0 / 0 = NaN, which would
    // also poison weightedF1Measure and f1MeasurePerClass.
    if (p + r == 0) 0.0 else 2 * p * r / (p + r)
  }

  /**
   * Returns micro-averaged Recall
   * (equals to microPrecision and microF1measure for multiclass classifier)
   * @return microRecall.
   */
  def microRecall: Double = tpByClass.values.sum.toDouble / labelCount.toDouble

  /**
   * Returns micro-averaged Precision
   * (equals to microRecall and microF1measure for multiclass classifier)
   * @return microPrecision.
   */
  def microPrecision: Double = microRecall

  /**
   * Returns micro-averaged F1-measure
   * (equals to microPrecision and microRecall for multiclass classifier)
   * @return microF1measure.
   */
  def microF1Measure: Double = microRecall

  /**
   * Returns weighted averaged Recall
   * @return weightedRecall.
   */
  def weightedRecall: Double = weightedAverage(recall)

  /**
   * Returns weighted averaged Precision
   * @return weightedPrecision.
   */
  def weightedPrecision: Double = weightedAverage(precision)

  /**
   * Returns weighted averaged F1-measure
   * @return weightedF1Measure.
   */
  def weightedF1Measure: Double = weightedAverage(f1Measure)

  /**
   * Returns map with Precisions for individual classes
   * @return precisionPerClass.
   */
  def precisionPerClass: Map[Double, Double] = perClass(precision)

  /**
   * Returns map with Recalls for individual classes
   * @return recallPerClass.
   */
  def recallPerClass: Map[Double, Double] = perClass(recall)

  /**
   * Returns map with F1-measures for individual classes
   * @return f1MeasurePerClass.
   */
  def f1MeasurePerClass: Map[Double, Double] = perClass(f1Measure)
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.evaluation
19+
20+
import org.apache.spark.mllib.util.LocalSparkContext
21+
import org.scalatest.FunSuite
22+
23+
class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
  test("Multiclass evaluation metrics") {
    /*
     * Confusion matrix for 3-class classification with total 9 instances
     * (rows are true classes, columns are predicted classes):
     * |2|1|1| true class0 (4 instances)
     * |1|3|0| true class1 (4 instances)
     * |0|0|1| true class2 (1 instance)
     */
    val predictionsWithLabels = Seq(
      (0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
      (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0))
    val metrics = new MulticlassMetrics(sc.parallelize(predictionsWithLabels, 2))

    val tolerance = 0.00001

    // True when the two values differ by less than the tolerance.
    def isClose(actual: Double, expected: Double): Boolean =
      math.abs(actual - expected) < tolerance

    // Expected per-class values, indexed by class 0, 1, 2.
    val classes = Seq(0.0, 1.0, 2.0)
    val expectedPrecision = Seq(2.0 / (2.0 + 1.0), 3.0 / (3.0 + 1.0), 1.0 / (1.0 + 1.0))
    val expectedRecall = Seq(2.0 / (2.0 + 2.0), 3.0 / (3.0 + 1.0), 1.0 / (1.0 + 0.0))
    val expectedF1 = expectedPrecision.zip(expectedRecall).map { case (p, r) =>
      2 * p * r / (p + r)
    }

    classes.zip(expectedPrecision).foreach { case (category, expected) =>
      assert(isClose(metrics.precision(category), expected))
    }
    classes.zip(expectedRecall).foreach { case (category, expected) =>
      assert(isClose(metrics.recall(category), expected))
    }
    classes.zip(expectedF1).foreach { case (category, expected) =>
      assert(isClose(metrics.f1Measure(category), expected))
    }

    // Micro-averaged metrics: all correct predictions over all instances.
    val correct = 2.0 + 3.0 + 1.0
    val incorrect = 1.0 + 1.0 + 1.0
    assert(isClose(metrics.microRecall, correct / (correct + incorrect)))
    assert(isClose(metrics.microRecall, metrics.microPrecision))
    assert(isClose(metrics.microRecall, metrics.microF1Measure))
    assert(isClose(metrics.microRecall, metrics.weightedRecall))

    // Weighted metrics: each class contributes in proportion to its share of true labels.
    val classWeights = Seq(4.0 / 9.0, 4.0 / 9.0, 1.0 / 9.0)
    def weighted(perClassValues: Seq[Double]): Double =
      classWeights.zip(perClassValues).map { case (weight, value) => weight * value }.sum

    assert(isClose(metrics.weightedPrecision, weighted(expectedPrecision)))
    assert(isClose(metrics.weightedRecall, weighted(expectedRecall)))
    assert(isClose(metrics.weightedF1Measure, weighted(expectedF1)))
  }
}

0 commit comments

Comments
 (0)