Skip to content

Commit 7b1450b

Browse files
chenghao-intelmarmbrus
authored andcommitted
[SPARK-7235] [SQL] Refactor the grouping sets
The logical plan `Expand` takes the `output` as a constructor argument, which breaks the reference chain. We need to refactor the code, as well as the column pruning. Author: Cheng Hao <hao.cheng@intel.com> Closes #5780 from chenghao-intel/expand and squashes the following commits: 76e4aa4 [Cheng Hao] revert the change for case insenstive 7c10a83 [Cheng Hao] refactor the grouping sets
1 parent 4f7fbef commit 7b1450b

File tree

5 files changed

+78
-71
lines changed

5 files changed

+78
-71
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 9 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -192,49 +192,17 @@ class Analyzer(
192192
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
193193
}
194194

195-
/**
196-
* Create an array of Projections for the child projection, and replace the projections'
197-
* expressions which equal GroupBy expressions with Literal(null), if those expressions
198-
* are not set for this grouping set (according to the bit mask).
199-
*/
200-
private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = {
201-
val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
202-
203-
g.bitmasks.foreach { bitmask =>
204-
// get the non selected grouping attributes according to the bit mask
205-
val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
206-
var bit = g.groupByExprs.length - 1
207-
while (bit >= 0) {
208-
if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
209-
bit -= 1
210-
}
211-
212-
val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
213-
case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
214-
// if the input attribute in the Invalid Grouping Expression set of for this group
215-
// replace it with constant null
216-
Literal.create(null, expr.dataType)
217-
case x if x == g.gid =>
218-
// replace the groupingId with concrete value (the bit mask)
219-
Literal.create(bitmask, IntegerType)
220-
})
221-
222-
result += substitution
223-
}
224-
225-
result.toSeq
226-
}
227-
228195
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
229-
case a: Cube if a.resolved =>
230-
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
231-
case a: Rollup if a.resolved =>
232-
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
233-
case x: GroupingSets if x.resolved =>
196+
case a: Cube =>
197+
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
198+
case a: Rollup =>
199+
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
200+
case x: GroupingSets =>
201+
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
234202
Aggregate(
235-
x.groupByExprs :+ x.gid,
203+
x.groupByExprs :+ VirtualColumn.groupingIdAttribute,
236204
x.aggregations,
237-
Expand(expand(x), x.child.output :+ x.gid, x.child))
205+
Expand(x.bitmasks, x.groupByExprs, gid, x.child))
238206
}
239207
}
240208

