[SPARK-22103] Move HashAggregateExec parent consume to a separate function in codegen #19324

Closed
@@ -242,6 +242,9 @@ class CodegenContext {
private val classFunctions: mutable.Map[String, mutable.Map[String, String]] =
mutable.Map(outerClassName -> mutable.Map.empty[String, String])

// Verbatim extra code to be added to the OuterClass.
private val extraCode: mutable.ListBuffer[String] = mutable.ListBuffer[String]()

// Returns the size of the most recently added class.
private def currClassSize(): Int = classSize(classes.head._1)

@@ -328,6 +331,22 @@ class CodegenContext {
(inlinedFunctions ++ initNestedClasses ++ declareNestedClasses).mkString("\n")
}

/**
* Emits any source code added with addExtraCode
*/
def emitExtraCode(): String = {
extraCode.mkString("\n")
}

/**
* Add extra source code to the outermost generated class.
* @param code verbatim source code to be added.
*/
def addExtraCode(code: String): Unit = {
Contributor: I'd call it addInnerClass, as ideally you can't add arbitrary code to the outer class.

Member: +1. Although it doesn't prevent you from adding functions, we have addNewFunction for that, so we'd better state that this is just for inner classes.

extraCode.append(code)
classSize(outerClassName) += code.length
Member: classSize is mainly used to deal with the limit on the number of named constants, so I think we don't need to add the extra code size to it if we only add inner classes?

}
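The buffer-and-join behavior of addExtraCode/emitExtraCode can be sketched in isolation. This mock class and its names are illustrative stand-ins, not Spark's actual CodegenContext:

```java
import java.util.ArrayList;
import java.util.List;

// Minimal standalone mock of the extra-code buffer added in this PR:
// snippets are appended verbatim and later joined with newlines for
// splicing into the outermost generated class.
class ExtraCodeBuffer {
    private final List<String> extraCode = new ArrayList<>();

    void addExtraCode(String code) {
        extraCode.add(code);
    }

    String emitExtraCode() {
        return String.join("\n", extraCode);
    }
}
```

With two snippets added, emitExtraCode() returns them separated by a single newline, mirroring extraCode.mkString("\n") in the Scala diff.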

final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -197,11 +197,14 @@ trait CodegenSupport extends SparkPlan {
*
* This should be overridden by subclasses to support codegen.
*
* For example, Filter will generate the code like this:
* Note: The operator should not assume the existence of an outer processing loop,
* which it can jump from with "continue;"!
*
* For example, filter could generate this:
* # code to evaluate the predicate expression, result is isNull1 and value2
* if (isNull1 || !value2) continue;
* # call consume(), which will call parent.doConsume()
* if (!isNull1 && value2) {
Contributor: this may lead to deeply nested code, but I don't have a better idea for now.

Author: In reality the filter code generates a do { } while(false) with a continue inside to jump out, just like it did before; there is an appropriate comment about it there. I didn't want to complicate this example, so changing "will generate" to "could generate" is intentional, to show that it could, but not necessarily will :-)

* # call consume(), which will call parent.doConsume()
* }
*
* Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
*/
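The contract change can be illustrated with plain Java (hand-written stand-ins, not generated code): the old shape's continue only works inside an enclosing loop, while the new shape guards the consume call, or wraps itself in a local do { } while (false), so it makes no assumption about its surroundings.

```java
// Hand-written illustration of the two generated shapes for a filter.
// isNull/value play the roles of isNull1/value2 from the example above.
class FilterShapes {
    static int consumed = 0;

    // Old shape: assumes an enclosing processing loop to jump out of.
    static void oldShape(boolean[] isNull, boolean[] value) {
        for (int i = 0; i < isNull.length; i++) {
            if (isNull[i] || !value[i]) continue; // jumps the outer loop
            consumed++;                           // consume() stand-in
        }
    }

    // New shape: self-contained; "continue" exits only the local do-while.
    static void newShape(boolean isNull, boolean value) {
        do {
            if (isNull || !value) continue;       // exits the do-while only
            consumed++;                           // consume() stand-in
        } while (false);
    }
}
```

Either shape consumes exactly the rows that pass the predicate; the difference is only in what the snippet assumes about the code it is spliced into.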
@@ -329,6 +332,15 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
def doCodeGen(): (CodegenContext, CodeAndComment) = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)

// main next function.
ctx.addNewFunction("processNext",
Author (@juliuszsompolski, Sep 22, 2017): tangent fix: add processNext() with addNewFunction, so that it is also taken into account by #18810.
s"""
protected void processNext() throws java.io.IOException {
${code.trim}
}
""", inlineToOuterClass = true)

val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
@@ -352,9 +364,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
${ctx.initPartition()}
}

protected void processNext() throws java.io.IOException {
${code.trim}
}
${ctx.emitExtraCode()}

${ctx.declareAddedFunctions()}
}
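Roughly, the generated outer class now ends up shaped like the hand-written sketch below (hypothetical names, not actual Spark output): processNext is one of the added functions rather than being hard-coded in the template, and any snippet registered via addExtraCode lands alongside it.

```java
// Hypothetical skeleton of the generated outer class after this change.
class GeneratedIteratorSketch {
    private int processed = 0;  // stand-in for generated mutable state

    public void initPartition() { processed = 0; }

    // stand-in for code emitted via emitExtraCode(), e.g. a generated
    // fast hash map class
    static class GeneratedFastHashMap { /* ... */ }

    // stand-in for code emitted via declareAddedFunctions(); registered
    // with addNewFunction("processNext", ..., inlineToOuterClass = true)
    protected void processNext() {
        processed++;  // stand-in for the produced pipeline body
    }

    public int getProcessed() { return processed; }
}
```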
@@ -425,12 +425,14 @@ case class HashAggregateExec(

/**
* Generate the code for output.
* @return function name for the result code.
*/
private def generateResultCode(
ctx: CodegenContext,
keyTerm: String,
bufferTerm: String,
plan: String): String = {
private def generateResultFunction(ctx: CodegenContext): String = {
val funcName = ctx.freshName("doAggregateWithKeysOutput")
val keyTerm = ctx.freshName("keyTerm")
val bufferTerm = ctx.freshName("bufferTerm")

val body =
if (modes.contains(Final) || modes.contains(Complete)) {
// generate output using resultExpressions
ctx.currentVars = null
Expand Down Expand Up @@ -462,18 +464,36 @@ case class HashAggregateExec(
$evaluateAggResults
${consume(ctx, resultVars)}
"""

} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
// This should be the last operator in a stage, we should output UnsafeRow directly
Author: tangent fix: the partial aggregation doesn't necessarily have to be the last operator in the stage, e.g. if the shuffle requirement between the partial/final aggregation was already satisfied, or between steps 2 and 3 in planAggregateWithOneDistinct. Outputting the UnsafeRow through UnsafeRowJoiner was unnecessary then.
val joinerTerm = ctx.freshName("unsafeRowJoiner")
ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
s"$joinerTerm = $plan.createUnsafeJoiner();")
val resultRow = ctx.freshName("resultRow")
// resultExpressions are Attributes of groupingExpressions and aggregateBufferAttributes.
assert(resultExpressions.forall(_.isInstanceOf[Attribute]))
assert(resultExpressions.length ==
groupingExpressions.length + aggregateBufferAttributes.length)
Contributor: why don't we have these 2 requirements for the modes.contains(Final) || modes.contains(Complete) branch?

Author: Final/Complete aggregations can have arbitrary projections in their resultExpressions, while partial aggregations are always constructed with only the grouping keys and aggregate expressions. The code that was here before with the UnsafeRowJoiner was relying on this assumption, so now I have put it into an assertion.

ctx.currentVars = null

ctx.INPUT_ROW = keyTerm
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).genCode(ctx)
}
val evaluateKeyVars = evaluateVariables(keyVars)

ctx.INPUT_ROW = bufferTerm
val resultBufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).genCode(ctx)
}
val evaluateResultBufferVars = evaluateVariables(resultBufferVars)

ctx.currentVars = keyVars ++ resultBufferVars
val inputAttrs = resultExpressions.map(_.toAttribute)
val resultVars = resultExpressions.map { e =>
BindReferences.bindReference(e, inputAttrs).genCode(ctx)
}
s"""
UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
${consume(ctx, null, resultRow)}
$evaluateKeyVars
$evaluateResultBufferVars
${consume(ctx, resultVars)}
"""

} else {
// generate result based on grouping key
Contributor: we only reach this branch when aggregateExpressions is empty; is that possible?

Author: Yes, e.g. for an aggregation coming from Distinct.
ctx.INPUT_ROW = keyTerm
@@ -483,6 +503,13 @@ case class HashAggregateExec(
}
consume(ctx, eval)
}
ctx.addNewFunction(funcName,
s"""
private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
throws java.io.IOException {
$body
}
""")
}
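The effect of this refactor can be sketched with stand-in Java (names hypothetical): instead of splicing the potentially large result-output code into each of the call sites, one function taking the key and buffer rows is emitted, and every call site reduces to a single call.

```java
import java.util.ArrayList;
import java.util.List;

// Stand-in for the emitted doAggregateWithKeysOutput(keyTerm, bufferTerm):
// the large per-row output code lives in one function, so each iterator
// loop body shrinks to a single call instead of an inlined code block.
class ResultFunctionSketch {
    static final List<String> output = new ArrayList<>();

    static void doAggregateWithKeysOutput(String key, String buffer) {
        output.add(key + ":" + buffer);  // stand-in for the real output code
    }

    // Stand-in for one of the iterator loops in doProduceWithKeys.
    static void drain(String[][] rows) {
        for (String[] kv : rows) {
            doAggregateWithKeysOutput(kv[0], kv[1]);  // was: inlined outputCode
        }
    }
}
```

Keeping each generated method small this way also helps stay under JVM per-method limits as more operators are fused into one stage.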

/**
@@ -581,11 +608,6 @@ case class HashAggregateExec(
val iterTerm = ctx.freshName("mapIter")
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")

val doAgg = ctx.freshName("doAggregateWithKeys")
val peakMemory = metricTerm(ctx, "peakMemory")
val spillSize = metricTerm(ctx, "spillSize")
val avgHashProbe = metricTerm(ctx, "avgHashProbe")

def generateGenerateCode(): String = {
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
@@ -599,10 +621,14 @@ case class HashAggregateExec(
}
} else ""
}
ctx.addExtraCode(generateGenerateCode())

val doAgg = ctx.freshName("doAggregateWithKeys")
val peakMemory = metricTerm(ctx, "peakMemory")
val spillSize = metricTerm(ctx, "spillSize")
val avgHashProbe = metricTerm(ctx, "avgHashProbe")
val doAggFuncName = ctx.addNewFunction(doAgg,
s"""
${generateGenerateCode}
Author: this is a tangent fix: this generated code for the hash map was piggy-backed here together with the doAggregateWithKeys function, and it could become inaccessible from the top function if the function gets generated in a nested class (after #18075).

private void $doAgg() throws java.io.IOException {
$hashMapTerm = $thisPlan.createHashMap();
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
@@ -618,7 +644,7 @@ case class HashAggregateExec(
// generate code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
val outputFunc = generateResultFunction(ctx)
val numOutput = metricTerm(ctx, "numOutputRows")

// The child could change `copyResult` to true, but we had already consumed all the rows,
@@ -641,7 +667,7 @@ case class HashAggregateExec(
$numOutput.add(1);
UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
$outputCode
$outputFunc($keyTerm, $bufferTerm);

if (shouldStop()) return;
}
@@ -654,18 +680,23 @@ case class HashAggregateExec(
val row = ctx.freshName("fastHashMapRow")
ctx.currentVars = null
ctx.INPUT_ROW = row
var schema: StructType = groupingKeySchema
bufferSchema.foreach(i => schema = schema.add(i))
val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex
.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) })
val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
groupingKeySchema.toAttributes.zipWithIndex
.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }
)
val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
bufferSchema.toAttributes.zipWithIndex
.map { case (attr, i) =>
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) })
s"""
| while ($iterTermForFastHashMap.hasNext()) {
| $numOutput.add(1);
| org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
| (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
| $iterTermForFastHashMap.next();
| ${generateRow.code}
| ${consume(ctx, Seq.empty, {generateRow.value})}
| ${generateKeyRow.code}
| ${generateBufferRow.code}
| $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
Contributor: we didn't call outputCode before; are you fixing a potential bug?

Author: generateRow.code was doing the job of outputCode before, i.e. putting all the expected output into one UnsafeRow, from which the parent can consume it.

|
| if (shouldStop()) return;
| }
@@ -692,7 +723,7 @@ case class HashAggregateExec(
$numOutput.add(1);
UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
$outputCode
$outputFunc($keyTerm, $bufferTerm);

if (shouldStop()) return;
}
@@ -201,11 +201,14 @@ case class FilterExec(condition: Expression, child: SparkPlan)
ev
}

// Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
Contributor: this is tricky; how hard would it be to fix all the places that use continue?

Contributor: ah I see, you are trying to avoid generating deeply nested if-else branches.

Author: genPredicate and generated (~50 lines above) would have to be rewritten to not use continue. As you pointed out in a previous comment, that would potentially lead to very nested scopes. It shouldn't be a problem for the compiler, but for code generation genPredicate would have to maintain these scopes and track where to end them, i.e. wherever it does not place a continue it would have to open a nested scope, which would then have to be closed in the correct place.

s"""
|$generated
|$nullChecks
|$numOutput.add(1);
|${consume(ctx, resultVars)}
|do {
| $generated
| $nullChecks
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
|} while(false);
""".stripMargin
}

@@ -316,9 +319,10 @@ case class SampleExec(
""".stripMargin.trim)

s"""
| if ($sampler.sample() == 0) continue;
| $numOutput.add(1);
| ${consume(ctx, input)}
| if ($sampler.sample() != 0) {
| $numOutput.add(1);
| ${consume(ctx, input)}
| }
""".stripMargin.trim
}
}