Skip to content

[SPARK-27604][SQL] Enhance constant propagation #24553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import scala.collection.mutable.{ArrayBuffer, Stack}
import scala.collection.mutable.{ArrayBuffer, Map, Stack}

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -55,119 +55,138 @@ object ConstantFolding extends Rule[LogicalPlan] {
}

/**
* Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding
* Substitutes [[Expression Expressions]] which can be statically evaluated with their corresponding
* value in conjunctive [[Expression Expressions]]
* eg.
* {{{
* SELECT * FROM table WHERE i = 5 AND j = i + 3
* ==> SELECT * FROM table WHERE i = 5 AND j = 8
* i = 5 AND j = i + 3 => ... i = 5 AND j = 8
* abs(i) = 5 AND j <= abs(i) + 3 => ... abs(i) = 5 AND j <= 8
* }}}
*
* Approach used:
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
* in the AND node.
* - Populate a mapping of expression => constant value by looking at all the deterministic equals
* predicates
* - Using this mapping, replace occurrence of the expressions with the corresponding constant
* values in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f: Filter =>
val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
if (newCondition.isDefined) {
f.copy(condition = newCondition.get)
} else {
f
}
}
case f: Filter => f.mapExpressions(e => traverse(e, Some(false))._1)

type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]
// Constant propagation can remove equalities from [[Join]] conditions as they don't add any
// real value, but [[ExtractEquiJoinKeys]] is not prepared to handle that situation.
// SPARK-30598 can solve this issue.
case j: Join => j

case o => o.mapExpressions(e => traverse(e, None)._1)
}

/**
* Traverse a condition as a tree and replace attributes with constant values.
* - On matching [[And]], recursively traverse each children and get propagated mappings.
* If the current node is not child of another [[And]], replace all occurrences of the
* attributes with the corresponding constant values.
* - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
* of attribute => constant.
* - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping.
* - Otherwise, stop traversal and propagate empty mapping.
* @param condition condition to be traversed
* @param replaceChildren whether to replace attributes with constant values in children
* @param nullIsFalse whether a boolean expression result can be considered to false e.g. in the
* case of `WHERE e`, null result of expression `e` means the same as if it
* resulted false
* Traverse a condition as a tree and replace expressions with constant values.
* - On matching [[EqualTo]] or [[EqualNullSafe]], recursively traverse left and right children
* and propagate the expression => constant mapping.
* - On matching [[And]], recursively traverse left subtree and collect propagated mapping to
* replace expressions to constants in right subtree. Then recursively traverse right subtree
* and collect propagated mapping to replace expressions to constants in left subtree.
* - Otherwise, recursively traverse each children, propagate empty mapping.
* - During expression tree traversal tracks a boolean context that controls if constant
* propagation of a nullable expression can be safely applied.
* - E.g. in the case of `WHERE a = c AND f(a)` or `IF(a = c AND f(a), ..., ...)` where `a` is a
* nullable expression and `c` is a constant the null result of `a = c AND f(a)` means the
* same as if it resulted `false` therefore constant propagation can be safely applied (`a = c
* AND f(a)` => `a = c AND f(c)`). This context is represented by `Some(False)`.
* - In the case of `SELECT a = c AND f(a)` the `null` result really means `null`. In this
* context constant propagation can't be applied safely. This context is represented by
* `None`.
* - There is also a 3rd context due to an enclosing `Not` in which the context flips. E.g.
* constant propagation can't be applied on `WHERE NOT(a = c AND f(a))` but can be again on
* `WHERE NOT(IF(..., NOT(a = c AND f(a)), ...)`. This context is represented by `Some(True)`.
* @param expression expression to be traversed
* @param nullValue optional boolean that a null boolean expression result can be considered to
* @return A tuple including:
* 1. Option[Expression]: optional changed condition after traversal
* 2. EqualityPredicates: propagated mapping of attribute => constant
* 1. Expression: possibly changed expression after traversal
* 2. Map[Expression, Literal]: propagated mapping of expression => constant
*/
private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
: (Option[Expression], EqualityPredicates) =
condition match {
case e @ EqualTo(left: AttributeReference, right: Literal)
if safeToReplace(left, nullIsFalse) =>
(None, Seq(((left, right), e)))
case e @ EqualTo(left: Literal, right: AttributeReference)
if safeToReplace(right, nullIsFalse) =>
(None, Seq(((right, left), e)))
case e @ EqualNullSafe(left: AttributeReference, right: Literal)
if safeToReplace(left, nullIsFalse) =>
(None, Seq(((left, right), e)))
case e @ EqualNullSafe(left: Literal, right: AttributeReference)
if safeToReplace(right, nullIsFalse) =>
(None, Seq(((right, left), e)))
case a: And =>
val (newLeft, equalityPredicatesLeft) =
traverse(a.left, replaceChildren = false, nullIsFalse)
val (newRight, equalityPredicatesRight) =
traverse(a.right, replaceChildren = false, nullIsFalse)
val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
} else {
if (newLeft.isDefined || newRight.isDefined) {
Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
} else {
None
}
private def traverse(
expression: Expression,
nullValue: Option[Boolean] = None): (Expression, Map[Expression, Literal]) =
expression match {
case et @ EqualTo(left, right: Literal) if safeToReplace(left, nullValue) =>
(et.mapChildren(traverse(_)._1), Map(left.canonicalized -> right))
case et @ EqualTo(left: Literal, right) if safeToReplace(right, nullValue) =>
(et.mapChildren(traverse(_)._1), Map(right.canonicalized -> left))
case ens @ EqualNullSafe(left, right: Literal) if safeToReplace(left, nullValue) =>
(ens.mapChildren(traverse(_)._1), Map(left.canonicalized -> right))
case ens @ EqualNullSafe(left: Literal, right) if safeToReplace(right, nullValue) =>
(ens.mapChildren(traverse(_)._1), Map(right.canonicalized -> left))
case a @ And(left, right) =>
val (newLeft, equalityPredicatesLeft) = traverse(left, nullValue)
val replacedRight = replaceConstants(right, equalityPredicatesLeft)
val (replacedNewRight, equalityPredicatesRight) = traverse(replacedRight, nullValue)
val replacedNewLeft = replaceConstants(newLeft, equalityPredicatesRight)
val newAnd = a.withNewChildren(Seq(replacedNewLeft, replacedNewRight))
(newAnd, equalityPredicatesLeft ++= equalityPredicatesRight)
case o: Or => (o.mapChildren(traverse(_, nullValue)._1), Map.empty)
case n: Not => (n.mapChildren(traverse(_, nullValue.map(!_))._1), Map.empty)
case i @ If(predicate, trueValue, falseValue) =>
val newPredicate = traverse(predicate, Some(false))._1
val newTrueValue = traverse(trueValue, nullValue)._1
val newFalseValue = traverse(falseValue, nullValue)._1
val newIf = i.withNewChildren(Seq(newPredicate, newTrueValue, newFalseValue))
(newIf, Map.empty)
case cw @ CaseWhen(branches, elseValue) =>
val newBranches = branches.flatMap {
case (w, t) => Seq(traverse(w, Some(false))._1, traverse(t, nullValue)._1)
}
(newSelf, equalityPredicates)
case o: Or =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse)
val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse)
val newSelf = if (newLeft.isDefined || newRight.isDefined) {
Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
} else {
None
}
(newSelf, Seq.empty)
case n: Not =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false)
(newChild.map(Not), Seq.empty)
case _ => (None, Seq.empty)
val newElseValue = elseValue.map(traverse(_, nullValue)._1)
val newCaseWhen = cw.withNewChildren(newBranches ++ newElseValue)
(newCaseWhen, Map.empty)
case af @ ArrayFilter(argument, lf: LambdaFunction) =>
val newArgument = traverse(argument, nullValue)._1
val newLF: LambdaFunction = traverseLambdaFunction(lf, false)
val newArrayFilter = af.withNewChildren(Seq(newArgument, newLF))
(newArrayFilter, Map.empty)
case ae @ ArrayExists(argument, lf: LambdaFunction) =>
val newArgument = traverse(argument, nullValue)._1
val newLF: LambdaFunction = traverseLambdaFunction(lf,
SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC))
val newArrayExists = ae.withNewChildren(Seq(newArgument, newLF))
(newArrayExists, Map.empty)
case mf @ MapFilter(argument, lf: LambdaFunction) =>
val newArgument = traverse(argument, nullValue)._1
val newLF: LambdaFunction = traverseLambdaFunction(lf, false)
val newMapFilter = mf.withNewChildren(Seq(newArgument, newLF))
(newMapFilter, Map.empty)

