Skip to content

Commit 104ea17

Browse files
cloud-fancmonkey
authored and committed
[SPARK-19309][SQL] disable common subexpression elimination for conditional expressions
## What changes were proposed in this pull request? As I pointed out in apache#15807 (comment), the current subexpression elimination framework has a problem: it always evaluates all common subexpressions at the beginning, even if they are inside conditional expressions and may not be accessed. Ideally we should implement it like a Scala lazy val, so we only evaluate it when it gets accessed at least once. apache#15837 tries this approach, but it seems too complicated and may introduce performance regression. This PR simply stops common subexpression elimination for conditional expressions, with some cleanup. ## How was this patch tested? regression test Author: Wenchen Fan <wenchen@databricks.com> Closes apache#16659 from cloud-fan/codegen.
1 parent 22d6ac4 commit 104ea17

File tree

7 files changed

+84
-171
lines changed

7 files changed

+84
-171
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,28 +67,34 @@ class EquivalentExpressions {
6767
/**
6868
* Adds the expression to this data structure recursively. Stops if a matching expression
6969
* is found. That is, if `expr` has already been added, its children are not added.
70-
* If ignoreLeaf is true, leaf nodes are ignored.
7170
*/
72-
def addExprTree(
73-
root: Expression,
74-
ignoreLeaf: Boolean = true,
75-
skipReferenceToExpressions: Boolean = true): Unit = {
76-
val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
71+
def addExprTree(expr: Expression): Unit = {
72+
val skip = expr.isInstanceOf[LeafExpression] ||
7773
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
7874
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
79-
root.find(_.isInstanceOf[LambdaVariable]).isDefined
80-
// There are some special expressions that we should not recurse into children.
75+
expr.find(_.isInstanceOf[LambdaVariable]).isDefined
76+
77+
// There are some special expressions that we should not recurse into all of its children.
8178
// 1. CodegenFallback: its children will not be used to generate code (call eval() instead)
82-
// 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
83-
val shouldRecurse = root match {
84-
// TODO: some expressions implements `CodegenFallback` but can still do codegen,
85-
// e.g. `CaseWhen`, we should support them.
86-
case _: CodegenFallback => false
87-
case _: ReferenceToExpressions if skipReferenceToExpressions => false
88-
case _ => true
79+
// 2. If: common subexpressions will always be evaluated at the beginning, but the true and
80+
// false expressions in `If` may not get accessed, according to the predicate
81+
// expression. We should only recurse into the predicate expression.
82+
// 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
83+
// condition. We should only recurse into the first condition expression as it
84+
// will always get accessed.
85+
// 4. Coalesce: it's also a conditional expression, we should only recurse into the first
86+
// child, because the others may not get accessed.
87+
def childrenToRecurse: Seq[Expression] = expr match {
88+
case _: CodegenFallback => Nil
89+
case i: If => i.predicate :: Nil
90+
// `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
91+
case c: CaseWhenCodegen => c.children.head :: Nil
92+
case c: Coalesce => c.children.head :: Nil
93+
case other => other.children
8994
}
90-
if (!skip && !addExpr(root) && shouldRecurse) {
91-
root.children.foreach(addExprTree(_, ignoreLeaf))
95+
96+
if (!skip && !addExpr(expr)) {
97+
childrenToRecurse.foreach(addExprTree)
9298
}
9399
}
94100

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ object UnsafeProjection {
117117
* Returns an UnsafeProjection for given Array of DataTypes.
118118
*/
119119
def create(fields: Array[DataType]): UnsafeProjection = {
120-
create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
120+
create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true)))
121121
}
122122

123123
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala

Lines changed: 0 additions & 92 deletions
This file was deleted.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -726,18 +726,18 @@ class CodegenContext {
726726
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
727727

728728
// Add each expression tree and compute the common subexpressions.
729-
expressions.foreach(equivalentExpressions.addExprTree(_, true, false))
729+
expressions.foreach(equivalentExpressions.addExprTree)
730730

731731
// Get all the expressions that appear at least twice and set up the state for subexpression
732732
// elimination.
733733
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
734734
val codes = commonExprs.map { e =>
735735
val expr = e.head
736736
// Generate the code for this expression tree.
737-
val code = expr.genCode(this)
738-
val state = SubExprEliminationState(code.isNull, code.value)
737+
val eval = expr.genCode(this)
738+
val state = SubExprEliminationState(eval.isNull, eval.value)
739739
e.foreach(subExprEliminationExprs.put(_, state))
740-
code.code.trim
740+
eval.code.trim
741741
}
742742
SubExprCodes(codes, subExprEliminationExprs.toMap)
743743
}
@@ -747,7 +747,7 @@ class CodegenContext {
747747
* common subexpressions, generates the functions that evaluate those expressions and populates
748748
* the mapping of common subexpressions to the generated functions.
749749
*/
750-
private def subexpressionElimination(expressions: Seq[Expression]) = {
750+
private def subexpressionElimination(expressions: Seq[Expression]): Unit = {
751751
// Add each expression tree and compute the common subexpressions.
752752
expressions.foreach(equivalentExpressions.addExprTree(_))
753753

@@ -761,13 +761,13 @@ class CodegenContext {
761761
val value = s"${fnName}Value"
762762

763763
// Generate the code for this expression tree and wrap it in a function.
764-
val code = expr.genCode(this)
764+
val eval = expr.genCode(this)
765765
val fn =
766766
s"""
767767
|private void $fnName(InternalRow $INPUT_ROW) {
768-
| ${code.code.trim}
769-
| $isNull = ${code.isNull};
770-
| $value = ${code.value};
768+
| ${eval.code.trim}
769+
| $isNull = ${eval.isNull};
770+
| $value = ${eval.value};
771771
|}
772772
""".stripMargin
773773

@@ -780,9 +780,6 @@ class CodegenContext {
780780
// The cost of doing subexpression elimination is:
781781
// 1. Extra function call, although this is probably *good* as the JIT can decide to
782782
// inline or not.
783-
// 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly
784-
// very often. The reason it is not loaded is because of a prior branch.
785-
// 3. Extra store into isLoaded.
786783
// The benefit doing subexpression elimination is:
787784
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
788785
// above.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ import org.apache.spark.sql.Row
2525
import org.apache.spark.sql.catalyst.InternalRow
2626
import org.apache.spark.sql.catalyst.dsl.expressions._
2727
import org.apache.spark.sql.catalyst.expressions.codegen._
28-
import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
29-
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
28+
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, GetExternalRowField, ValidateExternalType}
29+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils}
3030
import org.apache.spark.sql.types._
3131
import org.apache.spark.unsafe.types.UTF8String
3232
import org.apache.spark.util.ThreadUtils
@@ -313,4 +313,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
313313
test("SPARK-17160: field names are properly escaped by AssertTrue") {
314314
GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
315315
}
316+
317+
test("should not apply common subexpression elimination on conditional expressions") {
318+
val row = InternalRow(null)
319+
val bound = BoundReference(0, IntegerType, true)
320+
val assertNotNull = AssertNotNull(bound, Nil)
321+
val expr = If(IsNull(bound), Literal(1), Add(assertNotNull, assertNotNull))
322+
val projection = GenerateUnsafeProjection.generate(
323+
Seq(expr), subexpressionEliminationEnabled = true)
324+
// should not throw exception
325+
projection(row)
326+
}
316327
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
9797
val add2 = Add(add, add)
9898

9999
var equivalence = new EquivalentExpressions
100-
equivalence.addExprTree(add, true)
101-
equivalence.addExprTree(abs, true)
102-
equivalence.addExprTree(add2, true)
100+
equivalence.addExprTree(add)
101+
equivalence.addExprTree(abs)
102+
equivalence.addExprTree(add2)
103103

104104
// Should only have one equivalence for `one + two`
105105
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1)
@@ -115,41 +115,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
115115
val mul2 = Multiply(mul, mul)
116116
val sqrt = Sqrt(mul2)
117117
val sum = Add(mul2, sqrt)
118-
equivalence.addExprTree(mul, true)
119-
equivalence.addExprTree(mul2, true)
120-
equivalence.addExprTree(sqrt, true)
121-
equivalence.addExprTree(sum, true)
118+
equivalence.addExprTree(mul)
119+
equivalence.addExprTree(mul2)
120+
equivalence.addExprTree(sqrt)
121+
equivalence.addExprTree(sum)
122122

123123
// (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
124124
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3)
125125
assert(equivalence.getEquivalentExprs(mul).size == 3)
126126
assert(equivalence.getEquivalentExprs(mul2).size == 3)
127127
assert(equivalence.getEquivalentExprs(sqrt).size == 2)
128128
assert(equivalence.getEquivalentExprs(sum).size == 1)
129-
130-
// Some expressions inspired by TPCH-Q1
131-
// sum(l_quantity) as sum_qty,
132-
// sum(l_extendedprice) as sum_base_price,
133-
// sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
134-
// sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
135-
// avg(l_extendedprice) as avg_price,
136-
// avg(l_discount) as avg_disc
137-
equivalence = new EquivalentExpressions
138-
val quantity = Literal(1)
139-
val price = Literal(1.1)
140-
val discount = Literal(.24)
141-
val tax = Literal(0.1)
142-
equivalence.addExprTree(quantity, false)
143-
equivalence.addExprTree(price, false)
144-
equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false)
145-
equivalence.addExprTree(
146-
Multiply(
147-
Multiply(price, Subtract(Literal(1), discount)),
148-
Add(Literal(1), tax)), false)
149-
equivalence.addExprTree(price, false)
150-
equivalence.addExprTree(discount, false)
151-
// quantity, price, discount and (price * (1 - discount))
152-
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 4)
153129
}
154130

