Skip to content

Commit 3139d64

Browse files
mgaido91 and cloud-fan
authored and committed
[SPARK-23179][SQL] Support option to throw exception if overflow occurs during Decimal arithmetic
## What changes were proposed in this pull request? SQL ANSI 2011 states that in case of overflow during arithmetic operations, an exception should be thrown. This is what most of the SQL DBs do (eg. SQLServer, DB2). Hive currently returns NULL (as Spark does) but HIVE-18291 is open to be SQL compliant. The PR introduce an option to decide which behavior Spark should follow, ie. returning NULL on overflow or throwing an exception. ## How was this patch tested? added UTs Closes #20350 from mgaido91/SPARK-23179. Authored-by: Marco Gaido <marcogaido91@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 7cbe01e commit 3139d64

File tree

10 files changed

+223
-37
lines changed

10 files changed

+223
-37
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ object DecimalPrecision extends TypeCoercionRule {
8282
PromotePrecision(Cast(e, dataType))
8383
}
8484

85+
private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow
86+
8587
override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
8688
// fix decimal precision for expressions
8789
case q => q.transformExpressionsUp(
@@ -105,7 +107,7 @@ object DecimalPrecision extends TypeCoercionRule {
105107
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
106108
}
107109
CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
108-
resultType)
110+
resultType, nullOnOverflow)
109111

110112
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
111113
val resultScale = max(s1, s2)
@@ -116,7 +118,7 @@ object DecimalPrecision extends TypeCoercionRule {
116118
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
117119
}
118120
CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
119-
resultType)
121+
resultType, nullOnOverflow)
120122

121123
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
122124
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
@@ -126,7 +128,7 @@ object DecimalPrecision extends TypeCoercionRule {
126128
}
127129
val widerType = widerDecimalType(p1, s1, p2, s2)
128130
CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
129-
resultType)
131+
resultType, nullOnOverflow)
130132

131133
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
132134
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
@@ -148,7 +150,7 @@ object DecimalPrecision extends TypeCoercionRule {
148150
}
149151
val widerType = widerDecimalType(p1, s1, p2, s2)
150152
CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
151-
resultType)
153+
resultType, nullOnOverflow)
152154

153155
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
154156
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
@@ -159,7 +161,7 @@ object DecimalPrecision extends TypeCoercionRule {
159161
// resultType may have lower precision, so we cast them into wider type first.
160162
val widerType = widerDecimalType(p1, s1, p2, s2)
161163
CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
162-
resultType)
164+
resultType, nullOnOverflow)
163165

164166
case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
165167
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
@@ -170,7 +172,7 @@ object DecimalPrecision extends TypeCoercionRule {
170172
// resultType may have lower precision, so we cast them into wider type first.
171173
val widerType = widerDecimalType(p1, s1, p2, s2)
172174
CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
173-
resultType)
175+
resultType, nullOnOverflow)
174176

175177
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
176178
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
236236
collect(left, negate) ++ collect(right, !negate)
237237
case UnaryMinus(child) =>
238238
collect(child, !negate)
239-
case CheckOverflow(child, _) =>
239+
case CheckOverflow(child, _, _) =>
240240
collect(child, negate)
241241
case PromotePrecision(child) =>
242242
collect(child, negate)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ object RowEncoder {
114114
d,
115115
"fromDecimal",
116116
inputObject :: Nil,
117-
returnNullable = false), d)
117+
returnNullable = false), d, SQLConf.get.decimalOperationsNullOnOverflow)
118118

119119
case StringType => createSerializerForString(inputObject)
120120

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,30 +81,34 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
8181

8282
/**
8383
* Rounds the decimal to given scale and check whether the decimal can fit in provided precision
84-
* or not, returns null if not.
84+
* or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an
85+
* `ArithmeticException` is thrown.
8586
*/
86-
case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {
87+
case class CheckOverflow(
88+
child: Expression,
89+
dataType: DecimalType,
90+
nullOnOverflow: Boolean) extends UnaryExpression {
8791

8892
override def nullable: Boolean = true
8993

9094
override def nullSafeEval(input: Any): Any =
91-
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
95+
input.asInstanceOf[Decimal].toPrecision(
96+
dataType.precision,
97+
dataType.scale,
98+
Decimal.ROUND_HALF_UP,
99+
nullOnOverflow)
92100

93101
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
94102
nullSafeCodeGen(ctx, ev, eval => {
95-
val tmp = ctx.freshName("tmp")
96103
s"""
97-
| Decimal $tmp = $eval.clone();
98-
| if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
99-
| ${ev.value} = $tmp;
100-
| } else {
101-
| ${ev.isNull} = true;
102-
| }
104+
|${ev.value} = $eval.toPrecision(
105+
| ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
106+
|${ev.isNull} = ${ev.value} == null;
103107
""".stripMargin
104108
})
105109
}
106110

107-
override def toString: String = s"CheckOverflow($child, $dataType)"
111+
override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)"
108112

