Skip to content

[SPARK-28200][SQL] Decimal overflow handling in ExpressionEncoder #25016

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object SerializerBuildHelper {

// Reads the conf on every call (rather than caching it) so that runtime changes
// to spark.sql.decimalOperations.nullOnOverflow are picked up by newly built
// serializers. Used below as the CheckOverflow nullOnOverflow flag.
private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow

/** Serializer for boxed booleans: unboxes via `java.lang.Boolean.booleanValue`. */
def createSerializerForBoolean(inputObject: Expression): Expression = {
  Invoke(inputObject, "booleanValue", BooleanType)
}
Expand Down Expand Up @@ -99,25 +102,25 @@ object SerializerBuildHelper {
}

/**
 * Serializer for `java.math.BigDecimal` values.
 *
 * Converts the input to Spark's internal `Decimal` via `Decimal.apply`, then wraps
 * the result in [[CheckOverflow]] against `DecimalType.SYSTEM_DEFAULT` so values
 * that exceed the default precision/scale are either replaced with null or raise
 * an error, depending on [[nullOnOverflow]].
 *
 * NOTE: the original span contained duplicated pre-change diff lines
 * (a bare `StaticInvoke(` and a second `returnNullable = false)`), which made it
 * syntactically invalid; this is the reconstructed post-change code.
 */
def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
  CheckOverflow(
    StaticInvoke(
      Decimal.getClass,
      DecimalType.SYSTEM_DEFAULT,
      "apply",
      inputObject :: Nil,
      returnNullable = false),
    DecimalType.SYSTEM_DEFAULT,
    nullOnOverflow)
}

/**
 * Serializer for `scala.math.BigDecimal` values. Reuses the Java BigDecimal
 * path, so the same overflow handling applies to both.
 */
def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = {
  createSerializerForJavaBigDecimal(inputObject)
}

/**
 * Serializer for `java.math.BigInteger` values.
 *
 * Converts the input to Spark's internal `Decimal` via `Decimal.apply`, then wraps
 * the result in [[CheckOverflow]] against `DecimalType.BigIntDecimal` (the decimal
 * type used for big integers) so overflowing values become null or raise an error
 * per [[nullOnOverflow]].
 *
 * NOTE: the original span contained duplicated pre-change diff lines
 * (a bare `StaticInvoke(` and a second `returnNullable = false)`), which made it
 * syntactically invalid; this is the reconstructed post-change code.
 */
def createSerializerForJavaBigInteger(inputObject: Expression): Expression = {
  CheckOverflow(
    StaticInvoke(
      Decimal.getClass,
      DecimalType.BigIntDecimal,
      "apply",
      inputObject :: Nil,
      returnNullable = false),
    DecimalType.BigIntDecimal,
    nullOnOverflow)
}

def createSerializerForScalaBigInt(inputObject: Expression): Expression = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ClosureCleaner
Expand Down Expand Up @@ -379,6 +380,78 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
assert(e.getMessage.contains("tuple with more than 22 elements are not supported"))
}

// Scala / Java big decimals ----------------------------------------------------------

// Round-trip cases: values that fit the default decimal type encode and decode
// unchanged. NOTE(review): assumes DecimalType.SYSTEM_DEFAULT is (38, 18), so
// 20 integer digits + 18 fractional digits is representable — confirm against
// DecimalType.
encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
"scala decimal within precision/scale limit")
encodeDecodeTest(new java.math.BigDecimal(("9" * 20) + "." + "9" * 18),
"java decimal within precision/scale limit")

encodeDecodeTest(-BigDecimal(("9" * 20) + "." + "9" * 18),
"negative scala decimal within precision/scale limit")
encodeDecodeTest(new java.math.BigDecimal(("9" * 20) + "." + "9" * 18).negate,
"negative java decimal within precision/scale limit")

// Overflow cases: 21 integer digits exceed the integral capacity of the default
// type, so encoding must null out or fail per the nullOnOverflow conf.
testOverflowingBigNumeric(BigDecimal("1" * 21), "scala big decimal")
testOverflowingBigNumeric(new java.math.BigDecimal("1" * 21), "java big decimal")

testOverflowingBigNumeric(-BigDecimal("1" * 21), "negative scala big decimal")
testOverflowingBigNumeric(new java.math.BigDecimal("1" * 21).negate, "negative java big decimal")

// Overflow must be detected on the integral part even when a fractional part
// (short or very long) is present.
testOverflowingBigNumeric(BigDecimal(("1" * 21) + ".123"),
"scala big decimal with fractional part")
testOverflowingBigNumeric(new java.math.BigDecimal(("1" * 21) + ".123"),
"java big decimal with fractional part")