155131
test("Expression equivalence - non deterministic") {
@@ -167,11 +143,24 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
167143
val add = Add(two, fallback)
168144

169145
val equivalence = new EquivalentExpressions
170-
equivalence.addExprTree(add, true)
146+
equivalence.addExprTree(add)
171147
// the `two` inside `fallback` should not be added
172148
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
173149
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
174150
}
151+
152+
test("Children of conditional expressions") {
153+
val condition = And(Literal(true), Literal(false))
154+
val add = Add(Literal(1), Literal(2))
155+
val ifExpr = If(condition, add, add)
156+
157+
val equivalence = new EquivalentExpressions
158+
equivalence.addExprTree(ifExpr)
159+
// the `add` inside `If` should not be added
160+
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
161+
// only ifExpr and its predicate expression
162+
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2)
163+
}
175164
}
176165

177166
case class CodegenFallbackExpression(child: Expression)

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,15 @@ case class SimpleTypedAggregateExpression(
143143
override lazy val aggBufferAttributes: Seq[AttributeReference] =
144144
bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference])
145145

146+
private def serializeToBuffer(expr: Expression): Seq[Expression] = {
147+
bufferSerializer.map(_.transform {
148+
case _: BoundReference => expr
149+
})
150+
}
151+
146152
override lazy val initialValues: Seq[Expression] = {
147153
val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
148-
bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil))
154+
serializeToBuffer(zero)
149155
}
150156

