Skip to content

[SPARK-47511][SQL] Canonicalize With expressions by re-assigning IDs #45649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,84 @@ case class With(child: Expression, defs: Seq[CommonExpressionDef])
newChildren: IndexedSeq[Expression]): Expression = {
copy(child = newChildren.head, defs = newChildren.tail.map(_.asInstanceOf[CommonExpressionDef]))
}

/**
* Builds a map of ids (originally assigned ids -> canonicalized ids) to be re-assigned during
* canonicalization.
*/
private lazy val canonicalizationIdMap: Map[Long, Long] = {
// Start numbering after taking into account all nested With expression id maps.
var currentId = child.map {
case w: With => w.canonicalizationIdMap.size
case _ => 0L
}.sum
defs.map { d =>
currentId += 1
d.id.id -> currentId
}.toMap
}

/**
* Canonicalize by re-assigning all ids in CommonExpressionRef's and CommonExpressionDef's
* starting from 0. This uses [[canonicalizationIdMap]], which contains all mappings for
* CommonExpressionDef's defined in this scope.
* Note that this takes into account nested With expressions by sharing a numbering scope (see
* [[canonicalizationIdMap]].
*/
override lazy val canonicalized: Expression = copy(
child = child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case r: CommonExpressionRef if !r.id.canonicalized =>
r.copy(id = r.id.canonicalize(canonicalizationIdMap))
}.canonicalized,
defs = defs.map {
case d: CommonExpressionDef if !d.id.canonicalized =>
d.copy(id = d.id.canonicalize(canonicalizationIdMap)).canonicalized
.asInstanceOf[CommonExpressionDef]
case d => d.canonicalized.asInstanceOf[CommonExpressionDef]
}
)
}

object With {
/**
* Helper function to create a [[With]] statement with an arbitrary number of common expressions.
* Note that the number of arguments in `commonExprs` should be the same as the number of
* arguments taken by `replaced`.
*
* @param commonExprs list of common expressions
* @param replaced closure that defines the common expressions in the main expression
* @return the expression returned by replaced with its arguments replaced by commonExprs in order
*/
def apply(commonExprs: Expression*)(replaced: Seq[Expression] => Expression): With = {
val commonExprDefs = commonExprs.map(CommonExpressionDef(_))
val commonExprRefs = commonExprDefs.map(new CommonExpressionRef(_))
With(replaced(commonExprRefs), commonExprDefs)
}
}

case class CommonExpressionId(id: Long = CommonExpressionId.newId, canonicalized: Boolean = false) {
Copy link
Contributor

@cloud-fan cloud-fan Mar 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In QueryPlan we have this

  /**
   * A private mutable variable to indicate whether this plan is the result of canonicalization.
   * This is used solely for making sure we wouldn't execute a canonicalized plan.
   * See [[canonicalized]] on how this is set.
   */
  @transient private var _isCanonicalizedPlan: Boolean = false

  protected def isCanonicalizedPlan: Boolean = _isCanonicalizedPlan

Shall we do the same in With?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what replacement are you suggesting? the canonicalized parameter in CommonExpressionId is used to distinguish between an ID that was assigned initially by curId.getAndIncrement() or newly assigned by canonicalization. it is more of a property of the ID itself than the With/CommonExpressionDef/CommonExpressionRef operators? also, if we haven't called .canonicalized on the outermost With, it is possible to have a With expression that contains some canonicalized IDs but some non-canonicalized IDs

/**
* Re-assign to a canonicalized id based on idMap. If it is not found in idMap, the id is defined
* in an outer scope and will be replaced later.
*/
def canonicalize(idMap: Map[Long, Long]): CommonExpressionId = {
if (idMap.contains(id)) {
copy(id = idMap(id), canonicalized = true)
} else {
this
}
}
}

object CommonExpressionId {
private[sql] val curId = new java.util.concurrent.atomic.AtomicLong()
def newId: Long = curId.getAndIncrement()
}

