Skip to content

Commit 512fb32

Browse files
luluorta authored and cloud-fan committed
[SPARK-26218][SQL][FOLLOW UP] Fix the corner case of codegen when casting float to Integer
### What changes were proposed in this pull request?

This is a follow-up of [#27151](#27151). It fixes the same issue for the codegen path.

### Why are the changes needed?

Without this fix, the result is corrupt.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a unit test.

Closes #30585 from luluorta/SPARK-26218.

Authored-by: luluorta <luluorta@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent ff13f57 commit 512fb32

File tree

2 files changed

+24
-33
lines changed
  • sql/catalyst/src
    • main/scala/org/apache/spark/sql/catalyst/expressions
    • test/scala/org/apache/spark/sql/catalyst/expressions

2 files changed

+24
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 19 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1393,25 +1393,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
13931393
"""
13941394
}
13951395

1396-
private[this] def lowerAndUpperBound(
1397-
fractionType: String,
1398-
integralType: String): (String, String) = {
1399-
assert(fractionType == "float" || fractionType == "double")
1400-
val typeIndicator = fractionType.charAt(0)
1401-
val (min, max) = integralType.toLowerCase(Locale.ROOT) match {
1402-
case "long" => (Long.MinValue, Long.MaxValue)
1403-
case "int" => (Int.MinValue, Int.MaxValue)
1404-
case "short" => (Short.MinValue, Short.MaxValue)
1405-
case "byte" => (Byte.MinValue, Byte.MaxValue)
1396+
private[this] def lowerAndUpperBound(integralType: String): (String, String) = {
1397+
val (min, max, typeIndicator) = integralType.toLowerCase(Locale.ROOT) match {
1398+
case "long" => (Long.MinValue, Long.MaxValue, "L")
1399+
case "int" => (Int.MinValue, Int.MaxValue, "")
1400+
case "short" => (Short.MinValue, Short.MaxValue, "")
1401+
case "byte" => (Byte.MinValue, Byte.MaxValue, "")
14061402
}
14071403
(min.toString + typeIndicator, max.toString + typeIndicator)
14081404
}
14091405

1410-
private[this] def castFractionToIntegralTypeCode(
1411-
fractionType: String,
1412-
integralType: String): CastFunction = {
1406+
private[this] def castFractionToIntegralTypeCode(integralType: String): CastFunction = {
14131407
assert(ansiEnabled)
1414-
val (min, max) = lowerAndUpperBound(fractionType, integralType)
1408+
val (min, max) = lowerAndUpperBound(integralType)
14151409
val mathClass = classOf[Math].getName
14161410
// When casting floating values to integral types, Spark uses the method `Numeric.toInt`
14171411
// Or `Numeric.toLong` directly. For positive floating values, it is equivalent to `Math.floor`;
@@ -1449,12 +1443,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
14491443
(c, evPrim, evNull) => code"$evNull = true;"
14501444
case TimestampType => castTimestampToIntegralTypeCode(ctx, "byte")
14511445
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "byte")
1452-
case _: ShortType | _: IntegerType | _: LongType if ansiEnabled =>
1446+
case ShortType | IntegerType | LongType if ansiEnabled =>
14531447
castIntegralTypeToIntegralTypeExactCode("byte")
1454-
case _: FloatType if ansiEnabled =>
1455-
castFractionToIntegralTypeCode("float", "byte")
1456-
case _: DoubleType if ansiEnabled =>
1457-
castFractionToIntegralTypeCode("double", "byte")
1448+
case FloatType | DoubleType if ansiEnabled =>
1449+
castFractionToIntegralTypeCode("byte")
14581450
case x: NumericType =>
14591451
(c, evPrim, evNull) => code"$evPrim = (byte) $c;"
14601452
}
@@ -1482,12 +1474,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
14821474
(c, evPrim, evNull) => code"$evNull = true;"
14831475
case TimestampType => castTimestampToIntegralTypeCode(ctx, "short")
14841476
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "short")
1485-
case _: IntegerType | _: LongType if ansiEnabled =>
1477+
case IntegerType | LongType if ansiEnabled =>
14861478
castIntegralTypeToIntegralTypeExactCode("short")
1487-
case _: FloatType if ansiEnabled =>
1488-
castFractionToIntegralTypeCode("float", "short")
1489-
case _: DoubleType if ansiEnabled =>
1490-
castFractionToIntegralTypeCode("double", "short")
1479+
case FloatType | DoubleType if ansiEnabled =>
1480+
castFractionToIntegralTypeCode("short")
14911481
case x: NumericType =>
14921482
(c, evPrim, evNull) => code"$evPrim = (short) $c;"
14931483
}
@@ -1513,11 +1503,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
15131503
(c, evPrim, evNull) => code"$evNull = true;"
15141504
case TimestampType => castTimestampToIntegralTypeCode(ctx, "int")
15151505
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "int")
1516-
case _: LongType if ansiEnabled => castIntegralTypeToIntegralTypeExactCode("int")
1517-
case _: FloatType if ansiEnabled =>
1518-
castFractionToIntegralTypeCode("float", "int")
1519-
case _: DoubleType if ansiEnabled =>
1520-
castFractionToIntegralTypeCode("double", "int")
1506+
case LongType if ansiEnabled => castIntegralTypeToIntegralTypeExactCode("int")
1507+
case FloatType | DoubleType if ansiEnabled =>
1508+
castFractionToIntegralTypeCode("int")
15211509
case x: NumericType =>
15221510
(c, evPrim, evNull) => code"$evPrim = (int) $c;"
15231511
}
@@ -1544,10 +1532,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
15441532
case TimestampType =>
15451533
(c, evPrim, evNull) => code"$evPrim = (long) ${timestampToLongCode(c)};"
15461534
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "long")
1547-
case _: FloatType if ansiEnabled =>
1548-
castFractionToIntegralTypeCode("float", "long")
1549-
case _: DoubleType if ansiEnabled =>
1550-
castFractionToIntegralTypeCode("double", "long")
1535+
case FloatType | DoubleType if ansiEnabled =>
1536+
castFractionToIntegralTypeCode("long")
15511537
case x: NumericType =>
15521538
(c, evPrim, evNull) => code"$evPrim = (long) $c;"
15531539
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -975,6 +975,11 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
975975
}
976976
}
977977
}
978+
979+
test("SPARK-26218: Fix the corner case of codegen when casting float to Integer") {
980+
checkExceptionInExpression[ArithmeticException](
981+
cast(cast(Literal("2147483648"), FloatType), IntegerType), "overflow")
982+
}
978983
}
979984

980985
/**

0 commit comments

Comments
 (0)