testOverflowingBigNumeric(BigDecimal(("1" * 21) + "." + "9999" * 100),
"scala big decimal with long fractional part")
testOverflowingBigNumeric(new java.math.BigDecimal(("1" * 21) + "." + "9999" * 100),
"java big decimal with long fractional part")

// Scala / Java big integers ----------------------------------------------------------

// 38 digits is the largest big integer that round-trips; 39+ digits overflow.
encodeDecodeTest(BigInt("9" * 38), "scala big integer within precision limit")
encodeDecodeTest(new BigInteger("9" * 38), "java big integer within precision limit")

encodeDecodeTest(-BigInt("9" * 38),
"negative scala big integer within precision limit")
encodeDecodeTest(new BigInteger("9" * 38).negate(),
"negative java big integer within precision limit")

testOverflowingBigNumeric(BigInt("1" * 39), "scala big int")
testOverflowingBigNumeric(new BigInteger("1" * 39), "java big integer")

testOverflowingBigNumeric(-BigInt("1" * 39), "negative scala big int")
testOverflowingBigNumeric(new BigInteger("1" * 39).negate, "negative java big integer")

// Extreme case: far beyond any supported precision.
testOverflowingBigNumeric(BigInt("9" * 100), "scala very large big int")
testOverflowingBigNumeric(new BigInteger("9" * 100), "java very big int")
/**
 * Registers a pair of tests for an overflowing numeric value, one per setting of
 * spark.sql.decimalOperations.nullOnOverflow:
 *  - conf true: encode/decode round-trips the value to null;
 *  - conf false: encoding fails with a RuntimeException wrapping an
 *    ArithmeticException.
 */
private def testOverflowingBigNumeric[T: TypeTag](bigNumeric: T, testName: String): Unit = {
  for (allowNullOnOverflow <- Seq(true, false)) {
    testAndVerifyNotLeakingReflectionObjects(
      s"overflowing $testName, allowNullOnOverflow=$allowNullOnOverflow") {
      withSQLConf(
        SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> allowNullOnOverflow.toString
      ) {
        // Build the encoder explicitly instead of resolving it implicitly, so the
        // SQLConf override above is actually honored at construction time.
        val encoder = ExpressionEncoder[T]()
        if (allowNullOnOverflow) {
          // Overflow is tolerated: the value round-trips to null.
          val roundTripped = encoder.resolveAndBind().fromRow(encoder.toRow(bigNumeric))
          assert(roundTripped === null)
        } else {
          // Overflow is an error: encoding throws, with the arithmetic failure
          // as the cause.
          val failure = intercept[RuntimeException](encoder.toRow(bigNumeric))
          assert(failure.getMessage.contains("Error while encoding"))
          assert(failure.getCause.getClass === classOf[ArithmeticException])
        }
      }
    }
  }
}

private def encodeDecodeTest[T : ExpressionEncoder](
input: T,
testName: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,32 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
assert(row.toSeq(schema).head == decimal)
}

test("SPARK-23179: RowEncoder should respect nullOnOverflow for decimals") {
  // A 100-digit value overflows the system default decimal type; exercise both
  // the Scala and the Java BigDecimal representations.
  val schema = new StructType().add("decimal", DecimalType.SYSTEM_DEFAULT)
  val digits = "9" * 100
  Seq(Row(BigDecimal(digits)), Row(new java.math.BigDecimal(digits)))
    .foreach(testDecimalOverflow(schema, _))
}

/**
 * Verifies both overflow behaviors for a row whose decimal column overflows:
 * an ArithmeticException (possibly wrapped in a RuntimeException) when
 * nullOnOverflow is disabled, and a null round-trip when it is enabled.
 */
private def testDecimalOverflow(schema: StructType, row: Row): Unit = {
  withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
    val encoder = RowEncoder(schema).resolveAndBind()
    val thrown = intercept[Exception](encoder.toRow(row))
    // The arithmetic failure may surface directly or as the cause of a
    // RuntimeException, depending on the execution path.
    thrown match {
      case direct: ArithmeticException =>
        assert(direct.getMessage.contains("cannot be represented as Decimal"))
      case wrapped: RuntimeException =>
        assert(wrapped.getCause.isInstanceOf[ArithmeticException])
        assert(wrapped.getCause.getMessage.contains("cannot be represented as Decimal"))
    }
  }

  withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
    val encoder = RowEncoder(schema).resolveAndBind()
    // Overflow is replaced by null, so the round-tripped column reads back null.
    assert(encoder.fromRow(encoder.toRow(row)).get(0) == null)
  }
}

test("RowEncoder should preserve schema nullability") {
val schema = new StructType().add("int", IntegerType, nullable = false)
val encoder = RowEncoder(schema).resolveAndBind()
Expand Down