[SPARK-44871][SQL][3.3] Fix percentile_disc behaviour
### What changes were proposed in this pull request?
This PR fixes the `percentile_disc()` function, which currently returns incorrect results in some cases. For example:
```
SELECT
  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a)
```
currently returns:
```
+---+---+---+---+---+---+---+---+---+---+---+
| p0| p1| p2| p3| p4| p5| p6| p7| p8| p9|p10|
+---+---+---+---+---+---+---+---+---+---+---+
|0.0|0.0|0.0|1.0|1.0|2.0|2.0|2.0|3.0|3.0|4.0|
+---+---+---+---+---+---+---+---+---+---+---+
```
but after this PR it returns the correct result:
```
+---+---+---+---+---+---+---+---+---+---+---+
| p0| p1| p2| p3| p4| p5| p6| p7| p8| p9|p10|
+---+---+---+---+---+---+---+---+---+---+---+
|0.0|0.0|0.0|1.0|1.0|2.0|2.0|3.0|3.0|4.0|4.0|
+---+---+---+---+---+---+---+---+---+---+---+
```

### Why are the changes needed?
Bugfix. `percentile_disc(p)` should return the first value in the ordered group whose cumulative distribution (`cume_dist()`) is greater than or equal to `p`; the old implementation was based on the `percentile_cont` position calculation and so picked the wrong value in some cases.

### Does this PR introduce _any_ user-facing change?
Yes, fixes a correctness bug, but the old behaviour can be restored with `spark.sql.legacy.percentileDiscCalculation=true`.

### How was this patch tested?
Added new UTs.

Closes #42611 from peter-toth/SPARK-44871-fix-percentile-disc-behaviour-3.3.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
peter-toth authored and MaxGekk committed Aug 23, 2023
1 parent 352810b · commit aa6f6f7
Showing 4 changed files with 234 additions and 7 deletions.
First changed file (the `PercentileBase` / `PercentileDisc` aggregate implementation):

```diff
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashMap
@@ -154,11 +155,8 @@ abstract class PercentileBase extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]]
     val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
-    val maxPosition = accumulatedCounts.last._2 - 1
 
-    percentages.map { percentile =>
-      getPercentile(accumulatedCounts, maxPosition * percentile)
-    }
+    percentages.map(getPercentile(accumulatedCounts, _))
   }
 
   private def generateOutput(results: Seq[Double]): Any = {
@@ -176,8 +174,11 @@ abstract class PercentileBase extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]]
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(
-      accumulatedCounts: Seq[(AnyRef, Long)], position: Double): Double = {
+  protected def getPercentile(
+      accumulatedCounts: Seq[(AnyRef, Long)],
+      percentile: Double): Double = {
+    val position = (accumulatedCounts.last._2 - 1) * percentile
+
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
@@ -200,6 +201,7 @@ abstract class PercentileBase extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]]
     }
 
     if (discrete) {
+      // We end up here only if spark.sql.legacy.percentileDiscCalculation=true
       toDoubleValue(lowerKey)
     } else {
       // Linear interpolation to get the exact percentile
@@ -389,7 +391,9 @@ case class PercentileDisc(
     percentageExpression: Expression,
     reverse: Boolean = false,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends PercentileBase with BinaryLike[Expression] {
+    inputAggBufferOffset: Int = 0,
+    legacyCalculation: Boolean = SQLConf.get.getConf(SQLConf.LEGACY_PERCENTILE_DISC_CALCULATION))
+  extends PercentileBase with BinaryLike[Expression] {
 
   val frequencyExpression: Expression = Literal(1L)
 
@@ -417,4 +421,25 @@ case class PercentileDisc(
       child = newLeft,
       percentageExpression = newRight
     )
+
+  override protected def getPercentile(
+      accumulatedCounts: Seq[(AnyRef, Long)],
+      percentile: Double): Double = {
+    if (legacyCalculation) {
+      super.getPercentile(accumulatedCounts, percentile)
+    } else {
+      // `percentile_disc(p)` returns the value with the smallest `cume_dist()` value given that is
+      // greater than or equal to `p` so `position` here is `p` adjusted by max position.
+      val position = accumulatedCounts.last._2 * percentile
+
+      val higher = position.ceil.toLong
+
+      // Use binary search to find the higher position.
+      val countsArray = accumulatedCounts.map(_._2).toArray[Long]
+      val higherIndex = binarySearchCount(countsArray, 0, accumulatedCounts.size, higher)
+      val higherKey = accumulatedCounts(higherIndex)._1
+
+      toDoubleValue(higherKey)
+    }
+  }
 }
```
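For intuition on the lookup above: `accumulatedCounts` pairs each distinct ordered value with its running row count, and `binarySearchCount` (an existing helper in the Spark sources, presumably in `PercentileBase`) locates the first entry whose count reaches `higher = ceil(n * p)`. A hypothetical stand-in for that search, assuming first-index-with-count-at-least-target semantics, for illustration only:
```scala
// Hypothetical re-implementation of the binary search used above; the real
// helper lives in the Spark sources and its exact signature may differ.
def firstIndexAtLeast(counts: Array[Long], target: Long): Int = {
  var lo = 0
  var hi = counts.length // search in [lo, hi)
  while (lo < hi) {
    val mid = (lo + hi) >>> 1
    if (counts(mid) < target) lo = mid + 1 else hi = mid
  }
  lo
}

// VALUES (0),(1),(2),(3),(4): accumulated counts are [1, 2, 3, 4, 5].
// percentile_disc(0.7): position = 5 * 0.7 = 3.5, higher = ceil(3.5) = 4,
// first index with count >= 4 is 3, whose key is 3.0 -- the corrected p7.
assert(firstIndexAtLeast(Array(1L, 2L, 3L, 4L, 5L), 4L) == 3)
```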
Second changed file (`org.apache.spark.sql.internal.SQLConf`):

```diff
@@ -3774,6 +3774,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val LEGACY_PERCENTILE_DISC_CALCULATION = buildConf("spark.sql.legacy.percentileDiscCalculation")
+    .internal()
+    .doc("If true, the old bogus percentile_disc calculation is used. The old calculation " +
+      "incorrectly mapped the requested percentile to the sorted range of values in some cases " +
+      "and so returned incorrect results. Also, the new implementation is faster as it doesn't " +
+      "contain the interpolation logic that the old percentile_cont based one did.")
+    .version("3.3.4")
+    .booleanConf
+    .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
```
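For reference, the flag can be flipped per session from Scala as well as SQL (this snippet assumes a live `SparkSession` named `spark`):
```scala
// Restore the pre-fix percentile_disc behaviour for this session only.
// The flag is internal and defaults to false (the corrected calculation).
spark.conf.set("spark.sql.legacy.percentileDiscCalculation", "true")

// ... run queries that depend on the legacy results ...

// Switch back to the fixed behaviour.
spark.conf.set("spark.sql.legacy.percentileDiscCalculation", "false")
```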
New test file `sql/core/src/test/resources/sql-tests/inputs/percentiles.sql` (74 additions):

```sql
-- SPARK-44871: Fix percentile_disc behaviour
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0) AS v(a);

SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1) AS v(a);

SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2) AS v(a);

SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a);

SET spark.sql.legacy.percentileDiscCalculation = true;

SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a);

SET spark.sql.legacy.percentileDiscCalculation = false;
```
New golden file `sql/core/src/test/resources/sql-tests/results/percentiles.sql.out` (118 additions):

```
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 7


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 0.0 1.0 1.0 1.0 2.0 2.0 2.0 2.0


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 1.0 1.0 2.0 2.0 3.0 3.0 4.0 4.0


-- !query
SET spark.sql.legacy.percentileDiscCalculation = true
-- !query schema
struct<key:string,value:string>
-- !query output
spark.sql.legacy.percentileDiscCalculation true


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 1.0 1.0 2.0 2.0 2.0 3.0 3.0 4.0


-- !query
SET spark.sql.legacy.percentileDiscCalculation = false
-- !query schema
struct<key:string,value:string>
-- !query output
spark.sql.legacy.percentileDiscCalculation	false
```
