[SPARK-44871][SQL] Fix percentile_disc behaviour
### What changes were proposed in this pull request?
This PR fixes the `percentile_disc()` function, which currently returns incorrect results in some cases. E.g.:
```
SELECT
  percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
  percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
  percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
  percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
  percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
  percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
  percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
  percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
  percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
  percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
  percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a)
```
currently returns:
```
+---+---+---+---+---+---+---+---+---+---+---+
| p0| p1| p2| p3| p4| p5| p6| p7| p8| p9|p10|
+---+---+---+---+---+---+---+---+---+---+---+
|0.0|0.0|0.0|1.0|1.0|2.0|2.0|2.0|3.0|3.0|4.0|
+---+---+---+---+---+---+---+---+---+---+---+
```
but after this PR it returns the correct result:
```
+---+---+---+---+---+---+---+---+---+---+---+
| p0| p1| p2| p3| p4| p5| p6| p7| p8| p9|p10|
+---+---+---+---+---+---+---+---+---+---+---+
|0.0|0.0|0.0|1.0|1.0|2.0|2.0|3.0|3.0|4.0|4.0|
+---+---+---+---+---+---+---+---+---+---+---+
```

### Why are the changes needed?
Bugfix.

### Does this PR introduce _any_ user-facing change?
Yes, it fixes a correctness bug, but the old behaviour can be restored by setting `spark.sql.legacy.percentileDiscCalculation=true`.
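For example, the legacy behaviour can be re-enabled for a session with:
```
SET spark.sql.legacy.percentileDiscCalculation = true;
```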

### How was this patch tested?
Added new UTs.

Closes apache#42559 from peter-toth/SPARK-44871-fix-percentile-disc-behaviour.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Peter Toth <peter.toth@gmail.com>
peter-toth committed Aug 22, 2023
1 parent 65b8ca2 commit bd8fbf4
Showing 7 changed files with 356 additions and 18 deletions.
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.TypeCollection.NumericAndAnsiInterval
import org.apache.spark.util.collection.OpenHashMap
@@ -164,11 +165,8 @@ abstract class PercentileBase
val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
val maxPosition = accumulatedCounts.last._2 - 1

percentages.map { percentile =>
getPercentile(accumulatedCounts, maxPosition * percentile)
}
percentages.map(getPercentile(accumulatedCounts, _))
}

private def generateOutput(percentiles: Seq[Double]): Any = {
@@ -191,8 +189,11 @@
* This function has been based upon similar function from HIVE
* `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
*/
private def getPercentile(
accumulatedCounts: Seq[(AnyRef, Long)], position: Double): Double = {
protected def getPercentile(
accumulatedCounts: Seq[(AnyRef, Long)],
percentile: Double): Double = {
val position = (accumulatedCounts.last._2 - 1) * percentile

// We may need to do linear interpolation to get the exact percentile
val lower = position.floor.toLong
val higher = position.ceil.toLong
@@ -215,6 +216,7 @@
}

if (discrete) {
// We end up here only if spark.sql.legacy.percentileDiscCalculation=true
toDoubleValue(lowerKey)
} else {
// Linear interpolation to get the exact percentile
@@ -384,7 +386,9 @@ case class PercentileDisc(
percentageExpression: Expression,
reverse: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends PercentileBase with BinaryLike[Expression] {
inputAggBufferOffset: Int = 0,
legacyCalculation: Boolean = SQLConf.get.getConf(SQLConf.LEGACY_PERCENTILE_DISC_CALCULATION))
extends PercentileBase with BinaryLike[Expression] {

val frequencyExpression: Expression = Literal(1L)

@@ -412,4 +416,25 @@ case class PercentileDisc(
child = newLeft,
percentageExpression = newRight
)

override protected def getPercentile(
accumulatedCounts: Seq[(AnyRef, Long)],
percentile: Double): Double = {
if (legacyCalculation) {
super.getPercentile(accumulatedCounts, percentile)
} else {
      // `percentile_disc(p)` returns the first value in the ordered set whose `cume_dist()` is
      // greater than or equal to `p`, so `position` here is `p` scaled by the total row count.
val position = accumulatedCounts.last._2 * percentile

val higher = position.ceil.toLong

// Use binary search to find the higher position.
val countsArray = accumulatedCounts.map(_._2).toArray[Long]
val higherIndex = binarySearchCount(countsArray, 0, accumulatedCounts.size, higher)
val higherKey = accumulatedCounts(higherIndex)._1

toDoubleValue(higherKey)
}
}
}
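As a standalone illustration of the new mapping (using the five distinct values from the PR description, each with frequency 1, so the accumulated counts are simply 1 through 5):
```
// Accumulated (value, count) pairs for VALUES (0), (1), (2), (3), (4)
val accumulatedCounts = Seq(0.0 -> 1L, 1.0 -> 2L, 2.0 -> 3L, 3.0 -> 4L, 4.0 -> 5L)
val percentile = 0.7

val position = accumulatedCounts.last._2 * percentile  // 5 * 0.7 = 3.5
val higher = position.ceil.toLong                       // 4

// First value whose accumulated count reaches `higher` (a linear stand-in for binarySearchCount)
val result = accumulatedCounts.find(_._2 >= higher).map(_._1).get  // 3.0 (the old code returned 2.0)
```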
@@ -4368,6 +4368,16 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val LEGACY_PERCENTILE_DISC_CALCULATION = buildConf("spark.sql.legacy.percentileDiscCalculation")
.internal()
.doc("If true, the old bogus percentile_disc calculation is used. The old calculation " +
"incorrectly mapped the requested percentile to the sorted range of values in some cases " +
"and so returned incorrect results. Also, the new implementation is faster as it doesn't " +
"contain the interpolation logic that the old percentile_cont based one did.")
.version("3.3.4")
.booleanConf
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
*

Large diffs are not rendered by default.

@@ -26,6 +26,6 @@ Aggregate [percentile_cont(thousand#x, cast(0.5 as double), false) AS percentile
-- !query
select percentile_disc(0.5) within group (order by thousand) from tenk1
-- !query analysis
Aggregate [percentile_disc(thousand#x, cast(0.5 as double), false, 0, 0) AS percentile_disc(0.5) WITHIN GROUP (ORDER BY thousand)#x]
Aggregate [percentile_disc(thousand#x, cast(0.5 as double), false, 0, 0, false) AS percentile_disc(0.5) WITHIN GROUP (ORDER BY thousand)#x]
+- SubqueryAlias spark_catalog.default.tenk1
+- Relation spark_catalog.default.tenk1[unique1#x,unique2#x,two#x,four#x,ten#x,twenty#x,hundred#x,thousand#x,twothousand#x,fivethous#x,tenthous#x,odd#x,even#x,stringu1#x,stringu2#x,string4#x] parquet
@@ -26,6 +26,6 @@ Aggregate [percentile_cont(thousand#x, cast(0.5 as double), false) AS percentile
-- !query
select percentile_disc(0.5) within group (order by thousand) from tenk1
-- !query analysis
Aggregate [percentile_disc(thousand#x, cast(0.5 as double), false, 0, 0) AS percentile_disc(0.5) WITHIN GROUP (ORDER BY thousand)#x]
Aggregate [percentile_disc(thousand#x, cast(0.5 as double), false, 0, 0, false) AS percentile_disc(0.5) WITHIN GROUP (ORDER BY thousand)#x]
+- SubqueryAlias spark_catalog.default.tenk1
+- Relation spark_catalog.default.tenk1[unique1#x,unique2#x,two#x,four#x,ten#x,twenty#x,hundred#x,thousand#x,twothousand#x,fivethous#x,tenthous#x,odd#x,even#x,stringu1#x,stringu2#x,string4#x] parquet
77 changes: 76 additions & 1 deletion sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
@@ -299,4 +299,79 @@ SELECT
percentile_cont(0.5) WITHIN GROUP (ORDER BY dt2)
FROM intervals
GROUP BY k
ORDER BY k;
ORDER BY k;

-- SPARK-44871: Fix percentile_disc behaviour
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0) AS v(a);

SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1) AS v(a);

SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2) AS v(a);

SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a);

SET spark.sql.legacy.percentileDiscCalculation = true;

SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a);

SET spark.sql.legacy.percentileDiscCalculation = false;
116 changes: 116 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
@@ -730,3 +730,119 @@ struct<k:int,median(dt2):interval day to second,percentile(dt2, 0.5, 1):interval
2 0 00:22:30.000000000 0 00:22:30.000000000 0 00:22:30.000000000
3 0 01:00:00.000000000 0 01:00:00.000000000 0 01:00:00.000000000
4 NULL NULL NULL


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 0.0 1.0 1.0 1.0 2.0 2.0 2.0 2.0


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 1.0 1.0 2.0 2.0 3.0 3.0 4.0 4.0


-- !query
SET spark.sql.legacy.percentileDiscCalculation = true
-- !query schema
struct<key:string,value:string>
-- !query output
spark.sql.legacy.percentileDiscCalculation true


-- !query
SELECT
percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0,
percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1,
percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2,
percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3,
percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4,
percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5,
percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6,
percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7,
percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8,
percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9,
percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10
FROM VALUES (0), (1), (2), (3), (4) AS v(a)
-- !query schema
struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double>
-- !query output
0.0 0.0 0.0 1.0 1.0 2.0 2.0 2.0 3.0 3.0 4.0


-- !query
SET spark.sql.legacy.percentileDiscCalculation = false
-- !query schema
struct<key:string,value:string>
-- !query output
spark.sql.legacy.percentileDiscCalculation false
