Skip to content

Commit ee953d2

Browse files
committed
[SPARK-28201][SQL] Revisit MakeDecimal behavior on overflow
1 parent facf9c3 commit ee953d2

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
22+
import org.apache.spark.sql.internal.SQLConf
2223
import org.apache.spark.sql.types._
2324

2425
/**
@@ -46,19 +47,35 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
4647
*/
4748
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
4849

50+
private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
51+
private lazy val doEval = if (nullOnOverflow) {
52+
input: Long => new Decimal().setOrNull(input, precision, scale)
53+
} else {
54+
input: Long => new Decimal().set(input, precision, scale)
55+
}
56+
4957
override def dataType: DataType = DecimalType(precision, scale)
50-
override def nullable: Boolean = true
58+
override def nullable: Boolean = child.nullable || nullOnOverflow
5159
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
5260

53-
protected override def nullSafeEval(input: Any): Any =
54-
Decimal(input.asInstanceOf[Long], precision, scale)
61+
protected override def nullSafeEval(input: Any): Any = doEval(input.asInstanceOf[Long])
5562

5663
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
5764
nullSafeCodeGen(ctx, ev, eval => {
65+
val setMethod = if (nullOnOverflow) {
66+
"setOrNull"
67+
} else {
68+
"set"
69+
}
70+
val setNull = if (nullable) {
71+
s"${ev.isNull} = ${ev.value} == null;"
72+
} else {
73+
""
74+
}
5875
s"""
59-
${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale);
60-
${ev.isNull} = ${ev.value} == null;
61-
"""
76+
|${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
77+
|$setNull
78+
|""".stripMargin
6279
})
6380
}
6481
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
7676
*/
7777
def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
7878
if (setOrNull(unscaled, precision, scale) == null) {
79-
throw new IllegalArgumentException("Unscaled value too large for precision")
79+
throw new ArithmeticException("Unscaled value too large for precision")
8080
}
8181
this
8282
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.internal.SQLConf
2122
import org.apache.spark.sql.types.{Decimal, DecimalType, LongType}
2223

2324
class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -31,8 +32,24 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
3132
}
3233

3334
test("MakeDecimal") {
34-
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
35-
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
35+
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
36+
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
37+
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
38+
val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
39+
checkEvaluation(overflowExpr, null)
40+
checkEvaluationWithMutableProjection(overflowExpr, null)
41+
evaluateWithoutCodegen(overflowExpr, null)
42+
checkEvaluationWithUnsafeProjection(overflowExpr, null)
43+
}
44+
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
45+
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
46+
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
47+
val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
48+
intercept[ArithmeticException](checkEvaluationWithMutableProjection(overflowExpr, null))
49+
intercept[ArithmeticException](evaluateWithoutCodegen(overflowExpr, null))
50+
intercept[ArithmeticException](checkEvaluationWithUnsafeProjection(overflowExpr, null))
51+
}
52+
3653
}
3754

3855
test("PromotePrecision") {

0 commit comments

Comments
 (0)