diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala
index 8447a5f9b5153..da04c5a1c8a1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.TypeCollection.NumericAndAnsiInterval
 import org.apache.spark.util.collection.OpenHashMap
@@ -168,11 +169,8 @@ abstract class PercentileBase
     val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
-    val maxPosition = accumulatedCounts.last._2 - 1
 
-    percentages.map { percentile =>
-      getPercentile(accumulatedCounts, maxPosition * percentile)
-    }
+    percentages.map(getPercentile(accumulatedCounts, _))
   }
 
   private def generateOutput(percentiles: Seq[Double]): Any = {
@@ -195,8 +193,11 @@ abstract class PercentileBase
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(
-      accumulatedCounts: Seq[(AnyRef, Long)], position: Double): Double = {
+  protected def getPercentile(
+      accumulatedCounts: Seq[(AnyRef, Long)],
+      percentile: Double): Double = {
+    val position = (accumulatedCounts.last._2 - 1) * percentile
+
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
@@ -219,6 +220,7 @@ abstract class PercentileBase
     }
 
     if (discrete) {
+      // We end up here only if spark.sql.legacy.percentileDiscCalculation=true
       toDoubleValue(lowerKey)
     } else {
       // Linear interpolation to get the exact percentile
@@ -388,7 +390,9 @@ case class PercentileDisc(
     percentageExpression: Expression,
     reverse: Boolean = false,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends PercentileBase with BinaryLike[Expression] {
+    inputAggBufferOffset: Int = 0,
+    legacyCalculation: Boolean = SQLConf.get.getConf(SQLConf.LEGACY_PERCENTILE_DISC_CALCULATION))
+  extends PercentileBase with BinaryLike[Expression] {
 
   val frequencyExpression: Expression = Literal(1L)
 
@@ -416,4 +420,25 @@ case class PercentileDisc(
     child = newLeft,
     percentageExpression = newRight
   )
+
+  override protected def getPercentile(
+      accumulatedCounts: Seq[(AnyRef, Long)],
+      percentile: Double): Double = {
+    if (legacyCalculation) {
+      super.getPercentile(accumulatedCounts, percentile)
+    } else {
+      // `percentile_disc(p)` returns the value with the smallest `cume_dist()` that is greater
+      // than or equal to `p`, so `position` here is `p` scaled by the max position.
+      val position = accumulatedCounts.last._2 * percentile
+
+      val higher = position.ceil.toLong
+
+      // Use binary search to find the higher position.
+      val countsArray = accumulatedCounts.map(_._2).toArray[Long]
+      val higherIndex = binarySearchCount(countsArray, 0, accumulatedCounts.size, higher)
+      val higherKey = accumulatedCounts(higherIndex)._1
+
+      toDoubleValue(higherKey)
+    }
+  }
 }
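
For illustration, a minimal standalone Scala sketch (not part of the patch; all names are local to the sketch) of the corrected selection rule above: accumulate counts over the sorted (value, frequency) pairs, scale `p` by the total count N, and binary-search the cumulative counts for ceil(p * N), as the new `getPercentile` override does.

    import java.util.{Arrays => JArrays}

    // Sketch of the corrected percentile_disc rule: pick the first value whose
    // cumulative count reaches ceil(p * N), found with a binary search.
    object PercentileDiscSketch {
      def percentileDisc(sortedCounted: Seq[(Double, Long)], p: Double): Double = {
        val accumulated = sortedCounted
          .scanLeft((0.0, 0L)) { case ((_, acc), (v, c)) => (v, acc + c) }
          .tail
        val higher = math.ceil(accumulated.last._2 * p).toLong
        val counts = accumulated.map(_._2).toArray
        // java.util.Arrays.binarySearch returns -(insertionPoint) - 1 on a miss.
        val ix = JArrays.binarySearch(counts, higher)
        val index = if (ix < 0) -(ix + 1) else ix
        accumulated(math.min(index, accumulated.size - 1))._1
      }

      def main(args: Array[String]): Unit = {
        val data = Seq(0.0 -> 1L, 1.0 -> 1L, 2.0 -> 1L, 3.0 -> 1L, 4.0 -> 1L)
        // Prints 0.0 0.0 0.0 1.0 1.0 2.0 2.0 3.0 3.0 4.0 4.0, matching the new
        // golden results for VALUES (0), (1), (2), (3), (4) below.
        println((0 to 10).map(i => percentileDisc(data, i / 10.0)).mkString(" "))
      }
    }
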
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 743c53c7f62be..1a55b990edbe3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4215,6 +4215,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val LEGACY_PERCENTILE_DISC_CALCULATION = buildConf("spark.sql.legacy.percentileDiscCalculation")
+    .internal()
+    .doc("If true, the old bogus percentile_disc calculation is used. The old calculation " +
+      "incorrectly mapped the requested percentile to the sorted range of values in some cases " +
+      "and so returned incorrect results. Also, the new implementation is faster as it doesn't " +
+      "contain the interpolation logic that the old percentile_cont-based one did.")
+    .version("3.3.4")
+    .booleanConf
+    .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
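
To make the flag's effect concrete: based on the base-class `getPercentile` path above (discrete case, distinct values), the legacy rule reduces to indexing the sorted values at floor((n - 1) * p), i.e. the percentile_cont position with the interpolation dropped. A hedged standalone sketch (illustrative names, not patch code) showing why that is wrong:

    // Legacy percentile_disc, simplified for distinct values:
    // index the sorted values at floor((n - 1) * p).
    object LegacyPercentileDiscSketch {
      def legacyPercentileDisc(sorted: Seq[Double], p: Double): Double =
        sorted(math.floor((sorted.size - 1) * p).toInt)

      def main(args: Array[String]): Unit = {
        val data = Seq(0.0, 1.0, 2.0, 3.0, 4.0)
        // Legacy returns 2.0 for p = 0.7, but cume_dist(2.0) = 3/5 = 0.6 < 0.7,
        // so 2.0 violates the percentile_disc definition; the corrected result is 3.0.
        println(legacyPercentileDisc(data, 0.7)) // prints 2.0
      }
    }
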
diff --git a/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql b/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
index c55c300b5e805..87c5d4be90ce1 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
@@ -299,4 +299,79 @@ SELECT
   percentile_cont(0.5) WITHIN GROUP (ORDER BY dt2)
 FROM intervals
 GROUP BY k
-ORDER BY k;
\ No newline at end of file
+ORDER BY k;
+
+-- SPARK-44871: Fix percentile_disc behaviour
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0) AS v(a);
+
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0), (1) AS v(a);
+
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0), (1), (2) AS v(a);
+
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0), (1), (2), (3), (4) AS v(a);
+
+SET spark.sql.legacy.percentileDiscCalculation = true;
+
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0), (1), (2), (3), (4) AS v(a);
+
+SET spark.sql.legacy.percentileDiscCalculation = false;
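
The golden outputs that follow can be derived with a closed form: for n distinct ordered values, percentile_disc(p) returns the element at 1-based index max(1, ceil(n * p)). A small illustrative sketch (not part of the patch) that reproduces all four corrected result rows:

    // Closed-form check for the corrected golden outputs below.
    object GoldenOutputsCheck {
      def expected(sorted: Seq[Double], p: Double): Double = {
        val k = math.max(1L, math.ceil(sorted.size * p).toLong)
        sorted((k - 1).toInt)
      }

      def main(args: Array[String]): Unit = {
        val datasets = Seq(
          Seq(0.0),
          Seq(0.0, 1.0),
          Seq(0.0, 1.0, 2.0),
          Seq(0.0, 1.0, 2.0, 3.0, 4.0))
        for (data <- datasets) {
          // Each printed line matches one corrected result row in percentiles.sql.out.
          println((0 to 10).map(i => expected(data, i / 10.0)).mkString(" "))
        }
      }
    }
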
diff --git a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
index cd99ded56bf96..54d2164621f75 100644
--- a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
@@ -730,3 +730,119 @@ struct
+
+
+-- !query
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0) AS v(a)
+-- !query schema
+struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
+-- !query output
+0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
+
+
+-- !query
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0), (1) AS v(a)
+-- !query schema
+struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
+-- !query output
+0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0
+
+
+-- !query
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0), (1), (2) AS v(a)
+-- !query schema
+struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
+-- !query output
+0.0 0.0 0.0 0.0 1.0 1.0 1.0 2.0 2.0 2.0 2.0
+
+
+-- !query
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0), (1), (2), (3), (4) AS v(a)
+-- !query schema
+struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
+-- !query output
+0.0 0.0 0.0 1.0 1.0 2.0 2.0 3.0 3.0 4.0 4.0
+
+
+-- !query
+SET spark.sql.legacy.percentileDiscCalculation = true
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+spark.sql.legacy.percentileDiscCalculation true
+
+
+-- !query
+SELECT
+  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
+  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
+  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
+  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
+  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
+  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
+  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
+  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
+  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
+  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
+  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
+FROM VALUES (0), (1), (2), (3), (4) AS v(a)
+-- !query schema
+struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
+-- !query output
+0.0 0.0 0.0 1.0 1.0 2.0 2.0 2.0 3.0 3.0 4.0
+
+
+-- !query
+SET spark.sql.legacy.percentileDiscCalculation = false
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+spark.sql.legacy.percentileDiscCalculation false
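
For manual comparison of the two behaviours, the flag can also be toggled per session; a minimal sketch assuming an existing SparkSession named `spark` (not part of the patch):

    // Run the same query under both calculations; results per the golden file above.
    spark.conf.set("spark.sql.legacy.percentileDiscCalculation", "true")
    spark.sql(
      """SELECT percentile_disc(0.7) WITHIN GROUP (ORDER BY a) AS p7
        |FROM VALUES (0), (1), (2), (3), (4) AS v(a)""".stripMargin).show()
    // prints 2.0 with the flag on; rerunning with the flag set to "false"
    // yields the corrected 3.0
    spark.conf.set("spark.sql.legacy.percentileDiscCalculation", "false")
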