Skip to content

Commit a05941d

Browse files
committed
replace TP/FP/TN/FN by their full names
1 parent 3f42e98 commit a05941d

File tree

4 files changed

+43
-41
lines changed

4 files changed

+43
-41
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,19 +27,19 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
2727
/** Precision. */
2828
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
2929
override def apply(c: BinaryConfusionMatrix): Double =
30-
c.tp.toDouble / (c.tp + c.fp)
30+
c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
3131
}
3232

3333
/** False positive rate. */
3434
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
3535
override def apply(c: BinaryConfusionMatrix): Double =
36-
c.fp.toDouble / c.n
36+
c.numFalsePositives.toDouble / c.numNegatives
3737
}
3838

3939
/** Recall. */
4040
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
4141
override def apply(c: BinaryConfusionMatrix): Double =
42-
c.tp.toDouble / c.p
42+
c.numTruePositives.toDouble / c.numPositives
4343
}
4444

4545
/**

mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -33,22 +33,22 @@ private case class BinaryConfusionMatrixImpl(
3333
totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable {
3434

3535
/** number of true positives */
36-
override def tp: Long = count.numPositives
36+
override def numTruePositives: Long = count.numPositives
3737

3838
/** number of false positives */
39-
override def fp: Long = count.numNegatives
39+
override def numFalsePositives: Long = count.numNegatives
4040

4141
/** number of false negatives */
42-
override def fn: Long = totalCount.numPositives - count.numPositives
42+
override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives
4343

4444
/** number of true negatives */
45-
override def tn: Long = totalCount.numNegatives - count.numNegatives
45+
override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives
4646

4747
/** number of positives */
48-
override def p: Long = totalCount.numPositives
48+
override def numPositives: Long = totalCount.numPositives
4949

5050
/** number of negatives */
51-
override def n: Long = totalCount.numNegatives
51+
override def numNegatives: Long = totalCount.numNegatives
5252
}
5353

5454
/**
@@ -57,10 +57,10 @@ private case class BinaryConfusionMatrixImpl(
5757
* @param scoreAndLabels an RDD of (score, label) pairs.
5858
*/
5959
class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
60-
extends Serializable with Logging {
60+
extends Serializable with Logging {
6161

6262
private lazy val (
63-
cumCounts: RDD[(Double, LabelCounter)],
63+
cumulativeCounts: RDD[(Double, LabelCounter)],
6464
confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
6565
// Create a bin for each distinct score value, count positives and negatives within each bin,
6666
// and then sort by score values in descending order.
@@ -74,32 +74,32 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])
7474
iter.foreach(agg += _)
7575
Iterator(agg)
7676
}, preservesPartitioning = true).collect()
77-
val partitionwiseCumCounts =
77+
val partitionwiseCumulativeCounts =
7878
agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c)
79-
val totalCount = partitionwiseCumCounts.last
79+
val totalCount = partitionwiseCumulativeCounts.last
8080
logInfo(s"Total counts: $totalCount")
81-
val cumCounts = counts.mapPartitionsWithIndex(
81+
val cumulativeCounts = counts.mapPartitionsWithIndex(
8282
(index: Int, iter: Iterator[(Double, LabelCounter)]) => {
83-
val cumCount = partitionwiseCumCounts(index)
83+
val cumCount = partitionwiseCumulativeCounts(index)
8484
iter.map { case (score, c) =>
8585
cumCount += c
8686
(score, cumCount.clone())
8787
}
8888
}, preservesPartitioning = true)
89-
cumCounts.persist()
90-
val confusions = cumCounts.map { case (score, cumCount) =>
89+
cumulativeCounts.persist()
90+
val confusions = cumulativeCounts.map { case (score, cumCount) =>
9191
(score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
9292
}
93-
(cumCounts, confusions)
93+
(cumulativeCounts, confusions)
9494
}
9595

9696
/** Unpersist intermediate RDDs used in the computation. */
9797
def unpersist() {
98-
cumCounts.unpersist()
98+
cumulativeCounts.unpersist()
9999
}
100100

101101
/** Returns thresholds in descending order. */
102-
def thresholds(): RDD[Double] = cumCounts.map(_._1)
102+
def thresholds(): RDD[Double] = cumulativeCounts.map(_._1)
103103

104104
/**
105105
* Returns the receiver operating characteristic (ROC) curve,

mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -22,20 +22,20 @@ package org.apache.spark.mllib.evaluation.binary
2222
*/
2323
private[evaluation] trait BinaryConfusionMatrix {
2424
/** number of true positives */
25-
def tp: Long
25+
def numTruePositives: Long
2626

2727
/** number of false positives */
28-
def fp: Long
28+
def numFalsePositives: Long
2929

3030
/** number of false negatives */
31-
def fn: Long
31+
def numFalseNegatives: Long
3232

3333
/** number of true negatives */
34-
def tn: Long
34+
def numTrueNegatives: Long
3535

3636
/** number of positives */
37-
def p: Long = tp + fn
37+
def numPositives: Long = numTruePositives + numFalseNegatives
3838

3939
/** number of negatives */
40-
def n: Long = fp + tn
40+
def numNegatives: Long = numFalsePositives + numTrueNegatives
4141
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -27,27 +27,29 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
2727
val scoreAndLabels = sc.parallelize(
2828
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
2929
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
30-
val score = Seq(0.8, 0.6, 0.4, 0.1)
31-
val tp = Seq(1, 3, 3, 4)
32-
val fp = Seq(0, 1, 2, 3)
33-
val p = 4
34-
val n = 3
35-
val precision = tp.zip(fp).map { case (t, f) => t.toDouble / (t + f) }
36-
val recall = tp.map(t => t.toDouble / p)
37-
val fpr = fp.map(f => f.toDouble / n)
30+
val threshold = Seq(0.8, 0.6, 0.4, 0.1)
31+
val numTruePositives = Seq(1, 3, 3, 4)
32+
val numFalsePositives = Seq(0, 1, 2, 3)
33+
val numPositives = 4
34+
val numNegatives = 3
35+
val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
36+
t.toDouble / (t + f)
37+
}
38+
val recall = numTruePositives.map(t => t.toDouble / numPositives)
39+
val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
3840
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
3941
val pr = recall.zip(precision)
4042
val prCurve = Seq((0.0, 1.0)) ++ pr
41-
val f1 = pr.map { case (re, prec) => 2.0 * (prec * re) / (prec + re) }
42-
val f2 = pr.map { case (re, prec) => 5.0 * (prec * re) / (4.0 * prec + re)}
43-
assert(metrics.thresholds().collect().toSeq === score)
43+
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
44+
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
45+
assert(metrics.thresholds().collect().toSeq === threshold)
4446
assert(metrics.roc().collect().toSeq === rocCurve)
4547
assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
4648
assert(metrics.pr().collect().toSeq === prCurve)
4749
assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
48-
assert(metrics.fMeasureByThreshold().collect().toSeq === score.zip(f1))
49-
assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === score.zip(f2))
50-
assert(metrics.precisionByThreshold().collect().toSeq === score.zip(precision))
51-
assert(metrics.recallByThreshold().collect().toSeq === score.zip(recall))
50+
assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1))
51+
assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2))
52+
assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision))
53+
assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall))
5254
}
5355
}

0 commit comments

Comments
 (0)