Skip to content

Commit fa65718

Browse files
committed
Update Optimizer.scala
1 parent ab8e9a6 commit fa65718

File tree

1 file changed

+78
-53
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer

1 file changed

+78
-53
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 78 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ object DefaultOptimizer extends Optimizer {
3636
// SubQueries are only needed for analysis and can be removed before execution.
3737
Batch("Remove SubQueries", FixedPoint(100),
3838
EliminateSubQueries) ::
39+
Batch("Transform Condition", FixedPoint(100),
40+
TransformCondition) ::
3941
Batch("Operator Reordering", FixedPoint(100),
4042
UnionPushdown,
4143
CombineFilters,
@@ -60,6 +62,80 @@ object DefaultOptimizer extends Optimizer {
6062
ConvertToLocalRelation) :: Nil
6163
}
6264

65+
/**
66+
* Transform and/or Condition:
67+
* 1. a && a => a
68+
* 2. (a || b) && (a || c) => a || (b && c)
69+
* 3. a || a => a
70+
* 4. (a && b) || (a && c) => a && (b || c)
71+
*/
72+
object TransformCondition extends Rule[LogicalPlan] with PredicateHelper {
73+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
74+
case q: LogicalPlan => q transformExpressionsUp {
75+
case and @ And(left, right) => (left, right) match {
76+
77+
// a && a => a
78+
case (l, r) if l fastEquals r => l
79+
// (a || b) && (a || c) => a || (b && c)
80+
case _ =>
81+
// 1. Split left and right to get the disjunctive predicates,
82+
// i.e. lhsSet = (a, b), rhsSet = (a, c)
83+
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
84+
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
85+
// 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
86+
val lhsSet = splitDisjunctivePredicates(left).toSet
87+
val rhsSet = splitDisjunctivePredicates(right).toSet
88+
val common = lhsSet.intersect(rhsSet)
89+
if (common.isEmpty) {
90+
// No common factors, return the original predicate
91+
and
92+
} else {
93+
val ldiff = lhsSet.diff(common)
94+
val rdiff = rhsSet.diff(common)
95+
if (ldiff.isEmpty || rdiff.isEmpty) {
96+
// (a || b || c || ...) && (a || b) => (a || b)
97+
common.reduce(Or)
98+
} else {
99+
// (a || b || c || ...) && (a || b || d || ...) =>
100+
// ((c || ...) && (d || ...)) || a || b
101+
(common + And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
102+
}
103+
}
104+
} // end of And(left, right)
105+
106+
case or @ Or(left, right) => (left, right) match {
107+
108+
case (l, r) if l fastEquals r => l
109+
// (a && b) || (a && c) => a && (b || c)
110+
case _ =>
111+
// 1. Split left and right to get the conjunctive predicates,
112+
// i.e. lhsSet = (a, b), rhsSet = (a, c)
113+
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
114+
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
115+
// 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
116+
val lhsSet = splitConjunctivePredicates(left).toSet
117+
val rhsSet = splitConjunctivePredicates(right).toSet
118+
val common = lhsSet.intersect(rhsSet)
119+
if (common.isEmpty) {
120+
// No common factors, return the original predicate
121+
or
122+
} else {
123+
val ldiff = lhsSet.diff(common)
124+
val rdiff = rhsSet.diff(common)
125+
if (ldiff.isEmpty || rdiff.isEmpty) {
126+
// (a && b) || (a && b && c && ...) => a && b
127+
common.reduce(And)
128+
} else {
129+
// (a && b && c && ...) || (a && b && d && ...) =>
130+
// ((c && ...) || (d && ...)) && a && b
131+
(common + Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
132+
}
133+
}
134+
} // end of Or(left, right)
135+
}
136+
}
137+
}
138+
63139
/**
64140
* Pushes operations to either side of a Union.
65141
*/
@@ -347,32 +423,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
347423
// l && false => false
348424
case (_, Literal(false, BooleanType)) => Literal(false)
349425
// a && a => a
350-
case (l, r) if l fastEquals r => l
351-
// (a || b) && (a || c) => a || (b && c)
352-
case _ =>
353-
// 1. Split left and right to get the disjunctive predicates,
354-
// i.e. lhsSet = (a, b), rhsSet = (a, c)
355-
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
356-
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
357-
// 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
358-
val lhsSet = splitDisjunctivePredicates(left).toSet
359-
val rhsSet = splitDisjunctivePredicates(right).toSet
360-
val common = lhsSet.intersect(rhsSet)
361-
if (common.isEmpty) {
362-
// No common factors, return the original predicate
363-
and
364-
} else {
365-
val ldiff = lhsSet.diff(common)
366-
val rdiff = rhsSet.diff(common)
367-
if (ldiff.isEmpty || rdiff.isEmpty) {
368-
// (a || b || c || ...) && (a || b) => (a || b)
369-
common.reduce(Or)
370-
} else {
371-
// (a || b || c || ...) && (a || b || d || ...) =>
372-
// ((c || ...) && (d || ...)) || a || b
373-
(common + And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
374-
}
375-
}
426+
case _ => and
376427
} // end of And(left, right)
377428

378429
case or @ Or(left, right) => (left, right) match {
@@ -384,33 +435,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
384435
case (Literal(false, BooleanType), r) => r
385436
// l || false => l
386437
case (l, Literal(false, BooleanType)) => l
387-
// a || a => a
388-
case (l, r) if l fastEquals r => l
389-
// (a && b) || (a && c) => a && (b || c)
390-
case _ =>
391-
// 1. Split left and right to get the conjunctive predicates,
392-
// i.e. lhsSet = (a, b), rhsSet = (a, c)
393-
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
394-
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
395-
// 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff)
396-
val lhsSet = splitConjunctivePredicates(left).toSet
397-
val rhsSet = splitConjunctivePredicates(right).toSet
398-
val common = lhsSet.intersect(rhsSet)
399-
if (common.isEmpty) {
400-
// No common factors, return the original predicate
401-
or
402-
} else {
403-
val ldiff = lhsSet.diff(common)
404-
val rdiff = rhsSet.diff(common)
405-
if (ldiff.isEmpty || rdiff.isEmpty) {
406-
// (a && b) || (a && b && c && ...) => a && b
407-
common.reduce(And)
408-
} else {
409-
// (a && b && c && ...) || (a && b && d && ...) =>
410-
// ((c && ...) || (d && ...)) && a && b
411-
(common + Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
412-
}
413-
}
438+
case _ => or
414439
} // end of Or(left, right)
415440

416441
case not @ Not(exp) => exp match {

0 commit comments

Comments
 (0)