Skip to content

Commit 9c3df7d

Browse files
committed
use larger intermediate buffer for sum
1 parent 8591417 commit 9c3df7d

File tree

5 files changed

+26
-8
lines changed

5 files changed

+26
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala

+8-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
4545
case _ => DoubleType
4646
}
4747

48-
private lazy val sumDataType = resultType
48+
private lazy val sumDataType = child.dataType match {
49+
case LongType => DecimalType.BigIntDecimal
50+
case _ => resultType
51+
}
52+
53+
private lazy val castToResultType: (Expression) => Expression =
54+
if (sumDataType == resultType) (e: Expression) => e else (e: Expression) => Cast(e, resultType)
4955

5056
private lazy val sum = AttributeReference("sum", sumDataType)()
5157

@@ -78,5 +84,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
7884
)
7985
}
8086

81-
override lazy val evaluateExpression: Expression = sum
87+
override lazy val evaluateExpression: Expression = castToResultType(sum)
8288
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
215215
if (decimalVal.eq(null)) {
216216
longVal / POW_10(_scale)
217217
} else {
218-
decimalVal.longValue()
218+
// This will throw an exception if overflow occurs
219+
decimalVal.toLongExact
219220
}
220221
}
221222

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

+6
Original file line numberDiff line numberDiff line change
@@ -922,4 +922,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
922922
val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
923923
checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
924924
}
925+
926+
test("SPARK-24598: Cast to long should fail on overflow") {
927+
checkExceptionInExpression[ArithmeticException](
928+
cast(Literal.create(Decimal(Long.MaxValue) + Decimal(1)), LongType), "Overflow")
929+
checkEvaluation(cast(Literal.create(Decimal(Long.MaxValue)), LongType), Long.MaxValue)
930+
}
925931
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala

+1-5
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
114114
val random = new Random(seed)
115115

116116
def randomBound(): Long = {
117-
val n = if (random.nextBoolean()) {
118-
random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
119-
} else {
120-
random.nextLong() / 2
121-
}
117+
val n = random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
122118
if (random.nextBoolean()) n else -n
123119
}
124120

sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala

+9
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql
1919

20+
import org.apache.spark.SparkException
2021
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2122
import org.apache.spark.sql.expressions.Aggregator
2223
import org.apache.spark.sql.expressions.scalalang.typed
@@ -333,4 +334,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
333334
df.groupBy($"i").agg(VeryComplexResultAgg.toColumn),
334335
Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil)
335336
}
337+
338+
test("SPARK-24598: sum throws exception instead of silently overflow") {
339+
val df1 = Seq(Long.MinValue, -10, Long.MaxValue).toDF("i")
340+
checkAnswer(df1.agg(sum($"i")), Row(-11))
341+
val df2 = Seq(Long.MinValue, -10, 8).toDF("i")
342+
val e = intercept[SparkException](df2.agg(sum($"i")).collect())
343+
assert(e.getCause.isInstanceOf[ArithmeticException])
344+
}
336345
}

0 commit comments

Comments
 (0)