Skip to content

make RewriteWithExpression idempotent #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class ProtoToParsedPlanTestSuite
object Helper extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Finish Analysis", Once, ReplaceExpressions) ::
Batch("Rewrite With expression", FixedPoint(10), RewriteWithExpression) :: Nil
Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
}
Helper.execute(catalystPlan)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
Batch("Finish Analysis", Once, FinishAnalysis) ::
// We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
// may produce `With` expressions that need to be rewritten.
Batch("Rewrite With expression", fixedPoint, RewriteWithExpression) ::
Batch("Rewrite With expression", Once, RewriteWithExpression) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CommonExpressionRef, Expression, With}
import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, CommonExpressionRef, Expression, With}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
Expand All @@ -35,57 +35,48 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
val commonExprs = mutable.ArrayBuffer.empty[Alias]
// `With` can be nested, we should only rewrite the leaf `With` expression, as the outer
// `With` needs to add its own Project, in the next iteration when it becomes leaf.
// This is done via "transform down" and check if the common expression definitions does not
// contain nested `With`.
var newPlan: LogicalPlan = p.transformExpressionsDown {
case With(child, defs) if defs.forall(!_.containsPattern(WITH_EXPRESSION)) =>
val idToCheapExpr = mutable.HashMap.empty[Long, Expression]
val idToNonCheapExpr = mutable.HashMap.empty[Long, Alias]
defs.zipWithIndex.foreach { case (commonExprDef, index) =>
if (CollapseProject.isCheap(commonExprDef.child)) {
idToCheapExpr(commonExprDef.id) = commonExprDef.child
var newChildren = p.children
var newPlan: LogicalPlan = p.transformExpressionsUp {
case With(child, defs) =>
val refToExpr = mutable.HashMap.empty[Long, Expression]
val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias])

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (CollapseProject.isCheap(child)) {
refToExpr(id) = child
} else {
// TODO: we should calculate the ref count and also inline the common expression
// if it's ref count is 1.
val alias = Alias(commonExprDef.child, s"_common_expr_$index")()
commonExprs += alias
idToNonCheapExpr(commonExprDef.id) = alias
val childProjectionIndex = newChildren.indexWhere(
c => child.references.subsetOf(c.outputSet)
)
if (childProjectionIndex == -1) {
// When we cannot rewrite the common expressions, force to inline them so that the
// query can still run. This can happen if the join condition contains `With` and
// the common expression references columns from both join sides.
// TODO: things can go wrong if the common expression is nondeterministic. We
// don't fix it for now to match the old buggy behavior when certain
// `RuntimeReplaceable` did not use the `With` expression.
// TODO: we should calculate the ref count and also inline the common expression
// if it's ref count is 1.
refToExpr(id) = child
} else {
val alias = Alias(child, s"_common_expr_$index")()
childProjections(childProjectionIndex) += alias
refToExpr(id) = alias.toAttribute
}
}
}

child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case ref: CommonExpressionRef =>
idToCheapExpr.getOrElse(ref.id, idToNonCheapExpr(ref.id).toAttribute)
newChildren = newChildren.zip(childProjections).map { case (child, projections) =>
if (projections.nonEmpty) {
Project(child.output ++ projections, child)
} else {
child
}
}
}

var exprsToAdd = commonExprs.toSeq
val newChildren = newPlan.children.map { child =>
val (newExprs, others) = exprsToAdd.partition(_.references.subsetOf(child.outputSet))
exprsToAdd = others
if (newExprs.nonEmpty) {
Project(child.output ++ newExprs, child)
} else {
child
}
}

if (exprsToAdd.nonEmpty) {
// When we cannot rewrite the common expressions, force to inline them so that the query
// can still run. This can happen if the join condition contains `With` and the common
// expression references columns from both join sides.
// TODO: things can go wrong if the common expression is nondeterministic. We don't fix
// it for now to match the old buggy behavior when certain `RuntimeReplaceable`
// did not use the `With` expression.
val attrToExpr = AttributeMap(exprsToAdd.map { alias =>
alias.toAttribute -> alias.child
})
newPlan = newPlan.transformExpressionsUp {
case a: Attribute => attrToExpr.getOrElse(a, a)
}
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case ref: CommonExpressionRef => refToExpr(ref.id)
}
}

newPlan = newPlan.withNewChildren(newChildren)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.IntegerType
class RewriteWithExpressionSuite extends PlanTest {

object Optimizer extends RuleExecutor[LogicalPlan] {
val batches = Batch("Rewrite With expression", FixedPoint(10), RewriteWithExpression) :: Nil
val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
}

private val testRelation = LocalRelation($"a".int, $"b".int)
Expand Down