[SPARK-33092][SQL] Support subexpression elimination in ProjectExec #29975

Closed · wants to merge 5 commits
@@ -90,8 +90,13 @@ case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)
* @param codes Strings representing the codes that evaluate common subexpressions.
* @param states For each expression that participates in subexpression elimination,
* the state to use.
* @param exprCodesNeedEvaluate Some expression codes that need to be evaluated before
* calling common subexpressions.
*/
case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState])
case class SubExprCodes(
codes: Seq[String],
states: Map[Expression, SubExprEliminationState],
exprCodesNeedEvaluate: Seq[ExprCode])
Member commented:
Is this change needed to support subexpr elimination in ProjectExec? What I'm interested in is why we didn't need this change when supporting it in HashAggregateExec.

@viirya (Member, Author) commented on Oct 9, 2020:

Yes, this is needed. ProjectExec doesn't require all of its child's outputs to be evaluated in advance; instead, it only eagerly evaluates the outputs used more than twice, deferring the rest. So we need to extract the variables used by subexpressions and evaluate them before the subexpressions. In HashAggregateExec we don't need to consider that; the simplest way there is, of course, to evaluate all of the child's outputs.

Member commented:
Just a suggestion: exprCodesNeedEvaluate -> exprCodesForEarlyEvals?

@viirya (Member, Author) replied:

exprCodesForEarlyEvals sounds confusing. They are not for early-evaluating something; they are codes that must run before evaluating subexpressions.
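The mechanics under discussion can be sketched with a minimal, hypothetical Python model (the `ExprCode` class and `gen_subexpr_block` helper below are illustrative stand-ins, not Spark's actual API): a deferred child output still carries its generating code, which must be hoisted ahead of the common-subexpression code and then blanked so it is not emitted twice.

```python
class ExprCode:
    """Toy stand-in for Spark's ExprCode: `code` holds the statements that
    produce `value`; an empty `code` means the value is already evaluated."""
    def __init__(self, code, value):
        self.code = code
        self.value = value

def gen_subexpr_block(child_outputs, subexpr_code):
    # Hoist the codes of any not-yet-evaluated inputs so they run before
    # the common-subexpression code that references them.
    early_evals = [c.code for c in child_outputs if c.code != ""]
    for c in child_outputs:
        c.code = ""  # mark as evaluated so the code is not emitted again
    return "\n".join(early_evals + [subexpr_code])

a = ExprCode("int a = row.getInt(0);", "a")  # deferred child output
b = ExprCode("", "b")                        # already evaluated
block = gen_subexpr_block([a, b], "int s = a + a;")
```

Here `a`'s declaration is emitted once, ahead of the subexpression that uses it, and both inputs are left marked as evaluated.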


/**
* The main information about a newly added function.
@@ -1044,7 +1049,7 @@ class CodegenContext extends Logging {
// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
val commonExprVals = commonExprs.map(_.head.genCode(this))
lazy val commonExprVals = commonExprs.map(_.head.genCode(this))

lazy val nonSplitExprCode = {
commonExprs.zip(commonExprVals).map { case (exprs, eval) =>
Expand All @@ -1055,10 +1060,17 @@ class CodegenContext extends Logging {
}
}

val codes = if (commonExprVals.map(_.code.length).sum > SQLConf.get.methodSplitThreshold) {
val inputVarsForAllFuncs = commonExprs.map { expr =>
getLocalInputVariableValues(this, expr.head).toSeq
}
// Some operators do not require all of their child's outputs to be evaluated in advance.
// Instead they only early-evaluate part of the outputs; for example, `ProjectExec` only
// early-evaluates the outputs used more than twice. So we need to extract the variables
// used by subexpressions and evaluate them before the subexpressions.
val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr =>
Member commented:

> ProjectExec doesn't require all its child's outputs to be evaluated in advance. Instead it only early evaluates the outputs used more than twice (deferring evaluation). So we need to extract these variables used by subexpressions and evaluate them before subexpressions

Could you leave some comments here?

@viirya (Member, Author) replied:

Ok.

val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head)
(inputVars.toSeq, exprCodes.toSeq)
}.unzip

val splitThreshold = SQLConf.get.methodSplitThreshold
val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) {
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
commonExprs.zipWithIndex.map { case (exprs, i) =>
val expr = exprs.head
@@ -1109,7 +1121,7 @@
} else {
nonSplitExprCode
}
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten)
}
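The first step of this method, `equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)`, keeps only expressions occurring at least twice. That step can be sketched in Python over expressions modeled as nested tuples (purely illustrative; Spark's EquivalentExpressions tracks semantic equivalence and is far more involved):

```python
from collections import Counter

def common_subexprs(exprs):
    """Return the subexpressions that appear at least twice across `exprs`.
    Expressions are modeled as nested tuples, e.g. ('+', 'a', 'b')."""
    def subtrees(e):
        yield e
        if isinstance(e, tuple):
            for child in e[1:]:
                yield from subtrees(child)
    counts = Counter(t for e in exprs for t in subtrees(e))
    # mirrors `equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)`
    return [t for t, n in counts.items() if n > 1 and isinstance(t, tuple)]

# a*b occurs in both projections, so it is a candidate for elimination
exprs = [('+', ('*', 'a', 'b'), 'c'), ('-', ('*', 'a', 'b'), 'd')]
```

Only the repeated `a*b` subtree qualifies; single-occurrence expressions and leaf attributes are left for normal per-expression codegen.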

/**
@@ -1732,15 +1744,23 @@ object CodeGenerator extends Logging {
}

/**
* Extracts all the input variables from references and subexpression elimination states
* for a given `expr`. This result will be used to split the generated code of
* expressions into multiple functions.
* This method returns two values in a tuple.
*
* First value: all the input variables extracted from references and subexpression
* elimination states for the given `expr`. This result is used to split the
* generated code of expressions into multiple functions.
*
* Second value: the set of `ExprCode`s whose codes must be evaluated before
* evaluating subexpressions.
*/
def getLocalInputVariableValues(
Member commented:

Could you describe what the second value of the returned tuple is, in the code comment above?

@viirya (Member, Author) replied:

Sure.

ctx: CodegenContext,
expr: Expression,
subExprs: Map[Expression, SubExprEliminationState] = Map.empty): Set[VariableValue] = {
subExprs: Map[Expression, SubExprEliminationState] = Map.empty)
: (Set[VariableValue], Set[ExprCode]) = {
val argSet = mutable.Set[VariableValue]()
val exprCodesNeedEvaluate = mutable.Set[ExprCode]()

if (ctx.INPUT_ROW != null) {
argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
}
Expand All @@ -1761,16 +1781,21 @@ object CodeGenerator extends Logging {

case ref: BoundReference if ctx.currentVars != null &&
ctx.currentVars(ref.ordinal) != null =>
val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
collectLocalVariable(value)
collectLocalVariable(isNull)
val exprCode = ctx.currentVars(ref.ordinal)
// If the referred variable is not evaluated yet.
if (exprCode.code != EmptyBlock) {
exprCodesNeedEvaluate += exprCode.copy()
Member commented:
Do we need this copy? An unnecessary copy can happen if exprCodesNeedEvaluate already has the same entry?

@viirya (Member, Author) replied:

We copy exprCode because we need its unevaluated code, but we also need to empty the code of exprCode. The copy lets us evaluate the code before evaluating subexpressions; emptying the code of exprCode ensures we don't re-evaluate it.

Contributor commented:

> We need to empty code of exprCode so we don't re-evaluate the code.

AFAIK when we empty the code, we also do a copy, right?

@viirya (Member, Author) replied:

exprCode.code = EmptyBlock, do you mean this?

Contributor replied:

Sorry, I was wrong; the copy here is necessary.
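The copy-then-clear pattern this thread settles on can be shown with a small Python sketch (the `ExprCode` class and `hoist` helper are hypothetical names, not Spark's API): the copy preserves the unevaluated code for emission ahead of the subexpressions, while clearing the original prevents the same statements from being emitted again downstream.

```python
import copy

class ExprCode:
    """Toy stand-in for Spark's ExprCode: `code` holds the statements that
    produce `value`; an empty `code` means already evaluated."""
    def __init__(self, code, value):
        self.code = code
        self.value = value

def hoist(expr_code, need_evaluate):
    # Keep a copy carrying the unevaluated code, so it can be emitted
    # before the common-subexpression block...
    if expr_code.code != "":
        need_evaluate.append(copy.copy(expr_code))
        # ...then blank the original so downstream consumers of this
        # ExprCode don't emit (and re-evaluate) the same statements.
        expr_code.code = ""

need = []
ec = ExprCode("double x = in.getDouble(1);", "x")
hoist(ec, need)
```

After `hoist`, the original `ec` reads as already evaluated while the copy in `need` still carries the code to emit.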

exprCode.code = EmptyBlock
}
collectLocalVariable(exprCode.value)
collectLocalVariable(exprCode.isNull)

case e =>
stack.pushAll(e.children)
}
}

argSet.toSet
(argSet.toSet, exprCodesNeedEvaluate.toSet)
Member commented:

Set instead of Seq here?

@viirya (Member, Author) replied:

They are already Set. You mean using Seq?

Member replied:

Yea, my bad. Yes.

}

/**
@@ -263,7 +263,7 @@ case class HashAggregateExec(
} else {
val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
val inputVarsForOneFunc = aggExprsForOneFunc.map(
CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)._1).reduce(_ ++ _).toSeq
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)

// Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
@@ -66,10 +66,23 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val exprs = bindReferences[Expression](projectList, child.output)
val resultVars = exprs.map(_.genCode(ctx))
val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) {
// subexpression elimination
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
val genVars = ctx.withSubExprEliminationExprs(subExprs.states) {
exprs.map(_.genCode(ctx))
}
(subExprs.codes.mkString("\n"), genVars, subExprs.exprCodesNeedEvaluate)
} else {
("", exprs.map(_.genCode(ctx)), Seq.empty)
}

// Evaluation of non-deterministic expressions can't be deferred.
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
s"""
@cloud-fan (Contributor) commented on Oct 8, 2020:
Can you give an example of this part of the generated code for the query in https://issues.apache.org/jira/browse/SPARK-32989?

@viirya (Member, Author) replied:

OK. I will add it to the PR description.

@LuciferYang (Contributor) commented on Oct 8, 2020:

Manually testing examples like SPARK-32989, the performance is much better than before :)

@viirya (Member, Author) replied:

Thanks for testing! It is late in my timezone; I will update the generated code tomorrow.

|// common sub-expressions
|${evaluateVariables(localValInputs)}
|$subExprsCode
|${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
|${consume(ctx, resultVars)}
""".stripMargin
@@ -268,7 +268,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}
}
// this input data will fail to read middle way.
val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j)
val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j)
@viirya (Member, Author) commented on Oct 9, 2020:
Previously, failingUdf was evaluated twice for each row; now it is evaluated only once. So we need to increase the range to make it throw the exception as before.

val e3 = intercept[SparkException] {
input.write.format(cls.getName).option("path", path).mode("overwrite").save()
}
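A hedged Python model of why the range had to grow (the failure threshold of 12 calls below is made up purely for illustration; the real test's UDF and threshold differ): with subexpression elimination the UDF runs once per row instead of twice, so more rows are needed before it trips mid-read.

```python
calls = {"n": 0}
FAIL_AFTER = 12  # made-up threshold, for illustration only

def failing_udf(x):
    calls["n"] += 1
    if calls["n"] > FAIL_AFTER:
        raise RuntimeError("UDF failed mid-read")
    return x

def project(rows, cse):
    # Models `select(failingUdf('id).as('i)).select('i, -'i as 'j)`:
    # without CSE the UDF is invoked for both `i` and `-i`.
    out = []
    for r in rows:
        if cse:
            i = failing_udf(r)  # evaluated once, result reused
            out.append((i, -i))
        else:
            out.append((failing_udf(r), -failing_udf(r)))
    return out
```

With two invocations per row, 10 rows are enough to hit the threshold; with one invocation per row, the range must grow before the same failure occurs.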