Skip to content

[SPARK-18137][SQL]Fix RewriteDistinctAggregates UnresolvedException when a UDAF has a foldable TypeCheck #15668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,21 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}

// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions
.filter(_.isDistinct)
.groupBy(_.aggregateFunction.children.toSet)
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
if (unfoldableChildren.nonEmpty) {
// Only expand the unfoldable children
unfoldableChildren
} else {
// If aggregateFunction's children are all foldable
// we must expand at least one of the children (here we take the first child),
// Otherwise, we will get the wrong result. For example:
// count(distinct 1) would be rewritten to count(1) after the rewrite function.
// Generally, the distinct aggregateFunction should not run
// foldable TypeCheck for the first child.
e.aggregateFunction.children.take(1).toSet
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good catch. It would be great if we could get rid of this via constant folding (not needed in this PR). Another way of getting rid of this would be to create a separate processing group for these distincts.

}
}

// Check if the aggregates contains functions that do not support partial aggregation.
val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial)
Expand All @@ -136,8 +148,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def patchAggregateFunctionChildren(
af: AggregateFunction)(
attrs: Expression => Expression): AggregateFunction = {
af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction]
attrs: Expression => Option[Expression]): AggregateFunction = {
val newChildren = af.children.map(c => attrs(c).getOrElse(c))
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
}

// Setup unique distinct aggregate children.
Expand All @@ -161,7 +174,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af) { x =>
evalWithinGroup(id, distinctAggChildAttrLookup(x))
distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}
Expand All @@ -170,16 +183,20 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}

// Setup expand for the 'regular' aggregate expressions.
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
// only expand unfoldable children
val regularAggExprs = aggExpressions
.filter(e => !e.isDistinct && e.children.exists(!_.foldable))
val regularAggChildren = regularAggExprs
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
.distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)

// Setup aggregates for 'regular' aggregate expressions.
val regularGroupId = Literal(0)
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
val operator = Alias(e.copy(aggregateFunction = af), e.sql)()

// Select the result of the first aggregate in the last aggregate.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,41 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}

test("Generic UDAF aggregates") {

checkAnswer(sql(
"""
|SELECT percentile_approx(2, 0.99999),
| sum(distinct 1),
| count(distinct 1,2,3,4) FROM src LIMIT 1
""".stripMargin), sql("SELECT 2, 1, 1 FROM src LIMIT 1").collect().toSeq)

checkAnswer(sql(
"""
|SELECT ceiling(percentile_approx(distinct key, 0.99999)),
| count(distinct key),
| sum(distinct key),
| count(distinct 1),
| sum(distinct 1),
| sum(1) FROM src LIMIT 1
""".stripMargin),
sql(
"""
|SELECT max(key),
| count(distinct key),
| sum(distinct key),
| 1, 1, sum(1) FROM src LIMIT 1
""".stripMargin).collect().toSeq)

checkAnswer(sql(
"""
|SELECT ceiling(percentile_approx(distinct key, 0.9 + 0.09999)),
| count(distinct key), sum(distinct key),
| count(distinct 1), sum(distinct 1),
| sum(1) FROM src LIMIT 1
""".stripMargin),
sql("SELECT max(key), count(distinct key), sum(distinct key), 1, 1, sum(1) FROM src LIMIT 1")
.collect().toSeq)

checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)

Expand Down