-
Notifications
You must be signed in to change notification settings - Fork 28.6k
[SPARK-33092][SQL] Support subexpression elimination in ProjectExec #29975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,8 +90,13 @@ case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) | |
* @param codes Strings representing the codes that evaluate common subexpressions. | ||
* @param states For each expression that is participating in subexpression elimination, | ||
* the state to use. | ||
* @param exprCodesNeedEvaluate Some expression codes that need to be evaluated before | ||
* calling common subexpressions. | ||
*/ | ||
case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState]) | ||
case class SubExprCodes( | ||
codes: Seq[String], | ||
states: Map[Expression, SubExprEliminationState], | ||
exprCodesNeedEvaluate: Seq[ExprCode]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just a suggestion: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
/** | ||
* The main information about a new added function. | ||
|
@@ -1044,7 +1049,7 @@ class CodegenContext extends Logging { | |
// Get all the expressions that appear at least twice and set up the state for subexpression | ||
// elimination. | ||
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) | ||
val commonExprVals = commonExprs.map(_.head.genCode(this)) | ||
lazy val commonExprVals = commonExprs.map(_.head.genCode(this)) | ||
|
||
lazy val nonSplitExprCode = { | ||
commonExprs.zip(commonExprVals).map { case (exprs, eval) => | ||
|
@@ -1055,10 +1060,17 @@ class CodegenContext extends Logging { | |
} | ||
} | ||
|
||
val codes = if (commonExprVals.map(_.code.length).sum > SQLConf.get.methodSplitThreshold) { | ||
val inputVarsForAllFuncs = commonExprs.map { expr => | ||
getLocalInputVariableValues(this, expr.head).toSeq | ||
} | ||
// Some operators do not require all of their child's outputs to be evaluated in advance. | ||
// Instead, they only evaluate part of the outputs early; for example, `ProjectExec` only | ||
// evaluates early the outputs used more than twice. So we need to extract the variables used | ||
// by subexpressions and evaluate them before the subexpressions. | ||
val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Could you leave some comments here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok. |
||
val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head) | ||
(inputVars.toSeq, exprCodes.toSeq) | ||
}.unzip | ||
|
||
val splitThreshold = SQLConf.get.methodSplitThreshold | ||
val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) { | ||
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { | ||
commonExprs.zipWithIndex.map { case (exprs, i) => | ||
val expr = exprs.head | ||
|
@@ -1109,7 +1121,7 @@ class CodegenContext extends Logging { | |
} else { | ||
nonSplitExprCode | ||
} | ||
SubExprCodes(codes, localSubExprEliminationExprs.toMap) | ||
SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten) | ||
} | ||
|
||
/** | ||
|
@@ -1732,15 +1744,23 @@ object CodeGenerator extends Logging { | |
} | ||
|
||
/** | ||
* Extracts all the input variables from references and subexpression elimination states | ||
* for a given `expr`. This result will be used to split the generated code of | ||
* expressions into multiple functions. | ||
* This method returns two values in a tuple. | ||
* | ||
* First value: Extracts all the input variables from references and subexpression | ||
* elimination states for a given `expr`. This result will be used to split the | ||
* generated code of expressions into multiple functions. | ||
* | ||
* Second value: Returns the set of `ExprCode`s whose code must be evaluated before | ||
* evaluating subexpressions. | ||
*/ | ||
def getLocalInputVariableValues( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you describe what's the second value of the returned value in the code comment above? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. |
||
ctx: CodegenContext, | ||
expr: Expression, | ||
subExprs: Map[Expression, SubExprEliminationState] = Map.empty): Set[VariableValue] = { | ||
subExprs: Map[Expression, SubExprEliminationState] = Map.empty) | ||
: (Set[VariableValue], Set[ExprCode]) = { | ||
val argSet = mutable.Set[VariableValue]() | ||
val exprCodesNeedEvaluate = mutable.Set[ExprCode]() | ||
|
||
if (ctx.INPUT_ROW != null) { | ||
argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow]) | ||
} | ||
|
@@ -1761,16 +1781,21 @@ object CodeGenerator extends Logging { | |
|
||
case ref: BoundReference if ctx.currentVars != null && | ||
ctx.currentVars(ref.ordinal) != null => | ||
val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal) | ||
collectLocalVariable(value) | ||
collectLocalVariable(isNull) | ||
val exprCode = ctx.currentVars(ref.ordinal) | ||
// If the referred variable is not evaluated yet. | ||
if (exprCode.code != EmptyBlock) { | ||
exprCodesNeedEvaluate += exprCode.copy() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need this copy? A unnecessary copy can happen if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copying There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
AFAIK when we empty the code, we also do a copy, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I was wrong, the copy here is necessary. |
||
exprCode.code = EmptyBlock | ||
} | ||
collectLocalVariable(exprCode.value) | ||
collectLocalVariable(exprCode.isNull) | ||
|
||
case e => | ||
stack.pushAll(e.children) | ||
} | ||
} | ||
|
||
argSet.toSet | ||
(argSet.toSet, exprCodesNeedEvaluate.toSet) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They are already There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, my bad. Yes. |
||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,10 +66,23 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) | |
|
||
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { | ||
val exprs = bindReferences[Expression](projectList, child.output) | ||
val resultVars = exprs.map(_.genCode(ctx)) | ||
val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) { | ||
// subexpression elimination | ||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) | ||
val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { | ||
exprs.map(_.genCode(ctx)) | ||
} | ||
(subExprs.codes.mkString("\n"), genVars, subExprs.exprCodesNeedEvaluate) | ||
} else { | ||
("", exprs.map(_.genCode(ctx)), Seq.empty) | ||
} | ||
|
||
// Evaluation of non-deterministic expressions can't be deferred. | ||
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) | ||
s""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you give an example of this part of the generated code for the query in https://issues.apache.org/jira/browse/SPARK-32989? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK. I will add it to the PR description. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I manually tested examples like SPARK-32989, and the performance is much better than before :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for testing! It is late in my timezone; I will update the generated code tomorrow. |
||
|// common sub-expressions | ||
|${evaluateVariables(localValInputs)} | ||
|$subExprsCode | ||
|${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))} | ||
|${consume(ctx, resultVars)} | ||
""".stripMargin | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -268,7 +268,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS | |
} | ||
} | ||
// this input data will fail to read middle way. | ||
val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j) | ||
val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
val e3 = intercept[SparkException] { | ||
input.write.format(cls.getName).option("path", path).mode("overwrite").save() | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change needed to support subexpr elimination in `ProjectExec`? What I'm interested in is why we didn't need this change when supporting it in `HashAggregateExec`.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is needed. `ProjectExec` doesn't require all of its child's outputs to be evaluated in advance. Instead, it only evaluates early the outputs used more than twice (deferring evaluation of the rest). So we need to extract the variables used by subexpressions and evaluate them before the subexpressions. In `HashAggregateExec` we don't need to consider that. The simplest way is to evaluate all of the child's outputs, of course.