Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-17868][SQL] Do not use bitmasks during parsing and analysis of CUBE/ROLLUP/GROUPING SETS #15484

Closed
wants to merge 9 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,9 @@ class Analyzer(
* Group Count: N + 1 (N is the number of group expressions)
*
* We need to get all of its subsets for the rule described above, the subset is
* represented as the bit masks.
* represented as a sequence of expressions.
*/
// Builds one bitmask per rollup grouping set: entry `idx` is `(1 << idx) - 1`, i.e. an Int
// with its `idx` lowest bits set, for idx = 0 .. r.groupByExprs.length.
// NOTE(review): the shift is done on Int, so this presumably only works for fewer than 32
// group by expressions -- confirm.
def bitmasks(r: Rollup): Seq[Int] = {
Seq.tabulate(r.groupByExprs.length + 1)(idx => (1 << idx) - 1)
}
// Returns every prefix of `exprs`, longest first and ending with the empty sequence, e.g.
// (a, b, c) => ((a, b, c), (a, b), (a), ()).
// The result is materialized with `toIndexedSeq` rather than `toSeq`: on Scala versions where
// `Iterator.toSeq` builds a lazy `Stream`, callers would otherwise receive a partially
// evaluated collection backed by the iterator.
def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.inits.toIndexedSeq

/*
* GROUP BY a, b, c WITH CUBE
Expand All @@ -230,10 +228,14 @@ class Analyzer(
* Group Count: 2 ^ N (N is the number of group expressions)
*
* We need to get all of its subsets for a given GROUPBY expression, the subsets are
* represented as the bit masks.
* represented as sequences of expressions.
*/
def bitmasks(c: Cube): Seq[Int] = {
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd also write unit tests specifically for cubeExprs and rollupExprs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I think you can just use subsets? e.g.

scala> Seq(1, 2, 3).toSet.subsets.foreach(println)
Set()
Set(1)
Set(2)
Set(3)
Set(1, 2)
Set(1, 3)
Set(2, 3)
Set(1, 2, 3)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid we can't just map the exprs to a set because we want to keep the original order.

case x :: xs =>
val initial = cubeExprs(xs)
initial.map(x +: _) ++ initial
case Nil =>
Seq(Seq.empty)
}

private def hasGroupingAttribute(expr: Expression): Boolean = {
Expand All @@ -256,103 +258,125 @@ class Analyzer(
expr transform {
case e: GroupingID =>
if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
gid
Alias(gid, toPrettySQL(e))()
} else {
throw new AnalysisException(
s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
s"grouping columns (${groupByExprs.mkString(",")})")
}
case Grouping(col: Expression) =>
case e @ Grouping(col: Expression) =>
val idx = groupByExprs.indexOf(col)
if (idx >= 0) {
Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
Literal(1)), ByteType)
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
Literal(1)), ByteType), toPrettySQL(e))()
} else {
throw new AnalysisException(s"Column of grouping ($col) can't be found " +
s"in grouping columns ${groupByExprs.mkString(",")}")
}
}
}

