 
 package org.apache.spark.mllib.evaluation
 
-import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 
 /**
  * Evaluator for multilabel classification.
- * NB: type Double both for prediction and label is retained
- * for compatibility with model.predict that returns Double
- * and MLUtils.loadLibSVMFile that loads class labels as Double
- *
  * @param predictionAndLabels an RDD of (predictions, labels) pairs, both are non-null sets.
  */
-class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) extends Logging {
+class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
 
-  private lazy val numDocs = predictionAndLabels.count
+  private lazy val numDocs: Long = predictionAndLabels.count
 
-  private lazy val numLabels = predictionAndLabels.flatMap{case (_, labels) => labels}.distinct.count
+  private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
+    labels}.distinct.count
 
   /**
    * Returns strict Accuracy
    * (for equal sets of labels)
-   * @return strictAccuracy.
    */
-  lazy val strictAccuracy = predictionAndLabels.filter{ case (predictions, labels) =>
+  lazy val strictAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
     predictions == labels}.count.toDouble / numDocs
 
   /**
    * Returns Accuracy
-   * @return Accuracy.
    */
-  lazy val accuracy = predictionAndLabels.map{ case (predictions, labels) =>
+  lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
     labels.intersect(predictions).size.toDouble / labels.union(predictions).size}.sum / numDocs
 
   /**
    * Returns Hamming-loss
-   * @return hammingLoss.
    */
-  lazy val hammingLoss = (predictionAndLabels.map{ case (predictions, labels) =>
+  lazy val hammingLoss: Double = (predictionAndLabels.map { case (predictions, labels) =>
     labels.diff(predictions).size + predictions.diff(labels).size}.
     sum).toDouble / (numDocs * numLabels)
 
   /**
    * Returns Document-based Precision averaged by the number of documents
-   * @return macroPrecisionDoc.
    */
-  lazy val macroPrecisionDoc = (predictionAndLabels.map{ case (predictions, labels) =>
-    if (predictions.size > 0)
-      predictions.intersect(labels).size.toDouble / predictions.size else 0}.sum) / numDocs
+  lazy val macroPrecisionDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
+    if (predictions.size > 0) {
+      predictions.intersect(labels).size.toDouble / predictions.size
+    } else 0
+  }.sum) / numDocs
 
   /**
    * Returns Document-based Recall averaged by the number of documents
-   * @return macroRecallDoc.
    */
-  lazy val macroRecallDoc = (predictionAndLabels.map{ case (predictions, labels) =>
+  lazy val macroRecallDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
     labels.intersect(predictions).size.toDouble / labels.size}.sum) / numDocs
 
   /**
    * Returns Document-based F1-measure averaged by the number of documents
-   * @return macroRecallDoc.
    */
-  lazy val macroF1MeasureDoc = (predictionAndLabels.map{ case (predictions, labels) =>
+  lazy val macroF1MeasureDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
     2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum) / numDocs
 
   /**
    * Returns micro-averaged document-based Precision
    * (equals to label-based microPrecision)
-   * @return microPrecisionDoc.
    */
-  lazy val microPrecisionDoc = microPrecisionClass
+  lazy val microPrecisionDoc: Double = microPrecisionClass
 
   /**
    * Returns micro-averaged document-based Recall
    * (equals to label-based microRecall)
-   * @return microRecallDoc.
    */
-  lazy val microRecallDoc = microRecallClass
+  lazy val microRecallDoc: Double = microRecallClass
 
   /**
    * Returns micro-averaged document-based F1-measure
    * (equals to label-based microF1measure)
-   * @return microF1MeasureDoc.
    */
-  lazy val microF1MeasureDoc = microF1MeasureClass
+  lazy val microF1MeasureDoc: Double = microF1MeasureClass
 
-  private lazy val tpPerClass = predictionAndLabels.flatMap{ case (predictions, labels) =>
+  private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
     predictions.intersect(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
 
-  private lazy val fpPerClass = predictionAndLabels.flatMap{ case (predictions, labels) =>
+  private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
     predictions.diff(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
 
   private lazy val fnPerClass = predictionAndLabels.flatMap{ case (predictions, labels) =>
@@ -113,38 +102,39 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
   /**
    * Returns Precision for a given label (category)
    * @param label the label.
-   * @return Precision.
    */
-  def precisionClass(label: Double) = if ((tpPerClass(label) + fpPerClass.getOrElse(label, 0)) == 0)
-    0 else tpPerClass(label).toDouble / (tpPerClass(label) + fpPerClass.getOrElse(label, 0))
+  def precisionClass(label: Double) = {
+    val tp = tpPerClass(label)
+    val fp = fpPerClass.getOrElse(label, 0)
+    if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
+  }
 
   /**
    * Returns Recall for a given label (category)
    * @param label the label.
-   * @return Recall.
    */
-  def recallClass(label: Double) = if ((tpPerClass(label) + fnPerClass.getOrElse(label, 0)) == 0)
-    0 else
-    tpPerClass(label).toDouble / (tpPerClass(label) + fnPerClass.getOrElse(label, 0))
+  def recallClass(label: Double) = {
+    val tp = tpPerClass(label)
+    val fn = fnPerClass.getOrElse(label, 0)
+    if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
+  }
 
   /**
    * Returns F1-measure for a given label (category)
    * @param label the label.
-   * @return F1-measure.
    */
   def f1MeasureClass(label: Double) = {
     val precision = precisionClass(label)
     val recall = recallClass(label)
     if ((precision + recall) == 0) 0 else 2 * precision * recall / (precision + recall)
   }
 
-  private lazy val sumTp = tpPerClass.foldLeft(0L){ case (sum, (_, tp)) => sum + tp}
-  private lazy val sumFpClass = fpPerClass.foldLeft(0L){ case (sum, (_, fp)) => sum + fp}
-  private lazy val sumFnClass = fnPerClass.foldLeft(0L){ case (sum, (_, fn)) => sum + fn}
+  private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp}
+  private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp}
+  private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn}
 
   /**
    * Returns micro-averaged label-based Precision
-   * @return microPrecisionClass.
    */
   lazy val microPrecisionClass = {
     val sumFp = fpPerClass.foldLeft(0L){ case (sumFp, (_, fp)) => sumFp + fp}
@@ -153,7 +143,6 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
 
   /**
    * Returns micro-averaged label-based Recall
-   * @return microRecallClass.
    */
   lazy val microRecallClass = {
     val sumFn = fnPerClass.foldLeft(0.0){ case (sumFn, (_, fn)) => sumFn + fn}
@@ -162,8 +151,6 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
 
   /**
    * Returns micro-averaged label-based F1-measure
-   * @return microRecallClass.
    */
   lazy val microF1MeasureClass = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
-
 }
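
For reference, a minimal usage sketch of the class as it stands after this change (not part of the diff). It assumes an existing SparkContext named sc and the Set[Double]-based signature shown above; the prediction and label sets below are made up purely for illustration.

// Hypothetical example: (predicted labels, true labels) per document.
import org.apache.spark.mllib.evaluation.MultilabelMetrics

val predictionAndLabels = sc.parallelize(Seq(
  (Set(0.0, 1.0), Set(0.0, 2.0)),   // one correct label, one spurious, one missed
  (Set(0.0, 2.0), Set(0.0, 2.0)),   // exact match
  (Set.empty[Double], Set(2.0))     // nothing predicted for a labeled document
))

val metrics = new MultilabelMetrics(predictionAndLabels)
println(s"strict accuracy = ${metrics.strictAccuracy}")      // exact set matches / numDocs
println(s"Hamming loss    = ${metrics.hammingLoss}")         // set-difference errors / (numDocs * numLabels)
println(s"precision(0.0)  = ${metrics.precisionClass(0.0)}") // per-label precision for label 0.0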