@@ -64,49 +64,89 @@ object ConstantFolding extends Rule[LogicalPlan] {
64
64
* }}}
65
65
*
66
66
* Approach used:
67
- * - Start from AND operator as the root
68
- * - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they
69
- * don't have a `NOT` or `OR` operator in them
70
67
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
71
68
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
72
69
* in the AND node.
73
70
*/
74
71
object ConstantPropagation extends Rule [LogicalPlan ] with PredicateHelper {
75
- private def containsNonConjunctionPredicates (expression : Expression ): Boolean = expression.find {
76
- case _ : Not | _ : Or => true
77
- case _ => false
78
- }.isDefined
79
-
80
72
def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
81
- case f : Filter => f transformExpressionsUp {
82
- case and : And =>
83
- val conjunctivePredicates =
84
- splitConjunctivePredicates(and)
85
- .filter(expr => expr.isInstanceOf [EqualTo ] || expr.isInstanceOf [EqualNullSafe ])
86
- .filterNot(expr => containsNonConjunctionPredicates(expr))
87
-
88
- val equalityPredicates = conjunctivePredicates.collect {
89
- case e @ EqualTo (left : AttributeReference , right : Literal ) => ((left, right), e)
90
- case e @ EqualTo (left : Literal , right : AttributeReference ) => ((right, left), e)
91
- case e @ EqualNullSafe (left : AttributeReference , right : Literal ) => ((left, right), e)
92
- case e @ EqualNullSafe (left : Literal , right : AttributeReference ) => ((right, left), e)
93
- }
73
+ case f : Filter =>
74
+ val (newCondition, _) = traverse(f.condition, replaceChildren = true )
75
+ if (newCondition.isDefined) {
76
+ f.copy(condition = newCondition.get)
77
+ } else {
78
+ f
79
+ }
80
+ }
94
81
95
- val constantsMap = AttributeMap (equalityPredicates.map(_._1))
96
- val predicates = equalityPredicates.map(_._2).toSet
82
+ type EqualityPredicates = Seq [((AttributeReference , Literal ), BinaryComparison )]
97
83
98
- def replaceConstants (expression : Expression ) = expression transform {
99
- case a : AttributeReference =>
100
- constantsMap.get(a) match {
101
- case Some (literal) => literal
102
- case None => a
103
- }
84
+ /**
85
+ * Traverse a condition as a tree and replace attributes with constant values.
86
+ * - On matching [[And ]], recursively traverse each children and get propagated mappings.
87
+ * If the current node is not child of another [[And ]], replace all occurrences of the
88
+ * attributes with the corresponding constant values.
89
+ * - If a child of [[And ]] is [[EqualTo ]] or [[EqualNullSafe ]], propagate the mapping
90
+ * of attribute => constant.
91
+ * - On matching [[Or ]] or [[Not ]], recursively traverse each children, propagate empty mapping.
92
+ * - Otherwise, stop traversal and propagate empty mapping.
93
+ * @param condition condition to be traversed
94
+ * @param replaceChildren whether to replace attributes with constant values in children
95
+ * @return A tuple including:
96
+ * 1. Option[Expression]: optional changed condition after traversal
97
+ * 2. EqualityPredicates: propagated mapping of attribute => constant
98
+ */
99
+ private def traverse (condition : Expression , replaceChildren : Boolean )
100
+ : (Option [Expression ], EqualityPredicates ) =
101
+ condition match {
102
+ case e @ EqualTo (left : AttributeReference , right : Literal ) => (None , Seq (((left, right), e)))
103
+ case e @ EqualTo (left : Literal , right : AttributeReference ) => (None , Seq (((right, left), e)))
104
+ case e @ EqualNullSafe (left : AttributeReference , right : Literal ) =>
105
+ (None , Seq (((left, right), e)))
106
+ case e @ EqualNullSafe (left : Literal , right : AttributeReference ) =>
107
+ (None , Seq (((right, left), e)))
108
+ case a : And =>
109
+ val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false )
110
+ val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false )
111
+ val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
112
+ val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
113
+ Some (And (replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
114
+ replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
115
+ } else {
116
+ if (newLeft.isDefined || newRight.isDefined) {
117
+ Some (And (newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
118
+ } else {
119
+ None
120
+ }
104
121
}
105
-
106
- and transform {
107
- case e @ EqualTo (_, _) if ! predicates.contains(e) => replaceConstants(e)
108
- case e @ EqualNullSafe (_, _) if ! predicates.contains(e) => replaceConstants(e)
122
+ (newSelf, equalityPredicates)
123
+ case o : Or =>
124
+ // Ignore the EqualityPredicates from children since they are only propagated through And.
125
+ val (newLeft, _) = traverse(o.left, replaceChildren = true )
126
+ val (newRight, _) = traverse(o.right, replaceChildren = true )
127
+ val newSelf = if (newLeft.isDefined || newRight.isDefined) {
128
+ Some (Or (left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
129
+ } else {
130
+ None
109
131
}
132
+ (newSelf, Seq .empty)
133
+ case n : Not =>
134
+ // Ignore the EqualityPredicates from children since they are only propagated through And.
135
+ val (newChild, _) = traverse(n.child, replaceChildren = true )
136
+ (newChild.map(Not ), Seq .empty)
137
+ case _ => (None , Seq .empty)
138
+ }
139
+
140
+ private def replaceConstants (condition : Expression , equalityPredicates : EqualityPredicates )
141
+ : Expression = {
142
+ val constantsMap = AttributeMap (equalityPredicates.map(_._1))
143
+ val predicates = equalityPredicates.map(_._2).toSet
144
+ def replaceConstants0 (expression : Expression ) = expression transform {
145
+ case a : AttributeReference => constantsMap.getOrElse(a, a)
146
+ }
147
+ condition transform {
148
+ case e @ EqualTo (_, _) if ! predicates.contains(e) => replaceConstants0(e)
149
+ case e @ EqualNullSafe (_, _) if ! predicates.contains(e) => replaceConstants0(e)
110
150
}
111
151
}
112
152
}
0 commit comments