[SPARK-40903][SQL] Avoid reordering decimal Add for canonicalization if data type is changed #38379


Closed · wants to merge 5 commits

Conversation

@gengliangwang (Member) commented Oct 24, 2022:

What changes were proposed in this pull request?

Avoid reordering `Add` during canonicalization if it is of decimal type and reordering would change the result data type.
Expressions are canonicalized for comparisons and for explain output. For a non-decimal `Add` expression, the operands can be sorted by hashcode and the result stays the same.
However, an `Add` expression of decimal type behaves differently: given one decimal `(p1, s1)` and another decimal `(p2, s2)`, the integral part of the result is `max(p1-s1, p2-s2) + 1` and the fractional part is `max(s1, s2)`, so the result data type is `decimal(max(p1-s1, p2-s2) + 1 + max(s1, s2), max(s1, s2))`.
Thus the order matters:

For `(decimal(12,5) + decimal(12,6)) + decimal(3,2)`, the first add `decimal(12,5) + decimal(12,6)` results in `decimal(14,6)`, and then `decimal(14,6) + decimal(3,2)` results in `decimal(15,6)`.
For `(decimal(12,6) + decimal(3,2)) + decimal(12,5)`, the first add `decimal(12,6) + decimal(3,2)` results in `decimal(13,6)`, and then `decimal(13,6) + decimal(12,5)` results in `decimal(14,6)`.
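
To make the rule concrete, here is a minimal Scala sketch of the result-type computation described above. The `DecType` case class and the `addResultType` helper are hypothetical illustrations, not Spark's actual `DecimalType` API (which additionally caps precision at 38):

```
// Hypothetical model of the decimal-addition result-type rule above.
case class DecType(precision: Int, scale: Int)

def addResultType(l: DecType, r: DecType): DecType = {
  val scale = math.max(l.scale, r.scale)                        // max(s1, s2)
  val integral =
    math.max(l.precision - l.scale, r.precision - r.scale) + 1  // max(p1-s1, p2-s2) + 1
  DecType(integral + scale, scale)
}

// (decimal(12,5) + decimal(12,6)) + decimal(3,2) => DecType(15,6)
addResultType(addResultType(DecType(12, 5), DecType(12, 6)), DecType(3, 2))
// (decimal(12,6) + decimal(3,2)) + decimal(12,5) => DecType(14,6)
addResultType(addResultType(DecType(12, 6), DecType(3, 2)), DecType(12, 5))
```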

In the following query:

```
create table foo(a decimal(12, 5), b decimal(12, 6)) using orc
select sum(coalesce(a + b + 1.75, a)) from foo
```

At first `coalesce(a + b + 1.75, a)` is resolved as `coalesce(a + b + 1.75, cast(a as decimal(15, 6)))`. In the canonicalized version, the expression becomes `coalesce(1.75 + b + a, cast(a as decimal(15, 6)))`. As explained above, `1.75 + b + a` is of `decimal(14, 6)`, which is different from `cast(a as decimal(15, 6))`. Thus the following error occurs:

```
java.lang.IllegalArgumentException: requirement failed: All input types must be the same except nullable, containsNull, valueContainsNull flags. The input types found are
	DecimalType(14,6)
	DecimalType(15,6)
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.sql.catalyst.expressions.ComplexTypeMergingExpression.dataTypeCheck(Expression.scala:1149)
	at org.apache.spark.sql.catalyst.expressions.ComplexTypeMergingExpression.dataTypeCheck$(Expression.scala:1143)
```
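
A reproduction sketch in spark-shell, assuming a `SparkSession` named `spark`; exactly which action first triggers canonicalization (and thus the error) can vary, since canonicalized expressions are used for plan comparisons and explanations:

```
// Hypothetical spark-shell reproduction of the query above.
spark.sql("create table foo(a decimal(12, 5), b decimal(12, 6)) using orc")
// Canonicalization reorders a + b + 1.75 into 1.75 + b + a, typed decimal(14,6),
// while the other coalesce branch stays cast(a as decimal(15,6)), so
// ComplexTypeMergingExpression.dataTypeCheck fails as shown above.
spark.sql("select sum(coalesce(a + b + 1.75, a)) from foo").collect()
```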

This PR fixes the bug.

Why are the changes needed?

Bug fix

Does this PR introduce any user-facing change?

No

How was this patch tested?

A new test case

gengliangwang changed the title from "fix decimal add's canonicalized" to "[SPARK-40903][SQL] Avoid reordering decimal Add for canonicalization" on Oct 24, 2022
@gengliangwang (Member Author) commented:

cc @peter-toth @ulysses-you

The github-actions bot added the SQL label on Oct 24, 2022.
@gengliangwang (Member Author) commented:

I confirmed that the regression was caused by the refactoring PR #36698. Before the refactor, the expression looked like:

```
coalesce(CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(a#0 as decimal(14,6))) + promote_precision(cast(b#1 as decimal(14,6)))), DecimalType(14,6), true) as decimal(15,6))) + promote_precision(cast(1.75 as decimal(15,6)))), DecimalType(15,6), true), a#0)
```

All the children of `Add` are cast to the final data type, so reordering `Add` during canonicalization doesn't matter.

```
@@ -477,7 +477,10 @@ case class Add(
   override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Add =
     copy(left = newLeft, right = newRight)
 
-  override lazy val canonicalized: Expression = {
+  override lazy val canonicalized: Expression = dataType match {
+    case _: DecimalType =>
```
Contributor:

Can we add some comments to explain the reason?

Contributor (@ulysses-you):

Thank you @gengliangwang for catching this. Can we make it more fine-grained? Not every decimal Add will fail, so we could check whether reordering is safe, e.g., when precision and scale are the same across all left and right operands.
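
A minimal sketch of the kind of check suggested here, with a hypothetical `safeToReorder` helper (not the code that was merged): when every decimal operand in the flattened `Add` chain shares one precision and scale, any association order yields the same result type.

```
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.DecimalType

// Hypothetical helper: reordering a flattened chain of decimal Adds is safe
// when all decimal operands share the same precision and scale, since every
// association order then produces the same result type.
def safeToReorder(operands: Seq[Expression]): Boolean =
  operands.map(_.dataType)
    .collect { case d: DecimalType => (d.precision, d.scale) }
    .distinct.size <= 1
```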

Member Author:

@cloud-fan comment added

Member Author:

@ulysses-you Are you sure about that? My concern is that if both left and right contain decimal Adds, the result may still be different after sorting all the sub-Adds.

Contributor (@peter-toth):

How about adding an extra `Cast` into the canonicalized form if needed, like:

```
  override lazy val canonicalized: Expression = {
    // TODO: do not reorder consecutive `Add`s with different `evalMode`
    val reordered = orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
    if (dataType != reordered.dataType) {
      Cast(reordered, dataType)
    } else {
      reordered
    }
  }
```

@peter-toth (Contributor) commented Oct 25, 2022:

Hmm, maybe adding an extra `Cast` is not a good idea, as two expressions with different data types shouldn't be considered equal. But if `reordered`'s data type matches the original, then why can't we reorder?

Member Author:

@peter-toth The ideal solution would be to add extra casts in the canonicalization of all children of `ComplexTypeMergingExpression` whenever the data type changes. However, some `ComplexTypeMergingExpression`s override the canonicalization. So I will take your suggestion and reorder the `Add` only if the result data type is unchanged. Thank you.

```
-    orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
+    val reorderResult =
+      orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
+    if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
```
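
For context, a sketch of how the full override might read with this guard in place. The `else` branch is an assumption (the hunk above ends at the guard): it falls back to the default child-wise canonicalization via `withCanonicalizedChildren`.

```
// Sketch only: reorder commutative Add chains during canonicalization, but
// keep the original shape when reordering would change the result data type
// (possible for decimal Add, as discussed above).
override lazy val canonicalized: Expression = {
  val reorderResult =
    orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode))
  if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
    reorderResult
  } else {
    // Reordering changed the type, e.g. decimal(14,6) vs decimal(15,6);
    // canonicalize the children in place instead.
    withCanonicalizedChildren
  }
}
```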
Contributor:

Not a big concern, but there is the cost of recalculating the data type. I'm fine with this.

@gengliangwang (Member Author) commented:

Merging to master. @cloud-fan @peter-toth @ulysses-you thanks for the review.

gengliangwang changed the title from "[SPARK-40903][SQL] Avoid reordering decimal Add for canonicalization" to "[SPARK-40903][SQL] Avoid reordering decimal Add for canonicalization if data type is changed" on Oct 26, 2022
SandishKumarHN pushed a commit to SandishKumarHN/spark that referenced this pull request on Dec 12, 2022 (closes apache#38379 from gengliangwang/fixDecimalAdd; lead-authored and signed off by Gengliang Wang).