
Commit 2828345

do not merge different aggregate implementations and add test
1 parent 6134fa9 commit 2828345

File tree: 7 files changed (+98, -32 lines)

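The commit makes MergeScalarSubqueries merge two scalar-subquery Aggregate plans only when both would be planned with the same physical aggregate implementation (HashAggregateExec, ObjectHashAggregateExec or SortAggregateExec). A merged aggregate can only run on an implementation that every merged expression supports, so mixing implementations would demote the faster side. A minimal standalone sketch of the idea (illustrative names, not Spark API):

sealed trait AggImpl
case object HashAgg extends AggImpl        // mutable, fixed-width aggregation buffer
case object ObjectHashAgg extends AggImpl  // TypedImperativeAggregate buffer
case object SortAgg extends AggImpl        // fallback for everything else

// The rule now merges two subquery aggregates only if they classify identically.
def mergeable(a: AggImpl, b: AggImpl): Boolean = a == b

assert(mergeable(HashAgg, HashAgg))   // the merged plan still hash-aggregates
assert(!mergeable(HashAgg, SortAgg))  // merging would force both to sort-aggregate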

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala

Lines changed: 26 additions & 3 deletions

@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{MULTI_SCALAR_SUBQUERY, SCALAR_SUBQUERY}
@@ -134,7 +135,7 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
           val newOutputMap = createOutputMap(np.projectList, newProjectList)
           Project(distinctExpressions(ep.output ++ newProjectList), mergedChild) -> newOutputMap
         }
-      case (np: Aggregate, ep: Aggregate) =>
+      case (np: Aggregate, ep: Aggregate) if supportedAggregateMerge(np, ep) =>
        tryMergePlans(np.child, ep.child).flatMap { case (mergedChild, outputMap) =>
           val newGroupingExpression = replaceAttributes(np.groupingExpressions, outputMap)
           if (ExpressionSet(newGroupingExpression) == ExpressionSet(ep.groupingExpressions)) {
@@ -147,8 +148,7 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
             None
           }
         }
-      case _ =>
-        None
+      case _ => None
     }
   }

@@ -168,6 +168,29 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
     ExpressionSet(expressions).toSeq.asInstanceOf[Seq[NamedExpression]]
   }

+  // Merging different aggregate implementations could cause performance regression
+  private def supportedAggregateMerge(newPlan: Aggregate, existingPlan: Aggregate) = {
+    val newPlanAggregateExpressions = newPlan.aggregateExpressions.flatMap(_.collect {
+      case a: AggregateExpression => a
+    })
+    val existingPlanAggregateExpressions = existingPlan.aggregateExpressions.flatMap(_.collect {
+      case a: AggregateExpression => a
+    })
+    val newPlanSupportsHashAggregate = Aggregate.supportsHashAggregate(
+      newPlanAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
+    val existingPlanSupportsHashAggregate = Aggregate.supportsHashAggregate(
+      existingPlanAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
+    newPlanSupportsHashAggregate && existingPlanSupportsHashAggregate ||
+      !newPlanSupportsHashAggregate && !existingPlanSupportsHashAggregate && {
+        val newPlanSupportsObjectHashAggregate =
+          Aggregate.supportsObjectHashAggregate(newPlanAggregateExpressions)
+        val existingPlanSupportsObjectHashAggregate =
+          Aggregate.supportsObjectHashAggregate(existingPlanAggregateExpressions)
+        newPlanSupportsObjectHashAggregate && existingPlanSupportsObjectHashAggregate ||
+          !newPlanSupportsObjectHashAggregate && !existingPlanSupportsObjectHashAggregate
+      }
+  }
+
   private def removeReferences(
       plan: LogicalPlan,
       mergedSubqueries: ArrayBuffer[LogicalPlan]): LogicalPlan = {
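Read as a truth table over the two capability flags, supportedAggregateMerge accepts a pair only when it agrees on hash support and, failing that, on object-hash support. The same boolean logic as a standalone sketch (illustrative, not Spark API):

// h* = supportsHashAggregate, o* = supportsObjectHashAggregate
def compatible(hNew: Boolean, oNew: Boolean, hOld: Boolean, oOld: Boolean): Boolean =
  (hNew && hOld) ||
    (!hNew && !hOld && ((oNew && oOld) || (!oNew && !oOld)))

assert(compatible(true, false, true, false))    // hash + hash: merge
assert(compatible(false, true, false, true))    // objectHash + objectHash: merge
assert(!compatible(true, false, false, false))  // hash + sort: keep separate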

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 19 additions & 1 deletion

@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, MultiInstanceRe
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
 import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -828,6 +828,24 @@ case class Aggregate(
     copy(child = newChild)
 }

+object Aggregate {
+  def supportsAggregationBufferSchema(schema: StructType): Boolean = {
+    schema.forall(f => UnsafeRow.isMutable(f.dataType))
+  }
+
+  def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+    val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
+    supportsAggregationBufferSchema(aggregationBufferSchema)
+  }
+
+  def supportsObjectHashAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+    aggregateExpressions.map(_.aggregateFunction).exists {
+      case _: TypedImperativeAggregate[_] => true
+      case _ => false
+    }
+  }
+}
+
 case class Window(
     windowExpressions: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
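The new Aggregate.supportsAggregationBufferSchema helper requires every buffer field to be UnsafeRow-mutable, i.e. updatable in place in a fixed-width row. A quick check of that boundary, assuming spark-catalyst on the classpath (e.g. in spark-shell):

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.{IntegerType, StringType}

// Fixed-width types can be mutated in place, so hash aggregation applies.
assert(UnsafeRow.isMutable(IntegerType))
// Variable-length types (e.g. the buffer of max over a string column) cannot,
// so the planner falls back past HashAggregateExec.
assert(!UnsafeRow.isMutable(StringType))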

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala

Lines changed: 46 additions & 3 deletions

@@ -20,18 +20,18 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{GetStructField, MultiScalarSubquery, ScalarSubquery}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._

 class MergeScalarSubqueriesSuite extends PlanTest {

   private object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("MergeScalarSubqueries", Once, MergeScalarSubqueries) :: Nil
+    val batches = Batch("MergeScalarSubqueries", Once, MergeScalarSubqueries) :: Nil
   }

-  val testRelation = LocalRelation('a.int, 'b.int)
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.string)

   test("Simple non-correlated scalar subquery merge") {
     val subquery1 = testRelation
@@ -70,4 +70,47 @@ class MergeScalarSubqueriesSuite extends PlanTest {
     // as only `Optimizer` can insert such a node to the plan
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer, false)
   }
+
+  test("Do not merge different aggregate implementations") {
+    // supports HashAggregate
+    val subquery1 = testRelation
+      .groupBy('b)(max('a))
+    val subquery2 = testRelation
+      .groupBy('b)(min('a))
+
+    // supports ObjectHashAggregate
+    val subquery3 = testRelation
+      .groupBy('b)(CollectList('a).toAggregateExpression(isDistinct = false))
+    val subquery4 = testRelation
+      .groupBy('b)(CollectSet('a).toAggregateExpression(isDistinct = false))
+
+    // supports SortAggregate
+    val subquery5 = testRelation
+      .groupBy('b)(max('c))
+    val subquery6 = testRelation
+      .groupBy('b)(min('c))
+
+    val originalQuery = testRelation
+      .select(ScalarSubquery(subquery1), ScalarSubquery(subquery2), ScalarSubquery(subquery3),
+        ScalarSubquery(subquery4), ScalarSubquery(subquery5), ScalarSubquery(subquery6))
+
+    val hashAggregates = testRelation
+      .groupBy('b)(max('a), min('a)).analyze
+    val objectHashAggregates = testRelation
+      .groupBy('b)(CollectList('a).toAggregateExpression(isDistinct = false),
+        CollectSet('a).toAggregateExpression(isDistinct = false)).analyze
+    val sortAggregates = testRelation
+      .groupBy('b)(max('c), min('c)).analyze
+    val correctAnswer = testRelation
+      .select(GetStructField(MultiScalarSubquery(hashAggregates), 0).as("scalarsubquery()"),
+        GetStructField(MultiScalarSubquery(hashAggregates), 1).as("scalarsubquery()"),
+        GetStructField(MultiScalarSubquery(objectHashAggregates), 0).as("scalarsubquery()"),
+        GetStructField(MultiScalarSubquery(objectHashAggregates), 1).as("scalarsubquery()"),
+        GetStructField(MultiScalarSubquery(sortAggregates), 0).as("scalarsubquery()"),
+        GetStructField(MultiScalarSubquery(sortAggregates), 1).as("scalarsubquery()"))
+
+    // checkAnalysis is disabled because `Analyzer` is not prepared for `MultiScalarSubquery` nodes
+    // as only `Optimizer` can insert such a node to the plan
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer, false)
+  }
 }
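The three expected plans line up with the three implementations: max/min over the int column 'a keep fixed-width mutable buffers (HashAggregate), CollectList/CollectSet are TypedImperativeAggregates (ObjectHashAggregate), and max/min over the string column 'c have a non-mutable buffer and no typed-imperative function, so they fall through to SortAggregate. The rule therefore produces three separate MultiScalarSubquery plans instead of one fully merged aggregate.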

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java

Lines changed: 2 additions & 6 deletions

@@ -25,6 +25,7 @@
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate$;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.KVIterator;
@@ -68,12 +69,7 @@ public final class UnsafeFixedWidthAggregationMap {
    * schema, false otherwise.
    */
  public static boolean supportsAggregationBufferSchema(StructType schema) {
-    for (StructField field: schema.fields()) {
-      if (!UnsafeRow.isMutable(field.dataType())) {
-        return false;
-      }
-    }
-    return true;
+    return Aggregate$.MODULE$.supportsAggregationBufferSchema(schema);
  }

  /**

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 3 additions & 2 deletions

@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate

 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}

@@ -50,7 +51,7 @@ object AggUtils {
       initialInputBufferOffset: Int = 0,
       resultExpressions: Seq[NamedExpression] = Nil,
       child: SparkPlan): SparkPlan = {
-    val useHash = HashAggregateExec.supportsAggregate(
+    val useHash = Aggregate.supportsHashAggregate(
       aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
     if (useHash) {
       HashAggregateExec(
@@ -63,7 +64,7 @@ object AggUtils {
         child = child)
     } else {
       val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
-      val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
+      val useObjectHash = Aggregate.supportsObjectHashAggregate(aggregateExpressions)

       if (objectHashEnabled && useObjectHash) {
         ObjectHashAggregateExec(
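For context, createAggregate picks the physical operator in a fixed cascade, now routed through the shared Aggregate helpers. A standalone restatement of the selection order (illustrative; the real code builds the corresponding SparkPlan nodes):

// Hash if the buffer is mutable; otherwise object hash when enabled and a
// TypedImperativeAggregate is present; otherwise sort-based.
def chooseAggregateImpl(
    supportsHash: Boolean,       // Aggregate.supportsHashAggregate(...)
    supportsObjectHash: Boolean, // Aggregate.supportsObjectHashAggregate(...)
    objectHashEnabled: Boolean): String =
  if (supportsHash) "HashAggregateExec"
  else if (objectHashEnabled && supportsObjectHash) "ObjectHashAggregateExec"
  else "SortAggregateExec"

assert(chooseAggregateImpl(false, true, true) == "ObjectHashAggregateExec")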

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 2 additions & 8 deletions

@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution._
@@ -55,7 +56,7 @@ case class HashAggregateExec(
   with BlockingOperatorWithCodegen
   with GeneratePredicateHelper {

-  require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
+  require(Aggregate.supportsHashAggregate(aggregateBufferAttributes))

   override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
@@ -1139,10 +1140,3 @@ case class HashAggregateExec(
   override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec =
     copy(child = newChild)
 }
-
-object HashAggregateExec {
-  def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
-    val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
-    UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema)
-  }
-}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala

Lines changed: 0 additions & 9 deletions

@@ -142,12 +142,3 @@ case class ObjectHashAggregateExec(
   override protected def withNewChildInternal(newChild: SparkPlan): ObjectHashAggregateExec =
     copy(child = newChild)
 }
-
-object ObjectHashAggregateExec {
-  def supportsAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
-    aggregateExpressions.map(_.aggregateFunction).exists {
-      case _: TypedImperativeAggregate[_] => true
-      case _ => false
-    }
-  }
-}
