Skip to content

Commit a5b6d71

Browse files
committed
remove unnecessary partial aggregate
1 parent 6b8cb1f commit a5b6d71

File tree

3 files changed

+79
-4
lines changed

3 files changed

+79
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ import java.nio.charset.StandardCharsets
2121
import java.sql.Timestamp
2222

2323
import org.apache.spark.rdd.RDD
24-
import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext}
24+
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
2525
import org.apache.spark.sql.catalyst.InternalRow
2626
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
2727
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
2828
import org.apache.spark.sql.catalyst.rules.Rule
2929
import org.apache.spark.sql.catalyst.util.DateTimeUtils
30+
import org.apache.spark.sql.execution.aggregate.MergePartialAggregate
3031
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec}
3132
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
3233
import org.apache.spark.sql.internal.SQLConf
@@ -100,6 +101,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
100101
python.ExtractPythonUDFs,
101102
PlanSubqueries(sparkSession),
102103
EnsureRequirements(sparkSession.sessionState.conf),
104+
MergePartialAggregate,
103105
CollapseCodegenStages(sparkSession.sessionState.conf),
104106
ReuseExchange(sparkSession.sessionState.conf),
105107
ReuseSubquery(sparkSession.sessionState.conf))
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.aggregate
19+
20+
import org.apache.spark.sql.catalyst.expressions.aggregate._
21+
import org.apache.spark.sql.catalyst.rules.Rule
22+
import org.apache.spark.sql.execution.SparkPlan
23+
24+
/**
 * Physical-plan rule that collapses two directly stacked [[HashAggregateExec]] nodes into one
 * when they are the two halves of the same multi-step aggregation, i.e. when no exchange ended
 * up between them.
 *
 * A pair is only rewritten when
 *  - the aggregate modes of the two nodes form one of the recognised combinations below, and
 *  - the outer node groups on exactly the attributes produced by the inner node's grouping
 *    expressions.
 *
 * The second condition matters because `forall` over an empty list of aggregate expressions is
 * vacuously true: two unrelated aggregates without aggregate functions (plain DISTINCT /
 * GROUP BY) that happen to be adjacent would otherwise satisfy every mode check and be merged,
 * leaving result expressions that refer to attributes the surviving node does not produce.
 *
 * NOTE(review): the grouping check assumes the planner derives the outer node's grouping keys
 * from the inner node's grouping expressions via `toAttribute` -- confirm against the code that
 * plans the aggregate pairs.
 */
object MergePartialAggregate extends Rule[SparkPlan] {

  override def apply(plan: SparkPlan): SparkPlan = plan transform {
    // Normal partial aggregate pair: evaluate everything in one Complete-mode aggregate that
    // keeps the inner node's grouping/input and emits the outer node's results.
    case outer @ HashAggregateExec(_, _, _, _, _, _, inner: HashAggregateExec)
        if isPair(outer, Final, inner, Partial) =>
      inner.copy(
        aggregateExpressions = inner.aggregateExpressions.map(_.copy(mode = Complete)),
        aggregateAttributes = inner.aggregateExpressions.map(_.resultAttribute),
        resultExpressions = outer.resultExpressions)

    // First partial aggregate pair for aggregation with distinct: the merge step is redundant,
    // so only the partial step is kept.
    case outer @ HashAggregateExec(_, _, _, _, _, _, inner: HashAggregateExec)
        if isPair(outer, PartialMerge, inner, Partial) =>
      inner

    // Second partial aggregate pair for aggregation with distinct.
    // This is actually a no-op. For aggregation with distinct, the output of first partial
    // aggregate is partitioned by grouping expressions and distinct attributes, and the second
    // partial aggregate requires input to be partitioned by grouping attributes, which is not
    // satisfied. `EnsureRequirements` will always insert exchange between these 2 aggregate exec
    // and we will never hit this branch.
    case outer @ HashAggregateExec(_, _, _, _, _, _, inner: HashAggregateExec)
        if isPair(outer, Final, inner, PartialMerge) =>
      outer.copy(child = inner.child)

    // TODO: add similar logic for sort aggregate
  }

