@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
20
20
import org .apache .spark .sql .catalyst .expressions ._
21
21
import org .apache .spark .sql .catalyst .plans ._
22
22
import org .apache .spark .sql .types ._
23
+ import org .apache .spark .util .collection .OpenHashSet
23
24
24
25
case class Project (projectList : Seq [NamedExpression ], child : LogicalPlan ) extends UnaryNode {
25
26
override def output : Seq [Attribute ] = projectList.map(_.toAttribute)
@@ -228,24 +229,76 @@ case class Window(
228
229
/**
229
230
* Apply the all of the GroupExpressions to every input row, hence we will get
230
231
* multiple output rows for a input row.
231
- * @param projections The group of expressions, all of the group expressions should
232
- * output the same schema specified by the parameter `output`
233
- * @param output The output Schema
232
+ * @param bitmasks The bitmask set represents the grouping sets
233
+ * @param groupByExprs The grouping by expressions
234
234
* @param child Child operator
235
235
*/
236
236
case class Expand (
237
- projections : Seq [Seq [Expression ]],
238
- output : Seq [Attribute ],
237
+ bitmasks : Seq [Int ],
238
+ groupByExprs : Seq [Expression ],
239
+ gid : Attribute ,
239
240
child : LogicalPlan ) extends UnaryNode {
240
241
override def statistics : Statistics = {
241
242
val sizeInBytes = child.statistics.sizeInBytes * projections.length
242
243
Statistics (sizeInBytes = sizeInBytes)
243
244
}
245
+
246
+ val projections : Seq [Seq [Expression ]] = expand()
247
+
248
+ /**
249
+ * Extract attribute set according to the grouping id
250
+ * @param bitmask bitmask to represent the selected of the attribute sequence
251
+ * @param exprs the attributes in sequence
252
+ * @return the attributes of non selected specified via bitmask (with the bit set to 1)
253
+ */
254
+ private def buildNonSelectExprSet (bitmask : Int , exprs : Seq [Expression ])
255
+ : OpenHashSet [Expression ] = {
256
+ val set = new OpenHashSet [Expression ](2 )
257
+
258
+ var bit = exprs.length - 1
259
+ while (bit >= 0 ) {
260
+ if (((bitmask >> bit) & 1 ) == 0 ) set.add(exprs(bit))
261
+ bit -= 1
262
+ }
263
+
264
+ set
265
+ }
266
+
267
+ /**
268
+ * Create an array of Projections for the child projection, and replace the projections'
269
+ * expressions which equal GroupBy expressions with Literal(null), if those expressions
270
+ * are not set for this grouping set (according to the bit mask).
271
+ */
272
+ private [this ] def expand (): Seq [Seq [Expression ]] = {
273
+ val result = new scala.collection.mutable.ArrayBuffer [Seq [Expression ]]
274
+
275
+ bitmasks.foreach { bitmask =>
276
+ // get the non selected grouping attributes according to the bit mask
277
+ val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
278
+
279
+ val substitution = (child.output :+ gid).map(expr => expr transformDown {
280
+ case x : Expression if nonSelectedGroupExprSet.contains(x) =>
281
+ // if the input attribute in the Invalid Grouping Expression set of for this group
282
+ // replace it with constant null
283
+ Literal .create(null , expr.dataType)
284
+ case x if x == gid =>
285
+ // replace the groupingId with concrete value (the bit mask)
286
+ Literal .create(bitmask, IntegerType )
287
+ })
288
+
289
+ result += substitution
290
+ }
291
+
292
+ result.toSeq
293
+ }
294
+
295
+ override def output : Seq [Attribute ] = {
296
+ child.output :+ gid
297
+ }
244
298
}
245
299
246
300
trait GroupingAnalytics extends UnaryNode {
247
301
self : Product =>
248
- def gid : AttributeReference
249
302
def groupByExprs : Seq [Expression ]
250
303
def aggregations : Seq [NamedExpression ]
251
304
@@ -266,17 +319,12 @@ trait GroupingAnalytics extends UnaryNode {
266
319
* @param child Child operator
267
320
* @param aggregations The Aggregation expressions, those non selected group by expressions
268
321
* will be considered as constant null if it appears in the expressions
269
- * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
270
- * the bitmask indicates the selected GroupBy Expressions for each
271
- * aggregating output row.
272
- * The associated output will be one of the value in `bitmasks`
273
322
*/
274
323
case class GroupingSets (
275
324
bitmasks : Seq [Int ],
276
325
groupByExprs : Seq [Expression ],
277
326
child : LogicalPlan ,
278
- aggregations : Seq [NamedExpression ],
279
- gid : AttributeReference = VirtualColumn .newGroupingId) extends GroupingAnalytics {
327
+ aggregations : Seq [NamedExpression ]) extends GroupingAnalytics {
280
328
281
329
def withNewAggs (aggs : Seq [NamedExpression ]): GroupingAnalytics =
282
330
this .copy(aggregations = aggs)
@@ -290,15 +338,11 @@ case class GroupingSets(
290
338
* @param child Child operator
291
339
* @param aggregations The Aggregation expressions, those non selected group by expressions
292
340
* will be considered as constant null if it appears in the expressions
293
- * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
294
- * the bitmask indicates the selected GroupBy Expressions for each
295
- * aggregating output row.
296
341
*/
297
342
case class Cube (
298
343
groupByExprs : Seq [Expression ],
299
344
child : LogicalPlan ,
300
- aggregations : Seq [NamedExpression ],
301
- gid : AttributeReference = VirtualColumn .newGroupingId) extends GroupingAnalytics {
345
+ aggregations : Seq [NamedExpression ]) extends GroupingAnalytics {
302
346
303
347
def withNewAggs (aggs : Seq [NamedExpression ]): GroupingAnalytics =
304
348
this .copy(aggregations = aggs)
@@ -313,15 +357,11 @@ case class Cube(
313
357
* @param child Child operator
314
358
* @param aggregations The Aggregation expressions, those non selected group by expressions
315
359
* will be considered as constant null if it appears in the expressions
316
- * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
317
- * the bitmask indicates the selected GroupBy Expressions for each
318
- * aggregating output row.
319
360
*/
320
361
case class Rollup (
321
362
groupByExprs : Seq [Expression ],
322
363
child : LogicalPlan ,
323
- aggregations : Seq [NamedExpression ],
324
- gid : AttributeReference = VirtualColumn .newGroupingId) extends GroupingAnalytics {
364
+ aggregations : Seq [NamedExpression ]) extends GroupingAnalytics {
325
365
326
366
def withNewAggs (aggs : Seq [NamedExpression ]): GroupingAnalytics =
327
367
this .copy(aggregations = aggs)
0 commit comments