Skip to content

Commit 9cf0396

Browse files
update code according to the code review comment
1 parent 536c005 commit 9cf0396

File tree

3 files changed

+37
-15
lines changed

3 files changed

+37
-15
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ abstract class Expression extends TreeNode[Expression] {
222222
}
223223
}
224224

225+
/**
226+
* Root class for rewritten two-operand UDF expressions. By default, we assume the result is null
227+
* if either operand is null. Expressions that behave differently need an explicit case in the
228+
* optimization rule [[optimizer.NullPropagation NullPropagation]].
229+
*/
225230
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
226231
self: Product =>
227232

@@ -238,6 +243,11 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
238243
self: Product =>
239244
}
240245

246+
/**
247+
* Root class for rewritten single-operand UDF expressions. By default, we assume the result is
248+
* null if the operand is null. Expressions that behave differently need an explicit case in the
249+
* optimization rule [[optimizer.NullPropagation NullPropagation]].
250+
*/
241251
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
242252
self: Product =>
243253

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.errors.TreeNodeException
21+
import org.apache.spark.sql.catalyst.trees
2122

2223
abstract sealed class SortDirection
2324
case object Ascending extends SortDirection
@@ -27,7 +28,10 @@ case object Descending extends SortDirection
2728
* An expression that can be used to sort a tuple. This class extends expression primarily so that
2829
* transformations over expression will descend into its child.
2930
*/
30-
case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
31+
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
32+
with trees.UnaryNode[Expression] {
33+
34+
override def references = child.references
3135
override def dataType = child.dataType
3236
override def nullable = child.nullable
3337

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.types._
2626
object Optimizer extends RuleExecutor[LogicalPlan] {
2727
val batches =
2828
Batch("ConstantFolding", Once,
29+
NullPropagation,
2930
ConstantFolding,
3031
BooleanSimplification,
3132
SimplifyFilters,
@@ -87,23 +88,18 @@ object ColumnPruning extends Rule[LogicalPlan] {
8788

8889
/**
8990
* Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
90-
* equivalent [[catalyst.expressions.Literal Literal]] values.
91+
* equivalent [[catalyst.expressions.Literal Literal]] values. This rule specifically handles
92+
* null value propagation from the bottom to the top of the expression tree.
9193
*/
92-
object ConstantFolding extends Rule[LogicalPlan] {
94+
object NullPropagation extends Rule[LogicalPlan] {
9395
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
9496
case q: LogicalPlan => q transformExpressionsUp {
9597
// Skip redundant folding of literals.
9698
case l: Literal => l
97-
// if it's foldable
98-
case e if e.foldable => Literal(e.eval(null), e.dataType)
9999
case e @ Count(Literal(null, _)) => Literal(null, e.dataType)
100100
case e @ Sum(Literal(null, _)) => Literal(null, e.dataType)
101101
case e @ Average(Literal(null, _)) => Literal(null, e.dataType)
102-
case e @ IsNull(Literal(null, _)) => Literal(true, BooleanType)
103-
case e @ IsNull(Literal(_, _)) => Literal(false, BooleanType)
104102
case e @ IsNull(c @ Rand) => Literal(false, BooleanType)
105-
case e @ IsNotNull(Literal(null, _)) => Literal(false, BooleanType)
106-
case e @ IsNotNull(Literal(_, _)) => Literal(true, BooleanType)
107103
case e @ IsNotNull(c @ Rand) => Literal(true, BooleanType)
108104
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
109105
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
@@ -113,10 +109,10 @@ object ConstantFolding extends Rule[LogicalPlan] {
113109
case Literal(null, _) => false
114110
case _ => true
115111
})
116-
if(newChildren.length == null) {
112+
if(newChildren.length == 0) {
117113
Literal(null, e.dataType)
118-
} else if(newChildren.length == children.length){
119-
e
114+
} else if(newChildren.length == 1) {
115+
newChildren(0)
120116
} else {
121117
Coalesce(newChildren)
122118
}
@@ -126,9 +122,8 @@ object ConstantFolding extends Rule[LogicalPlan] {
126122
case Literal(candidate, _) if(candidate == v) => true
127123
case _ => false
128124
})) => Literal(true, BooleanType)
129-
130-
case e @ SortOrder(_, _) => e
131-
// put exceptional cases(Unary & Binary Expression) before here.
125+
// Put exceptional cases (Unary & Binary Expressions that do not produce null for a constant
126+
// null operand) before this point.
132127
case e: UnaryExpression => e.child match {
133128
case Literal(null, _) => Literal(null, e.dataType)
134129
case _ => e
@@ -141,6 +136,19 @@ object ConstantFolding extends Rule[LogicalPlan] {
141136
}
142137
}
143138
}
139+
/**
140+
* Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
141+
* equivalent [[catalyst.expressions.Literal Literal]] values.
* The expressions of each LogicalPlan matched by `plan transform` are rewritten with
* `transformExpressionsDown`: an existing Literal is kept as-is, and any other expression that
* reports `foldable` is evaluated and replaced by a Literal holding the result and the
* expression's own `dataType`.
142+
*/
143+
object ConstantFolding extends Rule[LogicalPlan] {
144+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
145+
case q: LogicalPlan => q transformExpressionsDown {
146+
// Skip redundant folding of literals: a Literal is returned unchanged, not re-evaluated.
147+
case l: Literal => l
148+
// NOTE(review): eval is given null as its input -- presumably a foldable expression reads
// no input row; confirm against Expression.eval.
case e if e.foldable => Literal(e.eval(null), e.dataType)
149+
}
150+
}
151+
}
144152

145153
/**
146154
* Simplifies boolean expressions where the answer can be determined without evaluating both sides.

0 commit comments

Comments
 (0)