Skip to content

Commit

Permalink
[SPARK-42851][SQL] Guard EquivalentExpressions.addExpr() with support…
Browse files Browse the repository at this point in the history
…edExpression()

### What changes were proposed in this pull request?

In `EquivalentExpressions.addExpr()`, add a guard `supportedExpression()` to make it consistent with `addExprTree()` and `getExprState()`.

### Why are the changes needed?

This fixes a regression caused by apache#39010 which added the `supportedExpression()` to `addExprTree()` and `getExprState()` but not `addExpr()`.

One example of a use case affected by the inconsistency is the `PhysicalAggregation` pattern in physical planning. There, it calls `addExpr()` to deduplicate the aggregate expressions, and then calls `getExprState()` to deduplicate the result expressions. Guarding inconsistently will cause the aggregate and result expressions go out of sync, eventually resulting in query execution error (or whole-stage codegen error).

### Does this PR introduce _any_ user-facing change?

This fixes a regression affecting Spark 3.3.2+, where it may manifest as an error running aggregate operators with higher-order functions.

Example running the SQL command:
```sql
select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)
```
example error message before the fix:
```
java.lang.IllegalStateException: Couldn't find max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))#4 in [max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))#3]
```
after the fix this error is gone.

### How was this patch tested?

Added new test cases to `SubexpressionEliminationSuite` for the immediate issue, and to `DataFrameAggregateSuite` for an example of user-visible symptom.

Closes apache#40473 from rednaxelafx/spark-42851.

Authored-by: Kris Mok <kris.mok@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
rednaxelafx authored and cloud-fan committed Mar 21, 2023
1 parent c9a530e commit ef0a76e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ class EquivalentExpressions {
* Returns true if there was already a matching expression.
*/
def addExpr(expr: Expression): Boolean = {
updateExprInMap(expr, equivalenceMap)
if (supportedExpression(expr)) {
updateExprInMap(expr, equivalenceMap)
} else {
false
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType}
import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType}

class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("Semantic equals and hash") {
Expand Down Expand Up @@ -449,6 +449,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
assert(e2.getCommonSubexpressions.size == 1)
assert(e2.getCommonSubexpressions.head == add)
}

test("SPARK-42851: Handle supportExpression consistently across add and get") {
val expr = {
val function = (lambda: Expression) => Add(lambda, Literal(1))
val elementType = IntegerType
val colClass = classOf[Array[Int]]
val inputType = ObjectType(colClass)
val inputObject = BoundReference(0, inputType, nullable = true)
objects.MapObjects(function, inputObject, elementType, true, Option(colClass))
}
val equivalence = new EquivalentExpressions
equivalence.addExpr(expr)
val hasMatching = equivalence.addExpr(expr)
val cseState = equivalence.getExprState(expr)
assert(hasMatching == cseState.isDefined)
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1538,6 +1538,13 @@ class DataFrameAggregateSuite extends QueryTest
)
checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil)
}

test("SPARK-42851: common subexpression should consistently handle aggregate and result exprs") {
val res = sql(
"select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)"
)
checkAnswer(res, Row(Array(1), Array(1)))
}
}

case class B(c: Option[Double])
Expand Down

0 comments on commit ef0a76e

Please sign in to comment.