-
Notifications
You must be signed in to change notification settings - Fork 28.6k
[SPARK-22103] Move HashAggregateExec parent consume to a separate function in codegen #19324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -242,6 +242,9 @@ class CodegenContext { | |
private val classFunctions: mutable.Map[String, mutable.Map[String, String]] = | ||
mutable.Map(outerClassName -> mutable.Map.empty[String, String]) | ||
|
||
// Verbatim extra code to be added to the OuterClass. | ||
private val extraCode: mutable.ListBuffer[String] = mutable.ListBuffer[String]() | ||
|
||
// Returns the size of the most recently added class. | ||
private def currClassSize(): Int = classSize(classes.head._1) | ||
|
||
|
@@ -328,6 +331,22 @@ class CodegenContext { | |
(inlinedFunctions ++ initNestedClasses ++ declareNestedClasses).mkString("\n") | ||
} | ||
|
||
/** | ||
* Emits any source code added with addExtraCode | ||
*/ | ||
def emitExtraCode(): String = { | ||
extraCode.mkString("\n") | ||
} | ||
|
||
/** | ||
* Add extra source code to the outermost generated class. | ||
* @param code verbatim source code to be added. | ||
*/ | ||
def addExtraCode(code: String): Unit = { | ||
extraCode.append(code) | ||
classSize(outerClassName) += code.length | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
} | ||
|
||
final val JAVA_BOOLEAN = "boolean" | ||
final val JAVA_BYTE = "byte" | ||
final val JAVA_SHORT = "short" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -197,11 +197,14 @@ trait CodegenSupport extends SparkPlan { | |
* | ||
* This should be override by subclass to support codegen. | ||
* | ||
* For example, Filter will generate the code like this: | ||
* Note: The operator should not assume the existence of an outer processing loop, | ||
* which it can jump from with "continue;"! | ||
* | ||
* For example, filter could generate this: | ||
* # code to evaluate the predicate expression, result is isNull1 and value2 | ||
* if (isNull1 || !value2) continue; | ||
* # call consume(), which will call parent.doConsume() | ||
* if (!isNull1 && value2) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this may lead to deeply nested code, but I don't have a better idea for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in reality the filter code generates a |
||
* # call consume(), which will call parent.doConsume() | ||
* } | ||
* | ||
* Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input). | ||
*/ | ||
|
@@ -329,6 +332,15 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co | |
def doCodeGen(): (CodegenContext, CodeAndComment) = { | ||
val ctx = new CodegenContext | ||
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) | ||
|
||
// main next function. | ||
ctx.addNewFunction("processNext", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tangent fix: add processNext() with |
||
s""" | ||
protected void processNext() throws java.io.IOException { | ||
${code.trim} | ||
} | ||
""", inlineToOuterClass = true) | ||
|
||
val source = s""" | ||
public Object generate(Object[] references) { | ||
return new GeneratedIterator(references); | ||
|
@@ -352,9 +364,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co | |
${ctx.initPartition()} | ||
} | ||
|
||
protected void processNext() throws java.io.IOException { | ||
${code.trim} | ||
} | ||
${ctx.emitExtraCode()} | ||
|
||
${ctx.declareAddedFunctions()} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -425,12 +425,14 @@ case class HashAggregateExec( | |
|
||
/** | ||
* Generate the code for output. | ||
* @return function name for the result code. | ||
*/ | ||
private def generateResultCode( | ||
ctx: CodegenContext, | ||
keyTerm: String, | ||
bufferTerm: String, | ||
plan: String): String = { | ||
private def generateResultFunction(ctx: CodegenContext): String = { | ||
val funcName = ctx.freshName("doAggregateWithKeysOutput") | ||
val keyTerm = ctx.freshName("keyTerm") | ||
val bufferTerm = ctx.freshName("bufferTerm") | ||
|
||
val body = | ||
if (modes.contains(Final) || modes.contains(Complete)) { | ||
// generate output using resultExpressions | ||
ctx.currentVars = null | ||
|
@@ -462,18 +464,36 @@ case class HashAggregateExec( | |
$evaluateAggResults | ||
${consume(ctx, resultVars)} | ||
""" | ||
|
||
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) { | ||
// This should be the last operator in a stage, we should output UnsafeRow directly | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tangent fix: The partial aggregation doesn't necessarily have to be the last operator in the stage. E.g. if the shuffle requirement between the partial/final aggregation was already satisfied, or between 2. and 3. in |
||
val joinerTerm = ctx.freshName("unsafeRowJoiner") | ||
ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, | ||
s"$joinerTerm = $plan.createUnsafeJoiner();") | ||
val resultRow = ctx.freshName("resultRow") | ||
// resultExpressions are Attributes of groupingExpressions and aggregateBufferAttributes. | ||
assert(resultExpressions.forall(_.isInstanceOf[Attribute])) | ||
assert(resultExpressions.length == | ||
groupingExpressions.length + aggregateBufferAttributes.length) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why we don't have these 2 requirements for the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Final/Complete aggregations can have arbitrary projections in their |
||
|
||
ctx.currentVars = null | ||
|
||
ctx.INPUT_ROW = keyTerm | ||
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => | ||
BoundReference(i, e.dataType, e.nullable).genCode(ctx) | ||
} | ||
val evaluateKeyVars = evaluateVariables(keyVars) | ||
|
||
ctx.INPUT_ROW = bufferTerm | ||
val resultBufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) => | ||
BoundReference(i, e.dataType, e.nullable).genCode(ctx) | ||
} | ||
val evaluateResultBufferVars = evaluateVariables(resultBufferVars) | ||
|
||
ctx.currentVars = keyVars ++ resultBufferVars | ||
val inputAttrs = resultExpressions.map(_.toAttribute) | ||
val resultVars = resultExpressions.map { e => | ||
BindReferences.bindReference(e, inputAttrs).genCode(ctx) | ||
} | ||
s""" | ||
UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); | ||
${consume(ctx, null, resultRow)} | ||
$evaluateKeyVars | ||
$evaluateResultBufferVars | ||
${consume(ctx, resultVars)} | ||
""" | ||
|
||
} else { | ||
// generate result based on grouping key | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we only go to this branch when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, e.g. for aggregation coming from Distinct. |
||
ctx.INPUT_ROW = keyTerm | ||
|
@@ -483,6 +503,13 @@ case class HashAggregateExec( | |
} | ||
consume(ctx, eval) | ||
} | ||
ctx.addNewFunction(funcName, | ||
s""" | ||
private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm) | ||
throws java.io.IOException { | ||
$body | ||
} | ||
""") | ||
} | ||
|
||
/** | ||
|
@@ -581,11 +608,6 @@ case class HashAggregateExec( | |
val iterTerm = ctx.freshName("mapIter") | ||
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") | ||
|
||
val doAgg = ctx.freshName("doAggregateWithKeys") | ||
val peakMemory = metricTerm(ctx, "peakMemory") | ||
val spillSize = metricTerm(ctx, "spillSize") | ||
val avgHashProbe = metricTerm(ctx, "avgHashProbe") | ||
|
||
def generateGenerateCode(): String = { | ||
if (isFastHashMapEnabled) { | ||
if (isVectorizedHashMapEnabled) { | ||
|
@@ -599,10 +621,14 @@ case class HashAggregateExec( | |
} | ||
} else "" | ||
} | ||
ctx.addExtraCode(generateGenerateCode()) | ||
|
||
val doAgg = ctx.freshName("doAggregateWithKeys") | ||
val peakMemory = metricTerm(ctx, "peakMemory") | ||
val spillSize = metricTerm(ctx, "spillSize") | ||
val avgHashProbe = metricTerm(ctx, "avgHashProbe") | ||
val doAggFuncName = ctx.addNewFunction(doAgg, | ||
s""" | ||
${generateGenerateCode} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a tangent fix: this generated code for the hash map was piggy-backed here together with the |
||
private void $doAgg() throws java.io.IOException { | ||
$hashMapTerm = $thisPlan.createHashMap(); | ||
${child.asInstanceOf[CodegenSupport].produce(ctx, this)} | ||
|
@@ -618,7 +644,7 @@ case class HashAggregateExec( | |
// generate code for output | ||
val keyTerm = ctx.freshName("aggKey") | ||
val bufferTerm = ctx.freshName("aggBuffer") | ||
val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) | ||
val outputFunc = generateResultFunction(ctx) | ||
val numOutput = metricTerm(ctx, "numOutputRows") | ||
|
||
// The child could change `copyResult` to true, but we had already consumed all the rows, | ||
|
@@ -641,7 +667,7 @@ case class HashAggregateExec( | |
$numOutput.add(1); | ||
UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey(); | ||
UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue(); | ||
$outputCode | ||
$outputFunc($keyTerm, $bufferTerm); | ||
|
||
if (shouldStop()) return; | ||
} | ||
|
@@ -654,18 +680,23 @@ case class HashAggregateExec( | |
val row = ctx.freshName("fastHashMapRow") | ||
ctx.currentVars = null | ||
ctx.INPUT_ROW = row | ||
var schema: StructType = groupingKeySchema | ||
bufferSchema.foreach(i => schema = schema.add(i)) | ||
val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex | ||
.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }) | ||
val generateKeyRow = GenerateUnsafeProjection.createCode(ctx, | ||
groupingKeySchema.toAttributes.zipWithIndex | ||
.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) } | ||
) | ||
val generateBufferRow = GenerateUnsafeProjection.createCode(ctx, | ||
bufferSchema.toAttributes.zipWithIndex | ||
.map { case (attr, i) => | ||
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) }) | ||
s""" | ||
| while ($iterTermForFastHashMap.hasNext()) { | ||
| $numOutput.add(1); | ||
| org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row = | ||
| (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row) | ||
| $iterTermForFastHashMap.next(); | ||
| ${generateRow.code} | ||
| ${consume(ctx, Seq.empty, {generateRow.value})} | ||
| ${generateKeyRow.code} | ||
| ${generateBufferRow.code} | ||
| $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we didn't call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| | ||
| if (shouldStop()) return; | ||
| } | ||
|
@@ -692,7 +723,7 @@ case class HashAggregateExec( | |
$numOutput.add(1); | ||
UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); | ||
UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); | ||
$outputCode | ||
$outputFunc($keyTerm, $bufferTerm); | ||
|
||
if (shouldStop()) return; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -201,11 +201,14 @@ case class FilterExec(condition: Expression, child: SparkPlan) | |
ev | ||
} | ||
|
||
// Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is tricky, how hard it is to fix all places that use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah i see, you are trying to avoid generating deeply nested if-else branches. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
s""" | ||
|$generated | ||
|$nullChecks | ||
|$numOutput.add(1); | ||
|${consume(ctx, resultVars)} | ||
|do { | ||
| $generated | ||
| $nullChecks | ||
| $numOutput.add(1); | ||
| ${consume(ctx, resultVars)} | ||
|} while(false); | ||
""".stripMargin | ||
} | ||
|
||
|
@@ -316,9 +319,10 @@ case class SampleExec( | |
""".stripMargin.trim) | ||
|
||
s""" | ||
| if ($sampler.sample() == 0) continue; | ||
| $numOutput.add(1); | ||
| ${consume(ctx, input)} | ||
| if ($sampler.sample() != 0) { | ||
| $numOutput.add(1); | ||
| ${consume(ctx, input)} | ||
| } | ||
""".stripMargin.trim | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd call it
addInnerClass
, as ideally you can't add arbitrary code to outer class.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 Although it doesn't prevent you going to add functions, but we have
addNewFunction
for it. So we'd better claim that this is just for inner class.