Skip to content

[SPARK-4256] Make Binary Evaluation Metrics functions defined in cases where there ar... #3118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,43 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
def apply(c: BinaryConfusionMatrix): Double
}

/** Precision. */
/** Precision. Defined as 1.0 when no examples are predicted positive (TP + FP == 0). */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double =
c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
override def apply(c: BinaryConfusionMatrix): Double = {
val totalPositives = c.numTruePositives + c.numFalsePositives
if (totalPositives == 0) {
1.0
} else {
c.numTruePositives.toDouble / totalPositives
}
}
}

/** False positive rate. */
/** False positive rate. Defined as 0.0 when there are no negative examples. */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double =
c.numFalsePositives.toDouble / c.numNegatives
override def apply(c: BinaryConfusionMatrix): Double = {
if (c.numNegatives == 0) {
0.0
} else {
c.numFalsePositives.toDouble / c.numNegatives
}
}
}

/** Recall. */
/** Recall. Defined as 0.0 when there are no positive examples. */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double =
c.numTruePositives.toDouble / c.numPositives
override def apply(c: BinaryConfusionMatrix): Double = {
if (c.numPositives == 0) {
0.0
} else {
c.numTruePositives.toDouble / c.numPositives
}
}
}

/**
* F-Measure.
* F-Measure. Defined as 0.0 when both precision and recall are 0, e.g. when there are
* positive predictions but every one of them is a false positive.
* @param beta the beta constant in F-Measure
* @see http://en.wikipedia.org/wiki/F1_score
*/
Expand All @@ -52,6 +69,10 @@ private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificati
override def apply(c: BinaryConfusionMatrix): Double = {
val precision = Precision(c)
val recall = Recall(c)
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
if (precision + recall == 0) {
0.0
} else {
(1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,104 @@ import org.apache.spark.mllib.util.TestingUtils._

class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {

def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5

def cond2(x: ((Double, Double), (Double, Double))): Boolean =
private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)

/**
 * Asserts that two sequences of doubles have the same length and that corresponding
 * elements agree according to `areWithinEpsilon`.
 *
 * The explicit length check matters because `zip` silently truncates to the shorter
 * input; without it a missing or extra element (or an empty result) would pass unnoticed.
 *
 * @param left  the values produced by the code under test
 * @param right the expected values
 */
private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
  assert(left.length == right.length)
  assert(left.zip(right).forall(areWithinEpsilon))
}

/**
 * Asserts that two sequences of (Double, Double) pairs have the same length and that
 * corresponding pairs agree according to `pairsWithinEpsilon`.
 *
 * The explicit length check matters because `zip` silently truncates to the shorter
 * input; without it a missing or extra point (or an empty result) would pass unnoticed.
 *
 * @param left  the pairs produced by the code under test
 * @param right the expected pairs
 */
private def assertTupleSequencesMatch(
    left: Seq[(Double, Double)],
    right: Seq[(Double, Double)]): Unit = {
  assert(left.length == right.length)
  assert(left.zip(right).forall(pairsWithinEpsilon))
}

/**
 * Compares everything `metrics` exposes against the supplied expected values.
 *
 * Checked, in order: thresholds, the ROC curve and its area, the PR curve and its area,
 * F-measure by threshold (once without an explicit beta, once with beta = 2.0),
 * precision by threshold and recall by threshold. Expected areas are derived from the
 * expected curves with `AreaUnderCurve.of` and compared with an absolute tolerance of
 * 1E-5; sequence comparisons are delegated to the sequence-matching helpers.
 *
 * @param metrics            the metrics object under test
 * @param expectedThresholds expected thresholds; also paired with each per-threshold series
 * @param expectedROCCurve   expected ROC curve points
 * @param expectedPRCurve    expected precision-recall curve points
 * @param expectedFMeasures1 expected values for `fMeasureByThreshold()`
 * @param expectedFmeasures2 expected values for `fMeasureByThreshold(2.0)`
 * @param expectedPrecisions expected values for `precisionByThreshold()`
 * @param expectedRecalls    expected values for `recallByThreshold()`
 */
private def validateMetrics(
    metrics: BinaryClassificationMetrics,
    expectedThresholds: Seq[Double],
    expectedROCCurve: Seq[(Double, Double)],
    expectedPRCurve: Seq[(Double, Double)],
    expectedFMeasures1: Seq[Double],
    expectedFmeasures2: Seq[Double],
    expectedPrecisions: Seq[Double],
    expectedRecalls: Seq[Double]): Unit = {

  assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)

  // Curves are checked point by point, then through their areas.
  assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
  assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5)
  assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve)
  assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5)

  // Per-threshold series come back as (threshold, value) pairs.
  assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(),
    expectedThresholds.zip(expectedFMeasures1))
  assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(),
    expectedThresholds.zip(expectedFmeasures2))
  assertTupleSequencesMatch(metrics.precisionByThreshold().collect(),
    expectedThresholds.zip(expectedPrecisions))
  assertTupleSequencesMatch(metrics.recallByThreshold().collect(),
    expectedThresholds.zip(expectedRecalls))
}

