Skip to content

Commit 1153f75

Browse files
committed
Do our best to avoid overflow in the avg() function.
1 parent b77c19b commit 1153f75

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,10 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
214214
override def toString = s"AVG($child)"
215215

216216
override def asPartial: SplitEvaluation = {
217-
val partialSum = Alias(Sum(child), "PartialSum")()
218-
val partialCount = Alias(Count(child), "PartialCount")()
219-
val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
220-
val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
217+
val partialSum = Alias(Sum(Cast(child, dataType)), "PartialSum")()
218+
val partialCount = Alias(Cast(Count(child), dataType), "PartialCount")()
219+
val castedSum = Sum(partialSum.toAttribute)
220+
val castedCount = Sum(partialCount.toAttribute)
221221

222222
SplitEvaluation(
223223
Divide(castedSum, castedCount),

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ class SQLQuerySuite extends QueryTest {
123123
2.0)
124124
}
125125

126+
test("average overflow test") {
127+
checkAnswer(
128+
sql("SELECT AVG(a),b FROM testData1 group by b"),
129+
Seq((2147483645.0,1),(2.0,2)))
130+
}
131+
126132
test("count") {
127133
checkAnswer(
128134
sql("SELECT COUNT(*) FROM testData2"),

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ object TestData {
3030
(1 to 100).map(i => TestData(i, i.toString)))
3131
testData.registerAsTable("testData")
3232

33+
case class TestData1(a: Int, b: Int)
34+
val testData1: SchemaRDD =
35+
TestSQLContext.sparkContext.parallelize(
36+
TestData1(2147483644, 1) ::
37+
TestData1(1, 2) ::
38+
TestData1(2147483645, 1) ::
39+
TestData1(2, 2) ::
40+
TestData1(2147483646, 1) ::
41+
TestData1(3, 2) :: Nil)
42+
testData1.registerAsTable("testData1")
43+
3344
case class TestData2(a: Int, b: Int)
3445
val testData2: SchemaRDD =
3546
TestSQLContext.sparkContext.parallelize(

0 commit comments

Comments (0)