Skip to content

Commit 7029e89

Browse files
rootroot
authored andcommitted
[SPARK-18137][SQL]Fix RewriteDistinctAggregates UnresolvedException when the UDAF has a foldable TypeCheck
1 parent c8c0906 commit 7029e89

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,19 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
142142

143143
// Setup unique distinct aggregate children.
144144
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
145-
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
146-
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
145+
val distinctAggChildFoldable = distinctAggChildren.filter(_.foldable)
146+
// 1.only unfoldable child should be expand
147+
// 2.if foldable child mapped to AttributeRefference using expressionAttributePair,
148+
// the udaf function(such as ApproximatePercentile)
149+
// which has a foldable TypeCheck will failed,because AttributeRefference is unfoldable
150+
val distinctAggChildUnFoldableAttrMap = distinctAggChildren
151+
.filter(!_.foldable).map(expressionAttributePair)
152+
153+
val distinctAggChildrenUnFoldableAttrs = distinctAggChildUnFoldableAttrMap.map(_._2)
147154

148155
// Setup expand & aggregate operators for distinct aggregate expressions.
149-
val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
156+
val distinctAggChildAttrLookup = (distinctAggChildUnFoldableAttrMap
157+
++ distinctAggChildFoldable.map(c => c -> c)).toMap
150158
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
151159
case ((group, expressions), i) =>
152160
val id = Literal(i + 1)
@@ -172,11 +180,15 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
172180
// Setup expand for the 'regular' aggregate expressions.
173181
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
174182
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
175-
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
176-
183+
val regularAggChildFoldable = regularAggChildren.filter(_.foldable)
184+
val regularAggChildUnFoldable = regularAggChildren.filter(!_.foldable)
185+
val regularAggChildUnFoldableAttrMap = regularAggChildUnFoldable
186+
.map(expressionAttributePair)
187+
val regularAggChildUnFoldableAttrs = regularAggChildUnFoldableAttrMap.map(_._2)
177188
// Setup aggregates for 'regular' aggregate expressions.
178189
val regularGroupId = Literal(0)
179-
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
190+
val regularAggChildAttrLookup = (regularAggChildUnFoldableAttrMap
191+
++ regularAggChildFoldable.map(c => c -> c)).toMap
180192
val regularAggOperatorMap = regularAggExprs.map { e =>
181193
// Perform the actual aggregation in the initial aggregate.
182194
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
@@ -207,13 +219,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
207219
Seq(a.groupingExpressions ++
208220
distinctAggChildren.map(nullify) ++
209221
Seq(regularGroupId) ++
210-
regularAggChildren)
222+
regularAggChildUnFoldable)
211223
} else {
212224
Seq.empty[Seq[Expression]]
213225
}
214226

215227
// Construct the distinct aggregate input projections.
216-
val regularAggNulls = regularAggChildren.map(nullify)
228+
val regularAggNulls = regularAggChildUnFoldable.map(nullify)
217229
val distinctAggProjections = distinctAggOperatorMap.map {
218230
case (projection, _) =>
219231
a.groupingExpressions ++
@@ -224,12 +236,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
224236
// Construct the expand operator.
225237
val expand = Expand(
226238
regularAggProjection ++ distinctAggProjections,
227-
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
239+
groupByAttrs ++ distinctAggChildrenUnFoldableAttrs ++ Seq(gid)
240+
++ regularAggChildUnFoldableAttrs,
228241
a.child)
229242

230243
// Construct the first aggregate operator. This de-duplicates the all the children of
231244
// distinct operators, and applies the regular aggregate operators.
232-
val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
245+
val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildrenUnFoldableAttrs :+ gid
233246
val firstAggregate = Aggregate(
234247
firstAggregateGroupBy,
235248
firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,16 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
150150
}
151151

152152
test("Generic UDAF aggregates") {
153+
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999))" +
154+
", count(distinct key),sum(distinct key) FROM src LIMIT 1"),
155+
sql("SELECT max(key), count(distinct key),sum(distinct key) FROM src LIMIT 1")
156+
.collect().toSeq)
157+
158+
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.09999 + 0.9))" +
159+
", count(distinct key),sum(distinct key),1 FROM src LIMIT 1"),
160+
sql("SELECT max(key), count(distinct key),sum(distinct key), 1 FROM src LIMIT 1")
161+
.collect().toSeq)
162+
153163
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"),
154164
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
155165

0 commit comments

Comments
 (0)