Skip to content

Commit 756c969

Browse files
chenghao-intel authored and rxin committed
[WIP][Spark-SQL] Optimize the Constant Folding for Expression
Currently, expressions do not handle a constant null well during constant folding, e.g. Sum(a, 0) actually always produces Literal(0, NumericType) at runtime. For example: ``` explain select isnull(key+null) from src; == Logical Plan == Project [HiveGenericUdf#isnull((key#30 + CAST(null, IntegerType))) AS c_0#28] MetastoreRelation default, src, None == Optimized Logical Plan == Project [true AS c_0#28] MetastoreRelation default, src, None == Physical Plan == Project [true AS c_0#28] HiveTableScan [], (MetastoreRelation default, src, None), None ``` I've created a new optimization rule called NullPropagation for this kind of constant folding. Author: Cheng Hao <hao.cheng@intel.com> Author: Michael Armbrust <michael@databricks.com> Closes #482 from chenghao-intel/optimize_constant_folding and squashes the following commits: 2f14b50 [Cheng Hao] Fix code style issues 68b9fad [Cheng Hao] Remove the Literal pattern matching for NullPropagation 29c8166 [Cheng Hao] Update the code for feedback of code review 50444cc [Cheng Hao] Remove the unnecessary null checking 80f9f18 [Cheng Hao] Update the UnitTest for aggregation constant folding 27ea3d7 [Cheng Hao] Fix Constant Folding Bugs & Add More Unittests b28e03a [Cheng Hao] Merge pull request #1 from marmbrus/pr/482 9ccefdb [Michael Armbrust] Add tests for optimized expression evaluation. 543ef9d [Cheng Hao] fix code style issues 9cf0396 [Cheng Hao] update code according to the code review comment 536c005 [Cheng Hao] Add Exceptional case for constant folding 3c045c7 [Cheng Hao] Optimize the Constant Folding by adding more rules 2645d4f [Cheng Hao] Constant Folding(null propagation) (cherry picked from commit 3eb53bd) Signed-off-by: Reynold Xin <rxin@apache.org>
1 parent 00fac73 commit 756c969

14 files changed

+1502
-32
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

+11-11
Original file line numberDiff line numberDiff line change
@@ -114,37 +114,37 @@ package object dsl {
114114
def attr = analysis.UnresolvedAttribute(s)
115115

116116
/** Creates a new AttributeReference of type boolean */
117-
def boolean = AttributeReference(s, BooleanType, nullable = false)()
117+
def boolean = AttributeReference(s, BooleanType, nullable = true)()
118118

119119
/** Creates a new AttributeReference of type byte */
120-
def byte = AttributeReference(s, ByteType, nullable = false)()
120+
def byte = AttributeReference(s, ByteType, nullable = true)()
121121

122122
/** Creates a new AttributeReference of type short */
123-
def short = AttributeReference(s, ShortType, nullable = false)()
123+
def short = AttributeReference(s, ShortType, nullable = true)()
124124

125125
/** Creates a new AttributeReference of type int */
126-
def int = AttributeReference(s, IntegerType, nullable = false)()
126+
def int = AttributeReference(s, IntegerType, nullable = true)()
127127

128128
/** Creates a new AttributeReference of type long */
129-
def long = AttributeReference(s, LongType, nullable = false)()
129+
def long = AttributeReference(s, LongType, nullable = true)()
130130

131131
/** Creates a new AttributeReference of type float */
132-
def float = AttributeReference(s, FloatType, nullable = false)()
132+
def float = AttributeReference(s, FloatType, nullable = true)()
133133

134134
/** Creates a new AttributeReference of type double */
135-
def double = AttributeReference(s, DoubleType, nullable = false)()
135+
def double = AttributeReference(s, DoubleType, nullable = true)()
136136

137137
/** Creates a new AttributeReference of type string */
138-
def string = AttributeReference(s, StringType, nullable = false)()
138+
def string = AttributeReference(s, StringType, nullable = true)()
139139

140140
/** Creates a new AttributeReference of type decimal */
141-
def decimal = AttributeReference(s, DecimalType, nullable = false)()
141+
def decimal = AttributeReference(s, DecimalType, nullable = true)()
142142

143143
/** Creates a new AttributeReference of type timestamp */
144-
def timestamp = AttributeReference(s, TimestampType, nullable = false)()
144+
def timestamp = AttributeReference(s, TimestampType, nullable = true)()
145145

146146
/** Creates a new AttributeReference of type binary */
147-
def binary = AttributeReference(s, BinaryType, nullable = false)()
147+
def binary = AttributeReference(s, BinaryType, nullable = true)()
148148
}
149149

