Skip to content

[SPARK-22719][SQL]Refactor ConstantPropagation #19912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,49 +64,89 @@ object ConstantFolding extends Rule[LogicalPlan] {
* }}}
*
* Approach used:
* - Start from AND operator as the root
* - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they
* don't have a `NOT` or `OR` operator in them
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
* in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find {
case _: Not | _: Or => true
case _ => false
}.isDefined

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f: Filter => f transformExpressionsUp {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will it be the same effect if we turn transformExpressionsUp to transformExpressionsDown?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this doesn't work because it still replaces the replaced And.

case and: And =>
val conjunctivePredicates =
splitConjunctivePredicates(and)
.filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe])
.filterNot(expr => containsNonConjunctionPredicates(expr))

val equalityPredicates = conjunctivePredicates.collect {
case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e)
case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e)
case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e)
case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e)
}
case f: Filter =>
val (newCondition, _) = traverse(f.condition, replaceChildren = true)
if (newCondition.isDefined) {
f.copy(condition = newCondition.get)
} else {
f
}
}

val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]

def replaceConstants(expression: Expression) = expression transform {
case a: AttributeReference =>
constantsMap.get(a) match {
case Some(literal) => literal
case None => a
}
/**
* Traverse a condition as a tree and replace attributes with constant values.
* - On matching [[And]], recursively traverse each children and get propagated mappings.
* If the current node is not child of another [[And]], replace all occurrences of the
* attributes with the corresponding constant values.
* - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
* of attribute => constant.
* - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping.
* - Otherwise, stop traversal and propagate empty mapping.
* @param condition condition to be traversed
* @param replaceChildren whether to replace attributes with constant values in children
* @return A tuple including:
* 1. Option[Expression]: optional changed condition after traversal
* 2. EqualityPredicates: propagated mapping of attribute => constant
*/
private def traverse(condition: Expression, replaceChildren: Boolean)
: (Option[Expression], EqualityPredicates) =
condition match {
case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e)))
case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e)))
case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
(None, Seq(((left, right), e)))
case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
(None, Seq(((right, left), e)))
case a: And =>
val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false)
val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false)
val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
} else {
if (newLeft.isDefined || newRight.isDefined) {
Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
} else {
None
}
}

and transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e)
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e)
(newSelf, equalityPredicates)
case o: Or =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newLeft, _) = traverse(o.left, replaceChildren = true)
val (newRight, _) = traverse(o.right, replaceChildren = true)
val newSelf = if (newLeft.isDefined || newRight.isDefined) {
Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
} else {
None
}
(newSelf, Seq.empty)
case n: Not =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newChild, _) = traverse(n.child, replaceChildren = true)
(newChild.map(Not), Seq.empty)
case _ => (None, Seq.empty)
}

private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
: Expression = {
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
def replaceConstants0(expression: Expression) = expression transform {
case a: AttributeReference => constantsMap.getOrElse(a, a)
}
condition transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
}
}
}
Expand Down