/**
* A wrapper of common expression to carry the id.
*/
case class CommonExpressionDef(child: Expression, id: Long = CommonExpressionDef.newId)
case class CommonExpressionDef(child: Expression, id: CommonExpressionId = new CommonExpressionId())
extends UnaryExpression with Unevaluable {
override def dataType: DataType = child.dataType
override protected def withNewChildInternal(newChild: Expression): Expression =
Expand All @@ -51,13 +123,8 @@ case class CommonExpressionDef(child: Expression, id: Long = CommonExpressionDef
* A reference to the common expression by its id. Only resolved common expressions can be
* referenced, so that we can determine the data type and nullable of the reference node.
*/
case class CommonExpressionRef(id: Long, dataType: DataType, nullable: Boolean)
case class CommonExpressionRef(id: CommonExpressionId, dataType: DataType, nullable: Boolean)
extends LeafExpression with Unevaluable {
def this(exprDef: CommonExpressionDef) = this(exprDef.id, exprDef.dataType, exprDef.nullable)
override val nodePatterns: Seq[TreePattern] = Seq(COMMON_EXPR_REF)
}

object CommonExpressionDef {
private[sql] val curId = new java.util.concurrent.atomic.AtomicLong()
def newId: Long = curId.getAndIncrement()
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{COALESCE, NULL_CHECK, TreePattern}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -158,11 +159,15 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {

def this(left: Expression, right: Expression) = {
this(left, right, {
val commonExpr = CommonExpressionDef(left)
val ref = new CommonExpressionRef(commonExpr)
With(If(EqualTo(ref, right), Literal.create(null, left.dataType), ref), Seq(commonExpr))
})
this(left, right,
if (SQLConf.get.getConf(SQLConf.REPLACE_NULLIF_USING_WITH_EXPR)) {
With(left) { case Seq(ref) =>
If(EqualTo(ref, right), Literal.create(null, left.dataType), ref)
}
} else {
If(EqualTo(left, right), Literal.create(null, left.dataType), left)
}
)
}

override def parameters: Seq[Expression] = Seq(left, right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,18 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
// Rewrite nested With expressions first
val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
val refToExpr = mutable.HashMap.empty[Long, Expression]
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (child.containsPattern(COMMON_EXPR_REF)) {
throw SparkException.internalError(
"Common expression definition cannot reference other Common expression definitions")
}
if (id.canonicalized) {
throw SparkException.internalError(
"Cannot rewrite canonicalized Common expression definitions")
}

if (CollapseProject.isCheap(child)) {
refToExpr(id) = child
Expand Down Expand Up @@ -114,6 +118,10 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
if (!refToExpr.contains(ref.id)) {
throw SparkException.internalError("Undefined common expression id " + ref.id)
}
if (ref.id.canonicalized) {
throw SparkException.internalError(
"Cannot rewrite canonicalized Common expression references")
}
refToExpr(ref.id)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3372,6 +3372,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val REPLACE_NULLIF_USING_WITH_EXPR =
buildConf("spark.databricks.sql.replaceNullIfUsingWithExpr")
.internal()
.doc("When true, NullIf expressions are rewritten using With expressions to avoid " +
"expression duplication.")
.booleanConf
.createWithDefault(true)

val USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES =
buildConf("spark.sql.defaultColumn.useNullsForMissingDefaultValues")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD
import org.apache.spark.sql.types.{BooleanType, Decimal, DecimalType, IntegerType, LongType, StringType, StructField, StructType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.types.{BooleanType, Decimal, DecimalType, DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampNTZType, TimestampType}

class CanonicalizeSuite extends SparkFunSuite {

Expand Down Expand Up @@ -351,4 +351,107 @@ class CanonicalizeSuite extends SparkFunSuite {
assert(op.canonicalized.toJSON.nonEmpty)
SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString)
}

test("canonicalization of With expressions with one common expression") {
val expr = Divide(Literal.create(1, IntegerType), AttributeReference("a", IntegerType)())
val common1 = IsNull(With(expr.copy()) { case Seq(expr) =>
If(EqualTo(expr, Literal.create(0.0, DoubleType)), Literal.create(0.0, DoubleType), expr)
})
val common2 = IsNull(With(expr.copy()) { case Seq(expr) =>
If(EqualTo(expr, Literal.create(0.0, DoubleType)), Literal.create(0.0, DoubleType), expr)
})
// Check that canonicalization is consistent across multiple executions.
assert(common1.canonicalized == common2.canonicalized)
// Check that CommonExpressionDef starts ID'ing at 1 and that its child is canonicalized.
assert(common1.canonicalized.exists {
case d: CommonExpressionDef => d.id.id == 1 && d.child == expr.canonicalized
case _ => false
})
// Check that CommonExpressionRef ID corresponds to the def.
assert(common1.canonicalized.exists {
case r: CommonExpressionRef => r.id.id == 1
case _ => false
})
}

test("canonicalization of With expressions with multiple common expressions") {
val expr1 = Divide(Literal.create(1, IntegerType), AttributeReference("a", IntegerType)())
val expr2 = Multiply(Literal.create(2, IntegerType), AttributeReference("a", IntegerType)())
val common1 = With(expr1.copy(), expr2.copy()) { case Seq(expr1, expr2) =>
If(EqualTo(expr1, expr2), expr1, expr2)
}
val common2 = With(expr1.copy(), expr2.copy()) { case Seq(expr1, expr2) =>
If(EqualTo(expr1, expr2), expr1, expr2)
}
// Check that canonicalization is consistent across multiple executions.
assert(common1.canonicalized == common2.canonicalized)
// Check that CommonExpressionDef starts ID'ing at 1 and that its child is canonicalized.
assert(common1.canonicalized.exists {
case d: CommonExpressionDef => d.id.id == 1 && d.child == expr1.canonicalized
case _ => false
})
assert(common1.canonicalized.exists {
case d: CommonExpressionDef => d.id.id == 2 && d.child == expr2.canonicalized
case _ => false
})
// Check that CommonExpressionRef ID corresponds to the def.
assert(common1.canonicalized.exists {
case r: CommonExpressionRef => r.id.id == 1
case _ => false
})
assert(common1.canonicalized.exists {
case r: CommonExpressionRef => r.id.id == 2
case _ => false
})
}

test("canonicalization of With expressions with nested common expressions") {
val expr1 = AttributeReference("a", BooleanType)()
val expr2 = AttributeReference("b", BooleanType)()

val common1 = With(expr1) { case Seq(expr1) =>
Or(With(expr2) { case Seq(expr2) =>
And(EqualTo(expr1, expr2), EqualTo(expr1, expr2))
}, expr1)
}
val common2 = With(expr1) { case Seq(expr1) =>
Or(With(expr2) { case Seq(expr2) =>
And(EqualTo(expr1, expr2), EqualTo(expr1, expr2))
}, expr1)
}
// Check that canonicalization is consistent across multiple executions.
assert(common1.canonicalized == common2.canonicalized)
// Check that CommonExpressionDef starts ID'ing at 1 and that its child is canonicalized.
assert(common1.canonicalized.exists {
case d: CommonExpressionDef => d.id.id == 1 && d.child == expr2.canonicalized
case _ => false
})
assert(common1.canonicalized.exists {
case d: CommonExpressionDef => d.id.id == 2 && d.child == expr1.canonicalized
case _ => false
})
// Check that CommonExpressionRef ID corresponds to the def.
assert(common1.canonicalized.exists {
case r: CommonExpressionRef => r.id.id == 1
case _ => false
})
assert(common1.canonicalized.exists {
case r: CommonExpressionRef => r.id.id == 2
case _ => false
})

val common3 = With(expr1.newInstance()) { case Seq(expr1) =>
Or(With(expr2.newInstance()) { case Seq(expr2) =>
And(EqualTo(expr1, expr2), EqualTo(expr1, expr2))
}, expr1)
}
val common4 = With(expr1.newInstance()) { case Seq(expr1) =>
Or(With(expr2.newInstance()) { case Seq(expr2) =>
And(EqualTo(expr2, expr1), EqualTo(expr2, expr1))
}, expr1)
}
// Check that canonicalization for two different expressions with similar structures is
// different.
assert(common3.canonicalized != common4.canonicalized)
}
}