[3.0][SQL] Revert SPARK-32018
### What changes were proposed in this pull request?

Revert the SPARK-32018-related changes in branch-3.0: apache#29125 and apache#29404.

### Why are the changes needed?

apache#29404 was made to fix a correctness regression introduced by apache#29125. However, the resulting decimal-overflow behavior in non-ANSI mode changes oddly across releases:
1. from 3.0.0 to 3.0.1: decimal overflow throws an exception instead of returning null;
2. from 3.0.1 to 3.1.0: decimal overflow returns null instead of throwing an exception.

So this PR proposes to revert both apache#29404 and apache#29125, so that Spark returns null on decimal overflow in both Spark 3.0.0 and Spark 3.0.1.
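For illustration, a minimal sketch of the kind of query affected (assuming a Spark 3.0.x `spark-shell` session with the default, non-ANSI settings; the column and variable names are illustrative, not taken from this PR):

```scala
import org.apache.spark.sql.functions._

// 10^19 fits into Decimal(38, 18) (20 integer digits), but summing eleven copies
// of it exceeds the 20-integer-digit limit of the sum's result type.
val big = "1" + "0" * 19
val overflowedSum = spark.range(0, 11, 1, 1)
  .select(expr(s"cast('$big' as decimal(38, 18)) as d"), lit("1").as("key"))
  .groupBy("key")
  .agg(sum(col("d")).as("sumD"))

// With this revert (as in Spark 3.0.0), the overflowed sum comes back as null in
// non-ANSI mode; with apache#29125/apache#29404 applied it raised an exception instead.
overflowedSum.show()
```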

### Does this PR introduce _any_ user-facing change?

Yes, Spark will return null on decimal overflow in Spark 3.0.1.
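As a hedged side note (independent of this revert): users who prefer a runtime error over a null result can opt into ANSI mode, which Spark 3.0 already exposes as a separate setting:

```scala
// Illustrative only; this config exists independently of this change.
spark.conf.set("spark.sql.ansi.enabled", "true")
```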

### How was this patch tested?

Unit tests

Closes apache#29450 from gengliangwang/revertDecimalOverflow.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
gengliangwang authored and cloud-fan committed Aug 17, 2020
1 parent c4807ce commit ee12374
Showing 4 changed files with 4 additions and 64 deletions.
2 changes: 1 addition & 1 deletion sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -288,7 +288,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) {
       Platform.putLong(baseObject, baseOffset + cursor, 0L);
       Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);

-      if (value == null || !value.changePrecision(precision, value.scale())) {
+      if (value == null) {
         setNullAt(ordinal);
         // keep the offset for future update
         Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
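For reference, a small sketch of the `Decimal.changePrecision` contract that the removed check relied on (the sample values below are illustrative, not taken from this PR): it mutates the Decimal and reports whether the value still fits the requested precision/scale. The reverted branch used a `false` result to null out the slot; the restored code only nulls out null inputs.

```scala
import org.apache.spark.sql.types.Decimal

val fits = Decimal(BigDecimal("10000000000000000000"))       // 20 integer digits
assert(fits.changePrecision(38, 18))                         // 20 + 18 digits == 38: still fits

val overflows = Decimal(BigDecimal("100000000000000000000")) // 21 integer digits
assert(!overflows.changePrecision(38, 18))                   // 21 + 18 digits > 38: overflow
```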
19 changes: 3 additions & 16 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -71,36 +71,23 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   )

   override lazy val updateExpressions: Seq[Expression] = {
-    val sumWithChild = resultType match {
-      case d: DecimalType =>
-        CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, nullOnOverflow = false)
-      case _ =>
-        coalesce(sum, zero) + child.cast(sumDataType)
-    }
-
     if (child.nullable) {
       Seq(
         /* sum = */
-        coalesce(sumWithChild, sum)
+        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
       )
     } else {
       Seq(
         /* sum = */
-        sumWithChild
+        coalesce(sum, zero) + child.cast(sumDataType)
       )
     }
   }

   override lazy val mergeExpressions: Seq[Expression] = {
-    val sumWithRight = resultType match {
-      case d: DecimalType =>
-        CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow = false)
-
-      case _ => coalesce(sum.left, zero) + sum.right
-    }
     Seq(
       /* sum = */
-      coalesce(sumWithRight, sum.left)
+      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
     )
   }

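To make the restored expression easier to read, here is a plain-Scala analogue (an illustrative sketch, not the actual Catalyst expressions) of how `coalesce(coalesce(sum, zero) + child, sum)` treats a nullable input: a null addend makes the addition null, and the outer coalesce then keeps the previous buffer value.

```scala
// Plain-Scala sketch of the reverted update expression for a nullable input:
//   sum := coalesce(coalesce(sum, zero) + input, sum)
def updateBuffer(buffer: Option[BigDecimal], input: Option[BigDecimal]): Option[BigDecimal] = {
  val zero = BigDecimal(0)
  // Adding a null input yields null ...
  val added = input.map(value => buffer.getOrElse(zero) + value)
  // ... so the outer coalesce falls back to the existing buffer value.
  added.orElse(buffer)
}

// e.g. updateBuffer(Some(BigDecimal(5)), None) == Some(5): a null input leaves the sum unchanged.
```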
37 changes: 0 additions & 37 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,7 +21,6 @@ import scala.util.Random

 import org.scalatest.Matchers.the

-import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1045,42 +1044,6 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(sql(queryTemplate("FIRST")), Row(1))
     checkAnswer(sql(queryTemplate("LAST")), Row(3))
   }
-
-  private def exceptionOnDecimalOverflow(df: DataFrame): Unit = {
-    val msg = intercept[SparkException] {
-      df.collect()
-    }.getCause.getMessage
-    assert(msg.contains("cannot be represented as Decimal(38, 18)"))
-  }
-
-  test("SPARK-32018: Throw exception on decimal overflow at partial aggregate phase") {
-    val decimalString = "1" + "0" * 19
-    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
-    val hashAgg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
-      .groupBy("key")
-      .agg(sum($"d").alias("sumD"))
-      .select($"sumD")
-    exceptionOnDecimalOverflow(hashAgg)
-
-    val sortAgg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("a").as("str"),
-        lit("1").as("key")).groupBy("key")
-      .agg(sum($"d").alias("sumD"), min($"str").alias("minStr")).select($"sumD", $"minStr")
-    exceptionOnDecimalOverflow(sortAgg)
-  }
-
-  test("SPARK-32018: Throw exception on decimal overflow at merge aggregation phase") {
-    val decimalString = "5" + "0" * 19
-    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1))
-      .union(spark.range(0, 1, 1, 1))
-    val agg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
-      .groupBy("key")
-      .agg(sum($"d").alias("sumD"))
-      .select($"sumD")
-    exceptionOnDecimalOverflow(agg)
-  }
 }

 case class B(c: Option[Double])
10 changes: 0 additions & 10 deletions sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -178,14 +178,4 @@ class UnsafeRowSuite extends SparkFunSuite {
     // Makes sure hashCode on unsafe array won't crash
     unsafeRow.getArray(0).hashCode()
   }
-
-  test("SPARK-32018: setDecimal with overflowed value") {
-    val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18)
-    val row = InternalRow.apply(d1)
-    val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row)
-    assert(unsafeRow.getDecimal(0, 38, 18) === d1)
-    val d2 = (d1 * Decimal(10)).toPrecision(39, 18)
-    unsafeRow.setDecimal(0, d2, 38)
-    assert(unsafeRow.getDecimal(0, 38, 18) === null)
-  }
 }
