
Conversation

@skambha (Contributor) commented Feb 18, 2020

What changes were proposed in this pull request?

JIRA SPARK-28067: Wrong results are returned for aggregate sum with decimals with whole stage codegen enabled

Repro:
WholeStage codegen enabled -> Wrong results
WholeStage codegen disabled -> Throws an exception: Decimal precision 39 exceeds max precision 38

Issues:

  1. Wrong results are returned, which is a correctness issue.
  2. Inconsistent behavior between whole stage codegen enabled and disabled.

Cause:
Sum does not account for the possibility of overflow in the intermediate steps, i.e. the updateExpressions and mergeExpressions.

This PR makes the following changes:

  • Throw an exception if there is a decimal overflow when computing the sum (a rough sketch of the idea follows this list).
  • This will be consistent with how Spark behaves when whole stage codegen is disabled.
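
A minimal, self-contained sketch of the intent (not the actual Catalyst diff): the helper name addChecked, the use of plain java.math.BigDecimal, and the hard-coded precision limit of 38 (Spark's DecimalType maximum) are assumptions made only for illustration.

import java.math.{BigDecimal => JBigDecimal}

// Illustrative only: fail fast when the running sum no longer fits the result
// type, instead of silently producing a wrong value at the end.
def addChecked(buffer: JBigDecimal, input: JBigDecimal, maxPrecision: Int = 38): JBigDecimal = {
  val sum = buffer.add(input)
  if (sum.precision() > maxPrecision) {
    throw new ArithmeticException(
      s"$sum cannot be represented as Decimal($maxPrecision, ${sum.scale()})")
  }
  sum
}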

Pros:

  • No wrong results
  • Consistent behavior between wholestage enabled and disabled
  • DBs have similar behavior, so there is precedent

Before Fix: - WRONG RESULTS

scala> val df = Seq(
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int]

scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum"))
df2: org.apache.spark.sql.DataFrame = [sum(decNum): decimal(38,18)]

scala> df2.show(40,false)
+---------------------------------------+                                       
|sum(decNum)                            |
+---------------------------------------+
|20000000000000000000.000000000000000000|
+---------------------------------------+

After fix:

scala> val df = Seq(
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int]

scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum"))
df2: org.apache.spark.sql.DataFrame = [sum(decNum): decimal(38,18)]

scala> df2.show(40,false)
20/02/18 13:36:19 ERROR Executor: Exception in task 1.0 in stage 1.0 (TID 9)    
java.lang.ArithmeticException: Decimal(expanded,100000000000000000000.000000000000000000,39,18}) cannot be represented as Decimal(38, 18).

Why are the changes needed?

The changes are needed in order to fix the wrong results that are returned for decimal aggregate sum.

Does this PR introduce any user-facing change?

Prior to this fix, the user would see wrong results for an aggregate sum that involved decimal overflow; now the user will see an exception. This behavior is also consistent with how Spark behaves when whole stage codegen is disabled.

How was this patch tested?

A new test has been added, and the existing sql, catalyst and hive test suites pass.

…overflow, throw exception and make it consistent to when wholestage codegen is disabled. Also fix the affected test from spark-28224
@AmplabJenkins

Can one of the admins verify this patch?

@skambha (Contributor, Author) commented Feb 18, 2020

Please see my notes in the JIRA for the two approaches to fix this issue. This is an implementation of the approach 1 fix. It is simpler and more straightforward than the approach 2 PR.

I have another PR, #27627, that takes approach 2 to fix this issue. Both will fix the incorrect results (which is good). Each has its pros and cons, as listed in my comment in the JIRA.

@skambha (Contributor, Author) left a review comment

SPARK-28224 only partially took care of decimal overflow for sum, for the case of two values. In the test case that was added as part of SPARK-28224, if you add another row to the dataset, you will get incorrect results instead of a null on overflow.

In this PR we address decimal overflow in aggregate sum by throwing an exception; hence this test has been modified.

Seq("true", "false").foreach { codegenEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegenEnabled)) {
val structDf = largeDecimals.select("a").agg(sum("a"))
if (!ansiEnabled) {
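      // Sketch only (not the actual test diff): with the fix, the overflow now
      // surfaces when the action runs, so the non-ANSI branch could assert on
      // the error instead of expecting a silent wrong result. intercept comes
      // from ScalaTest; the message check is an assumption about how the task
      // failure is reported to the driver.
      val err = intercept[Exception] { structDf.collect() }
      assert(err.getMessage.contains("cannot be represented as Decimal"))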

@HyukjinKwon (Member) commented

cc @mgaido91

@mgaido91 (Contributor) commented

This PR would introduce regressions. Checking every sum means that temporary overflows would cause an exception. E.g., if you sum MAX_INT, 10, -100, then MAX_INT + 10 would cause the exception. In the current code, this sum is handled properly and returns the correct result, because the temporary overflow is fixed by summing -100. So we would raise exceptions even when not needed. IIRC, other DBs treat this properly, so temporary overflows don't cause exceptions.

The proper fix for this would be to use a larger data type than the returned one as the buffer. I remember I had a PR for that (#25347). You can check its comments and history.
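
To make the temporary-overflow example above concrete, here is a small, self-contained Scala illustration; the values come from the comment, and Math.addExact merely stands in for a per-step overflow check (it is not what Spark uses):

import scala.util.Try

val xs = Seq(Int.MaxValue, 10, -100)

// With a wider (Long) accumulator the intermediate value never overflows,
// and the final result fits back into an Int:
val widerBuffer = xs.map(_.toLong).sum.toInt          // 2147483557

// Checking every step against the narrow type fails on the temporary
// overflow at Int.MaxValue + 10, even though the final sum would fit:
val perStepCheck = Try(xs.foldLeft(0)((acc, x) => Math.addExact(acc, x)))
// perStepCheck is a Failure wrapping java.lang.ArithmeticException: integer overflow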

@skambha (Contributor, Author) commented Feb 21, 2020

> This PR would introduce regressions. Checking every sum means that temporary overflows would cause an exception. E.g., if you sum MAX_INT, 10, -100, then MAX_INT + 10 would cause the exception. In the current code, this sum is handled properly and returns the correct result, because the temporary overflow is fixed by summing -100. So we would raise exceptions even when not needed. IIRC, other DBs treat this properly, so temporary overflows don't cause exceptions.

I see what you are saying, but this PR targets only the aggregate sum of the decimal type (where the result type is decimal), not int or long. Sum of ints is handled the same way as before and does not introduce any regressions for the above-mentioned use case. [1]

This PR is trying to handle the use case regarding aggregate Sum for decimal:

  • Sum of decimal type overflows and returns wrong results.
  • Note: even without this PR, the same sum on a decimal type throws an exception when whole stage codegen is disabled.

(Furthermore, even if spark.sql.ansi.enabled is set to true, we do not return null. This conf property is to ensure that any overflows will return null.)

Here, we are dealing with a correctness issue. This PR's approach is to follow the result returned by the whole-stage-codegen-disabled codepath.

Actually, this issue is mentioned in the PR for SPARK-23179 [3] as a special case; SPARK-28224 partially addressed it.

fwiw, I checked this on MS SQL Server and it throws an error as well. [2]

> The proper fix for this would be to use a larger data type than the returned one as the buffer. I remember I had a PR for that (#25347). You can check its comments and history.

Sure. I checked #25347; it deals with widening the data type for the aggregate sum of longs to decimal to avoid temporary overflow. The decision was not to make that change because a) it is not a correctness issue, b) of the performance hit, and c) a workaround exists: if the user sees an exception because of temporary overflow, they can cast the column to a decimal. [4] An illustration of that workaround is sketched below.
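
A hypothetical illustration of that workaround (the DataFrame df and the column name longCol are invented for the example; decimal(38,0) is just one choice of wider type):

import org.apache.spark.sql.functions.{col, sum}

// df and "longCol" are hypothetical; casting to a wide decimal gives the sum
// buffer enough headroom to absorb temporary overflows of the long type.
val total = df.agg(sum(col("longCol").cast("decimal(38,0)")))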

[1] SPARK-26218: Overflow on arithmetic operations returns incorrect result
[2] http://sqlfiddle.com/#!18/e7ecc/1
[3] SPARK-23179: Support option to throw exception if overflow occurs during Decimal arithmetic
[4] #25347 (comment)

Thanks for your comments.

@skambha (Contributor, Author) commented Feb 21, 2020

@mgaido91, since you worked on a lot of the overflow issues, I'd appreciate it if you could review the two approaches listed in SPARK-28067 and add your thoughts. Thanks.

@mgaido91 (Contributor) commented

Well, in this PR you are changing the logical plan; it's weird that the two execution modes return different results and we have to fix the plan for this.

@github-actions bot commented Jun 8, 2020

We're closing this PR because it hasn't been updated in a while. This isn't a judgement on the merit of the PR in any way. It's just a way of keeping the PR queue manageable.
If you'd like to revive this PR, please reopen it and ask a committer to remove the Stale tag!

github-actions bot added the Stale label on Jun 8, 2020
@skambha (Contributor, Author) commented Jun 8, 2020

Closing this in favor of the other approach in #27627, which got merged into trunk.

@skambha closed this on Jun 8, 2020