@@ -368,12 +336,7 @@ class Analyzer(
368336

369337
case q: LogicalPlan =>
370338
logTrace(s"Attempting to resolve ${q.simpleString}")
371-
q transformExpressionsUp {
372-
case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
373-
resolver(nameParts(0), VirtualColumn.groupingIdName) &&
374-
q.isInstanceOf[GroupingAnalytics] =>
375-
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
376-
q.asInstanceOf[GroupingAnalytics].gid
339+
q transformExpressionsUp {
377340
case u @ UnresolvedAttribute(nameParts) =>
378341
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
379342
val result =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,5 +262,5 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E
262262

263263
object VirtualColumn {
264264
val groupingIdName: String = "grouping__id"
265-
def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)()
265+
val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
266266
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ object UnionPushdown extends Rule[LogicalPlan] {
121121
*/
122122
object ColumnPruning extends Rule[LogicalPlan] {
123123
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
124+
case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
125+
if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
126+
a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
127+
124128
// Eliminate attributes that are not needed to calculate the specified aggregates.
125129
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
126130
a.copy(child = Project(a.references.toSeq, child))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
2020
import org.apache.spark.sql.catalyst.expressions._
2121
import org.apache.spark.sql.catalyst.plans._
2222
import org.apache.spark.sql.types._
23+
import org.apache.spark.util.collection.OpenHashSet
2324

2425
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
2526
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -228,24 +229,76 @@ case class Window(
228229
/**
229230
* Apply all of the GroupExpressions to every input row, hence we will get
230231
* multiple output rows for an input row.
231-
* @param projections The group of expressions, all of the group expressions should
232-
* output the same schema specified by the parameter `output`
233-
* @param output The output Schema
232+
* @param bitmasks The set of bitmasks that represents the grouping sets
233+
* @param groupByExprs The grouping by expressions
234234
* @param child Child operator
235235
*/
236236
case class Expand(
237-
projections: Seq[Seq[Expression]],
238-
output: Seq[Attribute],
237+
bitmasks: Seq[Int],
238+
groupByExprs: Seq[Expression],
239+
gid: Attribute,
239240
child: LogicalPlan) extends UnaryNode {
240241
override def statistics: Statistics = {
241242
val sizeInBytes = child.statistics.sizeInBytes * projections.length
242243
Statistics(sizeInBytes = sizeInBytes)
243244
}
245+
246+
val projections: Seq[Seq[Expression]] = expand()
247+
248+
/**
249+
* Extract attribute set according to the grouping id
250+
* @param bitmask bitmask indicating which attributes of the sequence are selected
251+
* @param exprs the attributes in sequence
252+
* @return the non-selected attributes, i.e. those whose bit in the bitmask is 0
253+
*/
254+
private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
255+
: OpenHashSet[Expression] = {
256+
val set = new OpenHashSet[Expression](2)
257+
258+
var bit = exprs.length - 1
259+
while (bit >= 0) {
260+
if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
261+
bit -= 1
262+
}
263+
264+
set
265+
}
266+
267+
/**
268+
* Create an array of Projections for the child projection, and replace the projections'
269+
* expressions which equal GroupBy expressions with Literal(null), if those expressions
270+
* are not set for this grouping set (according to the bit mask).
271+
*/
272+
private[this] def expand(): Seq[Seq[Expression]] = {
273+
val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
274+
275+
bitmasks.foreach { bitmask =>
276+
// get the non selected grouping attributes according to the bit mask
277+
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
278+
279+
val substitution = (child.output :+ gid).map(expr => expr transformDown {
280+
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
281+
// if the input attribute is in the non-selected grouping expression set for this group,
282+
// replace it with constant null
283+
Literal.create(null, expr.dataType)
284+
case x if x == gid =>
285+
// replace the groupingId with concrete value (the bit mask)
286+
Literal.create(bitmask, IntegerType)
287+
})
288+
289+
result += substitution
290+
}
291+
292+
result.toSeq
293+
}
294+
295+
override def output: Seq[Attribute] = {
296+
child.output :+ gid
297+
}
244298
}
245299

246300
trait GroupingAnalytics extends UnaryNode {
247301
self: Product =>
248-
def gid: AttributeReference
249302
def groupByExprs: Seq[Expression]
250303
def aggregations: Seq[NamedExpression]
251304

@@ -266,17 +319,12 @@ trait GroupingAnalytics extends UnaryNode {
266319
* @param child Child operator
267320
* @param aggregations The Aggregation expressions, those non selected group by expressions
268321
* will be considered as constant null if it appears in the expressions
269-
* @param gid The attribute represents the virtual column GROUPING__ID, and it's also
270-
* the bitmask indicates the selected GroupBy Expressions for each
271-
* aggregating output row.
272-
* The associated output will be one of the value in `bitmasks`
273322
*/
274323
case class GroupingSets(
275324
bitmasks: Seq[Int],
276325
groupByExprs: Seq[Expression],
277326
child: LogicalPlan,
278-
aggregations: Seq[NamedExpression],
279-
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
327+
aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
280328

281329
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
282330
this.copy(aggregations = aggs)
@@ -290,15 +338,11 @@ case class GroupingSets(
290338
* @param child Child operator
291339
* @param aggregations The Aggregation expressions, those non selected group by expressions
292340
* will be considered as constant null if it appears in the expressions
293-
* @param gid The attribute represents the virtual column GROUPING__ID, and it's also
294-
* the bitmask indicates the selected GroupBy Expressions for each
295-
* aggregating output row.
296341
*/
297342
case class Cube(
298343
groupByExprs: Seq[Expression],
299344
child: LogicalPlan,
300-
aggregations: Seq[NamedExpression],
301-
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
345+
aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
302346

303347
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
304348
this.copy(aggregations = aggs)
@@ -313,15 +357,11 @@ case class Cube(
313357
* @param child Child operator
314358
* @param aggregations The Aggregation expressions, those non selected group by expressions
315359
* will be considered as constant null if it appears in the expressions
316-
* @param gid The attribute represents the virtual column GROUPING__ID, and it's also
317-
* the bitmask indicates the selected GroupBy Expressions for each
318-
* aggregating output row.
319360
*/
320361
case class Rollup(
321362
groupByExprs: Seq[Expression],
322363
child: LogicalPlan,
323-
aggregations: Seq[NamedExpression],
324-
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
364+
aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
325365

326366
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
327367
this.copy(aggregations = aggs)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
308308
execution.Project(projectList, planLater(child)) :: Nil
309309
case logical.Filter(condition, child) =>
310310
execution.Filter(condition, planLater(child)) :: Nil
311-
case logical.Expand(projections, output, child) =>
312-
execution.Expand(projections, output, planLater(child)) :: Nil
311+
case e @ logical.Expand(_, _, _, child) =>
312+
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
313313
case logical.Aggregate(group, agg, child) =>
314314
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
315315
case logical.Window(projectList, windowExpressions, spec, child) =>

0 commit comments

Comments
 (0)