Skip to content

Commit 448faf0

Browse files
committed
fix
1 parent 45fd7a5 commit 448faf0

File tree

3 files changed

+83
-29
lines changed

3 files changed

+83
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
103103
RemoveDispensableExpressions,
104104
SimplifyBinaryComparison,
105105
ReplaceNullWithFalseInPredicate,
106-
SimplifyConditionalInPredicate,
106+
SimplifyConditionalsInPredicate,
107107
PruneFilters,
108108
SimplifyCasts,
109109
SimplifyCaseConversionExpressions,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If, LambdaFunction, Literal, MapFilter, Not, Or}
21+
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
22+
import org.apache.spark.sql.catalyst.plans.logical._
23+
import org.apache.spark.sql.catalyst.rules.Rule
24+
import org.apache.spark.sql.types.BooleanType
25+
import org.apache.spark.util.Utils
26+
27+
28+
object SimplifyConditionalsInPredicate extends Rule[LogicalPlan] {
29+
30+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
31+
case f @ Filter(cond, _) => f.copy(condition = simplifyConditional(cond))
32+
case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(simplifyConditional(cond)))
33+
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(simplifyConditional(cond)))
34+
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(simplifyConditional(cond)))
35+
case p: LogicalPlan => p transformExpressions {
36+
case i @ If(pred, _, _) => i.copy(predicate = simplifyConditional(pred))
37+
case cw @ CaseWhen(branches, _) =>
38+
val newBranches = branches.map { case (cond, value) =>
39+
simplifyConditional(cond) -> value
40+
}
41+
cw.copy(branches = newBranches)
42+
case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) =>
43+
val newLambda = lf.copy(function = simplifyConditional(func))
44+
af.copy(function = newLambda)
45+
case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _), false) =>
46+
val newLambda = lf.copy(function = simplifyConditional(func))
47+
ae.copy(function = newLambda)
48+
case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) =>
49+
val newLambda = lf.copy(function = simplifyConditional(func))
50+
mf.copy(function = newLambda)
51+
}
52+
}
53+
54+
private def simplifyConditional(e: Expression): Expression = e match {
55+
case Literal(null, BooleanType) => FalseLiteral
56+
case And(left, right) => And(simplifyConditional(left), simplifyConditional(right))
57+
case Or(left, right) => Or(simplifyConditional(left), simplifyConditional(right))
58+
case If(cond, t, FalseLiteral) => And(cond, t)
59+
case If(cond, t, TrueLiteral) => Or(Not(cond), t)
60+
case If(cond, FalseLiteral, f) => And(Not(cond), f)
61+
case If(cond, TrueLiteral, f) => Or(cond, f)
62+
case CaseWhen(Seq((cond, trueValue)),
63+
Some(FalseLiteral) | Some(Literal(null, BooleanType)) | None) =>
64+
And(cond, trueValue)
65+
case CaseWhen(Seq((cond, trueValue)), Some(TrueLiteral)) =>
66+
Or(Not(cond), trueValue)
67+
case CaseWhen(Seq((cond, FalseLiteral)), elseValue) =>
68+
And(Not(cond), elseValue.getOrElse(Literal(null, BooleanType)))
69+
case CaseWhen(Seq((cond, TrueLiteral)), elseValue) =>
70+
Or(cond, elseValue.getOrElse(Literal(null, BooleanType)))
71+
case e if e.dataType == BooleanType => e
72+
case e =>
73+
val message = "Expected a Boolean type expression in simplifyConditional, " +
74+
s"but got the type `${e.dataType.catalogString}` in `${e.sql}`."
75+
if (Utils.isTesting) {
76+
throw new IllegalArgumentException(message)
77+
} else {
78+
logWarning(message)
79+
e
80+
}
81+
}
82+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.plans.logical._
3030
import org.apache.spark.sql.catalyst.rules._
3131
import org.apache.spark.sql.internal.SQLConf
3232
import org.apache.spark.sql.types._
33-
import org.apache.spark.util.Utils
3433

3534
/*
3635
* Optimization rules defined in this file should not affect the structure of the logical plan.
@@ -586,33 +585,6 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
586585
}
587586

588587

589-
object SimplifyConditionalInPredicate extends Rule[LogicalPlan] {
590-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
591-
case f @ Filter(cond, _) => f.copy(condition = simplifyConditional(cond))
592-
}
593-
594-
private def simplifyConditional(e: Expression): Expression = e match {
595-
case cw @ CaseWhen(branches, elseValue) if cw.dataType == BooleanType && branches.size == 1 &&
596-
elseValue.forall(_.semanticEquals(FalseLiteral)) =>
597-
val (whenVal, thenVal) = branches.head
598-
And(whenVal, thenVal)
599-
case i @ If(pred, trueVal, FalseLiteral) if i.dataType == BooleanType =>
600-
And(pred, trueVal)
601-
case e if e.dataType == BooleanType =>
602-
e
603-
case e =>
604-
val message = "Expected a Boolean type expression in simplifyConditional, " +
605-
s"but got the type `${e.dataType.catalogString}` in `${e.sql}`."
606-
if (Utils.isTesting) {
607-
throw new IllegalArgumentException(message)
608-
} else {
609-
logWarning(message)
610-
e
611-
}
612-
}
613-
}
614-
615-
616588
/**
617589
* Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
618590
* For example, when the expression is just checking to see if a string starts with a given

0 commit comments

Comments
 (0)