Skip to content

Commit cff1424

Browse files
committed
Check the approach to check partial agg based on logical plan instead
1 parent e8609fd commit cff1424

File tree

1 file changed

+8
-41
lines changed

1 file changed

+8
-41
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAgg.scala

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -93,47 +93,14 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
9393
* Check if `partialAgg` to be partial aggregate of `finalAgg`.
9494
*/
9595
private def isPartialAgg(partialAgg: HashAggregateExec, finalAgg: HashAggregateExec): Boolean = {
96-
val partialGroupExprs = partialAgg.groupingExpressions
97-
val finalGroupExprs = finalAgg.groupingExpressions
98-
val partialAggExprs = partialAgg.aggregateExpressions
99-
val finalAggExprs = finalAgg.aggregateExpressions
100-
val partialAggAttrs = partialAggExprs.flatMap(_.aggregateFunction.aggBufferAttributes)
101-
val finalAggAttrs = finalAggExprs.map(_.resultAttribute)
102-
val partialResultExprs = partialGroupExprs ++
103-
partialAggExprs.flatMap(_.aggregateFunction.inputAggBufferAttributes)
104-
105-
val groupExprsEqual = partialGroupExprs.length == finalGroupExprs.length &&
106-
partialGroupExprs.zip(finalGroupExprs).forall {
107-
case (e1, e2) => e1.semanticEquals(e2)
108-
}
109-
val aggExprsEqual = partialAggExprs.length == finalAggExprs.length &&
110-
partialAggExprs.forall(_.mode == Partial) && finalAggExprs.forall(_.mode == Final) &&
111-
partialAggExprs.zip(finalAggExprs).forall {
112-
case (e1, e2) => e1.aggregateFunction.semanticEquals(e2.aggregateFunction)
113-
}
114-
val isPartialAggAttrsValid = partialAggAttrs.length == partialAgg.aggregateAttributes.length &&
115-
partialAggAttrs.zip(partialAgg.aggregateAttributes).forall {
116-
case (a1, a2) => a1.semanticEquals(a2)
96+
if (partialAgg.aggregateExpressions.forall(_.mode == Partial) &&
97+
finalAgg.aggregateExpressions.forall(_.mode == Final)) {
98+
(finalAgg.logicalLink, partialAgg.logicalLink) match {
99+
case (Some(agg1), Some(agg2)) => agg1.sameResult(agg2)
100+
case _ => false
117101
}
118-
val isFinalAggAttrsValid = finalAggAttrs.length == finalAgg.aggregateAttributes.length &&
119-
finalAggAttrs.zip(finalAgg.aggregateAttributes).forall {
120-
case (a1, a2) => a1.semanticEquals(a2)
121-
}
122-
val isPartialResultExprsValid =
123-
partialResultExprs.length == partialAgg.resultExpressions.length &&
124-
partialResultExprs.zip(partialAgg.resultExpressions).forall {
125-
case (a1, a2) => a1.semanticEquals(a2)
126-
}
127-
val isRequiredDistributionValid =
128-
partialAgg.requiredChildDistributionExpressions.isEmpty &&
129-
finalAgg.requiredChildDistributionExpressions.exists { exprs =>
130-
exprs.length == finalGroupExprs.length &&
131-
exprs.zip(finalGroupExprs).forall {
132-
case (e1, e2) => e1.semanticEquals(e2)
133-
}
134-
}
135-
136-
groupExprsEqual && aggExprsEqual && isPartialAggAttrsValid && isFinalAggAttrsValid &&
137-
isPartialResultExprsValid && isRequiredDistributionValid
102+
} else {
103+
false
104+
}
138105
}
139106
}

0 commit comments

Comments
 (0)