@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@ExpressionDescription(
@@ -89,5 +90,9 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
)
}

override lazy val evaluateExpression: Expression = sum
override lazy val evaluateExpression: Expression = resultType match {
case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
case _ => sum
}

}
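
A minimal sketch (not part of this patch) of how the wrapped evaluateExpression behaves end to end, assuming a local SparkSession; the flag is referenced via SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key exactly as in the test below, and BigDecimal columns are inferred as Decimal(38, 18), so the sum of two 20-integer-digit values needs 21 integer digits and overflows:

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().master("local[*]").appName("sum-overflow-sketch").getOrCreate()
import spark.implicits._

// Two decimals with 20 integer digits each; their sum needs 21 integer digits plus the
// 18-digit scale, which no longer fits in the Decimal(38, 18) result type.
val df = Seq(BigDecimal("9" * 20 + ".123"), BigDecimal("9" * 20 + ".123")).toDF("a")

// With nullOnOverflow enabled, the overflowing aggregate now evaluates to NULL.
spark.conf.set(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key, "true")
df.agg(sum("a")).show()    // sum(a) -> null

// With nullOnOverflow disabled, the same query fails; the ArithmeticException
// ("... cannot be represented as Decimal ...") arrives wrapped in a SparkException.
spark.conf.set(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key, "false")
try {
  df.agg(sum("a")).collect()
} catch {
  case e: SparkException => println(e.getCause.getMessage)
}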
@@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
import org.apache.spark.sql.test.SQLTestData.{NullStrings, TestData2}
import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -156,6 +156,27 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
structDf.select(xxhash64($"a", $"record.*")))
}

test("SPARK-28224: Aggregate sum big decimal overflow") {
val largeDecimals = spark.sparkContext.parallelize(
DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF()

Seq(true, false).foreach { nullOnOverflow =>
withSQLConf((SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key, nullOnOverflow.toString)) {
val structDf = largeDecimals.select("a").agg(sum("a"))
if (nullOnOverflow) {
checkAnswer(structDf, Row(null))
} else {
val e = intercept[SparkException] {
structDf.collect
}
assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
}
}
}
}

test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv")
val e = intercept[AnalysisException] {