diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 4dc5ce1de047b..034894bd86085 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -288,7 +288,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) {
       Platform.putLong(baseObject, baseOffset + cursor, 0L);
       Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);
 
-      if (value == null || !value.changePrecision(precision, value.scale())) {
+      if (value == null) {
         setNullAt(ordinal);
         // keep the offset for future update
         Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index d442549f20e80..d2daaac72fc85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -71,36 +71,23 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
   )
 
   override lazy val updateExpressions: Seq[Expression] = {
-    val sumWithChild = resultType match {
-      case d: DecimalType =>
-        CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, nullOnOverflow = false)
-      case _ =>
-        coalesce(sum, zero) + child.cast(sumDataType)
-    }
-
     if (child.nullable) {
       Seq(
         /* sum = */
-        coalesce(sumWithChild, sum)
+        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
       )
     } else {
       Seq(
         /* sum = */
-        sumWithChild
+        coalesce(sum, zero) + child.cast(sumDataType)
       )
     }
   }
 
   override lazy val mergeExpressions: Seq[Expression] = {
-    val sumWithRight = resultType match {
-      case d: DecimalType =>
-        CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow = false)
-
-      case _ => coalesce(sum.left, zero) + sum.right
-    }
     Seq(
       /* sum = */
-      coalesce(sumWithRight, sum.left)
+      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
     )
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8c0358e205b07..54327b38c100b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,7 +21,6 @@ import scala.util.Random
 
 import org.scalatest.Matchers.the
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1045,42 +1044,6 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(sql(queryTemplate("FIRST")), Row(1))
     checkAnswer(sql(queryTemplate("LAST")), Row(3))
   }
-
-  private def exceptionOnDecimalOverflow(df: DataFrame): Unit = {
-    val msg = intercept[SparkException] {
-      df.collect()
-    }.getCause.getMessage
-    assert(msg.contains("cannot be represented as Decimal(38, 18)"))
-  }
-
-  test("SPARK-32018: Throw exception on decimal overflow at partial aggregate phase") {
-    val decimalString = "1" + "0" * 19
-    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
-    val hashAgg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
-      .groupBy("key")
-      .agg(sum($"d").alias("sumD"))
-      .select($"sumD")
-    exceptionOnDecimalOverflow(hashAgg)
-
-    val sortAgg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("a").as("str"),
-        lit("1").as("key")).groupBy("key")
-      .agg(sum($"d").alias("sumD"), min($"str").alias("minStr")).select($"sumD", $"minStr")
-    exceptionOnDecimalOverflow(sortAgg)
-  }
-
-  test("SPARK-32018: Throw exception on decimal overflow at merge aggregation phase") {
-    val decimalString = "5" + "0" * 19
-    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1))
-      .union(spark.range(0, 1, 1, 1))
-    val agg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
-      .groupBy("key")
-      .agg(sum($"d").alias("sumD"))
-      .select($"sumD")
-    exceptionOnDecimalOverflow(agg)
-  }
 }
 
 case class B(c: Option[Double])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 9daa69ce9f155..a5f904c621e6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -178,14 +178,4 @@ class UnsafeRowSuite extends SparkFunSuite {
     // Makes sure hashCode on unsafe array won't crash
     unsafeRow.getArray(0).hashCode()
   }
-
-  test("SPARK-32018: setDecimal with overflowed value") {
-    val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18)
-    val row = InternalRow.apply(d1)
-    val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row)
-    assert(unsafeRow.getDecimal(0, 38, 18) === d1)
-    val d2 = (d1 * Decimal(10)).toPrecision(39, 18)
-    unsafeRow.setDecimal(0, d2, 38)
-    assert(unsafeRow.getDecimal(0, 38, 18) === null)
-  }
 }
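
Note (reviewer sketch, not part of the patch): the removed DataFrameAggregateSuite tests rest on a simple precision argument. A Decimal(38, 18) keeps 38 - 18 = 20 integer digits, so one row of 10^19 fits, but the twelve unioned rows (1 + 11) sum to 1.2 * 10^20, which needs 21 integer digits and therefore overflows. A standalone illustration of that arithmetic, with no Spark dependency (object and variable names are for illustration only):

    // Why the removed "partial aggregate phase" test overflowed Decimal(38, 18):
    // Decimal(precision = 38, scale = 18) holds at most 38 - 18 = 20 integer digits.
    object DecimalOverflowSketch {
      def main(args: Array[String]): Unit = {
        val perRow = BigDecimal("1" + "0" * 19)  // 10^19, as built in the removed test
        val sum = perRow * 12                    // range(0, 1) union range(0, 11) -> 12 rows
        val integerDigits = sum.precision - sum.scale
        println(s"integer digits: $integerDigits, fits in Decimal(38, 18): ${integerDigits <= 20}")
        // prints: integer digits: 21, fits in Decimal(38, 18): false
      }
    }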