
Commit 77e52fd

kazuyukitanimura authored and dongjoon-hyun committed
[SPARK-45786][SQL] Fix inaccurate Decimal multiplication and division results
### What changes were proposed in this pull request?

This PR fixes inaccurate Decimal multiplication and division results.

### Why are the changes needed?

Decimal multiplication and division results may be inaccurate due to rounding issues.

#### Multiplication:

```
scala> sql("select -14120025096157587712113961295153.858047 * -0.4652").show(truncate=false)
+----------------------------------------------------+
|(-14120025096157587712113961295153.858047 * -0.4652)|
+----------------------------------------------------+
|6568635674732509803675414794505.574764              |
+----------------------------------------------------+
```

The correct answer is `6568635674732509803675414794505.574763`. Please note that the last digit is `3` instead of `4`, as the exact product shows:

```
scala> java.math.BigDecimal("-14120025096157587712113961295153.858047").multiply(java.math.BigDecimal("-0.4652"))
val res21: java.math.BigDecimal = 6568635674732509803675414794505.5747634644
```

Since the fractional part `.574763` is followed by `4644`, it should not be rounded up.

#### Division:

```
scala> sql("select -0.172787979 / 533704665545018957788294905796.5").show(truncate=false)
+-------------------------------------------------+
|(-0.172787979 / 533704665545018957788294905796.5)|
+-------------------------------------------------+
|-3.237521E-31                                    |
+-------------------------------------------------+
```

The correct answer is `-3.237520E-31`. Please note that the last digit is `0` instead of `1`, as the exact quotient shows:

```
scala> java.math.BigDecimal("-0.172787979").divide(java.math.BigDecimal("533704665545018957788294905796.5"), 100, java.math.RoundingMode.DOWN)
val res22: java.math.BigDecimal = -3.237520489418037889998826491401059986665344697406144511563561222578738E-31
```

Since the fractional part `.237520` is followed by `4894...`, it should not be rounded up.

### Does this PR introduce _any_ user-facing change?

Yes, users will see correct Decimal multiplication and division results. Directly multiplying and dividing with `org.apache.spark.sql.types.Decimal()` (not via SQL) will return 39 digits at maximum instead of 38, and will round down instead of rounding half-up.

### How was this patch tested?

Test added.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43678 from kazuyukitanimura/SPARK-45786.

Authored-by: Kazuyuki Tanimura <ktanimura@apple.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
(cherry picked from commit 5ef3a84)
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
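To make the mechanism concrete: the inaccuracy comes from rounding twice, once inside `Decimal` and once more when the result is coerced to its final scale. Below is a minimal sketch using plain `java.math.BigDecimal` outside Spark; the trailing `setScale` is a stand-in assumption for the rounding that TypeCoercion later applies to the `(38, 6)` result type, not the exact Spark call path.

```scala
import java.math.{BigDecimal => JBigDecimal, MathContext, RoundingMode}

val a = new JBigDecimal("-14120025096157587712113961295153.858047")
val b = new JBigDecimal("-0.4652")

// Old behavior: HALF_UP at 38 significant digits rounds .5747634644 up to
// .5747635, and the second rounding to scale 6 then rounds up again.
val before = a.multiply(b, new MathContext(38, RoundingMode.HALF_UP))
  .setScale(6, RoundingMode.HALF_UP)
// before == 6568635674732509803675414794505.574764  (wrong last digit)

// New behavior: one extra digit of precision, rounded DOWN, leaves the guard
// digit intact, so the single final rounding lands on the exact answer.
val after = a.multiply(b, new MathContext(39, RoundingMode.DOWN))
  .setScale(6, RoundingMode.HALF_UP)
// after == 6568635674732509803675414794505.574763  (matches the exact product)
```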

3 files changed: +120 −9 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 6 additions & 2 deletions

```diff
@@ -497,7 +497,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {

   def / (that: Decimal): Decimal =
     if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal,
-      DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode))
+      DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode))

   def % (that: Decimal): Decimal =
     if (that.isZero) null
@@ -545,7 +545,11 @@ object Decimal {

   val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)

-  private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
+  // SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results
+  // because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer
+  // precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down
+  // the last extra digit.
+  private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN)

   private[sql] val ZERO = Decimal(0)
   private[sql] val ONE = Decimal(1)
```
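Division takes a slightly different path: it passes an explicit target scale to `BigDecimal.divide` rather than a `MathContext`, so the fix there is `MAX_SCALE + 1` combined with the new `DOWN` rounding mode. A minimal sketch of the same before/after contrast, using the division example from the commit message; again, the final `setScale(37, ...)` is an assumption standing in for the later rounding to the `(38, 37)` result type:

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

val num = new JBigDecimal("-0.172787979")
val den = new JBigDecimal("533704665545018957788294905796.5")

// Old behavior: HALF_UP at scale 38 turns ...32375204|89... into ...32375205,
// and the second rounding to scale 37 then rounds up to -3.237521E-31.
val before = num.divide(den, 38, RoundingMode.HALF_UP).setScale(37, RoundingMode.HALF_UP)

// New behavior: scale 39 with DOWN keeps the guard digit untouched, so the
// single final rounding yields the correct -3.237520E-31.
val after = num.divide(den, 39, RoundingMode.DOWN).setScale(37, RoundingMode.HALF_UP)
```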

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala

Lines changed: 107 additions & 0 deletions

```diff
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.expressions

+import java.math.RoundingMode
 import java.sql.{Date, Timestamp}
 import java.time.{Duration, Period}
 import java.time.temporal.ChronoUnit
@@ -225,6 +226,112 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }

+  test("SPARK-45786: Decimal multiply, divide, remainder, quot") {
+    // Some known cases
+    checkEvaluation(
+      Multiply(
+        Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)),
+        Literal(Decimal(BigDecimal("-0.4652"), 4, 4))
+      ),
+      Decimal(BigDecimal("6568635674732509803675414794505.574763"))
+    )
+    checkEvaluation(
+      Multiply(
+        Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)),
+        Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25))
+      ),
+      Decimal(BigDecimal("1367249507675382200.164877854336665327"))
+    )
+    checkEvaluation(
+      Divide(
+        Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)),
+        Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1))
+      ),
+      Decimal(BigDecimal("-3.237520E-31"))
+    )
+    checkEvaluation(
+      Divide(
+        Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)),
+        Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3))
+      ),
+      Decimal(BigDecimal("7.21642358550E-25"))
+    )
+
+    // Random tests
+    val rand = scala.util.Random
+    def makeNum(p: Int, s: Int): String = {
+      val int1 = rand.nextLong()
+      val int2 = rand.nextLong().abs
+      val frac1 = rand.nextLong().abs
+      val frac2 = rand.nextLong().abs
+      s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." + s"$frac1$frac2".take(s)
+    }
+
+    (0 until 100).foreach { _ =>
+      val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38
+      val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1
+      val p2 = rand.nextInt(38) + 1
+      val s2 = rand.nextInt(p2 + 1)
+
+      val n1 = makeNum(p1, s1)
+      val n2 = makeNum(p2, s2)
+
+      val mulActual = Multiply(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2))
+
+      val divActual = Divide(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val divExact = new java.math.BigDecimal(n1)
+        .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN)
+
+      val remActual = Remainder(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2))
+
+      val quotActual = IntegralDivide(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val quotExact =
+        new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2))
+
+      Seq(true, false).foreach { allowPrecLoss =>
+        withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) {
+          val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2)
+          val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP))
+          val mulExpected =
+            if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult
+          checkEvaluation(mulActual, mulExpected)
+
+          val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2)
+          val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP))
+          val divExpected =
+            if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult
+          checkEvaluation(divActual, divExpected)
+
+          val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2)
+          val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP))
+          val remExpected =
+            if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult
+          checkEvaluation(remActual, remExpected)
+
+          val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2)
+          val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP))
+          val quotExpected =
+            if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult
+          checkEvaluation(quotActual, quotExpected.toLong)
+        }
+      }
+    }
+  }
+
   private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = {
     testFunc(_.toDouble)
     testFunc(Decimal(_))
```
sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out

Lines changed: 7 additions & 7 deletions

```diff
@@ -155,7 +155,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "1000000000000000000000000000000000000.00000000000000000000000000000000000000"
+    "value" : "1000000000000000000000000000000000000.000000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -204,7 +204,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "10123456789012345678901234567890123456.00000000000000000000000000000000000000"
+    "value" : "10123456789012345678901234567890123456.000000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -229,7 +229,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "101234567890123456789012345678901234.56000000000000000000000000000000000000"
+    "value" : "101234567890123456789012345678901234.560000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -254,7 +254,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "10123456789012345678901234567890123.45600000000000000000000000000000000000"
+    "value" : "10123456789012345678901234567890123.456000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -279,7 +279,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "1012345678901234567890123456789012.34560000000000000000000000000000000000"
+    "value" : "1012345678901234567890123456789012.345600000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -304,7 +304,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000"
+    "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -337,7 +337,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000"
+    "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
```