// Actually most of the expressions could propagate nullValue safely.
// We use these few in tests.
case a: Alias => (a.mapChildren(traverse(_, nullValue)._1), Map.empty)
case ca: CreateArray => (ca.mapChildren(traverse(_, nullValue)._1), Map.empty)
case gai: GetArrayItem => (gai.mapChildren(traverse(_, nullValue)._1), Map.empty)
case cm: CreateMap => (cm.mapChildren(traverse(_, nullValue)._1), Map.empty)
case cmv: GetMapValue => (cmv.mapChildren(traverse(_, nullValue)._1), Map.empty)

// Stay on the safe side and don't propagate nullValue.
case o => (o.mapChildren(traverse(_)._1), Map.empty)
}

// We need to take into account if an attribute is nullable and the context of the conjunctive
// expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be
// substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing
// NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a
// null result of the enclosed expression means.
private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
!ar.nullable || nullIsFalse

private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
: Expression = {
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
def replaceConstants0(expression: Expression) = expression transform {
case a: AttributeReference => constantsMap.getOrElse(a, a)
}
condition transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
}
private def traverseLambdaFunction(lf: LambdaFunction, threeValuedLogic: Boolean) = {
val newFunction = traverse(lf.function, if (threeValuedLogic) None else Some(false))._1
lf.withNewChildren(newFunction +: lf.arguments).asInstanceOf[LambdaFunction]
}

