Skip to content

Commit 8f8353b

Browse files
authored
fix : cast_operands_to_decimal_type_to_fix_arithmetic_overflow (#1996)
1 parent ddab352 commit 8f8353b

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import scala.math.min
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Divide, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Subtract}
24+
import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Subtract}
2525
import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType}
2626

2727
import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -201,7 +201,6 @@ object CometIntegralDivide extends CometExpressionSerde with MathBase {
201201
inputs: Seq[Attribute],
202202
binding: Boolean): Option[ExprOuterClass.Expr] = {
203203
val div = expr.asInstanceOf[IntegralDivide]
204-
val rightExpr = nullIfWhenPrimitive(div.right)
205204

206205
if (!supportedDataType(div.left.dataType)) {
207206
withInfo(div, s"Unsupported datatype ${div.left.dataType}")
@@ -212,17 +211,28 @@ object CometIntegralDivide extends CometExpressionSerde with MathBase {
212211
return None
213212
}
214213

215-
val dataType = (div.left.dataType, div.right.dataType) match {
214+
// Precision is set to 19 (max precision for a numerical data type except DecimalType)
215+
216+
val left =
217+
if (div.left.dataType.isInstanceOf[DecimalType]) div.left
218+
else Cast(div.left, DecimalType(19, 0))
219+
val right =
220+
if (div.right.dataType.isInstanceOf[DecimalType]) div.right
221+
else Cast(div.right, DecimalType(19, 0))
222+
223+
val rightExpr = nullIfWhenPrimitive(right)
224+
225+
val dataType = (left.dataType, right.dataType) match {
216226
case (l: DecimalType, r: DecimalType) =>
217227
// copy from IntegralDivide.resultDecimalType
218228
val intDig = l.precision - l.scale + r.scale
219229
DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0)
220-
case _ => div.left.dataType
230+
case _ => left.dataType
221231
}
222232

223233
val divideExpr = createMathExpression(
224234
expr,
225-
div.left,
235+
left,
226236
rightExpr,
227237
inputs,
228238
binding,

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,18 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
113113
}
114114
}
115115

116+
test("Integral Division Overflow Handling Matches Spark Behavior") {
117+
withTable("t1") {
118+
withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
119+
val value = Long.MinValue
120+
sql("create table t1(c1 long, c2 short) using parquet")
121+
sql(s"insert into t1 values($value, -1)")
122+
val res = sql("select c1 div c2 from t1 order by c1")
123+
checkSparkAnswerAndOperator(res)
124+
}
125+
}
126+
}
127+
116128
test("basic data type support") {
117129
Seq(true, false).foreach { dictionaryEnabled =>
118130
withTempDir { dir =>
@@ -2686,7 +2698,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
26862698

26872699
test("test integral divide") {
26882700
// this test requires native_comet scan due to unsigned u8/u16 issue
2689-
withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_COMET) {
2701+
withSQLConf(
2702+
CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_COMET,
2703+
CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
26902704
Seq(true, false).foreach { dictionaryEnabled =>
26912705
withTempDir { dir =>
26922706
val path1 = new Path(dir.toURI.toString, "test1.parquet")

0 commit comments

Comments
 (0)