[SPARK-12978][SQL] Skip unnecessary final group-by when input data already clustered with group-by keys #10896

Closed
wants to merge 21 commits into from
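To make the intent concrete, here is a minimal, illustrative sketch (not part of the patch) using the DataFrame API exercised by the tests below; exact operator names and plan shapes depend on the build:

import org.apache.spark.sql.SparkSession

// Sketch only: when the input is already clustered by the group-by keys (here via
// repartition), the planner can emit a single aggregate with no extra shuffle and
// no final merge step.
val spark = SparkSession.builder().master("local[*]").appName("spark-12978-sketch").getOrCreate()
import spark.implicits._

val df = Seq((1, "a"), (1, "b"), (2, "c")).toDF("key", "value")
  .repartition($"key")   // data is now hash-partitioned by the grouping key
  .groupBy("key")
  .count()

// Before this change: HashAggregate(partial) followed by HashAggregate(final).
// After this change: a single HashAggregate and no Exchange in the executed plan.
df.explain()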
@@ -259,24 +259,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}

val aggregateOperator =
if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
if (functionsWithDistinct.nonEmpty) {
sys.error("Distinct columns cannot exist in Aggregate operator containing " +
"aggregate functions which don't support partial aggregation.")
} else {
aggregate.AggUtils.planAggregateWithoutPartial(
groupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
}
} else if (functionsWithDistinct.isEmpty) {
if (functionsWithDistinct.isEmpty) {
aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
} else {
if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
sys.error("Distinct columns cannot exist in Aggregate operator containing " +
"aggregate functions which don't support partial aggregation.")
}
aggregate.AggUtils.planAggregateWithOneDistinct(
groupingExpressions,
functionsWithDistinct,

Large diffs are not rendered by default.

@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.UnaryExecNode

/**
* A base class for aggregate implementation.
*/
abstract class AggregateExec extends UnaryExecNode {

def requiredChildDistributionExpressions: Option[Seq[Expression]]
def groupingExpressions: Seq[NamedExpression]
def aggregateExpressions: Seq[AggregateExpression]
def aggregateAttributes: Seq[Attribute]
def initialInputBufferOffset: Int
def resultExpressions: Seq[NamedExpression]

protected[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

override def requiredChildDistribution: List[Distribution] = {
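// Some(Nil) means a global aggregation, so all rows must be collected into a single
// partition (AllTuples). Some(exprs) requires rows with equal grouping keys to be
// co-located (ClusteredDistribution). None imposes no requirement on the child, as is
// the case for a map-side partial aggregate.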
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}
}
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -42,11 +41,7 @@ case class HashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryExecNode with CodegenSupport {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}
extends AggregateExec with CodegenSupport {

require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

@@ -60,21 +55,6 @@
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)

override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}

// This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
// map and/or the sort-based aggregation once it has processed a given number of input rows.
private val testFallbackStartsAt: Option[(Int, Int)] = {
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.Utils

@@ -38,30 +37,11 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryExecNode {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)
extends AggregateExec {

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.aggregate.PartialAggregate
import org.apache.spark.sql.internal.SQLConf

/**
@@ -151,18 +153,30 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
var children: Seq[SparkPlan] = operator.children
assert(requiredChildDistributions.length == children.length)
assert(requiredChildOrderings.length == children.length)
assert(requiredChildDistributions.length == operator.children.length)
assert(requiredChildOrderings.length == operator.children.length)

// Ensure that the operator's children satisfy their output distribution requirements:
children = children.zip(requiredChildDistributions).map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
def createShuffleExchange(dist: Distribution, child: SparkPlan) =
ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)

var (parent, children) = operator match {
case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
// If an aggregation needs a shuffle and supports partial aggregation, a map-side
// partial aggregate and a shuffle are added as children.
val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
(mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
case _ =>
// Ensure that the operator's children satisfy their output distribution requirements:
val childrenWithDist = operator.children.zip(requiredChildDistributions)
val newChildren = childrenWithDist.map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
createShuffleExchange(distribution, child)
}
(operator, newChildren)
}

// If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -246,7 +260,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
}
}

operator.withNewChildren(children)
parent.withNewChildren(children)
}

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
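As a hedged illustration of the two branches above (assuming a spark-shell style session with spark.implicits._ in scope; exact operator names depend on which aggregate implementation the planner picks):

// Child already clustered by the grouping key: no shuffle is needed, so a single
// aggregate is kept and no Exchange is inserted.
spark.range(100).repartition($"id").groupBy($"id").count().explain()

// Child not clustered by the grouping keys: the PartialAggregate branch plans a
// map-side partial aggregate, a ShuffleExchange on the grouping keys, and a merge
// aggregate via AggUtils.createMapMergeAggregatePair.
spark.range(100).selectExpr("id % 10 AS k", "id AS v").groupBy($"k").sum("v").explain()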
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}

/**
* Verifies that there is no Exchange between the Aggregations for `df`
* Verifies that there is a single Aggregation for `df`
*/
private def verifyNonExchangingAgg(df: DataFrame) = {
private def verifyNonExchangingSingleAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
case agg: HashAggregateExec =>
atFirstAgg = !atFirstAgg
case _ =>
if (atFirstAgg) {
fail("Should not have operators between the two aggregations")
fail("Should not have back to back Aggregates")
}
atFirstAgg = true
case _ =>
}
}

@@ -1292,9 +1292,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// Group by the column we are distributed by. This should generate a plan with no exchange
// between the aggregates
val df3 = testData.repartition($"key").groupBy("key").count()
verifyNonExchangingAgg(df3)
verifyNonExchangingAgg(testData.repartition($"key", $"value")
verifyNonExchangingSingleAgg(df3)
verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
.groupBy("key", "value").count())
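// Grouping by a superset of the distributeBy exprs also needs no exchange: hash
// partitioning on `key` alone already co-locates rows that agree on (`key`, `value`).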
verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())

// Grouping by just the first distributeBy expr, need to exchange.
verifyExchangingAgg(testData.repartition($"key", $"value")
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.Inner
@@ -37,36 +37,65 @@ class PlannerSuite extends SharedSQLContext {

setupTestData()

private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
val planner = spark.sessionState.planner
import planner._
val plannedOption = Aggregation(query).headOption
val planned =
plannedOption.getOrElse(
fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }

// For the new aggregation code path, there will be four aggregate operator for
// distinct aggregations.
assert(
aggregations.size == 2 || aggregations.size == 4,
s"The plan of query $query does not have partial aggregations.")
val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
val planned = Aggregation(query).headOption.map(ensureRequirements(_))
.getOrElse(fail(s"Could not plan aggregation query $query. Is it an aggregation query?"))
planned.collect { case n if n.nodeName contains "Aggregate" => n }
}

test("count is partially aggregated") {
val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
assert(testPartialAggregationPlan(query).size == 2,
s"The plan of query $query does not have partial aggregations.")
}

test("count distinct is partially aggregated") {
val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
// For the new aggregation code path, there will be four aggregate operators for distinct
// aggregations.
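// (roughly: partial, partial merge, partial aggregation on the distinct values, and
// final; see AggUtils.planAggregateWithOneDistinct).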
assert(testPartialAggregationPlan(query).size == 4,
s"The plan of query $query does not have partial aggregations.")
}

test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
// For the new aggregation code path, there will be four aggregate operators for distinct
// aggregations.
assert(testPartialAggregationPlan(query).size == 4,
s"The plan of query $query does not have partial aggregations.")
}

test("non-partial aggregation for aggregates") {
withTempView("testNonPartialAggregation") {
val schema = StructType(StructField("value", IntegerType, true) :: Nil)
val row = Row.fromSeq(Seq.fill(1)(null))
val rowRDD = sparkContext.parallelize(row :: Nil)
spark.createDataFrame(rowRDD, schema).repartition($"value")
.createOrReplaceTempView("testNonPartialAggregation")

val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
.queryExecution.executedPlan

// If the input data are already partitioned and the same columns are used as grouping
// keys and aggregation values, no partial aggregation exists in the query plan.
val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")

val planned2 = sql(
"""
|SELECT t.value, SUM(DISTINCT t.value)
|FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
|GROUP BY t.value
""".stripMargin).queryExecution.executedPlan

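// The subquery is globally sorted by `value`, so its output is range-partitioned on the
// grouping key and already clustered by it; again no partial aggregate is needed.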
val aggOps2 = planned2.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
}
}

test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
Expand Down