Skip to content

Commit 2afa238

Browse files
committed
Simplify CountFunction to evaluate its expression directly instead of traversing and evaluating all child expressions.
1 parent b2bdd0e commit 2afa238

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,8 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
298298
var count: Int = _
299299

300300
override def update(input: Row): Unit = {
301-
val evaluatedExpr = expr.map(_.eval(input))
302-
if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
301+
val evaluatedExpr = expr.eval(input)
302+
if (evaluatedExpr != null) {
303303
count += 1
304304
}
305305
}

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ class DslQuerySuite extends QueryTest {
120120
Seq((1,0), (2, 1))
121121
)
122122

123+
checkAnswer(
124+
testData3.groupBy('a)('a, Count('a + 'b)),
125+
Seq((1,0), (2, 1))
126+
)
127+
123128
checkAnswer(
124129
testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
125130
(2, 1, 2, 2, 1) :: Nil

0 commit comments

Comments
 (0)