Skip to content

Commit 1843f73

Browse files
committed
Scala style fix
1 parent 79e8476 commit 1843f73

File tree

2 files changed

+35
-58
lines changed

2 files changed

+35
-58
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,94 +17,83 @@
1717

1818
package org.apache.spark.mllib.evaluation
1919

20-
import org.apache.spark.Logging
2120
import org.apache.spark.rdd.RDD
2221
import org.apache.spark.SparkContext._
2322

2423
/**
2524
* Evaluator for multilabel classification.
26-
* NB: type Double both for prediction and label is retained
27-
* for compatibility with model.predict that returns Double
28-
* and MLUtils.loadLibSVMFile that loads class labels as Double
29-
*
3025
* @param predictionAndLabels an RDD of (predictions, labels) pairs; both elements are non-null sets.
3126
*/
32-
class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) extends Logging{
27+
class MultilabelMetrics(predictionAndLabels: RDD[(Set[Double], Set[Double])]) {
3328

34-
private lazy val numDocs = predictionAndLabels.count
29+
private lazy val numDocs: Long = predictionAndLabels.count
3530

36-
private lazy val numLabels = predictionAndLabels.flatMap{case(_, labels) => labels}.distinct.count
31+
private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
32+
labels}.distinct.count
3733

3834
/**
3935
* Returns strict Accuracy
4036
* (the fraction of documents whose predicted and true label sets are equal)
41-
* @return strictAccuracy.
4237
*/
43-
lazy val strictAccuracy = predictionAndLabels.filter{case(predictions, labels) =>
38+
lazy val strictAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
4439
predictions == labels}.count.toDouble / numDocs
4540

4641
/**
4742
* Returns Accuracy
48-
* @return Accuracy.
4943
*/
50-
lazy val accuracy = predictionAndLabels.map{ case(predictions, labels) =>
44+
lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
5145
labels.intersect(predictions).size.toDouble / labels.union(predictions).size}.sum / numDocs
5246

5347
/**
5448
* Returns Hamming-loss
55-
* @return hammingLoss.
5649
*/
57-
lazy val hammingLoss = (predictionAndLabels.map{ case(predictions, labels) =>
50+
lazy val hammingLoss: Double = (predictionAndLabels.map { case (predictions, labels) =>
5851
labels.diff(predictions).size + predictions.diff(labels).size}.
5952
sum).toDouble / (numDocs * numLabels)
6053

6154
/**
6255
* Returns Document-based Precision averaged by the number of documents
63-
* @return macroPrecisionDoc.
6456
*/
65-
lazy val macroPrecisionDoc = (predictionAndLabels.map{ case(predictions, labels) =>
66-
if(predictions.size >0)
67-
predictions.intersect(labels).size.toDouble / predictions.size else 0}.sum) / numDocs
57+
lazy val macroPrecisionDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
58+
if (predictions.size > 0) {
59+
predictions.intersect(labels).size.toDouble / predictions.size
60+
} else 0
61+
}.sum) / numDocs
6862

6963
/**
7064
* Returns Document-based Recall averaged by the number of documents
71-
* @return macroRecallDoc.
7265
*/
73-
lazy val macroRecallDoc = (predictionAndLabels.map{ case(predictions, labels) =>
66+
lazy val macroRecallDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
7467
labels.intersect(predictions).size.toDouble / labels.size}.sum) / numDocs
7568

7669
/**
7770
* Returns Document-based F1-measure averaged by the number of documents
78-
* @return macroRecallDoc.
7971
*/
80-
lazy val macroF1MeasureDoc = (predictionAndLabels.map{ case(predictions, labels) =>
72+
lazy val macroF1MeasureDoc: Double = (predictionAndLabels.map { case (predictions, labels) =>
8173
2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)}.sum) / numDocs
8274

8375
/**
8476
* Returns micro-averaged document-based Precision
8577
* (equal to label-based microPrecision)
86-
* @return microPrecisionDoc.
8778
*/
88-
lazy val microPrecisionDoc = microPrecisionClass
79+
lazy val microPrecisionDoc: Double = microPrecisionClass
8980

9081
/**
9182
* Returns micro-averaged document-based Recall
9283
* (equal to label-based microRecall)
93-
* @return microRecallDoc.
9484
*/
95-
lazy val microRecallDoc = microRecallClass
85+
lazy val microRecallDoc: Double = microRecallClass
9686

9787
/**
9888
* Returns micro-averaged document-based F1-measure
9989
* (equal to label-based microF1Measure)
100-
* @return microF1MeasureDoc.
10190
*/
102-
lazy val microF1MeasureDoc = microF1MeasureClass
91+
lazy val microF1MeasureDoc: Double = microF1MeasureClass
10392

104-
private lazy val tpPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
93+
private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
10594
predictions.intersect(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
10695

107-
private lazy val fpPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
96+
private lazy val fpPerClass = predictionAndLabels.flatMap { case(predictions, labels) =>
10897
predictions.diff(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
10998

11099
private lazy val fnPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
@@ -113,38 +102,39 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
113102
/**
114103
* Returns Precision for a given label (category)
115104
* @param label the label.
116-
* @return Precision.
117105
*/
118-
def precisionClass(label: Double) = if((tpPerClass(label) + fpPerClass.getOrElse(label, 0)) == 0)
119-
0 else tpPerClass(label).toDouble / (tpPerClass(label) + fpPerClass.getOrElse(label, 0))
106+
def precisionClass(label: Double) = {
107+
val tp = tpPerClass(label)
108+
val fp = fpPerClass.getOrElse(label, 0)
109+
if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
110+
}
120111

121112
/**
122113
* Returns Recall for a given label (category)
123114
* @param label the label.
124-
* @return Recall.
125115
*/
126-
def recallClass(label: Double) = if((tpPerClass(label) + fnPerClass.getOrElse(label, 0)) == 0)
127-
0 else
128-
tpPerClass(label).toDouble / (tpPerClass(label) + fnPerClass.getOrElse(label, 0))
116+
def recallClass(label: Double) = {
117+
val tp = tpPerClass(label)
118+
val fn = fnPerClass.getOrElse(label, 0)
119+
if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
120+
}
129121

130122
/**
131123
* Returns F1-measure for a given label (category)
132124
* @param label the label.
133-
* @return F1-measure.
134125
*/
135126
def f1MeasureClass(label: Double) = {
136127
val precision = precisionClass(label)
137128
val recall = recallClass(label)
138129
if((precision + recall) == 0) 0 else 2 * precision * recall / (precision + recall)
139130
}
140131

141-
private lazy val sumTp = tpPerClass.foldLeft(0L){ case(sum, (_, tp)) => sum + tp}
142-
private lazy val sumFpClass = fpPerClass.foldLeft(0L){ case(sum, (_, fp)) => sum + fp}
143-
private lazy val sumFnClass = fnPerClass.foldLeft(0L){ case(sum, (_, fn)) => sum + fn}
132+
private lazy val sumTp = tpPerClass.foldLeft(0L){ case (sum, (_, tp)) => sum + tp}
133+
private lazy val sumFpClass = fpPerClass.foldLeft(0L){ case (sum, (_, fp)) => sum + fp}
134+
private lazy val sumFnClass = fnPerClass.foldLeft(0L){ case (sum, (_, fn)) => sum + fn}
144135

145136
/**
146137
* Returns micro-averaged label-based Precision
147-
* @return microPrecisionClass.
148138
*/
149139
lazy val microPrecisionClass = {
150140
val sumFp = fpPerClass.foldLeft(0L){ case(sumFp, (_, fp)) => sumFp + fp}
@@ -153,7 +143,6 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
153143

154144
/**
155145
* Returns micro-averaged label-based Recall
156-
* @return microRecallClass.
157146
*/
158147
lazy val microRecallClass = {
159148
val sumFn = fnPerClass.foldLeft(0.0){ case(sumFn, (_, fn)) => sumFn + fn}
@@ -162,8 +151,6 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
162151

163152
/**
164153
* Returns micro-averaged label-based F1-measure
165-
* @return microRecallClass.
166154
*/
167155
lazy val microF1MeasureClass = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
168-
169156
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
package org.apache.spark.mllib.evaluation
1919

20-
import org.apache.spark.mllib.util.LocalSparkContext
21-
import org.apache.spark.rdd.RDD
2220
import org.scalatest.FunSuite
2321

22+
import org.apache.spark.mllib.util.LocalSparkContext
23+
import org.apache.spark.rdd.RDD
2424

2525
class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
2626
test("Multilabel evaluation metrics") {
@@ -45,7 +45,7 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
4545
* class 2 - doc 0, 3, 4, 6 (total 4)
4646
*
4747
*/
48-
val scoreAndLabels:RDD[(Set[Double], Set[Double])] = sc.parallelize(
48+
val scoreAndLabels: RDD[(Set[Double], Set[Double])] = sc.parallelize(
4949
Seq((Set(0.0, 1.0), Set(0.0, 2.0)),
5050
(Set(0.0, 2.0), Set(0.0, 1.0)),
5151
(Set(), Set(0.0)),
@@ -70,20 +70,16 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
7070
val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2)
7171
val microF1MeasureClass = 2.0 * sumTp.toDouble /
7272
(2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2))
73-
7473
val macroPrecisionDoc = 1.0 / 7 *
7574
(1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
7675
val macroRecallDoc = 1.0 / 7 *
7776
(1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)
7877
val macroF1MeasureDoc = (1.0 / 7) *
7978
2 * ( 1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) +
8079
2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2) )
81-
8280
val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
83-
8481
val strictAccuracy = 2.0 / 7
8582
val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
86-
8783
assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta)
8884
assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta)
8985
assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta)
@@ -93,20 +89,14 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
9389
assert(math.abs(metrics.f1MeasureClass(0.0) - f1measure0) < delta)
9490
assert(math.abs(metrics.f1MeasureClass(1.0) - f1measure1) < delta)
9591
assert(math.abs(metrics.f1MeasureClass(2.0) - f1measure2) < delta)
96-
9792
assert(math.abs(metrics.microPrecisionClass - microPrecisionClass) < delta)
9893
assert(math.abs(metrics.microRecallClass - microRecallClass) < delta)
9994
assert(math.abs(metrics.microF1MeasureClass - microF1MeasureClass) < delta)
100-
10195
assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta)
10296
assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta)
10397
assert(math.abs(metrics.macroF1MeasureDoc - macroF1MeasureDoc) < delta)
104-
10598
assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
10699
assert(math.abs(metrics.strictAccuracy - strictAccuracy) < delta)
107100
assert(math.abs(metrics.accuracy - accuracy) < delta)
108-
109-
110101
}
111-
112102
}

0 commit comments

Comments
 (0)