[SPARK-44871][SQL] Fix percentile_disc behaviour #42559
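For context before the diff: per the SQL-standard semantics this change targets, `percentile_disc(p)` should return the smallest ordered value whose `cume_dist()` is greater than or equal to `p`. A minimal, hypothetical reproduction (not part of this PR; the object name and local-mode session are assumptions for illustration):

import org.apache.spark.sql.SparkSession

// Minimal sketch: checks the expected percentile_disc semantics on a tiny dataset.
object PercentileDiscCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("percentile-disc-check").getOrCreate()
    // For the values 0..4, cume_dist() is 0.2, 0.4, 0.6, 0.8, 1.0, so the smallest value with
    // cume_dist() >= 0.7 is 3. With this fix the query returns 3.0; tracing the old code
    // (position = (N - 1) * p, rounded down) gives 2.0 instead.
    spark.sql(
      """SELECT percentile_disc(0.7) WITHIN GROUP (ORDER BY a) AS p
        |FROM VALUES (0), (1), (2), (3), (4) AS t(a)""".stripMargin).show()
    spark.stop()
  }
}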

@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.TypeCollection.NumericAndAnsiInterval
import org.apache.spark.util.collection.OpenHashMap
@@ -164,11 +165,9 @@ abstract class PercentileBase
val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
- val maxPosition = accumulatedCounts.last._2 - 1

- percentages.map { percentile =>
- getPercentile(accumulatedCounts, maxPosition * percentile)
- }

+ percentages.map(getPercentile(accumulatedCounts, _))
}

private def generateOutput(percentiles: Seq[Double]): Any = {
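As a side note on the unchanged lines in this hunk: `scanLeft(...).tail` turns per-value frequencies into `(value, cumulativeCount)` pairs, which is what `getPercentile` then searches. A standalone sketch with made-up values (not Spark code):

// Hypothetical values: how scanLeft builds cumulative counts from sorted (value, frequency) pairs.
val sortedCounts: Seq[(AnyRef, Long)] =
  Seq("a" -> 2L, "b" -> 1L, "c" -> 3L).map { case (k, c) => (k: AnyRef, c) }
val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
  case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
// accumulatedCounts == Seq(("a", 2L), ("b", 3L), ("c", 6L)); the seed element is dropped by .tail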
@@ -191,8 +190,11 @@ abstract class PercentileBase
* This function has been based upon similar function from HIVE
* `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
*/
- private def getPercentile(
- accumulatedCounts: Seq[(AnyRef, Long)], position: Double): Double = {
+ protected def getPercentile(
+ accumulatedCounts: Seq[(AnyRef, Long)],
+ percentile: Double): Double = {
+ val position = (accumulatedCounts.last._2 - 1) * percentile

// We may need to do linear interpolation to get the exact percentile
val lower = position.floor.toLong
val higher = position.ceil.toLong
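To make the `position`/`lower`/`higher` bookkeeping above concrete, here is a tiny worked example with hypothetical numbers; the interpolation expression itself sits just below this hunk and is only sketched here as ordinary linear interpolation:

// Hypothetical worked example for the continuous (interpolating) path.
val values = Array(10.0, 20.0, 30.0, 40.0, 50.0) // sorted, each with frequency 1
val percentile = 0.3
val position = (values.length - 1) * percentile  // (5 - 1) * 0.3 = 1.2
val lower = position.floor.toLong                // 1
val higher = position.ceil.toLong                // 2
// Ordinary linear interpolation between the neighbouring values:
val interpolated = values(lower.toInt) * (higher - position) + values(higher.toInt) * (position - lower)
// interpolated == 20.0 * 0.8 + 30.0 * 0.2 == 22.0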
@@ -215,6 +217,7 @@
}

if (discrete) {
// We end up here only if spark.sql.legacy.percentileDiscCalculation=true
toDoubleValue(lowerKey)
} else {
// Linear interpolation to get the exact percentile
@@ -384,7 +387,9 @@ case class PercentileDisc(
percentageExpression: Expression,
reverse: Boolean = false,
mutableAggBufferOffset: Int = 0,
- inputAggBufferOffset: Int = 0) extends PercentileBase with BinaryLike[Expression] {
+ inputAggBufferOffset: Int = 0,
+ legacyCalculation: Boolean = SQLConf.get.getConf(SQLConf.LEGACY_PERCENTILE_DISC_CALCULATION))
+ extends PercentileBase with BinaryLike[Expression] {

val frequencyExpression: Expression = Literal(1L)

@@ -412,4 +417,25 @@ case class PercentileDisc(
child = newLeft,
percentageExpression = newRight
)

override protected def getPercentile(
accumulatedCounts: Seq[(AnyRef, Long)],
percentile: Double): Double = {
if (legacyCalculation) {
super.getPercentile(accumulatedCounts, percentile)
} else {
// `percentile_disc(p)` returns the value with the smallest `cume_dist()` that is greater than
// or equal to `p`, so `position` here is `p` scaled by the total count (the last accumulated count).
val position = accumulatedCounts.last._2 * percentile

val higher = position.ceil.toLong

// Use binary search to find the higher position.
val countsArray = accumulatedCounts.map(_._2).toArray[Long]
val higherIndex = binarySearchCount(countsArray, 0, accumulatedCounts.size, higher)
val higherKey = accumulatedCounts(higherIndex)._1

toDoubleValue(higherKey)
}
}
}
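A short standalone trace (hypothetical data, not Spark code) of the non-legacy branch above: `position` is `p` times the total count, `higher` is its ceiling, and the answer is the first value whose cumulative count reaches `higher`. This matches the `percentile_disc(0.7)` example near the top of this page:

// Hypothetical trace of the new percentile_disc calculation for p = 0.7 over the values 0..4.
val accumulatedCounts = Seq(0.0 -> 1L, 1.0 -> 2L, 2.0 -> 3L, 3.0 -> 4L, 4.0 -> 5L)
val percentile = 0.7
val position = accumulatedCounts.last._2 * percentile          // 5 * 0.7 = 3.5
val higher = position.ceil.toLong                              // 4
// Stand-in for binarySearchCount: index of the first cumulative count >= higher.
val higherIndex = accumulatedCounts.indexWhere(_._2 >= higher) // 3
val higherKey = accumulatedCounts(higherIndex)._1              // 3.0 (cume_dist 0.8 >= 0.7)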
@@ -4368,6 +4368,16 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val LEGACY_PERCENTILE_DISC_CALCULATION = buildConf("spark.sql.legacy.percentileDiscCalculation")
.internal()
.doc("If true, the old bogus percentile_disc calculation is used. The old calculation " +
"incorrectly mapped the requested percentile to the sorted range of values in some cases " +
"and so returned incorrect results. Also, the new implementation is faster as it doesn't " +
"contain the interpolation logic that the old percentile_cont based one did.")
.version("3.3.4")
Comment from the PR author:

This bug was introduced with the very first version of percentile_disc in 3.3.0 (https://issues.apache.org/jira/browse/SPARK-37691), so 3.3.4 seems to be the earliest still-active release to which this fix should be backported.

.booleanConf
.createWithDefault(false)
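If the old results are needed temporarily (for example, to compare against an existing pipeline), the flag added above is a SQL conf, so it can be flipped at runtime; a minimal sketch assuming an existing SparkSession named `spark`:

// Minimal sketch: opt back into the legacy (incorrect) percentile_disc calculation, then restore.
spark.conf.set("spark.sql.legacy.percentileDiscCalculation", "true")
// ... run the queries that must reproduce the pre-fix behaviour ...
spark.conf.unset("spark.sql.legacy.percentileDiscCalculation")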

/**
* Holds information about keys that have been deprecated.
*