Skip to content

Commit e3db569

Browse files
committed
Addressing reviewer comments from mengxr. Added true positive rate and false positive rate. Fixed test suite code style.
1 parent a7e8bf0 commit e3db569

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,9 +43,24 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
4343
private lazy val fpByClass: Map[Double, Int] = predictionsAndLabels
4444
.map { case (prediction, label) =>
4545
(prediction, if (prediction != label) 1 else 0)
46-
}.reduceByKey(_ + _)
46+
}.reduceByKey(_ + _)
4747
.collectAsMap()
4848

49+
/**
50+
* Returns true positive rate for a given label (category)
51+
* @param label the label.
52+
*/
53+
def truePositiveRate(label: Double): Double = recall(label)
54+
55+
/**
56+
* Returns false positive rate for a given label (category)
57+
* @param label the label.
58+
*/
59+
def falsePositiveRate(label: Double): Double = {
60+
val fp = fpByClass.getOrElse(label, 0)
61+
fp.toDouble / (labelCount - labelCountByClass(label))
62+
}
63+
4964
/**
5065
* Returns precision for a given label (category)
5166
* @param label the label.
@@ -65,6 +80,7 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
6580
/**
6681
* Returns f-measure for a given label (category)
6782
* @param label the label.
83+
* @param beta the beta parameter.
6884
*/
6985
def fMeasure(label: Double, beta: Double): Double = {
7086
val p = precision(label)
@@ -113,15 +129,23 @@ class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Log
113129
precision(category) * count.toDouble / labelCount
114130
}.sum
115131

132+
/**
133+
* Returns weighted averaged f-measure
134+
* @param beta the beta parameter.
135+
*/
136+
def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) =>
137+
fMeasure(category, beta) * count.toDouble / labelCount
138+
}.sum
139+
116140
/**
117141
* Returns weighted averaged f1-measure
118142
*/
119-
lazy val weightedF1Measure: Double = labelCountByClass.map { case (category, count) =>
120-
fMeasure(category) * count.toDouble / labelCount
143+
lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) =>
144+
fMeasure(category, 1.0) * count.toDouble / labelCount
121145
}.sum
122146

123147
/**
124148
* Returns the sequence of labels in ascending order
125149
*/
126-
lazy val labels:Array[Double] = tpByClass.keys.toArray.sorted
150+
lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted
127151
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,21 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
3636
(1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
3737
val metrics = new MulticlassMetrics(scoreAndLabels)
3838
val delta = 0.0000001
39-
val precision0 = 2.0 / (2.0 + 1.0)
40-
val precision1 = 3.0 / (3.0 + 1.0)
41-
val precision2 = 1.0 / (1.0 + 1.0)
42-
val recall0 = 2.0 / (2.0 + 2.0)
43-
val recall1 = 3.0 / (3.0 + 1.0)
44-
val recall2 = 1.0 / (1.0 + 0.0)
39+
val fpRate0 = 1.0 / (9 - 4)
40+
val fpRate1 = 1.0 / (9 - 4)
41+
val fpRate2 = 1.0 / (9 - 1)
42+
val precision0 = 2.0 / (2 + 1)
43+
val precision1 = 3.0 / (3 + 1)
44+
val precision2 = 1.0 / (1 + 1)
45+
val recall0 = 2.0 / (2 + 2)
46+
val recall1 = 3.0 / (3 + 1)
47+
val recall2 = 1.0 / (1 + 0)
4548
val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
4649
val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
4750
val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
51+
assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
52+
assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
53+
assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
4854
assert(math.abs(metrics.precision(0.0) - precision0) < delta)
4955
assert(math.abs(metrics.precision(1.0) - precision1) < delta)
5056
assert(math.abs(metrics.precision(2.0) - precision2) < delta)
@@ -55,16 +61,16 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
5561
assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
5662
assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
5763
assert(math.abs(metrics.recall -
58-
(2.0 + 3.0 + 1.0) / ((2.0 + 3.0 + 1.0) + (1.0 + 1.0 + 1.0))) < delta)
64+
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
5965
assert(math.abs(metrics.recall - metrics.precision) < delta)
6066
assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
6167
assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
6268
assert(math.abs(metrics.weightedPrecision -
63-
((4.0 / 9.0) * precision0 + (4.0 / 9.0) * precision1 + (1.0 / 9.0) * precision2)) < delta)
69+
((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta)
6470
assert(math.abs(metrics.weightedRecall -
65-
((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta)
66-
assert(math.abs(metrics.weightedF1Measure -
67-
((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta)
71+
((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta)
72+
assert(math.abs(metrics.weightedFMeasure -
73+
((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta)
6874
assert(metrics.labels.sameElements(labels))
6975
}
7076
}

0 commit comments

Comments
 (0)