Skip to content

Commit c619f0a

Browse files
committed
Type coercion should allow every type to have a null value
1 parent e70aff6 commit c619f0a

File tree

2 files changed

+35
-13
lines changed

2 files changed

+35
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@ object HiveTypeCoercion {
2626
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
2727
// The conversions for integral and floating point types have a linear widening hierarchy:
2828
val numericPrecedence =
29-
Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
30-
// Boolean is only wider than Void
31-
val booleanPrecedence = Seq(NullType, BooleanType)
32-
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil
29+
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
30+
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
3331
}
3432

3533
/**
@@ -55,12 +53,17 @@ trait HiveTypeCoercion {
5553

5654
trait TypeWidening {
5755
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
58-
// Try and find a promotion rule that contains both types in question.
59-
val applicableConversion =
60-
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
61-
62-
// If found return the widest common type, otherwise None
63-
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
56+
val valueTypes = Seq(t1, t2).filter(t => t != NullType)
57+
if (valueTypes.distinct.size > 1) {
58+
// Try and find a promotion rule that contains both types in question.
59+
val applicableConversion =
60+
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
61+
62+
// If found return the widest common type, otherwise None
63+
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
64+
} else {
65+
Some(if (valueTypes.size == 0) NullType else valueTypes.head)
66+
}
6467
}
6568
}
6669

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class HiveTypeCoercionSuite extends FunSuite {
2626
val rules = new HiveTypeCoercion { }
2727
import rules._
2828

29-
test("tightest common bound for numeric and boolean types") {
29+
test("tightest common bound for types") {
3030
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
3131
var found = WidenTypes.findTightestCommonType(t1, t2)
3232
assert(found == tightestCommon,
@@ -37,6 +37,9 @@ class HiveTypeCoercionSuite extends FunSuite {
3737
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
3838
}
3939

40+
// Null
41+
widenTest(NullType, NullType, Some(NullType))
42+
4043
// Boolean
4144
widenTest(NullType, BooleanType, Some(BooleanType))
4245
widenTest(BooleanType, BooleanType, Some(BooleanType))
@@ -60,12 +63,28 @@ class HiveTypeCoercionSuite extends FunSuite {
6063
widenTest(DoubleType, DoubleType, Some(DoubleType))
6164

6265
// Integral mixed with floating point.
63-
widenTest(NullType, FloatType, Some(FloatType))
64-
widenTest(NullType, DoubleType, Some(DoubleType))
6566
widenTest(IntegerType, FloatType, Some(FloatType))
6667
widenTest(IntegerType, DoubleType, Some(DoubleType))
6768
widenTest(IntegerType, DoubleType, Some(DoubleType))
6869
widenTest(LongType, FloatType, Some(FloatType))
6970
widenTest(LongType, DoubleType, Some(DoubleType))
71+
72+
// StringType
73+
widenTest(NullType, StringType, Some(StringType))
74+
widenTest(StringType, StringType, Some(StringType))
75+
widenTest(IntegerType, StringType, None)
76+
widenTest(LongType, StringType, None)
77+
78+
// TimestampType
79+
widenTest(NullType, TimestampType, Some(TimestampType))
80+
widenTest(TimestampType, TimestampType, Some(TimestampType))
81+
widenTest(IntegerType, TimestampType, None)
82+
widenTest(StringType, TimestampType, None)
83+
84+
// ComplexType
85+
widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false)))
86+
widenTest(NullType, StructType(Seq()), Some(StructType(Seq())))
87+
widenTest(StringType, MapType(IntegerType, StringType, true), None)
88+
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
7089
}
7190
}

0 commit comments

Comments
 (0)