150150
implicit class DslAttribute(a: AttributeReference) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ abstract class Expression extends TreeNode[Expression] {
4444
* - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
4545
* child is foldable.
4646
*/
47-
// TODO: Supporting more foldable expressions. For example, deterministic Hive UDFs.
4847
def foldable: Boolean = false
4948
def nullable: Boolean
5049
def references: Set[Attribute]

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.errors.TreeNodeException
21+
import org.apache.spark.sql.catalyst.trees
2122

2223
abstract sealed class SortDirection
2324
case object Ascending extends SortDirection
@@ -27,7 +28,10 @@ case object Descending extends SortDirection
2728
* An expression that can be used to sort a tuple. This class extends expression primarily so that
2829
* transformations over expression will descend into its child.
2930
*/
30-
case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
31+
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
32+
with trees.UnaryNode[Expression] {
33+
34+
override def references = child.references
3135
override def dataType = child.dataType
3236
override def nullable = child.nullable
3337

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

+20-14
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
2828
val children = child :: ordinal :: Nil
2929
/** `Null` is returned for invalid ordinals. */
3030
override def nullable = true
31+
override def foldable = child.foldable && ordinal.foldable
3132
override def references = children.flatMap(_.references).toSet
3233
def dataType = child.dataType match {
3334
case ArrayType(dt) => dt
@@ -40,23 +41,27 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
4041
override def toString = s"$child[$ordinal]"
4142

4243
override def eval(input: Row): Any = {
43-
if (child.dataType.isInstanceOf[ArrayType]) {
44-
val baseValue = child.eval(input).asInstanceOf[Seq[_]]
45-
val o = ordinal.eval(input).asInstanceOf[Int]
46-
if (baseValue == null) {
47-
null
48-
} else if (o >= baseValue.size || o < 0) {
49-
null
50-
} else {
51-
baseValue(o)
52-
}
44+
val value = child.eval(input)
45+
if (value == null) {
46+
null
5347
} else {
54-
val baseValue = child.eval(input).asInstanceOf[Map[Any, _]]
5548
val key = ordinal.eval(input)
56-
if (baseValue == null) {
49+
if (key == null) {
5750
null
5851
} else {
59-
baseValue.get(key).orNull
52+
if (child.dataType.isInstanceOf[ArrayType]) {
53+
val baseValue = value.asInstanceOf[Seq[_]]
54+
val o = key.asInstanceOf[Int]
55+
if (o >= baseValue.size || o < 0) {
56+
null
57+
} else {
58+
baseValue(o)
59+
}
60+
} else {
61+
val baseValue = value.asInstanceOf[Map[Any, _]]
62+
val key = ordinal.eval(input)
63+
baseValue.get(key).orNull
64+
}
6065
}
6166
}
6267
}
@@ -69,7 +74,8 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio
6974
type EvaluatedType = Any
7075

7176
def dataType = field.dataType
72-
def nullable = field.nullable
77+
override def nullable = field.nullable
78+
override def foldable = child.foldable
7379

7480
protected def structType = child.dataType match {
7581
case s: StructType => s

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ abstract class BinaryPredicate extends BinaryExpression with Predicate {
6565
def nullable = left.nullable || right.nullable
6666
}
6767

68-
case class Not(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
69-
def references = child.references
68+
case class Not(child: Expression) extends UnaryExpression with Predicate {
7069
override def foldable = child.foldable
7170
def nullable = child.nullable
7271
override def toString = s"NOT $child"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

+67
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.types._
2626
object Optimizer extends RuleExecutor[LogicalPlan] {
2727
val batches =
2828
Batch("ConstantFolding", Once,
29+
NullPropagation,
2930
ConstantFolding,
3031
BooleanSimplification,
3132
SimplifyFilters,
@@ -85,6 +86,72 @@ object ColumnPruning extends Rule[LogicalPlan] {
8586
}
8687
}
8788

89+
/**
90+
* Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
91+
* equivalent [[catalyst.expressions.Literal Literal]] values. This rule is more specific with
92+
* Null value propagation from bottom to top of the expression tree.
93+
*/
94+
object NullPropagation extends Rule[LogicalPlan] {
95+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
96+
case q: LogicalPlan => q transformExpressionsUp {
97+
case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
98+
case e @ Sum(Literal(c, _)) if c == 0 => Literal(0, e.dataType)
99+
case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType)
100+
case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
101+
case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)
102+
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
103+
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
104+
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
105+
case e @ Coalesce(children) => {
106+
val newChildren = children.filter(c => c match {
107+
case Literal(null, _) => false
108+
case _ => true
109+
})
110+
if (newChildren.length == 0) {
111+
Literal(null, e.dataType)
112+
} else if (newChildren.length == 1) {
113+
newChildren(0)
114+
} else {
115+
Coalesce(newChildren)
116+
}
117+
}
118+
case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue
119+
case e @ In(Literal(v, _), list) if (list.exists(c => c match {
120+
case Literal(candidate, _) if candidate == v => true
121+
case _ => false
122+
})) => Literal(true, BooleanType)
123+
case e: UnaryMinus => e.child match {
124+
case Literal(null, _) => Literal(null, e.dataType)
125+
case _ => e
126+
}
127+
case e: Cast => e.child match {
128+
case Literal(null, _) => Literal(null, e.dataType)
129+
case _ => e
130+
}
131+
case e: Not => e.child match {
132+
case Literal(null, _) => Literal(null, e.dataType)
133+
case _ => e
134+
}
135+
// Put exceptional cases above if any
136+
case e: BinaryArithmetic => e.children match {
137+
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
138+
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
139+
case _ => e
140+
}
141+
case e: BinaryComparison => e.children match {
142+
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
143+
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
144+
case _ => e
145+
}
146+
case e: StringRegexExpression => e.children match {
147+
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
148+
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
149+
case _ => e
150+
}
151+
}
152+
}
153+
}
154+
88155
/**
89156
* Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
90157
* equivalent [[catalyst.expressions.Literal Literal]] values.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

+112-3
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,7 @@ class ExpressionEvaluationSuite extends FunSuite {
108108
truthTable.foreach {
109109
case (l,r,answer) =>
110110
val expr = op(Literal(l, BooleanType), Literal(r, BooleanType))
111-
val result = expr.eval(null)
112-
if (result != answer)
113-
fail(s"$expr should not evaluate to $result, expected: $answer")
111+
checkEvaluation(expr, answer)
114112
}
115113
}
116114
}
@@ -131,6 +129,7 @@ class ExpressionEvaluationSuite extends FunSuite {
131129

132130
test("LIKE literal Regular Expression") {
133131
checkEvaluation(Literal(null, StringType).like("a"), null)
132+
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
134133
checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null)
135134
checkEvaluation("abdef" like "abdef", true)
136135
checkEvaluation("a_%b" like "a\\__b", true)
@@ -159,9 +158,14 @@ class ExpressionEvaluationSuite extends FunSuite {
159158
checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
160159
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
161160
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))
161+
162+
checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%")))
162163
}
163164

164165
test("RLIKE literal Regular Expression") {
166+
checkEvaluation(Literal(null, StringType) rlike "abdef", null)
167+
checkEvaluation("abdef" rlike Literal(null, StringType), null)
168+
checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null)
165169
checkEvaluation("abdef" rlike "abdef", true)
166170
checkEvaluation("abbbbc" rlike "a.*c", true)
167171

@@ -257,6 +261,8 @@ class ExpressionEvaluationSuite extends FunSuite {
257261
assert(("abcdef" cast DecimalType).nullable === true)
258262
assert(("abcdef" cast DoubleType).nullable === true)
259263
assert(("abcdef" cast FloatType).nullable === true)
264+
265+
checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null)
260266
}
261267

262268
test("timestamp") {
@@ -287,5 +293,108 @@ class ExpressionEvaluationSuite extends FunSuite {
287293
// A test for higher precision than millis
288294
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
289295
}
296+
297+
test("null checking") {
298+
val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
299+
val c1 = 'a.string.at(0)
300+
val c2 = 'a.string.at(1)
301+
val c3 = 'a.boolean.at(2)
302+
val c4 = 'a.boolean.at(3)
303+
304+
checkEvaluation(IsNull(c1), false, row)
305+
checkEvaluation(IsNotNull(c1), true, row)
306+
307+
checkEvaluation(IsNull(c2), true, row)
308+
checkEvaluation(IsNotNull(c2), false, row)
309+
310+
checkEvaluation(IsNull(Literal(1, ShortType)), false)
311+
checkEvaluation(IsNotNull(Literal(1, ShortType)), true)
312+
313+
checkEvaluation(IsNull(Literal(null, ShortType)), true)
314+
checkEvaluation(IsNotNull(Literal(null, ShortType)), false)
315+
316+
checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
317+
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
318+
checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)
319+
320+
checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row)
321+
checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
322+
checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
323+
checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row)
324+
checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row)
325+
checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row)
326+
checkEvaluation(If(Literal(false, BooleanType),
327+
Literal("a", StringType), Literal("b", StringType)), "b", row)
328+
329+
checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
330+
checkEvaluation(In(Literal("^Ba*n", StringType),
331+
Literal("^Ba*n", StringType) :: Nil), true, row)
332+
checkEvaluation(In(Literal("^Ba*n", StringType),
333+
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
334+
}
335+
336+
test("complex type") {
337+
val row = new GenericRow(Array[Any](
338+
"^Ba*n", // 0
339+
null.asInstanceOf[String], // 1
340+
new GenericRow(Array[Any]("aa", "bb")), // 2
341+
Map("aa"->"bb"), // 3
342+
Seq("aa", "bb") // 4
343+
))
344+
345+
val typeS = StructType(
346+
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
347+
)
348+
val typeMap = MapType(StringType, StringType)
349+
val typeArray = ArrayType(StringType)
350+
351+
checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
352+
Literal("aa")), "bb", row)
353+
checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
354+
checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
355+
checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
356+
Literal(null, StringType)), null, row)
357+
358+
checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
359+
Literal(1)), "bb", row)
360+
checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
361+
checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
362+
checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
363+
Literal(null, IntegerType)), null, row)
364+
365+
checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
366+
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
367+
}
368+
369+
test("arithmetic") {
370+
val row = new GenericRow(Array[Any](1, 2, 3, null))
371+
val c1 = 'a.int.at(0)
372+
val c2 = 'a.int.at(1)
373+
val c3 = 'a.int.at(2)
374+
val c4 = 'a.int.at(3)
375+
376+
checkEvaluation(UnaryMinus(c1), -1, row)
377+
checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100)
378+
379+
checkEvaluation(Add(c1, c4), null, row)
380+
checkEvaluation(Add(c1, c2), 3, row)
381+
checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row)
382+
checkEvaluation(Add(Literal(null, IntegerType), c2), null, row)
383+
checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
384+
}
385+
386+
test("BinaryComparison") {
387+
val row = new GenericRow(Array[Any](1, 2, 3, null))
388+
val c1 = 'a.int.at(0)
389+
val c2 = 'a.int.at(1)
390+
val c3 = 'a.int.at(2)
391+
val c4 = 'a.int.at(3)
392+
393+
checkEvaluation(LessThan(c1, c4), null, row)
394+
checkEvaluation(LessThan(c1, c2), true, row)
395+
checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row)
396+
checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row)
397+
checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
398+
}
290399
}
291400

0 commit comments

Comments
 (0)