
Commit b6229df

viirya authored and dongjoon-hyun committed
[SPARK-32258][SQL] NormalizeFloatingNumbers directly normalizes IF/CaseWhen/Coalesce child expressions
### What changes were proposed in this pull request?

This patch lets the `NormalizeFloatingNumbers` rule directly normalize the children of certain expressions (`If`, `CaseWhen`, `Coalesce`), which simplifies the resulting expression tree.

### Why are the changes needed?

Currently the `NormalizeFloatingNumbers` rule treats these expressions as black boxes, but we can optimize a bit by normalizing their inner child expressions directly. Also see apache#28962 (comment).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

Closes apache#29061 from viirya/SPARK-32258.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent bc3d4ba commit b6229df
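In short, the rule used to wrap a whole `If`/`CaseWhen`/`Coalesce` in `KnownFloatingPointNormalized(NormalizeNaNAndZero(...))`; after this commit it recurses into the value-producing children instead. A minimal standalone sketch of that rewrite, using made-up `Expr`/`Normalized` types rather than the real Catalyst classes touched by the diff:

```scala
// Toy model only: Expr, Normalized and normalize are illustrative stand-ins,
// not the actual Catalyst expression classes.
object NormalizeSketch extends App {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: Double) extends Expr
  case class Coalesce(children: Seq[Expr]) extends Expr
  case class If(cond: Expr, trueValue: Expr, falseValue: Expr) extends Expr
  // Stand-in for KnownFloatingPointNormalized(NormalizeNaNAndZero(_)).
  case class Normalized(child: Expr) extends Expr

  def normalize(expr: Expr): Expr = expr match {
    // After SPARK-32258: recurse into the value branches instead of
    // wrapping the whole conditional in one normalization node.
    case If(cond, t, f)     => If(cond, normalize(t), normalize(f))
    case Coalesce(children) => Coalesce(children.map(normalize))
    // Leaf case, mirroring the pre-existing behavior for double-typed input.
    case other              => Normalized(other)
  }

  // A join key like coalesce(a, 0.0) now gets the marker on each child:
  println(normalize(Coalesce(Seq(Attr("a"), Lit(0.0)))))
  // Coalesce(List(Normalized(Attr(a)), Normalized(Lit(0.0))))
}
```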

File tree: 2 files changed, +48 -4 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -116,6 +116,15 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case CreateMap(children, useStringTypeWhenEmpty) =>
       CreateMap(children.map(normalize), useStringTypeWhenEmpty)
 
+    case If(cond, trueValue, falseValue) =>
+      If(cond, normalize(trueValue), normalize(falseValue))
+
+    case CaseWhen(branches, elseVale) =>
+      CaseWhen(branches.map(br => (br._1, normalize(br._2))), elseVale.map(normalize))
+
+    case Coalesce(children) =>
+      Coalesce(children.map(normalize))
+
     case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
       KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
 

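For context (not part of the diff above): the planner normalizes double-typed keys at all because Spark treats 0.0/-0.0 and all NaN encodings as equal in join, grouping, and window partition keys, while hashing and binary comparison see the raw bits. A small sketch demonstrating the mismatch:

```scala
import java.lang.Double.{doubleToRawLongBits, longBitsToDouble}

object WhyNormalize extends App {
  // 0.0 and -0.0 compare equal, yet their bit patterns differ, so a
  // hash-based join or aggregation would otherwise split them apart.
  println(0.0 == -0.0)                                            // true
  println(doubleToRawLongBits(0.0) == doubleToRawLongBits(-0.0))  // false

  // NaN has many valid encodings (and never equals itself), but Spark wants
  // all NaNs treated as the same key value, hence NormalizeNaNAndZero.
  val otherNaN = longBitsToDouble(0x7ff0000000000001L)
  println(otherNaN.isNaN)                                         // true
  println(doubleToRawLongBits(otherNaN) ==
    doubleToRawLongBits(Double.NaN))                              // false
}
```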
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala

Lines changed: 38 additions & 3 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, IsNull, KnownFloatingPointNormalized}
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, IsNull, KnownFloatingPointNormalized}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -85,8 +85,43 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest {
     val optimized = Optimize.execute(query)
     val doubleOptimized = Optimize.execute(optimized)
     val joinCond = IsNull(a) === IsNull(b) &&
-      KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(a, 0.0))) ===
-        KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(b, 0.0)))
+      coalesce(KnownFloatingPointNormalized(NormalizeNaNAndZero(a)),
+        KnownFloatingPointNormalized(NormalizeNaNAndZero(0.0))) ===
+          coalesce(KnownFloatingPointNormalized(NormalizeNaNAndZero(b)),
+            KnownFloatingPointNormalized(NormalizeNaNAndZero(0.0)))
+    val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+
+  test("SPARK-32258: normalize the children of If") {
+    val cond = If(a > 0.1D, a, a + 0.2D) === b
+    val query = testRelation1.join(testRelation2, condition = Some(cond))
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+
+    val joinCond = If(a > 0.1D,
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(a)),
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D))) ===
+        KnownFloatingPointNormalized(NormalizeNaNAndZero(b))
+    val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+
+  test("SPARK-32258: normalize the children of CaseWhen") {
+    val cond = CaseWhen(
+      Seq((a > 0.1D, a), (a > 0.2D, a + 0.2D)),
+      Some(a + 0.3D)) === b
+    val query = testRelation1.join(testRelation2, condition = Some(cond))
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+
+    val joinCond = CaseWhen(
+      Seq((a > 0.1D, KnownFloatingPointNormalized(NormalizeNaNAndZero(a))),
+        (a > 0.2D, KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D)))),
+      Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.3D)))) ===
+        KnownFloatingPointNormalized(NormalizeNaNAndZero(b))
     val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond))
 
     comparePlans(doubleOptimized, correctAnswer)

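To exercise just these tests locally, running the single suite through sbt, e.g. `build/sbt "catalyst/testOnly *NormalizeFloatingPointNumbersSuite"`, should cover all three SPARK-32258 cases (module name assumed from the usual Spark sbt project layout).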