private def safeToReplace(expression : Expression, nullValue: Option[Boolean]) =
!expression.foldable && expression.deterministic &&
(!expression.nullable || nullValue.contains(false))

private def replaceConstants(expression: Expression, constants: Map[Expression, Literal]) =
if (constants.isEmpty) {
expression
} else {
expression transform {
case e if constants.contains(e.canonicalized) => constants(e.canonicalized)
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf

/**
* Unit tests for constant propagation in expressions.
Expand All @@ -40,12 +41,13 @@ class ConstantPropagationSuite extends PlanTest {
BooleanSimplification) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int.notNull)
val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int.notNull, 'e.int.notNull)

private val columnA = 'a
private val columnB = 'b
private val columnC = 'c
private val columnD = 'd
private val columnE = 'e

test("basic test") {
val query = testRelation
Expand Down Expand Up @@ -154,13 +156,11 @@ class ConstantPropagationSuite extends PlanTest {

test("conflicting equality predicates") {
val query = testRelation
.select(columnA)
.where(
columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3)))
.analyze

val correctAnswer = testRelation
.select(columnA)
.where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze
val correctAnswer = testRelation.where(Literal.FalseLiteral)

comparePlans(Optimize.execute(query.analyze), correctAnswer)
}
Expand All @@ -186,4 +186,94 @@ class ConstantPropagationSuite extends PlanTest {
.analyze
comparePlans(Optimize.execute(query2), correctAnswer2)
}

test("Constant propagation in conflicting equalities") {
val query = testRelation
.select(columnA)
.where(columnA === Literal(1) && columnA === Literal(2))
.analyze
val correctAnswer = testRelation
.select(columnA)
.where(Literal.FalseLiteral)
.analyze
comparePlans(Optimize.execute(query), correctAnswer)
}

test("Enhanced constant propagation") {
def testSelect(expression: Expression, expected: Expression): Unit = {
val plan = testRelation.select(expression.as("x")).analyze
val expectedPlan = testRelation.select(expected.as("x")).analyze
comparePlans(Optimize.execute(plan), expectedPlan)
}

def testFilter(expression: Expression, expected: Expression): Unit = {
val plan = testRelation.select(columnA).where(expression).analyze
val expectedPlan = testRelation.select(columnA).where(expected).analyze
comparePlans(Optimize.execute(plan), expectedPlan)
}

val nullable =
abs(columnA) === Literal(1) && columnB === Literal(1) && abs(columnA) <= columnB
val reducedNullable = abs(columnA) === Literal(1) && columnB === Literal(1)

val nonNullable =
abs(columnD) === Literal(1) && columnE === Literal(1) && abs(columnD) <= columnE
val reducedNonNullable = abs(columnD) === Literal(1) && columnE === Literal(1)

val expression = nullable || nonNullable
val partlyReduced = nullable || reducedNonNullable
val reduced = reducedNullable || reducedNonNullable

val simplifiedNegatedNullable =
abs(columnA) =!= Literal(1) || columnB =!= Literal(1) || abs(columnA) > columnB
val reducedSimplifiedNegatedNullable = abs(columnA) =!= Literal(1) || columnB =!= Literal(1)

val reducedSimplifiedNegatedNonNullable = abs(columnD) =!= Literal(1) || columnE =!= Literal(1)

val partlyReducedSimplifiedNegated =
simplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable
val reducedSimplifiedNegated =
reducedSimplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable

testSelect(expression, partlyReduced)
testSelect(If(expression, expression, expression),
If(reduced, partlyReduced, partlyReduced))
testSelect(CaseWhen(Seq((expression, expression)), expression),
CaseWhen(Seq((reduced, partlyReduced)), partlyReduced))
testSelect(ArrayFilter(CreateArray(Seq(expression)), LambdaFunction(expression, Nil)),
ArrayFilter(CreateArray(Seq(partlyReduced)), LambdaFunction(reduced, Nil)))
Seq(true, false).foreach { tvl =>
withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> s"$tvl") {
testSelect(ArrayExists(CreateArray(Seq(expression)), LambdaFunction(expression, Nil)),
ArrayExists(CreateArray(Seq(partlyReduced)),
LambdaFunction(if (tvl) partlyReduced else reduced, Nil)))
}
}
testSelect(MapFilter(CreateMap(Seq(expression, expression)), LambdaFunction(expression, Nil)),
MapFilter(CreateMap(Seq(partlyReduced, partlyReduced)), LambdaFunction(reduced, Nil)))
testSelect(Not(If(expression, Not(expression), Not(expression))),
Not(If(reduced, partlyReducedSimplifiedNegated, partlyReducedSimplifiedNegated)))

testFilter(expression, reduced)
testFilter(If(expression, expression, expression),
If(reduced, reduced, reduced))
testFilter(CaseWhen(Seq((expression, expression)), expression),
CaseWhen(Seq((reduced, reduced)), reduced))
testFilter(
GetArrayItem(ArrayFilter(CreateArray(Seq(expression)), LambdaFunction(expression, Nil)), 1),
GetArrayItem(ArrayFilter(CreateArray(Seq(reduced)), LambdaFunction(reduced, Nil)), 1))
Seq(true, false).foreach { tvl =>
withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> s"$tvl") {
testFilter(ArrayExists(CreateArray(Seq(expression)), LambdaFunction(expression, Nil)),
ArrayExists(CreateArray(Seq(reduced)),
LambdaFunction(if (tvl) partlyReduced else reduced, Nil)))
}
}
testFilter(
GetMapValue(MapFilter(CreateMap(Seq(expression, expression)),
LambdaFunction(expression, Nil)), true),
GetMapValue(MapFilter(CreateMap(Seq(reduced, reduced)), LambdaFunction(reduced, Nil)), true))
testFilter(Not(If(expression, Not(expression), Not(expression))),
Not(If(reduced, reducedSimplifiedNegated, reducedSimplifiedNegated)))
}
}
Loading