Skip to content

WEB-11368: Skip literals so that HiveUDAF percentile_approx can operate correctly #186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,21 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions
.filter(_.isDistinct)
.groupBy(_.aggregateFunction.children.toSet)
.groupBy { e =>
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
if (unfoldableChildren.nonEmpty) {
// Only expand the unfoldable children
unfoldableChildren
} else {
// If all of the aggregateFunction's children are foldable,
// we must still expand at least one child (here we take the first one).
// Otherwise we get the wrong result; for example,
// count(distinct 1) would be rewritten to count(1) by this rule.
// Generally, a distinct aggregateFunction should not apply
// a foldable type check to its first child.
e.aggregateFunction.children.take(1).toSet
}
}

// Aggregation strategy can handle the query with single distinct
if (distinctAggGroups.size > 1) {
Expand All @@ -134,10 +148,9 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def patchAggregateFunctionChildren(
af: AggregateFunction)(
attrs: Expression => Expression): AggregateFunction = {
af.withNewChildren(af.children.map {
case afc => attrs(afc)
}).asInstanceOf[AggregateFunction]
attrs: Expression => Option[Expression]): AggregateFunction = {
val newChildren = af.children.map(c => attrs(c).getOrElse(c))
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
}

// Setup unique distinct aggregate children.
Expand All @@ -161,7 +174,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af) { x =>
evalWithinGroup(id, distinctAggChildAttrLookup(x))
distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}
Expand All @@ -170,16 +183,20 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
}

// Setup expand for the 'regular' aggregate expressions.
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
// Only expand the unfoldable children.
val regularAggExprs = aggExpressions
.filter(e => !e.isDistinct && e.children.exists(!_.foldable))
val regularAggChildren = regularAggExprs
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
.distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)

// Setup aggregates for 'regular' aggregate expressions.
val regularGroupId = Literal(0)
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()

// Select the result of the first aggregate in the last aggregate.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,40 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}

test("Generic UDAF aggregates") {
checkAnswer(sql(
"""
|SELECT percentile_approx(2, 0.99999),
| sum(distinct 1),
| count(distinct 1,2,3,4) FROM src LIMIT 1
""".stripMargin), sql("SELECT 2, 1, 1 FROM src LIMIT 1").collect().toSeq)

checkAnswer(sql(
"""
|SELECT ceiling(percentile_approx(distinct key, 0.99999)),
| count(distinct key),
| sum(distinct key),
| count(distinct 1),
| sum(distinct 1),
| sum(1) FROM src LIMIT 1
""".stripMargin),
sql(
"""
|SELECT max(key),
| count(distinct key),
| sum(distinct key),
| 1, 1, sum(1) FROM src LIMIT 1
""".stripMargin).collect().toSeq)

checkAnswer(sql(
"""
|SELECT ceiling(percentile_approx(distinct key, 0.9 + 0.09999)),
| count(distinct key), sum(distinct key),
| count(distinct 1), sum(distinct 1),
| sum(1) FROM src LIMIT 1
""".stripMargin),
sql("SELECT max(key), count(distinct key), sum(distinct key), 1, 1, sum(1) FROM src LIMIT 1")
.collect().toSeq)

checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)

Expand Down