
[SPARK-27577][MLlib] Correct thresholds downsampled in BinaryClassificationMetrics #24470

Closed
wants to merge 4 commits

Conversation

shishaochen commented Apr 26, 2019

What changes were proposed in this pull request?

Choose the last record in chunks when calculating metrics with downsampling in BinaryClassificationMetrics.

How was this patch tested?

A new unit test is added to verify thresholds from downsampled records.

shishaochen (Author)

@srowen Could you please have a look at this pull request? Thanks a lot!

srowen (Member) commented May 4, 2019

This doesn't look like a bug. I can't understand the argument in the JIRA why the last vs first element of a bin is more representative. Both are approximations.

shishaochen (Author) commented May 4, 2019

@srowen Yes, both are approximations. But the error is smaller if we choose the last element in each chunk as the threshold.
The essential problem is that the so-called "downsampling" is not true sampling: the code calculates precision, recall, etc. from the aggregated statistics (TP, FP, TN, FN) of all elements.

counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
  // The score of the combined point will be just the first one's score
  val firstScore = pairs.head._1
  // The point will contain all counts in this chunk
  val agg = new BinaryLabelCounter()
  pairs.foreach(pair => agg += pair._2)
  (firstScore, agg)
})

As you can see, the counters (BinaryLabelCounter) of all elements are merged into one rather than only the first element being returned.
Thus, by the definition of a threshold, the score of the last element (which is the minimum in the chunk) is the right threshold to use at inference time.
In online systems, we need to choose the right threshold to predict whether an instance is positive (score >= threshold) or negative (score < threshold).
For example, in a high-risk-detection model for videos where recall is extremely important, we choose a threshold from what BinaryClassificationMetrics prints. When numBins is set to 200, each chunk holds about 0.5% of the instances. A wrong threshold taken from the first element's score will miss many videos (out of 1 million per day in total) that should be flagged as dangerous.
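To make the mismatch concrete, here is a toy illustration in plain Scala (not Spark code; the data and chunk size are invented for this example):

```scala
// (score, label) pairs already sorted by score descending, chunked in
// groups of 2 the same way the downsampling code does.
val counts = Seq((0.9, 1), (0.8, 1), (0.6, 0), (0.5, 1))
val chunks = counts.grouped(2).toSeq

val firstScores = chunks.map(_.head._1) // current behavior: 0.9, 0.6
val lastScores  = chunks.map(_.last._1) // proposed fix:     0.8, 0.5

// Predicting positive iff score >= threshold. The first chunk's aggregated
// counter covers 2 instances, but threshold 0.9 marks only 1 instance
// positive; threshold 0.8 marks exactly those 2 instances positive.
def positives(t: Double): Int = counts.count(_._1 >= t)
println(firstScores.map(positives)) // List(1, 3)
println(lastScores.map(positives))  // List(2, 4)
```

Only the last (minimal) score of each chunk reproduces the aggregated counts (2 and 4 cumulative instances), so metrics computed at that threshold agree with the full data set.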

srowen (Member) commented May 4, 2019

I still don't see the argument that the first or last is better. They are simply the endpoints of the range of scores within the bin. As the number of bins increases, the range is smaller. If you are worried about this difference, you need more bins. Your argument cuts two ways: having a slightly higher threshold than desired can cause as many problems as slightly smaller.

What would be possibly better here is to compute the score of a bin as a weighted average of its elements. That would be OK though you'd have to change many tests. I think the current implementation is designed to match scikit (?)
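For reference, the weighted-average idea could look roughly like the following sketch (hypothetical, plain Scala, not what was merged; the (score, count) pair layout is a simplification of the real (score, BinaryLabelCounter) pairs):

```scala
// Score a bin by the count-weighted mean of its elements' scores instead of
// an endpoint of the range. Toy (score, count) pairs for one chunk:
val pairs = Seq((0.9, 3L), (0.8, 1L))
val total = pairs.map(_._2).sum
val binScore = pairs.map { case (s, c) => s * c }.sum / total
println(binScore) // (0.9*3 + 0.8*1) / 4 = 0.875
```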

shishaochen (Author)

@srowen I get your point!
Actually, if we choose the score of the last element in each chunk as the threshold, the calculated recall, precision, and F-measure at each threshold are exactly the same as those without sampling (numBins = 0).
In other words, they are exact metrics; the only difference from no sampling is the number of thresholds printed for the precision/recall/F1 curve.
So why return approximate values instead of the exact metrics of the full data set?

srowen (Member) commented May 5, 2019

Ah OK I agree with you, I see the argument now. It comes from the fact that the scores are sorted descending. The score of each bin is currently its maximum, not minimum. The precision / recall for each bin is calculated as if all of the instances in the bin were classified as positive. This only makes sense if the score is the minimum.

You might mention something to this effect in the comment in the code.
Also I think this may change some test results; let's see.

SparkQA commented May 5, 2019

Test build #4776 has finished for PR 24470 at commit fa2eae6.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

shishaochen (Author) commented May 6, 2019

@srowen Many thanks for your patience!
I have added an explanation in the code comments in BinaryClassificationMetrics.scala. Do the words below match your expectations?

counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
  // The score of the combined point will be just the last one's score, which is
  // also the minimum in each chunk since all scores are already sorted in
  // descending order.
  val lastScore = pairs.last._1
  // The combined point will contain all counts in this chunk. Thus, metrics
  // calculated at its score (or so-called threshold), like precision and
  // recall, are the same as those computed without sampling.
  val agg = new BinaryLabelCounter()
  pairs.foreach(pair => agg += pair._2)
  (lastScore, agg)
})

Besides, I have scanned all unit tests and class references in the Spark code repository. None of them uses numBins except one unit test in BinaryClassificationMetricsSuite, which only tests the ROC curve without thresholds. Thus, it is safe to merge this pull request.
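As a sanity check of that claim, here is a self-contained sketch (plain Scala, no Spark; the data and the Counter helper are invented) that mimics the fixed downsampling and then computes cumulative precision per threshold:

```scala
// Mimic the patched grouping: per chunk, keep (last score, aggregated counts),
// then accumulate counts across bins the way the metrics computation does.
final case class Counter(pos: Long, neg: Long)

val sorted = Seq((0.9, 1), (0.8, 0), (0.7, 1), (0.6, 1)) // sorted descending
val bins = sorted.grouped(2).map { pairs =>
  val pos = pairs.count(_._2 == 1).toLong
  (pairs.last._1, Counter(pos, pairs.length - pos))
}.toSeq

var (tp, fp) = (0L, 0L)
val precisionByThreshold = bins.map { case (t, c) =>
  tp += c.pos; fp += c.neg
  (t, tp.toDouble / (tp + fp))
}
println(precisionByThreshold) // List((0.8,0.5), (0.6,0.75))
```

At threshold 0.8, the full data set has 2 instances with score >= 0.8, of which 1 is positive (precision 0.5); at 0.6 it is 3 of 4 (precision 0.75) — identical to the downsampled values, matching the argument above.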

shishaochen (Author)

@srowen Excuse me, is there anything I should do before this pull request is merged? Thanks a lot!

srowen (Member) commented May 6, 2019

No, I leave these open for a day or two to make sure there aren't more comments.

srowen pushed a commit that referenced this pull request May 7, 2019
…cationMetrics

## What changes were proposed in this pull request?

Choose the last record in chunks when calculating metrics with downsampling in `BinaryClassificationMetrics`.

## How was this patch tested?

A new unit test is added to verify thresholds from downsampled records.

Closes #24470 from shishaochen/spark-mllib-binary-metrics.

Authored-by: Shaochen Shi <shishaochen@bytedance.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
(cherry picked from commit d5308cd)
Signed-off-by: Sean Owen <sean.owen@databricks.com>
srowen closed this in d5308cd May 7, 2019
srowen (Member) commented May 7, 2019

Merged to master/2.4/2.3

srowen pushed a commit that referenced this pull request May 7, 2019
…cationMetrics

rluta pushed a commit to rluta/spark that referenced this pull request Sep 17, 2019
…cationMetrics

kai-chi pushed a commit to kai-chi/spark that referenced this pull request Sep 26, 2019
…cationMetrics