test("binary evaluation metrics") {
val scoreAndLabels = sc.parallelize(
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val threshold = Seq(0.8, 0.6, 0.4, 0.1)
val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
val numTruePositives = Seq(1, 3, 3, 4)
val numFalsePositives = Seq(0, 1, 2, 3)
val numPositives = 4
val numNegatives = 3
val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
t.toDouble / (t + f)
}
val recall = numTruePositives.map(t => t.toDouble / numPositives)
val recalls = numTruePositives.map(t => t.toDouble / numPositives)
val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
val pr = recall.zip(precision)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
val pr = recalls.zip(precisions)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}

assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}

test("binary evaluation metrics for RDD where all examples have positive label") {
  // Two identically scored examples, both labelled positive, spread over two partitions.
  val positivesOnly = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
  val metrics = new BinaryClassificationMetrics(positivesOnly)

  // One distinct score gives one threshold, at which every example is a true positive.
  val expectedThresholds = Seq(0.5)
  val expectedPrecisions = Seq(1.0)
  val expectedRecalls = Seq(1.0)
  val expectedFpr = Seq(0.0)

  // (recall, precision) pairs drive both the PR curve and the F-measures.
  val recallPrecisionPairs = expectedRecalls.zip(expectedPrecisions)
  val expectedRoc = Seq((0.0, 0.0)) ++ expectedFpr.zip(expectedRecalls) ++ Seq((1.0, 1.0))
  val expectedPr = Seq((0.0, 1.0)) ++ recallPrecisionPairs
  val expectedF1 = recallPrecisionPairs.map { case (r, p) => 2.0 * (p * r) / (p + r) }
  val expectedF2 = recallPrecisionPairs.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r) }

  validateMetrics(
    metrics,
    expectedThresholds,
    expectedRoc,
    expectedPr,
    expectedF1,
    expectedF2,
    expectedPrecisions,
    expectedRecalls)
}

test("binary evaluation metrics for RDD where all examples have negative label") {
  // Two identically scored examples, both labelled negative, spread over two partitions.
  val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2)
  val metrics = new BinaryClassificationMetrics(scoreAndLabels)

  // At the single threshold every example is a false positive: precision and recall are
  // both 0.0 and the false positive rate is 1.0.
  val thresholds = Seq(0.5)
  val precisions = Seq(0.0)
  val recalls = Seq(0.0)
  val fpr = Seq(1.0)
  val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
  val pr = recalls.zip(precisions)
  val prCurve = Seq((0.0, 1.0)) ++ pr

  // The plain formula would divide by zero when recall and precision are both zero, so
  // that case is pinned to 0.0. Double literals keep the patterns the same type as the
  // (Double, Double) pairs being matched.
  val f1 = pr.map {
    case (0.0, 0.0) => 0.0
    case (r, p) => 2.0 * (p * r) / (p + r)
  }
  val f2 = pr.map {
    case (0.0, 0.0) => 0.0
    case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
  }

  validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
}
}