|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.spark.mllib.evaluation |
| 19 | + |
| 20 | +import org.apache.spark.rdd.RDD |
| 21 | +import org.apache.spark.Logging |
| 22 | +import org.apache.spark.SparkContext._ |
| 23 | + |
| 24 | +/** |
| 25 | + * Evaluator for multiclass classification. |
| 26 | + * |
| 27 | + * @param scoreAndLabels an RDD of (score, label) pairs. |
| 28 | + */ |
| 29 | +class MulticlassMetrics(scoreAndLabels: RDD[(Double, Double)]) extends Logging { |
| 30 | + |
| 31 | + /* class = category; label = instance of class; prediction = instance of class */ |
| 32 | + |
| 33 | + private lazy val labelCountByClass = scoreAndLabels.values.countByValue() |
| 34 | + private lazy val labelCount = labelCountByClass.foldLeft(0L){case(sum, (_, count)) => sum + count} |
| 35 | + private lazy val tpByClass = scoreAndLabels.map{ case (prediction, label) => |
| 36 | + (label, if(label == prediction) 1 else 0) }.reduceByKey{_ + _}.collectAsMap |
| 37 | + private lazy val fpByClass = scoreAndLabels.map{ case (prediction, label) => |
| 38 | + (prediction, if(prediction != label) 1 else 0) }.reduceByKey{_ + _}.collectAsMap |
| 39 | + |
| 40 | + /** |
| 41 | + * Returns Precision for a given label (category) |
| 42 | + * @param label the label. |
| 43 | + * @return Precision. |
| 44 | + */ |
| 45 | + def precision(label: Double): Double = if(tpByClass(label) + fpByClass.getOrElse(label, 0) == 0) 0 |
| 46 | + else tpByClass(label).toDouble / (tpByClass(label) + fpByClass.getOrElse(label, 0)).toDouble |
| 47 | + |
| 48 | + /** |
| 49 | + * Returns Recall for a given label (category) |
| 50 | + * @param label the label. |
| 51 | + * @return Recall. |
| 52 | + */ |
| 53 | + def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label).toDouble |
| 54 | + |
| 55 | + /** |
| 56 | + * Returns F1-measure for a given label (category) |
| 57 | + * @param label the label. |
| 58 | + * @return F1-measure. |
| 59 | + */ |
| 60 | + def f1Measure(label: Double): Double = |
| 61 | + 2 * precision(label) * recall(label) / (precision(label) + recall(label)) |
| 62 | + |
| 63 | + /** |
| 64 | + * Returns micro-averaged Recall |
| 65 | + * (equals to microPrecision and microF1measure for multiclass classifier) |
| 66 | + * @return microRecall. |
| 67 | + */ |
| 68 | + def microRecall: Double = |
| 69 | + tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount.toDouble |
| 70 | + |
| 71 | + /** |
| 72 | + * Returns micro-averaged Precision |
| 73 | + * (equals to microPrecision and microF1measure for multiclass classifier) |
| 74 | + * @return microPrecision. |
| 75 | + */ |
| 76 | + def microPrecision: Double = microRecall |
| 77 | + |
| 78 | + /** |
| 79 | + * Returns micro-averaged F1-measure |
| 80 | + * (equals to microPrecision and microRecall for multiclass classifier) |
| 81 | + * @return microF1measure. |
| 82 | + */ |
| 83 | + def microF1Measure: Double = microRecall |
| 84 | + |
| 85 | + /** |
| 86 | + * Returns weighted averaged Recall |
| 87 | + * @return weightedRecall. |
| 88 | + */ |
| 89 | + def weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) => |
| 90 | + wRecall + recall(category) * count.toDouble / labelCount.toDouble} |
| 91 | + |
| 92 | + /** |
| 93 | + * Returns weighted averaged Precision |
| 94 | + * @return weightedPrecision. |
| 95 | + */ |
| 96 | + def weightedPrecision: Double = |
| 97 | + labelCountByClass.foldLeft(0.0){case(wPrecision, (category, count)) => |
| 98 | + wPrecision + precision(category) * count.toDouble / labelCount.toDouble} |
| 99 | + |
| 100 | + /** |
| 101 | + * Returns weighted averaged F1-measure |
| 102 | + * @return weightedF1Measure. |
| 103 | + */ |
| 104 | + def weightedF1Measure: Double = |
| 105 | + labelCountByClass.foldLeft(0.0){case(wF1measure, (category, count)) => |
| 106 | + wF1measure + f1Measure(category) * count.toDouble / labelCount.toDouble} |
| 107 | + |
| 108 | + /** |
| 109 | + * Returns map with Precisions for individual classes |
| 110 | + * @return precisionPerClass. |
| 111 | + */ |
| 112 | + def precisionPerClass = |
| 113 | + labelCountByClass.map{case (category, _) => (category, precision(category))}.toMap |
| 114 | + |
| 115 | + /** |
| 116 | + * Returns map with Recalls for individual classes |
| 117 | + * @return recallPerClass. |
| 118 | + */ |
| 119 | + def recallPerClass = |
| 120 | + labelCountByClass.map{case (category, _) => (category, recall(category))}.toMap |
| 121 | + |
| 122 | + /** |
| 123 | + * Returns map with F1-measures for individual classes |
| 124 | + * @return f1MeasurePerClass. |
| 125 | + */ |
| 126 | + def f1MeasurePerClass = |
| 127 | + labelCountByClass.map{case (category, _) => (category, f1Measure(category))}.toMap |
| 128 | +} |
0 commit comments