// This requires transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a if !a.childrenResolved => a // be sure all of the children are resolved.
case p if p.expressions.exists(hasGroupingAttribute) =>
failAnalysis(
s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")

case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
/*
 * Wrap every group by expression in a new alias for the `Expand` operator.
 * A named expression keeps its own name; any other expression is named after its `toString`.
 */
private def constructGroupByAlias(groupByExprs: Seq[Expression]): Seq[Alias] = {
  // Pick the alias name for a single group by expression.
  def aliasNameOf(expr: Expression): String = expr match {
    case named: NamedExpression => named.name
    case unnamed => unnamed.toString
  }

  groupByExprs.map(expr => Alias(expr, aliasNameOf(expr))())
}

// Ensure all the expressions have been resolved.
case x: GroupingSets if x.expressions.forall(_.resolved) =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()

// Expand works by setting grouping expressions to null as determined by the bitmasks. To
// prevent these null values from being used in an aggregate instead of the original value
// we need to create new aliases for all group by expressions that will only be used for
// the intended purpose.
val groupByAliases: Seq[Alias] = x.groupByExprs.map {
case e: NamedExpression => Alias(e, e.name)()
case other => Alias(other, other.toString)()
/*
* Construct [[Expand]] operator with grouping sets.
*/
private def constructExpand(
selectedGroupByExprs: Seq[Seq[Expression]],
child: LogicalPlan,
groupByAliases: Seq[Alias],
gid: Attribute): LogicalPlan = {
// Change the nullability of group by aliases if necessary. For example, if we have
// GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we
// should change the nullability of b to be TRUE.
// TODO: For Cube/Rollup just set nullability to be `true`.
val expandedAttributes = groupByAliases.map { alias =>
if (selectedGroupByExprs.exists(!_.contains(alias.child))) {
alias.toAttribute.withNullability(true)
} else {
alias.toAttribute
}
}

// The rightmost bit in the bitmasks corresponds to the last expression in groupByAliases
// with 0 indicating this expression is in the grouping set. The following line of code
// calculates the bitmask representing the expressions that are absent in at least one grouping
// set (indicated by 1).
val nullBitmask = x.bitmasks.reduce(_ | _)

val attrLength = groupByAliases.length
val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
a.toAttribute.withNullability(((nullBitmask >> (attrLength - idx - 1)) & 1) == 1)
val groupingSetsAttributes = selectedGroupByExprs.map { groupingSetExprs =>
groupingSetExprs.map { expr =>
val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse(
failAnalysis(s"$expr doesn't show up in the GROUP BY list $groupByAliases"))
// Map alias to expanded attribute.
expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse(
alias.toAttribute)
}
}

val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
val groupingAttrs = expand.output.drop(x.child.output.length)
Expand(groupingSetsAttributes, groupByAliases, expandedAttributes, gid, child)
}

val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
// collect all the found AggregateExpression, so we can check an expression is part of
// any AggregateExpression or not.
val aggsBuffer = ArrayBuffer[Expression]()
// Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
/*
* Construct new aggregate expressions by replacing grouping functions.
*/
private def constructAggregateExprs(
groupByExprs: Seq[Expression],
aggregations: Seq[NamedExpression],
groupByAliases: Seq[Alias],
groupingAttrs: Seq[Expression],
gid: Attribute): Seq[NamedExpression] = aggregations.map {
// collect all the found AggregateExpression, so we can check an expression is part of
// any AggregateExpression or not.
val aggsBuffer = ArrayBuffer[Expression]()
// Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
}
replaceGroupingFunc(_, groupByExprs, gid).transformDown {
// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
// inside it.
case e: AggregateExpression =>
aggsBuffer += e
e
case e if isPartOfAggregation(e) => e
case e =>
// Replace expression by expand output attribute.
val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
if (index == -1) {
e
} else {
groupingAttrs(index)
}
replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
// inside it.
case e: AggregateExpression =>
aggsBuffer += e
e
case e if isPartOfAggregation(e) => e
case e =>
val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
if (index == -1) {
e
} else {
groupingAttrs(index)
}
}.asInstanceOf[NamedExpression]
}
}.asInstanceOf[NamedExpression]
}

Aggregate(groupingAttrs, aggregations, expand)
/*
* Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
*/
private def constructAggregate(
selectedGroupByExprs: Seq[Seq[Expression]],
groupByExprs: Seq[Expression],
aggregationExprs: Seq[NamedExpression],
child: LogicalPlan): LogicalPlan = {
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()

case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
val groupingExprs = findGroupingExprs(child)
// The unresolved grouping id will be resolved by ResolveMissingReferences
val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
f.copy(condition = newCond)
// Expand works by setting grouping expressions to null as determined by the
// `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
// instead of the original value we need to create new aliases for all group by expressions
// that will only be used for the intended purpose.
val groupByAliases = constructGroupByAlias(groupByExprs)

case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
val groupingExprs = findGroupingExprs(child)
val gid = VirtualColumn.groupingIdAttribute
// The unresolved grouping id will be resolved by ResolveMissingReferences
val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
s.copy(order = newOrder)
val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid)
val groupingAttrs = expand.output.drop(child.output.length)

