@@ -53,7 +53,8 @@ case class HashAggregateExec(
53
53
resultExpressions : Seq [NamedExpression ],
54
54
child : SparkPlan )
55
55
extends BaseAggregateExec
56
- with BlockingOperatorWithCodegen {
56
+ with BlockingOperatorWithCodegen
57
+ with GeneratePredicateHelper {
57
58
58
59
require(HashAggregateExec .supportsAggregate(aggregateBufferAttributes))
59
60
@@ -131,10 +132,8 @@ case class HashAggregateExec(
131
132
override def usedInputs : AttributeSet = inputSet
132
133
133
134
override def supportCodegen : Boolean = {
134
- // ImperativeAggregate and filter predicate are not supported right now
135
- // TODO: SPARK-30027 Support codegen for filter exprs in HashAggregateExec
136
- ! (aggregateExpressions.exists(_.aggregateFunction.isInstanceOf [ImperativeAggregate ]) ||
137
- aggregateExpressions.exists(_.filter.isDefined))
135
+ // ImperativeAggregate are not supported right now
136
+ ! aggregateExpressions.exists(_.aggregateFunction.isInstanceOf [ImperativeAggregate ])
138
137
}
139
138
140
139
override def inputRDDs (): Seq [RDD [InternalRow ]] = {
@@ -254,7 +253,7 @@ case class HashAggregateExec(
254
253
aggNames : Seq [String ],
255
254
aggBufferUpdatingExprs : Seq [Seq [Expression ]],
256
255
aggCodeBlocks : Seq [Block ],
257
- subExprs : Map [Expression , SubExprEliminationState ]): Option [String ] = {
256
+ subExprs : Map [Expression , SubExprEliminationState ]): Option [Seq [ String ] ] = {
258
257
val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
259
258
if (exprValsInSubExprs.exists(_.isInstanceOf [SimpleExprValue ])) {
260
259
// `SimpleExprValue`s cannot be used as an input variable for split functions, so
@@ -293,7 +292,7 @@ case class HashAggregateExec(
293
292
val inputVariables = args.map(_.variableName).mkString(" , " )
294
293
s " $doAggFuncName( $inputVariables); "
295
294
}
296
- Some (splitCodes.mkString( " \n " ).trim )
295
+ Some (splitCodes)
297
296
} else {
298
297
val errMsg = " Failed to split aggregate code into small functions because the parameter " +
299
298
" length of at least one split function went over the JVM limit: " +
@@ -308,6 +307,39 @@ case class HashAggregateExec(
308
307
}
309
308
}
310
309
310
+ private def generateEvalCodeForAggFuncs (
311
+ ctx : CodegenContext ,
312
+ input : Seq [ExprCode ],
313
+ inputAttrs : Seq [Attribute ],
314
+ boundUpdateExprs : Seq [Seq [Expression ]],
315
+ aggNames : Seq [String ],
316
+ aggCodeBlocks : Seq [Block ],
317
+ subExprs : SubExprCodes ): String = {
318
+ val aggCodes = if (conf.codegenSplitAggregateFunc &&
319
+ aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
320
+ val maybeSplitCodes = splitAggregateExpressions(
321
+ ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
322
+
323
+ maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
324
+ } else {
325
+ aggCodeBlocks.map(_.code)
326
+ }
327
+
328
+ aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map {
329
+ case (aggCode, (Partial | Complete , Some (condition))) =>
330
+ // Note: wrap in "do { } while(false);", so the generated checks can jump out
331
+ // with "continue;"
332
+ s """
333
+ |do {
334
+ | ${generatePredicateCode(ctx, condition, inputAttrs, input)}
335
+ | $aggCode
336
+ |} while(false);
337
+ """ .stripMargin
338
+ case (aggCode, _) =>
339
+ aggCode
340
+ }.mkString(" \n " )
341
+ }
342
+
311
343
private def doConsumeWithoutKeys (ctx : CodegenContext , input : Seq [ExprCode ]): String = {
312
344
// only have DeclarativeAggregate
313
345
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf [DeclarativeAggregate ])
@@ -354,24 +386,14 @@ case class HashAggregateExec(
354
386
""" .stripMargin
355
387
}
356
388
357
- val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
358
- aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
359
- val maybeSplitCode = splitAggregateExpressions(
360
- ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
361
-
362
- maybeSplitCode.getOrElse {
363
- aggCodeBlocks.fold(EmptyBlock )(_ + _).code
364
- }
365
- } else {
366
- aggCodeBlocks.fold(EmptyBlock )(_ + _).code
367
- }
368
-
389
+ val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
390
+ ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
369
391
s """
370
392
|// do aggregate
371
393
|// common sub-expressions
372
394
| $effectiveCodes
373
395
|// evaluate aggregate functions and update aggregation buffers
374
- | $codeToEvalAggFunc
396
+ | $codeToEvalAggFuncs
375
397
""" .stripMargin
376
398
}
377
399
@@ -908,7 +930,7 @@ case class HashAggregateExec(
908
930
}
909
931
}
910
932
911
- val inputAttr = aggregateBufferAttributes ++ inputAttributes
933
+ val inputAttrs = aggregateBufferAttributes ++ inputAttributes
912
934
// Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
913
935
// generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
914
936
// generating input columns, we use `currentVars`.
@@ -930,7 +952,7 @@ case class HashAggregateExec(
930
952
val updateRowInRegularHashMap : String = {
931
953
ctx.INPUT_ROW = unsafeRowBuffer
932
954
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
933
- bindReferences(updateExprsForOneFunc, inputAttr )
955
+ bindReferences(updateExprsForOneFunc, inputAttrs )
934
956
}
935
957
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
936
958
val effectiveCodes = subExprs.codes.mkString(" \n " )
@@ -961,23 +983,13 @@ case class HashAggregateExec(
961
983
""" .stripMargin
962
984
}
963
985
964
- val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
965
- aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
966
- val maybeSplitCode = splitAggregateExpressions(
967
- ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
968
-
969
- maybeSplitCode.getOrElse {
970
- aggCodeBlocks.fold(EmptyBlock )(_ + _).code
971
- }
972
- } else {
973
- aggCodeBlocks.fold(EmptyBlock )(_ + _).code
974
- }
975
-
986
+ val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
987
+ ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
976
988
s """
977
989
|// common sub-expressions
978
990
| $effectiveCodes
979
991
|// evaluate aggregate functions and update aggregation buffers
980
- | $codeToEvalAggFunc
992
+ | $codeToEvalAggFuncs
981
993
""" .stripMargin
982
994
}
983
995
@@ -986,7 +998,7 @@ case class HashAggregateExec(
986
998
if (isVectorizedHashMapEnabled) {
987
999
ctx.INPUT_ROW = fastRowBuffer
988
1000
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
989
- bindReferences(updateExprsForOneFunc, inputAttr )
1001
+ bindReferences(updateExprsForOneFunc, inputAttrs )
990
1002
}
991
1003
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
992
1004
val effectiveCodes = subExprs.codes.mkString(" \n " )
@@ -1016,18 +1028,8 @@ case class HashAggregateExec(
1016
1028
""" .stripMargin
1017
1029
}
1018
1030
1019
-
1020
- val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
1021
- aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
1022
- val maybeSplitCode = splitAggregateExpressions(
1023
- ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
1024
-
1025
- maybeSplitCode.getOrElse {
1026
- aggCodeBlocks.fold(EmptyBlock )(_ + _).code
1027
- }
1028
- } else {
1029
- aggCodeBlocks.fold(EmptyBlock )(_ + _).code
1030
- }
1031
+ val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
1032
+ ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
1031
1033
1032
1034
// If vectorized fast hash map is on, we first generate code to update row
1033
1035
// in vectorized fast hash map, if the previous loop up hit vectorized fast hash map.
@@ -1037,7 +1039,7 @@ case class HashAggregateExec(
1037
1039
| // common sub-expressions
1038
1040
| $effectiveCodes
1039
1041
| // evaluate aggregate functions and update aggregation buffers
1040
- | $codeToEvalAggFunc
1042
+ | $codeToEvalAggFuncs
1041
1043
|} else {
1042
1044
| $updateRowInRegularHashMap
1043
1045
|}
0 commit comments