@@ -93,47 +93,14 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
93
93
* Check if `partialAgg` to be partial aggregate of `finalAgg`.
94
94
*/
95
95
private def isPartialAgg (partialAgg : HashAggregateExec , finalAgg : HashAggregateExec ): Boolean = {
96
- val partialGroupExprs = partialAgg.groupingExpressions
97
- val finalGroupExprs = finalAgg.groupingExpressions
98
- val partialAggExprs = partialAgg.aggregateExpressions
99
- val finalAggExprs = finalAgg.aggregateExpressions
100
- val partialAggAttrs = partialAggExprs.flatMap(_.aggregateFunction.aggBufferAttributes)
101
- val finalAggAttrs = finalAggExprs.map(_.resultAttribute)
102
- val partialResultExprs = partialGroupExprs ++
103
- partialAggExprs.flatMap(_.aggregateFunction.inputAggBufferAttributes)
104
-
105
- val groupExprsEqual = partialGroupExprs.length == finalGroupExprs.length &&
106
- partialGroupExprs.zip(finalGroupExprs).forall {
107
- case (e1, e2) => e1.semanticEquals(e2)
108
- }
109
- val aggExprsEqual = partialAggExprs.length == finalAggExprs.length &&
110
- partialAggExprs.forall(_.mode == Partial ) && finalAggExprs.forall(_.mode == Final ) &&
111
- partialAggExprs.zip(finalAggExprs).forall {
112
- case (e1, e2) => e1.aggregateFunction.semanticEquals(e2.aggregateFunction)
113
- }
114
- val isPartialAggAttrsValid = partialAggAttrs.length == partialAgg.aggregateAttributes.length &&
115
- partialAggAttrs.zip(partialAgg.aggregateAttributes).forall {
116
- case (a1, a2) => a1.semanticEquals(a2)
96
+ if (partialAgg.aggregateExpressions.forall(_.mode == Partial ) &&
97
+ finalAgg.aggregateExpressions.forall(_.mode == Final )) {
98
+ (finalAgg.logicalLink, partialAgg.logicalLink) match {
99
+ case (Some (agg1), Some (agg2)) => agg1.sameResult(agg2)
100
+ case _ => false
117
101
}
118
- val isFinalAggAttrsValid = finalAggAttrs.length == finalAgg.aggregateAttributes.length &&
119
- finalAggAttrs.zip(finalAgg.aggregateAttributes).forall {
120
- case (a1, a2) => a1.semanticEquals(a2)
121
- }
122
- val isPartialResultExprsValid =
123
- partialResultExprs.length == partialAgg.resultExpressions.length &&
124
- partialResultExprs.zip(partialAgg.resultExpressions).forall {
125
- case (a1, a2) => a1.semanticEquals(a2)
126
- }
127
- val isRequiredDistributionValid =
128
- partialAgg.requiredChildDistributionExpressions.isEmpty &&
129
- finalAgg.requiredChildDistributionExpressions.exists { exprs =>
130
- exprs.length == finalGroupExprs.length &&
131
- exprs.zip(finalGroupExprs).forall {
132
- case (e1, e2) => e1.semanticEquals(e2)
133
- }
134
- }
135
-
136
- groupExprsEqual && aggExprsEqual && isPartialAggAttrsValid && isFinalAggAttrsValid &&
137
- isPartialResultExprsValid && isRequiredDistributionValid
102
+ } else {
103
+ false
104
+ }
138
105
}
139
106
}
0 commit comments