val aggregations = constructAggregateExprs(
groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid)

Aggregate(groupingAttrs, aggregations, expand)
}

private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
Expand All @@ -369,6 +393,41 @@ class Analyzer(
failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
}
}

// Rule entry point: rewrites Cube/Rollup/GroupingSets into the plan built by
// `constructAggregate`, and replaces grouping()/grouping_id() calls found in Filter
// conditions and Sort orders via `replaceGroupingFunc`.
// This requires transformUp to replace grouping()/grouping_id() in resolved Filter/Sort.
// The cases below are tried in order and the first match wins, so the two guard cases at the
// top take precedence over every rewrite that follows them.
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a if !a.childrenResolved => a // be sure all of the children are resolved.
// Reject any node whose expressions contain the deprecated Hive grouping id attribute.
case p if p.expressions.exists(hasGroupingAttribute) =>
failAnalysis(
s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")

// Ensure group by expressions and aggregate expressions have been resolved.
// Cube and Rollup differ only in how the selected grouping sets are derived
// (`cubeExprs` vs `rollupExprs`); both are handed to `constructAggregate`.
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child)
if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
constructAggregate(cubeExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child)
if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
constructAggregate(rollupExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
// Ensure all the expressions have been resolved.
// GroupingSets already carries its selected grouping sets, so they are passed through as-is.
case x: GroupingSets if x.expressions.forall(_.resolved) =>
constructAggregate(x.selectedGroupByExprs, x.groupByExprs, x.aggregations, x.child)

// We should make sure all expressions in condition have been resolved.
case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved =>
val groupingExprs = findGroupingExprs(child)
// The unresolved grouping id will be resolved by ResolveMissingReferences
val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
f.copy(condition = newCond)

// We should make sure all [[SortOrder]]s have been resolved.
case s @ Sort(order, _, child)
if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
val groupingExprs = findGroupingExprs(child)
val gid = VirtualColumn.groupingIdAttribute
// The unresolved grouping id will be resolved by ResolveMissingReferences
val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
s.copy(order = newOrder)
}
}

object ResolvePivot extends Rule[LogicalPlan] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,33 +492,18 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
ctx: AggregationContext,
selectExpressions: Seq[NamedExpression],
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
import ctx._
val groupByExpressions = expressionList(groupingExpressions)
val groupByExpressions = expressionList(ctx.groupingExpressions)

if (GROUPING != null) {
if (ctx.GROUPING != null) {
// GROUP BY .... GROUPING SETS (...)
val expressionMap = groupByExpressions.zipWithIndex.toMap
val numExpressions = expressionMap.size
val mask = (1 << numExpressions) - 1
val masks = ctx.groupingSet.asScala.map {
_.expression.asScala.foldLeft(mask) {
case (bitmap, eCtx) =>
// Find the index of the expression.
val e = typedVisit[Expression](eCtx)
val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse(
throw new ParseException(
s"$e doesn't show up in the GROUP BY list", ctx))
// 0 means that the column at the given index is a grouping column, 1 means it is not,
// so we unset the bit in bitmap.
bitmap & ~(1 << (numExpressions - 1 - index))
}
}
GroupingSets(masks, groupByExpressions, query, selectExpressions)
val selectedGroupByExprs =
ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)))
GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions)
} else {
// GROUP BY .... (WITH CUBE | WITH ROLLUP)?
val mappedGroupByExpressions = if (CUBE != null) {
val mappedGroupByExpressions = if (ctx.CUBE != null) {
Seq(Cube(groupByExpressions))
} else if (ROLLUP != null) {
} else if (ctx.ROLLUP != null) {
Seq(Rollup(groupByExpressions))
} else {
groupByExpressions
Copy link
Contributor Author

@jiangxb1987 jiangxb1987 Oct 14, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer check here whether each expression is in the GROUP BY list; that check has moved to the analysis stage.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is fine.

Expand Down
Loading