Skip to content

Commit 2870c76

Browse files
committed
[SPARK-48016][SQL][3.4] Fix a bug in try_divide function when with decimals
### What changes were proposed in this pull request? Currently, the following query will throw DIVIDE_BY_ZERO error instead of returning null ``` SELECT try_divide(1, decimal(0)); ``` This is caused by the rule `DecimalPrecision`: ``` case b BinaryOperator(left, right) if left.dataType != right.dataType => (left, right) match { ... case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && l.dataType.isInstanceOf[IntegralType] && literalPickMinimumPrecision => b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r)) ``` The result of the above makeCopy will contain `ANSI` as the `evalMode`, instead of `TRY`. This PR is to fix this bug by replacing the makeCopy method calls with withNewChildren ### Why are the changes needed? Bug fix in try_* functions. ### Does this PR introduce _any_ user-facing change? Yes, it fixes a long-standing bug in the try_divide function. ### How was this patch tested? New UT ### Was this patch authored or co-authored using generative AI tooling? No Closes apache#46289 from gengliangwang/PICK_PR_46286_BRANCH-3.4. Authored-by: Gengliang Wang <gengliang@apache.org> Signed-off-by: Gengliang Wang <gengliang@apache.org>
1 parent e2f34c7 commit 2870c76

File tree

7 files changed

+1130
-12
lines changed

7 files changed

+1130
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ object DecimalPrecision extends TypeCoercionRule {
8282
val resultType = widerDecimalType(p1, s1, p2, s2)
8383
val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
8484
val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
85-
b.makeCopy(Array(newE1, newE2))
85+
b.withNewChildren(Seq(newE1, newE2))
8686
}
8787

8888
/**
@@ -201,21 +201,21 @@ object DecimalPrecision extends TypeCoercionRule {
201201
case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
202202
l.dataType.isInstanceOf[IntegralType] &&
203203
literalPickMinimumPrecision =>
204-
b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r))
204+
b.withNewChildren(Seq(Cast(l, DecimalType.fromLiteral(l)), r))
205205
case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
206206
r.dataType.isInstanceOf[IntegralType] &&
207207
literalPickMinimumPrecision =>
208-
b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
208+
b.withNewChildren(Seq(l, Cast(r, DecimalType.fromLiteral(r))))
209209
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
210210
// and fixed-precision decimals in an expression with floats / doubles to doubles
211211
case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) =>
212-
b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
212+
b.withNewChildren(Seq(Cast(l, DecimalType.forType(l.dataType)), r))
213213
case (l @ DecimalType.Expression(_, _), r @ IntegralType()) =>
214-
b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
214+
b.withNewChildren(Seq(l, Cast(r, DecimalType.forType(r.dataType))))
215215
case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) =>
216-
b.makeCopy(Array(l, Cast(r, DoubleType)))
216+
b.withNewChildren(Seq(l, Cast(r, DoubleType)))
217217
case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) =>
218-
b.makeCopy(Array(Cast(l, DoubleType), r))
218+
b.withNewChildren(Seq(Cast(l, DoubleType), r))
219219
case _ => b
220220
}
221221
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,22 +1106,22 @@ object TypeCoercion extends TypeCoercionBase {
11061106

11071107
case a @ BinaryArithmetic(left @ StringType(), right)
11081108
if right.dataType != CalendarIntervalType =>
1109-
a.makeCopy(Array(Cast(left, DoubleType), right))
1109+
a.withNewChildren(Seq(Cast(left, DoubleType), right))
11101110
case a @ BinaryArithmetic(left, right @ StringType())
11111111
if left.dataType != CalendarIntervalType =>
1112-
a.makeCopy(Array(left, Cast(right, DoubleType)))
1112+
a.withNewChildren(Seq(left, Cast(right, DoubleType)))
11131113

11141114
// For equality between string and timestamp we cast the string to a timestamp
11151115
// so that things like rounding of subsecond precision does not affect the comparison.
11161116
case p @ Equality(left @ StringType(), right @ TimestampType()) =>
1117-
p.makeCopy(Array(Cast(left, TimestampType), right))
1117+
p.withNewChildren(Seq(Cast(left, TimestampType), right))
11181118
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
1119-
p.makeCopy(Array(left, Cast(right, TimestampType)))
1119+
p.withNewChildren(Seq(left, Cast(right, TimestampType)))
11201120

11211121
case p @ BinaryComparison(left, right)
11221122
if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined =>
11231123
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get
1124-
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
1124+
p.withNewChildren(Seq(castExpr(left, commonType), castExpr(right, commonType)))
11251125
}
11261126
}
11271127

0 commit comments

Comments
 (0)