109113
override def sql: String = child.sql
110114
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,8 +1138,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
11381138
val evaluationCode = dataType match {
11391139
case DecimalType.Fixed(_, s) =>
11401140
s"""
1141-
${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
1142-
${ev.isNull} = ${ev.value} == null;"""
1141+
|${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
1142+
| Decimal.$modeStr(), true);
1143+
|${ev.isNull} = ${ev.value} == null;
1144+
""".stripMargin
11431145
case ByteType =>
11441146
if (_scale < 0) {
11451147
s"""

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,6 +1441,16 @@ object SQLConf {
14411441
.booleanConf
14421442
.createWithDefault(true)
14431443

1444+
val DECIMAL_OPERATIONS_NULL_ON_OVERFLOW =
1445+
buildConf("spark.sql.decimalOperations.nullOnOverflow")
1446+
.internal()
1447+
.doc("When true (default), if an overflow on a decimal occurs, then NULL is returned. " +
1448+
"Spark's older versions and Hive behave in this way. If turned to false, SQL ANSI 2011 " +
1449+
"specification will be followed instead: an arithmetic exception is thrown, as most " +
1450+
"of the SQL databases do.")
1451+
.booleanConf
1452+
.createWithDefault(true)
1453+
14441454
val LITERAL_PICK_MINIMUM_PRECISION =
14451455
buildConf("spark.sql.legacy.literal.pickMinimumPrecision")
14461456
.internal()
@@ -2205,6 +2215,8 @@ class SQLConf extends Serializable with Logging {
22052215

22062216
def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
22072217

2218+
def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)
2219+
22082220
def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
22092221

22102222
def continuousStreamingEpochBacklogQueueSize: Int =

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,25 @@ final class Decimal extends Ordered[Decimal] with Serializable {
249249
/**
250250
* Create new `Decimal` with given precision and scale.
251251
*
252-
* @return a non-null `Decimal` value if successful or `null` if overflow would occur.
252+
* @return a non-null `Decimal` value if successful. Otherwise, if `nullOnOverflow` is true, null
253+
* is returned; if `nullOnOverflow` is false, an `ArithmeticException` is thrown.
253254
*/
254255
private[sql] def toPrecision(
255256
precision: Int,
256257
scale: Int,
257-
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
258+
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP,
259+
nullOnOverflow: Boolean = true): Decimal = {
258260
val copy = clone()
259-
if (copy.changePrecision(precision, scale, roundMode)) copy else null
261+
if (copy.changePrecision(precision, scale, roundMode)) {
262+
copy
263+
} else {
264+
if (nullOnOverflow) {
265+
null
266+
} else {
267+
throw new ArithmeticException(
268+
s"$toDebugString cannot be represented as Decimal($precision, $scale).")
269+
}
270+
}
260271
}
261272

262273
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,26 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
4545

4646
test("CheckOverflow") {
4747
val d1 = Decimal("10.1")
48-
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10"))
49-
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1)
50-
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1)
51-
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null)
48+
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0), true), Decimal("10"))
49+
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1), true), d1)
50+
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2), true), d1)
51+
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3), true), null)
52+
intercept[ArithmeticException](CheckOverflow(Literal(d1), DecimalType(4, 3), false).eval())
53+
intercept[ArithmeticException](checkEvaluationWithMutableProjection(
54+
CheckOverflow(Literal(d1), DecimalType(4, 3), false), null))
5255

5356
val d2 = Decimal(101, 3, 1)
54-
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10"))
55-
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2)
56-
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2)
57-
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null)
57+
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0), true), Decimal("10"))
58+
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1), true), d2)
59+
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2), true), d2)
60+
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3), true), null)
61+
intercept[ArithmeticException](CheckOverflow(Literal(d2), DecimalType(4, 3), false).eval())
62+
intercept[ArithmeticException](checkEvaluationWithMutableProjection(
63+
CheckOverflow(Literal(d2), DecimalType(4, 3), false), null))
5864

59-
checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
65+
checkEvaluation(CheckOverflow(
66+
Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), true), null)
67+
checkEvaluation(CheckOverflow(
68+
Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), false), null)
6069
}
61-
6270
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,28 @@ select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.1
8383
select 123456789123456789.1234567890 * 1.123456789123456789;
8484
select 12345678912345.123456789123 / 0.000000012345678;
8585

86+
-- throw an exception instead of returning NULL, according to SQL ANSI 2011
87+
set spark.sql.decimalOperations.nullOnOverflow=false;
88+
89+
-- test operations between decimals and constants
90+
select id, a*10, b/10 from decimals_test order by id;
91+
92+
-- test operations on constants
93+
select 10.3 * 3.0;
94+
select 10.3000 * 3.0;
95+
select 10.30000 * 30.0;
96+
select 10.300000000000000000 * 3.000000000000000000;
97+
select 10.300000000000000000 * 3.0000000000000000000;
98+
99+
-- arithmetic operations causing an overflow throw exception
100+
select (5e36 + 0.1) + 5e36;
101+
select (-4e36 - 0.1) - 7e36;
102+
select 12345678901234567890.0 * 12345678901234567890.0;
103+
select 1e35 / 0.1;
104+
105+
-- arithmetic operations causing a precision loss throw exception
106+
select 123456789123456789.1234567890 * 1.123456789123456789;
107+
select 123456789123456789.1234567890 * 1.123456789123456789;
108+
select 12345678912345.123456789123 / 0.000000012345678;
109+
86110
drop table decimals_test;

0 commit comments

Comments
 (0)