151157
override lazy val updateExpressions: Seq[Expression] = {
@@ -154,8 +160,7 @@ case class SimpleTypedAggregateExpression(
154160
"reduce",
155161
bufferExternalType,
156162
bufferDeserializer :: inputDeserializer.get :: Nil)
157-
158-
bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil))
163+
serializeToBuffer(reduced)
159164
}
160165

161166
override lazy val mergeExpressions: Seq[Expression] = {
@@ -170,8 +175,7 @@ case class SimpleTypedAggregateExpression(
170175
"merge",
171176
bufferExternalType,
172177
leftBuffer :: rightBuffer :: Nil)
173-
174-
bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil))
178+
serializeToBuffer(merged)
175179
}
176180

177181
override lazy val evaluateExpression: Expression = {
@@ -181,19 +185,17 @@ case class SimpleTypedAggregateExpression(
181185
outputExternalType,
182186
bufferDeserializer :: Nil)
183187

188+
val outputSerializeExprs = outputSerializer.map(_.transform {
189+
case _: BoundReference => resultObj
190+
})
191+
184192
dataType match {
185-
case s: StructType =>
193+
case _: StructType =>
186194
val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
187-
val struct = If(
188-
IsNull(objRef),
189-
Literal.create(null, dataType),
190-
CreateStruct(outputSerializer))
191-
ReferenceToExpressions(struct, resultObj :: Nil)
195+
If(IsNull(objRef), Literal.create(null, dataType), CreateStruct(outputSerializeExprs))
192196
case _ =>
193-
assert(outputSerializer.length == 1)
194-
outputSerializer.head transform {
195-
case b: BoundReference => resultObj
196-
}
197+
assert(outputSerializeExprs.length == 1)
198+
outputSerializeExprs.head
197199
}
198200
}
199201

0 commit comments

Comments
 (0)