Commit 0a440e9

ulysses-you authored and cloud-fan committed
[SPARK-41793][SQL] Incorrect result for window frames defined by a range clause on large decimals
### What changes were proposed in this pull request?
Use `DecimalAddNoOverflowCheck` instead of `Add` to create the bound ordering for window range frames.

### Why are the changes needed?
Before 3.4, `Add` did not check overflow; instead, we always wrapped `Add` with a `CheckOverflow`. After #36698, `Add` checks overflow by itself. However, the bound ordering of a window range frame uses `Add` to calculate the boundary that determines which input rows lie within the frame boundaries of an output row, so the behavior changed with an extra overflow check. Technically, we can allow an overflowing value if it is just an intermediate result. So this PR uses `DecimalAddNoOverflowCheck` in place of `Add` to restore the previous behavior.

### Does this PR introduce _any_ user-facing change?
Yes, it restores the previous (before 3.4) behavior.

### How was this patch tested?
Added a test.

Closes #40138 from ulysses-you/SPARK-41793.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit fec4f7f)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
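To illustrate why an overflow check on the frame boundary is a problem, here is a minimal sketch in plain Scala. It is a simplified model built on `java.math.BigDecimal`, not Spark's `Decimal` or its expression classes; the names `BoundOverflowSketch`, `fits`, `checkedAdd`, and `uncheckedAdd` are hypothetical and exist only for this illustration. The values are taken from the test added in this commit.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object BoundOverflowSketch {
  // Hypothetical helper: true if `v` is representable as decimal(precision, scale).
  def fits(v: JBigDecimal, precision: Int, scale: Int): Boolean =
    v.setScale(scale, RoundingMode.HALF_UP).precision <= precision

  // Models an overflow-checked add: None stands in for "overflow detected".
  def checkedAdd(
      a: JBigDecimal,
      b: JBigDecimal,
      precision: Int,
      scale: Int): Option[JBigDecimal] = {
    val sum = a.add(b)
    if (fits(sum, precision, scale)) Some(sum) else None
  }

  // Models an add without an overflow check: the exact intermediate value is kept.
  def uncheckedAdd(a: JBigDecimal, b: JBigDecimal): JBigDecimal = a.add(b)

  // Largest value of decimal(38, 2) and the FOLLOWING offset used in the new test.
  val current = new JBigDecimal("999999999999999999999999999999999999.99")
  val following = new JBigDecimal("6.7890")

  def main(args: Array[String]): Unit = {
    // The upper bound does not fit decimal(38, 2), so the checked add reports overflow.
    println(checkedAdd(current, following, 38, 2))
    // The unchecked add still yields a value that can be compared against input rows.
    println(uncheckedAdd(current, following))
  }
}
```

The upper bound is only compared against input rows and is never written back as a result, which is why keeping the unchecked intermediate value is sufficient in this model.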
1 parent 2d9a963 commit 0a440e9

3 files changed: +21 -2 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala

Lines changed: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ case class CheckOverflowInSum(
 }
 
 /**
- * An add expression for decimal values which is only used internally by Sum/Avg.
+ * An add expression for decimal values which is only used internally by Sum/Avg/Window.
  *
  * Nota that, this expression does not check overflow which is different with `Add`. When
  * aggregating values, Spark writes the aggregation buffer values to `UnsafeRow` via

sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala

Lines changed: 1 addition & 0 deletions
@@ -128,6 +128,7 @@ trait WindowExecBase extends UnaryExecNode {
             TimestampAddYMInterval(expr, boundOffset, Some(timeZone))
           case (TimestampType | TimestampNTZType, _: DayTimeIntervalType) =>
             TimeAdd(expr, boundOffset, Some(timeZone))
+          case (d: DecimalType, _: DecimalType) => DecimalAddNoOverflowCheck(expr, boundOffset, d)
           case (a, b) if a == b => Add(expr, boundOffset)
         }
         val bound = MutableProjection.create(boundExpr :: Nil, child.output)
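The effect of computing the boundary without an overflow check can be sketched with a simplified model of a range frame. This is not Spark's `WindowExecBase` logic: `RangeFrameSketch` and `frameCounts` are hypothetical names, a single partition is assumed, and `java.math.BigDecimal` is used so that the bounds are exact. The inputs mirror the values in the test added by this commit.

```scala
import java.math.{BigDecimal => JBigDecimal}

object RangeFrameSketch {
  // Hypothetical model of
  //   count(*) OVER (ORDER BY b RANGE BETWEEN preceding PRECEDING AND following FOLLOWING)
  // over one partition. Bounds are computed exactly, with no overflow check.
  def frameCounts(
      values: Seq[JBigDecimal],
      preceding: JBigDecimal,
      following: JBigDecimal): Seq[Int] =
    values.map { current =>
      val lower = current.subtract(preceding) // intermediate value, never output
      val upper = current.add(following)      // may exceed decimal(38, 2)
      values.count(v => v.compareTo(lower) >= 0 && v.compareTo(upper) <= 0)
    }

  val rows: Seq[JBigDecimal] = Seq(
    new JBigDecimal("11342371013783243717493546650944543.47"),
    new JBigDecimal("999999999999999999999999999999999999.99"))

  val counts: Seq[Int] =
    frameCounts(rows, new JBigDecimal("10.2345"), new JBigDecimal("6.7890"))

  def main(args: Array[String]): Unit =
    println(counts) // each row's frame contains only the row itself
}
```

In this model each row falls only inside its own frame, which matches the `Row(1) :: Row(1) :: Nil` expectation in the new test.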

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala

Lines changed: 19 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, NonFoldableLiteral, RangeFrame, SortOrder, SpecifiedWindowFrame, UnspecifiedFrame}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, NonFoldableLiteral, RangeFrame, SortOrder, SpecifiedWindowFrame, UnaryMinus, UnspecifiedFrame}
 import org.apache.spark.sql.catalyst.plans.logical.{Window => WindowNode}
 import org.apache.spark.sql.expressions.{Window, WindowSpec}
 import org.apache.spark.sql.functions._
@@ -474,4 +474,22 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df,
       Row(3, 1.5) :: Row(3, 1.5) :: Row(6, 2.0) :: Row(6, 2.0) :: Row(6, 2.0) :: Nil)
   }
+
+  test("SPARK-41793: Incorrect result for window frames defined by a range clause on large " +
+    "decimals") {
+    val window = new WindowSpec(Seq($"a".expr), Seq(SortOrder($"b".expr, Ascending)),
+      SpecifiedWindowFrame(RangeFrame,
+        UnaryMinus(Literal(BigDecimal(10.2345))), Literal(BigDecimal(6.7890))))
+
+    val df = Seq(
+      1 -> "11342371013783243717493546650944543.47",
+      1 -> "999999999999999999999999999999999999.99"
+    ).toDF("a", "b")
+      .select($"a", $"b".cast("decimal(38, 2)"))
+      .select(count("*").over(window))
+
+    checkAnswer(
+      df,
+      Row(1) :: Row(1) :: Nil)
+  }
 }