  /**
   * Returns true if `outer` and `inner` look like two steps of the same aggregation: every
   * aggregate expression of each node is in the expected mode and `outer` groups on the
   * grouping output of `inner`.
   */
  private def isPair(
      outer: HashAggregateExec,
      outerMode: AggregateMode,
      inner: HashAggregateExec,
      innerMode: AggregateMode): Boolean = {
    allInMode(outer, outerMode) && allInMode(inner, innerMode) && groupsOnOutputOf(outer, inner)
  }

  /** True if every aggregate expression of `agg` has `mode` (vacuously true when there are none). */
  private def allInMode(agg: HashAggregateExec, mode: AggregateMode): Boolean =
    agg.aggregateExpressions.forall(_.mode == mode)

  /**
   * True if the grouping keys of `outer` are, position by position, the attributes produced by
   * the grouping expressions of `inner`.
   */
  private def groupsOnOutputOf(outer: HashAggregateExec, inner: HashAggregateExec): Boolean = {
    val innerKeys = inner.groupingExpressions.map(_.toAttribute)
    outer.groupingExpressions.length == innerKeys.length &&
      outer.groupingExpressions.zip(innerKeys).forall {
        case (outerKey, innerKey) => outerKey.semanticEquals(innerKey)
      }
  }
}

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,21 @@
1818
package org.apache.spark.sql.execution
1919

2020
import org.apache.spark.rdd.RDD
21-
import org.apache.spark.sql.{execution, Row}
21+
import org.apache.spark.sql.{execution, QueryTest, Row}
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
2424
import org.apache.spark.sql.catalyst.plans.Inner
2525
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
2626
import org.apache.spark.sql.catalyst.plans.physical._
2727
import org.apache.spark.sql.execution.columnar.InMemoryRelation
28-
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
28+
import org.apache.spark.sql.execution.exchange._
2929
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
3030
import org.apache.spark.sql.functions._
3131
import org.apache.spark.sql.internal.SQLConf
3232
import org.apache.spark.sql.test.SharedSQLContext
3333
import org.apache.spark.sql.types._
3434

35-
class PlannerSuite extends SharedSQLContext {
35+
class PlannerSuite extends QueryTest with SharedSQLContext {
3636
import testImplicits._
3737

3838
setupTestData()
@@ -518,6 +518,24 @@ class PlannerSuite extends SharedSQLContext {
518518
fail(s"Should have only two shuffles:\n$outputPlan")
519519
}
520520
}
521+
522+
test("no partial aggregation if input relation is already partitioned") {
  val input = Seq("a" -> 1, "b" -> 2).toDF("i", "j")

  // Number of exchange (shuffle) operators in a physical plan.
  def countShuffles(plan: SparkPlan): Int = plan.collect { case e: Exchange => e }.length

  // Number of hash aggregate operators in a physical plan. The shuffle count alone does not
  // change when two aggregate steps are merged, so this is what actually shows whether the
  // redundant partial aggregate was removed.
  def countHashAggregates(plan: SparkPlan): Int =
    plan.collect { case a: execution.aggregate.HashAggregateExec => a }.length

  // Input is already partitioned by the grouping key: only the explicit repartition shuffles,
  // and the partial/final pair is expected to collapse into a single aggregate.
  val aggWithoutDistinct = input.repartition($"i").groupBy($"i").agg(sum($"j"))
  checkAnswer(aggWithoutDistinct, input.groupBy($"i").agg(sum($"j")))
  val planWithoutDistinct = aggWithoutDistinct.queryExecution.executedPlan
  assert(countShuffles(planWithoutDistinct) == 1)
  assert(countHashAggregates(planWithoutDistinct) == 1)

  // Input is partitioned by grouping key + distinct column: the explicit repartition plus one
  // more shuffle on the grouping key remain.
  // NOTE(review): the expected aggregate count assumes the distinct aggregation is planned as
  // four hash aggregate steps of which only the first two are adjacent -- confirm.
  val aggWithDistinct = input.repartition($"i", $"j").groupBy($"i").agg(countDistinct($"j"))
  checkAnswer(aggWithDistinct, input.groupBy($"i").agg(countDistinct($"j")))
  val planWithDistinct = aggWithDistinct.queryExecution.executedPlan
  assert(countShuffles(planWithDistinct) == 2)
  assert(countHashAggregates(planWithDistinct) == 3)
}
521539
}
522540

523541
// Used for unit-testing EnsureRequirements

0 commit comments

Comments
 (0)