
Commit 65a9ac2

maropu authored and dongjoon-hyun committed
[SPARK-30027][SQL] Support codegen for aggregate filters in HashAggregateExec
### What changes were proposed in this pull request?

This PR intends to support code generation for `HashAggregateExec` with filters.

Quick benchmark results:

```
$ ./bin/spark-shell --master=local[1] --conf spark.driver.memory=8g --conf spark.sql.shuffle.partitions=1 -v

scala> spark.range(100000000).selectExpr("id % 3 as k1", "id % 5 as k2", "rand() as v1", "rand() as v2").write.saveAsTable("t")
scala> sql("SELECT k1, k2, AVG(v1) FILTER (WHERE v2 > 0.5) FROM t GROUP BY k1, k2").write.format("noop").mode("overwrite").save()

>> Before this PR
Elapsed time: 16.170697619s

>> After this PR
Elapsed time: 6.7825313s
```

The query above is compiled into the code below:

```java
...
/* 285 */   private void agg_doAggregate_avg_0(boolean agg_exprIsNull_2_0, org.apache.spark.sql.catalyst.InternalRow agg_unsafeRowAggBuffer_0, double agg_expr_2_0) throws java.io.IOException {
/* 286 */     // evaluate aggregate function for avg
/* 287 */     boolean agg_isNull_10 = true;
/* 288 */     double agg_value_12 = -1.0;
/* 289 */     boolean agg_isNull_11 = agg_unsafeRowAggBuffer_0.isNullAt(0);
/* 290 */     double agg_value_13 = agg_isNull_11 ?
/* 291 */       -1.0 : (agg_unsafeRowAggBuffer_0.getDouble(0));
/* 292 */     if (!agg_isNull_11) {
/* 293 */       agg_agg_isNull_12_0 = true;
/* 294 */       double agg_value_14 = -1.0;
/* 295 */       do {
/* 296 */         if (!agg_exprIsNull_2_0) {
/* 297 */           agg_agg_isNull_12_0 = false;
/* 298 */           agg_value_14 = agg_expr_2_0;
/* 299 */           continue;
/* 300 */         }
/* 301 */
/* 302 */         if (!false) {
/* 303 */           agg_agg_isNull_12_0 = false;
/* 304 */           agg_value_14 = 0.0D;
/* 305 */           continue;
/* 306 */         }
/* 307 */
/* 308 */       } while (false);
/* 309 */
/* 310 */       agg_isNull_10 = false; // resultCode could change nullability.
/* 311 */
/* 312 */       agg_value_12 = agg_value_13 + agg_value_14;
/* 313 */
/* 314 */     }
/* 315 */     boolean agg_isNull_15 = false;
/* 316 */     long agg_value_17 = -1L;
/* 317 */     if (!false && agg_exprIsNull_2_0) {
/* 318 */       boolean agg_isNull_18 = agg_unsafeRowAggBuffer_0.isNullAt(1);
/* 319 */       long agg_value_20 = agg_isNull_18 ?
/* 320 */         -1L : (agg_unsafeRowAggBuffer_0.getLong(1));
/* 321 */       agg_isNull_15 = agg_isNull_18;
/* 322 */       agg_value_17 = agg_value_20;
/* 323 */     } else {
/* 324 */       boolean agg_isNull_19 = true;
/* 325 */       long agg_value_21 = -1L;
/* 326 */       boolean agg_isNull_20 = agg_unsafeRowAggBuffer_0.isNullAt(1);
/* 327 */       long agg_value_22 = agg_isNull_20 ?
/* 328 */         -1L : (agg_unsafeRowAggBuffer_0.getLong(1));
/* 329 */       if (!agg_isNull_20) {
/* 330 */         agg_isNull_19 = false; // resultCode could change nullability.
/* 331 */
/* 332 */         agg_value_21 = agg_value_22 + 1L;
/* 333 */
/* 334 */       }
/* 335 */       agg_isNull_15 = agg_isNull_19;
/* 336 */       agg_value_17 = agg_value_21;
/* 337 */     }
/* 338 */     // update unsafe row buffer
/* 339 */     if (!agg_isNull_10) {
/* 340 */       agg_unsafeRowAggBuffer_0.setDouble(0, agg_value_12);
/* 341 */     } else {
/* 342 */       agg_unsafeRowAggBuffer_0.setNullAt(0);
/* 343 */     }
/* 344 */
/* 345 */     if (!agg_isNull_15) {
/* 346 */       agg_unsafeRowAggBuffer_0.setLong(1, agg_value_17);
/* 347 */     } else {
/* 348 */       agg_unsafeRowAggBuffer_0.setNullAt(1);
/* 349 */     }
/* 350 */   }
...
```

### Why are the changes needed?

For better performance.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #27019 from maropu/AggregateFilterCodegen.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
1 parent 9c30116 · commit 65a9ac2
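As background on the syntax being compiled: the `FILTER` clause restricts which rows an aggregate consumes. A minimal spark-shell sketch, assuming the table `t` created in the benchmark above; because `AVG` skips NULLs, the filtered aggregate is equivalent to the classic `CASE WHEN` rewrite:

```scala
// Both queries average v1 only over rows with v2 > 0.5.
// The FILTER form feeds only matching rows into the aggregation buffer;
// the CASE form maps non-matching rows to NULL, which AVG then skips.
val withFilter = sql(
  "SELECT k1, k2, AVG(v1) FILTER (WHERE v2 > 0.5) AS a FROM t GROUP BY k1, k2")
val withCase = sql(
  "SELECT k1, k2, AVG(CASE WHEN v2 > 0.5 THEN v1 END) AS a FROM t GROUP BY k1, k2")
```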

File tree: 5 files changed (+151, -106 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 18 additions & 0 deletions

```diff
@@ -242,6 +242,24 @@ trait PredicateHelper extends AliasHelper with Logging {
       None
     }
   }
+
+  // If one expression and its children are null intolerant, it is null intolerant.
+  protected def isNullIntolerant(expr: Expression): Boolean = expr match {
+    case e: NullIntolerant => e.children.forall(isNullIntolerant)
+    case _ => false
+  }
+
+  protected def outputWithNullability(
+      output: Seq[Attribute],
+      nonNullAttrExprIds: Seq[ExprId]): Seq[Attribute] = {
+    output.map { a =>
+      if (a.nullable && nonNullAttrExprIds.contains(a.exprId)) {
+        a.withNullability(false)
+      } else {
+        a
+      }
+    }
+  }
 }

 @ExpressionDescription(
```
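To illustrate what the new `isNullIntolerant` helper checks, here is a simplified standalone model; `Expr`, `Attr`, `Add`, and `Coalesce` are hypothetical stand-ins, not Catalyst's actual classes. A tree is null intolerant only if every node carries the `NullIntolerant` marker, so a single null-tolerant node (such as a coalesce, which can turn null inputs into a non-null result) makes the whole tree tolerant:

```scala
// Simplified stand-ins for Catalyst's Expression / NullIntolerant (hypothetical).
sealed trait Expr { def children: Seq[Expr] }
trait NullIntolerant extends Expr

case class Attr(name: String) extends Expr with NullIntolerant {
  def children: Seq[Expr] = Nil
}
case class Add(l: Expr, r: Expr) extends Expr with NullIntolerant {
  def children: Seq[Expr] = Seq(l, r)
}
// A coalesce-like node can produce non-null output from null inputs: null TOLERANT.
case class Coalesce(args: Seq[Expr]) extends Expr {
  def children: Seq[Expr] = args
}

// Same shape as the helper added above: the node and all children must be intolerant.
def isNullIntolerant(expr: Expr): Boolean = expr match {
  case e: NullIntolerant => e.children.forall(isNullIntolerant)
  case _ => false
}

assert(isNullIntolerant(Add(Attr("a"), Attr("b"))))                  // true
assert(!isNullIntolerant(Add(Attr("a"), Coalesce(Seq(Attr("b"))))))  // false
```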

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 51 additions & 49 deletions

```diff
@@ -53,7 +53,8 @@ case class HashAggregateExec(
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends BaseAggregateExec
-  with BlockingOperatorWithCodegen {
+  with BlockingOperatorWithCodegen
+  with GeneratePredicateHelper {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -131,10 +132,8 @@ case class HashAggregateExec(
   override def usedInputs: AttributeSet = inputSet
 
   override def supportCodegen: Boolean = {
-    // ImperativeAggregate and filter predicate are not supported right now
-    // TODO: SPARK-30027 Support codegen for filter exprs in HashAggregateExec
-    !(aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) ||
-      aggregateExpressions.exists(_.filter.isDefined))
+    // ImperativeAggregate are not supported right now
+    !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
   }
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -254,7 +253,7 @@ case class HashAggregateExec(
       aggNames: Seq[String],
       aggBufferUpdatingExprs: Seq[Seq[Expression]],
       aggCodeBlocks: Seq[Block],
-      subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
+      subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = {
     val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
     if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
       // `SimpleExprValue`s cannot be used as an input variable for split functions, so
@@ -293,7 +292,7 @@ case class HashAggregateExec(
         val inputVariables = args.map(_.variableName).mkString(", ")
         s"$doAggFuncName($inputVariables);"
       }
-      Some(splitCodes.mkString("\n").trim)
+      Some(splitCodes)
     } else {
       val errMsg = "Failed to split aggregate code into small functions because the parameter " +
         "length of at least one split function went over the JVM limit: " +
@@ -308,6 +307,39 @@ case class HashAggregateExec(
     }
   }
 
+  private def generateEvalCodeForAggFuncs(
+      ctx: CodegenContext,
+      input: Seq[ExprCode],
+      inputAttrs: Seq[Attribute],
+      boundUpdateExprs: Seq[Seq[Expression]],
+      aggNames: Seq[String],
+      aggCodeBlocks: Seq[Block],
+      subExprs: SubExprCodes): String = {
+    val aggCodes = if (conf.codegenSplitAggregateFunc &&
+        aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+      val maybeSplitCodes = splitAggregateExpressions(
+        ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+      maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
+    } else {
+      aggCodeBlocks.map(_.code)
+    }
+
+    aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map {
+      case (aggCode, (Partial | Complete, Some(condition))) =>
+        // Note: wrap in "do { } while(false);", so the generated checks can jump out
+        // with "continue;"
+        s"""
+           |do {
+           |  ${generatePredicateCode(ctx, condition, inputAttrs, input)}
+           |  $aggCode
+           |} while(false);
+         """.stripMargin
+      case (aggCode, _) =>
+        aggCode
+    }.mkString("\n")
+  }
+
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
```
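The core of this change is the `zip` at the end of the new `generateEvalCodeForAggFuncs`: for `Partial` or `Complete` aggregates that carry a filter, the update code is wrapped in `do { } while(false);` so the generated predicate can abandon the buffer update with `continue;`. A string-level sketch of that wrapping, with a hypothetical `wrapWithFilter` helper and made-up predicate code (the real method obtains it from `generatePredicateCode`):

```scala
// Sketch: wrap per-function aggregate-update code so a failing filter
// can jump past the buffer update with "continue;" instead of nested ifs.
def wrapWithFilter(aggCode: String, predicateCode: Option[String]): String =
  predicateCode match {
    case Some(pred) =>
      s"""|do {
          |  $pred
          |  $aggCode
          |} while(false);""".stripMargin
    case None => aggCode
  }

// The predicate code is expected to emit "continue;" when the row fails the filter.
println(wrapWithFilter(
  "agg_doAggregate_avg_0(agg_exprIsNull_2_0, agg_unsafeRowAggBuffer_0, agg_expr_2_0);",
  Some("if (!(v2_value > 0.5)) continue;")))
```

The remaining hunks in this file route each call site through `generateEvalCodeForAggFuncs` and rename `inputAttr` to `inputAttrs`: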
```diff
@@ -354,24 +386,14 @@ case class HashAggregateExec(
       """.stripMargin
     }
 
-    val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
-        aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-      val maybeSplitCode = splitAggregateExpressions(
-        ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-      maybeSplitCode.getOrElse {
-        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-      }
-    } else {
-      aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-    }
-
+    val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+      ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
     s"""
        |// do aggregate
        |// common sub-expressions
        |$effectiveCodes
        |// evaluate aggregate functions and update aggregation buffers
-       |$codeToEvalAggFunc
+       |$codeToEvalAggFuncs
     """.stripMargin
   }
 
@@ -908,7 +930,7 @@ case class HashAggregateExec(
       }
     }
 
-    val inputAttr = aggregateBufferAttributes ++ inputAttributes
+    val inputAttrs = aggregateBufferAttributes ++ inputAttributes
     // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
     // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
     // generating input columns, we use `currentVars`.
@@ -930,7 +952,7 @@ case class HashAggregateExec(
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
       val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
-        bindReferences(updateExprsForOneFunc, inputAttr)
+        bindReferences(updateExprsForOneFunc, inputAttrs)
       }
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
       val effectiveCodes = subExprs.codes.mkString("\n")
@@ -961,23 +983,13 @@ case class HashAggregateExec(
         """.stripMargin
       }
 
-      val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
-          aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-        val maybeSplitCode = splitAggregateExpressions(
-          ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-        maybeSplitCode.getOrElse {
-          aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-        }
-      } else {
-        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-      }
-
+      val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+        ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
       s"""
          |// common sub-expressions
          |$effectiveCodes
          |// evaluate aggregate functions and update aggregation buffers
-         |$codeToEvalAggFunc
+         |$codeToEvalAggFuncs
       """.stripMargin
     }
 
@@ -986,7 +998,7 @@ case class HashAggregateExec(
       if (isVectorizedHashMapEnabled) {
         ctx.INPUT_ROW = fastRowBuffer
         val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
-          bindReferences(updateExprsForOneFunc, inputAttr)
+          bindReferences(updateExprsForOneFunc, inputAttrs)
         }
         val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
         val effectiveCodes = subExprs.codes.mkString("\n")
@@ -1016,18 +1028,8 @@ case class HashAggregateExec(
         """.stripMargin
       }
 
-
-      val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
-          aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-        val maybeSplitCode = splitAggregateExpressions(
-          ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-        maybeSplitCode.getOrElse {
-          aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-        }
-      } else {
-        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-      }
+      val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+        ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
 
       // If vectorized fast hash map is on, we first generate code to update row
       // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map.
@@ -1037,7 +1039,7 @@ case class HashAggregateExec(
        |  // common sub-expressions
        |  $effectiveCodes
        |  // evaluate aggregate functions and update aggregation buffers
-       |  $codeToEvalAggFunc
+       |  $codeToEvalAggFuncs
        |} else {
        |  $updateRowInRegularHashMap
        |}
```
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 76 additions & 54 deletions

```diff
@@ -109,59 +109,39 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 }
 
-/** Physical plan for Filter. */
-case class FilterExec(condition: Expression, child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport with PredicateHelper {
-
-  // Split out all the IsNotNulls from condition.
-  private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
-    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
-    case _ => false
-  }
-
-  // If one expression and its children are null intolerant, it is null intolerant.
-  private def isNullIntolerant(expr: Expression): Boolean = expr match {
-    case e: NullIntolerant => e.children.forall(isNullIntolerant)
-    case _ => false
-  }
-
-  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
-  private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
-
-  // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
-  // all the variables at the beginning to take advantage of short circuiting.
-  override def usedInputs: AttributeSet = AttributeSet.empty
-
-  override def output: Seq[Attribute] = {
-    child.output.map { a =>
-      if (a.nullable && notNullAttributes.contains(a.exprId)) {
-        a.withNullability(false)
-      } else {
-        a
-      }
+trait GeneratePredicateHelper extends PredicateHelper {
+  self: CodegenSupport =>
+
+  protected def generatePredicateCode(
+      ctx: CodegenContext,
+      condition: Expression,
+      inputAttrs: Seq[Attribute],
+      inputExprCode: Seq[ExprCode]): String = {
+    val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+      case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(AttributeSet(inputAttrs))
+      case _ => false
     }
-  }
-
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
-
-  override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    child.asInstanceOf[CodegenSupport].inputRDDs()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
-    child.asInstanceOf[CodegenSupport].produce(ctx, this)
-  }
-
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val numOutput = metricTerm(ctx, "numOutputRows")
-
+    val nonNullAttrExprIds = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+    val outputAttrs = outputWithNullability(inputAttrs, nonNullAttrExprIds)
+    generatePredicateCode(
+      ctx, inputAttrs, inputExprCode, outputAttrs, notNullPreds, otherPreds,
+      nonNullAttrExprIds)
+  }
+
+  protected def generatePredicateCode(
+      ctx: CodegenContext,
+      inputAttrs: Seq[Attribute],
+      inputExprCode: Seq[ExprCode],
+      outputAttrs: Seq[Attribute],
+      notNullPreds: Seq[Expression],
+      otherPreds: Seq[Expression],
+      nonNullAttrExprIds: Seq[ExprId]): String = {
     /**
      * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
      */
     def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
       val bound = BindReferences.bindReference(c, attrs)
-      val evaluated = evaluateRequiredVariables(child.output, in, c.references)
+      val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references)
 
       // Generate the code for the predicate.
       val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
```
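The first `generatePredicateCode` overload tightens nullability before generating any code: attributes guarded by an `IsNotNull` conjunct are treated as non-nullable, which lets the bound predicates drop dead null branches. A simplified standalone model of that step, using a hypothetical `Col` stand-in for Catalyst's `Attribute`:

```scala
// Hypothetical stand-in for Catalyst's Attribute.
case class Col(exprId: Long, name: String, nullable: Boolean) {
  def withNullability(n: Boolean): Col = copy(nullable = n)
}

// Mirrors the outputWithNullability helper added to PredicateHelper:
// attributes referenced by IsNotNull checks are reported as non-nullable.
def outputWithNullability(output: Seq[Col], nonNullIds: Seq[Long]): Seq[Col] =
  output.map { c =>
    if (c.nullable && nonNullIds.contains(c.exprId)) c.withNullability(false) else c
  }

val cols = Seq(Col(1, "v1", nullable = true), Col(2, "v2", nullable = true))
val tightened = outputWithNullability(cols, nonNullIds = Seq(2))
assert(!tightened(1).nullable) // v2 is now known to be non-null
```

The remaining hunks rewire `FilterExec` to extend the new trait and delegate predicate generation to it: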
```diff
@@ -195,10 +175,10 @@ case class FilterExec(condition: Expression, child: SparkPlan)
         if (idx != -1 && !generatedIsNotNullChecks(idx)) {
           generatedIsNotNullChecks(idx) = true
           // Use the child's output. The nullability is what the child produced.
-          genPredicate(notNullPreds(idx), input, child.output)
-        } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
+          genPredicate(notNullPreds(idx), inputExprCode, inputAttrs)
+        } else if (nonNullAttrExprIds.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
           extraIsNotNullAttrs += r
-          genPredicate(IsNotNull(r), input, child.output)
+          genPredicate(IsNotNull(r), inputExprCode, inputAttrs)
         } else {
           ""
         }
@@ -208,18 +188,61 @@ case class FilterExec(condition: Expression, child: SparkPlan)
       // enforced them with the IsNotNull checks above.
       s"""
          |$nullChecks
-         |${genPredicate(c, input, output)}
+         |${genPredicate(c, inputExprCode, outputAttrs)}
       """.stripMargin.trim
     }.mkString("\n")
 
     val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
       if (!generatedIsNotNullChecks(idx)) {
-        genPredicate(c, input, child.output)
+        genPredicate(c, inputExprCode, inputAttrs)
       } else {
         ""
       }
     }.mkString("\n")
 
+    s"""
+       |$generated
+       |$nullChecks
+     """.stripMargin
+  }
+}
+
+/** Physical plan for Filter. */
+case class FilterExec(condition: Expression, child: SparkPlan)
+  extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper {
+
+  // Split out all the IsNotNulls from condition.
+  private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
+    case _ => false
+  }
+
+  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
+  private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+
+  // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
+  // all the variables at the beginning to take advantage of short circuiting.
+  override def usedInputs: AttributeSet = AttributeSet.empty
+
+  override def output: Seq[Attribute] = outputWithNullability(child.output, notNullAttributes)
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    val predicateCode = generatePredicateCode(
+      ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)
+
     // Reset the isNull to false for the not-null columns, then the followed operators could
     // generate better code (remove dead branches).
     val resultVars = input.zipWithIndex.map { case (ev, i) =>
@@ -232,8 +255,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
     s"""
        |do {
-       |  $generated
-       |  $nullChecks
+       |  $predicateCode
        |  $numOutput.add(1);
        |  ${consume(ctx, resultVars)}
        |} while(false);
```