Skip to content

Commit b3d30a8

Browse files
shishaochensrowen
authored andcommitted
[SPARK-27577][MLLIB] Correct thresholds downsampled in BinaryClassificationMetrics
## What changes were proposed in this pull request? Choose the last record in chunks when calculating metrics with downsampling in `BinaryClassificationMetrics`. ## How was this patch tested? A new unit test is added to verify thresholds from downsampled records. Closes #24470 from shishaochen/spark-mllib-binary-metrics. Authored-by: Shaochen Shi <shishaochen@bytedance.com> Signed-off-by: Sean Owen <sean.owen@databricks.com> (cherry picked from commit d5308cd) Signed-off-by: Sean Owen <sean.owen@databricks.com>
1 parent 771da83 commit b3d30a8

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,15 @@ class BinaryClassificationMetrics @Since("1.3.0") (
175175
grouping = Int.MaxValue
176176
}
177177
counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
178-
// The score of the combined point will be just the first one's score
179-
val firstScore = pairs.head._1
180-
// The point will contain all counts in this chunk
178+
// The score of the combined point will be just the last one's score, which is also
179+
// the minimal in each chunk since all scores are already sorted in descending.
180+
val lastScore = pairs.last._1
181+
// The combined point will contain all counts in this chunk. Thus, calculated
182+
// metrics (like precision, recall, etc.) on its score (or so-called threshold) are
183+
// the same as those without sampling.
181184
val agg = new BinaryLabelCounter()
182185
pairs.foreach(pair => agg += pair._2)
183-
(firstScore, agg)
186+
(lastScore, agg)
184187
})
185188
}
186189
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,17 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
155155
(1.0, 1.0), (1.0, 1.0)
156156
) ==
157157
downsampledROC)
158+
159+
val downsampledRecall = downsampled.recallByThreshold().collect().sorted.toList
160+
assert(
161+
// May have to add 1 if the sample factor didn't divide evenly
162+
numBins + (if (scoreAndLabels.size % numBins == 0) 0 else 1) ==
163+
downsampledRecall.size)
164+
assert(
165+
List(
166+
(0.1, 1.0), (0.2, 1.0), (0.4, 0.75), (0.6, 0.75), (0.8, 0.25)
167+
) ==
168+
downsampledRecall)
158169
}
159170

160171
}

0 commit comments

Comments
 (0)