Skip to content

[SPARK-40903][SQL] Avoid reordering decimal Add for canonicalization if data type is changed #38379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,13 @@ abstract class Expression extends TreeNode[Expression] {
* This means that the lazy `cannonicalized` is called and computed only on the root of the
* adjacent expressions.
*/
lazy val canonicalized: Expression = {
lazy val canonicalized: Expression = withCanonicalizedChildren

/**
* The default process of canonicalization. It is a one pass, bottum-up expression tree
* computation based oncanonicalizing children before canonicalizing the current node.
*/
final protected def withCanonicalizedChildren: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
withNewChildren(canonicalizedChildren)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,15 @@ case class Add(

override lazy val canonicalized: Expression = {
// TODO: do not reorder consecutive `Add`s with different `evalMode`
orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
val reorderResult =
orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a big concern but the cost of re-calculate the data type. I'm fine with this

reorderResult
} else {
// SPARK-40903: Avoid reordering decimal Add for canonicalization if the result data type is
// changed, which may cause data checking error within ComplexTypeMergingExpression.
withCanonicalizedChildren
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, LongType, StringType, StructField, StructType}

class CanonicalizeSuite extends SparkFunSuite {

Expand Down Expand Up @@ -187,7 +187,23 @@ class CanonicalizeSuite extends SparkFunSuite {
test("SPARK-40362: Commutative operator under BinaryComparison") {
Seq(EqualTo, EqualNullSafe, GreaterThan, LessThan, GreaterThanOrEqual, LessThanOrEqual)
.foreach { bc =>
assert(bc(Add($"a", $"b"), Literal(10)).semanticEquals(bc(Add($"b", $"a"), Literal(10))))
assert(bc(Multiply($"a", $"b"), Literal(10)).semanticEquals(
bc(Multiply($"b", $"a"), Literal(10))))
}
}

test("SPARK-40903: Only reorder decimal Add when the result data type is not changed") {
val d = Decimal(1.2)
val literal1 = Literal.create(d, DecimalType(2, 1))
val literal2 = Literal.create(d, DecimalType(2, 1))
val literal3 = Literal.create(d, DecimalType(3, 2))
assert(Add(literal1, literal2).semanticEquals(Add(literal2, literal1)))
assert(Add(Add(literal1, literal2), literal3).semanticEquals(
Add(Add(literal3, literal2), literal1)))

val literal4 = Literal.create(d, DecimalType(12, 5))
val literal5 = Literal.create(d, DecimalType(12, 6))
assert(!Add(Add(literal4, literal5), literal1).semanticEquals(
Add(Add(literal1, literal5), literal4)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4518,6 +4518,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
}

test("SPARK-40903: Don't reorder Add for canonicalize if it is decimal type") {
val tableName = "decimalTable"
withTable(tableName) {
sql(s"create table $tableName(a decimal(12, 5), b decimal(12, 6)) using orc")
checkAnswer(sql(s"select sum(coalesce(a + b + 1.75, a)) from $tableName"), Row(null))
}
}
}

case class Foo(bar: Option[String])