Skip to content

Commit 18b75d4

Browse files
gengliangwang authored and gatorsmile committed
[SPARK-22719][SQL] Refactor ConstantPropagation
## What changes were proposed in this pull request?

The current time complexity of ConstantPropagation is O(n^2), which can be slow when the query is complex. Refactor the implementation to run in O(n) time, with some pruning to avoid traversing the whole `Condition`.

## How was this patch tested?

Unit tests, plus a simple benchmark test in ConstantPropagationSuite:

```
val condition = (1 to 500).map{_ => Rand(0) === Rand(0)}.reduce(And)
val query = testRelation
  .select(columnA)
  .where(condition)
val start = System.currentTimeMillis()
(1 to 40).foreach { _ =>
  Optimize.execute(query.analyze)
}
val end = System.currentTimeMillis()
println(end - start)
```

Run time before changes: 18989ms (474ms per loop)
Run time after changes: 1275ms (32ms per loop)

Author: Wang Gengliang <ltnwgl@gmail.com>

Closes #19912 from gengliangwang/ConstantPropagation.
1 parent f41c0a9 commit 18b75d4

File tree

1 file changed

+73
-33
lines changed

1 file changed

+73
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 73 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -64,49 +64,89 @@ object ConstantFolding extends Rule[LogicalPlan] {
6464
* }}}
6565
*
6666
* Approach used:
67-
* - Start from AND operator as the root
68-
* - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they
69-
* don't have a `NOT` or `OR` operator in them
7067
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
7168
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
7269
* in the AND node.
7370
*/
7471
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
75-
private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find {
76-
case _: Not | _: Or => true
77-
case _ => false
78-
}.isDefined
79-
8072
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
81-
case f: Filter => f transformExpressionsUp {
82-
case and: And =>
83-
val conjunctivePredicates =
84-
splitConjunctivePredicates(and)
85-
.filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe])
86-
.filterNot(expr => containsNonConjunctionPredicates(expr))
87-
88-
val equalityPredicates = conjunctivePredicates.collect {
89-
case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e)
90-
case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e)
91-
case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e)
92-
case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e)
93-
}
73+
case f: Filter =>
74+
val (newCondition, _) = traverse(f.condition, replaceChildren = true)
75+
if (newCondition.isDefined) {
76+
f.copy(condition = newCondition.get)
77+
} else {
78+
f
79+
}
80+
}
9481

95-
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
96-
val predicates = equalityPredicates.map(_._2).toSet
82+
type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]
9783

98-
def replaceConstants(expression: Expression) = expression transform {
99-
case a: AttributeReference =>
100-
constantsMap.get(a) match {
101-
case Some(literal) => literal
102-
case None => a
103-
}
84+
/**
85+
* Traverse a condition as a tree and replace attributes with constant values.
86+
* - On matching [[And]], recursively traverse each child and get propagated mappings.
87+
* If the current node is not a child of another [[And]], replace all occurrences of the
88+
* attributes with the corresponding constant values.
89+
* - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
90+
* of attribute => constant.
91+
* - On matching [[Or]] or [[Not]], recursively traverse each child and propagate an empty mapping.
92+
* - Otherwise, stop traversal and propagate empty mapping.
93+
* @param condition condition to be traversed
94+
* @param replaceChildren whether to replace attributes with constant values in children
95+
* @return A tuple including:
96+
* 1. Option[Expression]: optional changed condition after traversal
97+
* 2. EqualityPredicates: propagated mapping of attribute => constant
98+
*/
99+
private def traverse(condition: Expression, replaceChildren: Boolean)
100+
: (Option[Expression], EqualityPredicates) =
101+
condition match {
102+
case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e)))
103+
case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e)))
104+
case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
105+
(None, Seq(((left, right), e)))
106+
case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
107+
(None, Seq(((right, left), e)))
108+
case a: And =>
109+
val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false)
110+
val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false)
111+
val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
112+
val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
113+
Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
114+
replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
115+
} else {
116+
if (newLeft.isDefined || newRight.isDefined) {
117+
Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
118+
} else {
119+
None
120+
}
104121
}
105-
106-
and transform {
107-
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e)
108-
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e)
122+
(newSelf, equalityPredicates)
123+
case o: Or =>
124+
// Ignore the EqualityPredicates from children since they are only propagated through And.
125+
val (newLeft, _) = traverse(o.left, replaceChildren = true)
126+
val (newRight, _) = traverse(o.right, replaceChildren = true)
127+
val newSelf = if (newLeft.isDefined || newRight.isDefined) {
128+
Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
129+
} else {
130+
None
109131
}
132+
(newSelf, Seq.empty)
133+
case n: Not =>
134+
// Ignore the EqualityPredicates from children since they are only propagated through And.
135+
val (newChild, _) = traverse(n.child, replaceChildren = true)
136+
(newChild.map(Not), Seq.empty)
137+
case _ => (None, Seq.empty)
138+
}
139+
140+
private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
141+
: Expression = {
142+
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
143+
val predicates = equalityPredicates.map(_._2).toSet
144+
def replaceConstants0(expression: Expression) = expression transform {
145+
case a: AttributeReference => constantsMap.getOrElse(a, a)
146+
}
147+
condition transform {
148+
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
149+
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
110150
}
111151
}
112152
}

0 commit comments

Comments
 (0)