Skip to content

Commit f0c87dc

Browse files
adrian-wang and marmbrus
authored and committed
[SPARK-3363][SQL] Type Coercion should promote null to all other types.
Type Coercion should support every type to have null value Author: Daoyuan Wang <daoyuan.wang@intel.com> Author: Michael Armbrust <michael@databricks.com> Closes #2246 from adrian-wang/spark3363-0 and squashes the following commits: c6241de [Daoyuan Wang] minor code clean 595b417 [Daoyuan Wang] Merge pull request #2 from marmbrus/pr/2246 832e640 [Michael Armbrust] reduce code duplication ef6f986 [Daoyuan Wang] make double boolean miss in jsonRDD compatibleType c619f0a [Daoyuan Wang] Type Coercion should support every type to have null value
1 parent a028330 commit f0c87dc

File tree

3 files changed

+67
-54
lines changed

3 files changed

+67
-54
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 21 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -26,10 +26,22 @@ object HiveTypeCoercion {
2626
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
2727
// The conversion for integral and floating point types have a linear widening hierarchy:
2828
val numericPrecedence =
29-
Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
30-
// Boolean is only wider than Void
31-
val booleanPrecedence = Seq(NullType, BooleanType)
32-
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil
29+
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
30+
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
31+
32+
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
33+
val valueTypes = Seq(t1, t2).filter(t => t != NullType)
34+
if (valueTypes.distinct.size > 1) {
35+
// Try and find a promotion rule that contains both types in question.
36+
val applicableConversion =
37+
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
38+
39+
// If found return the widest common type, otherwise None
40+
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
41+
} else {
42+
Some(if (valueTypes.size == 0) NullType else valueTypes.head)
43+
}
44+
}
3345
}
3446

3547
/**
@@ -53,17 +65,6 @@ trait HiveTypeCoercion {
5365
Division ::
5466
Nil
5567

56-
trait TypeWidening {
57-
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
58-
// Try and find a promotion rule that contains both types in question.
59-
val applicableConversion =
60-
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
61-
62-
// If found return the widest common type, otherwise None
63-
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
64-
}
65-
}
66-
6768
/**
6869
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
6970
* instances higher in the query tree.
@@ -144,7 +145,8 @@ trait HiveTypeCoercion {
144145
* - LongType to FloatType
145146
* - LongType to DoubleType
146147
*/
147-
object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
148+
object WidenTypes extends Rule[LogicalPlan] {
149+
import HiveTypeCoercion._
148150

149151
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
150152
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
@@ -352,7 +354,9 @@ trait HiveTypeCoercion {
352354
/**
353355
* Coerces the type of different branches of a CASE WHEN statement to a common type.
354356
*/
355-
object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
357+
object CaseWhenCoercion extends Rule[LogicalPlan] {
358+
import HiveTypeCoercion._
359+
356360
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
357361
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
358362
val valueTypes = branches.sliding(2, 2).map {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 24 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -23,20 +23,20 @@ import org.apache.spark.sql.catalyst.types._
2323

2424
class HiveTypeCoercionSuite extends FunSuite {
2525

26-
val rules = new HiveTypeCoercion { }
27-
import rules._
28-
29-
test("tightest common bound for numeric and boolean types") {
26+
test("tightest common bound for types") {
3027
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
31-
var found = WidenTypes.findTightestCommonType(t1, t2)
28+
var found = HiveTypeCoercion.findTightestCommonType(t1, t2)
3229
assert(found == tightestCommon,
3330
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
3431
// Test both directions to make sure the widening is symmetric.
35-
found = WidenTypes.findTightestCommonType(t2, t1)
32+
found = HiveTypeCoercion.findTightestCommonType(t2, t1)
3633
assert(found == tightestCommon,
3734
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
3835
}
3936

37+
// Null
38+
widenTest(NullType, NullType, Some(NullType))
39+
4040
// Boolean
4141
widenTest(NullType, BooleanType, Some(BooleanType))
4242
widenTest(BooleanType, BooleanType, Some(BooleanType))
@@ -60,12 +60,28 @@ class HiveTypeCoercionSuite extends FunSuite {
6060
widenTest(DoubleType, DoubleType, Some(DoubleType))
6161

6262
// Integral mixed with floating point.
63-
widenTest(NullType, FloatType, Some(FloatType))
64-
widenTest(NullType, DoubleType, Some(DoubleType))
6563
widenTest(IntegerType, FloatType, Some(FloatType))
6664
widenTest(IntegerType, DoubleType, Some(DoubleType))
6765
widenTest(IntegerType, DoubleType, Some(DoubleType))
6866
widenTest(LongType, FloatType, Some(FloatType))
6967
widenTest(LongType, DoubleType, Some(DoubleType))
68+
69+
// StringType
70+
widenTest(NullType, StringType, Some(StringType))
71+
widenTest(StringType, StringType, Some(StringType))
72+
widenTest(IntegerType, StringType, None)
73+
widenTest(LongType, StringType, None)
74+
75+
// TimestampType
76+
widenTest(NullType, TimestampType, Some(TimestampType))
77+
widenTest(TimestampType, TimestampType, Some(TimestampType))
78+
widenTest(IntegerType, TimestampType, None)
79+
widenTest(StringType, TimestampType, None)
80+
81+
// ComplexType
82+
widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false)))
83+
widenTest(NullType, StructType(Seq()), Some(StructType(Seq())))
84+
widenTest(StringType, MapType(IntegerType, StringType, true), None)
85+
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
7086
}
7187
}

sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala

Lines changed: 22 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -125,38 +125,31 @@ private[sql] object JsonRDD extends Logging {
125125
* Returns the most general data type for two given data types.
126126
*/
127127
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
128-
// Try and find a promotion rule that contains both types in question.
129-
val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p
130-
.contains(t2))
131-
132-
// If found return the widest common type, otherwise None
133-
val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
134-
135-
if (returnType.isDefined) {
136-
returnType.get
137-
} else {
138-
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
139-
(t1, t2) match {
140-
case (other: DataType, NullType) => other
141-
case (NullType, other: DataType) => other
142-
case (StructType(fields1), StructType(fields2)) => {
143-
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
144-
case (name, fieldTypes) => {
145-
val dataType = fieldTypes.map(field => field.dataType).reduce(
146-
(type1: DataType, type2: DataType) => compatibleType(type1, type2))
147-
StructField(name, dataType, true)
128+
HiveTypeCoercion.findTightestCommonType(t1, t2) match {
129+
case Some(commonType) => commonType
130+
case None =>
131+
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
132+
(t1, t2) match {
133+
case (other: DataType, NullType) => other
134+
case (NullType, other: DataType) => other
135+
case (StructType(fields1), StructType(fields2)) => {
136+
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
137+
case (name, fieldTypes) => {
138+
val dataType = fieldTypes.map(field => field.dataType).reduce(
139+
(type1: DataType, type2: DataType) => compatibleType(type1, type2))
140+
StructField(name, dataType, true)
141+
}
148142
}
143+
StructType(newFields.toSeq.sortBy {
144+
case StructField(name, _, _) => name
145+
})
149146
}
150-
StructType(newFields.toSeq.sortBy {
151-
case StructField(name, _, _) => name
152-
})
147+
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
148+
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
149+
// TODO: We should use JsonObjectStringType to mark that values of field will be
150+
// strings and every string is a Json object.
151+
case (_, _) => StringType
153152
}
154-
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
155-
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
156-
// TODO: We should use JsonObjectStringType to mark that values of field will be
157-
// strings and every string is a Json object.
158-
case (_, _) => StringType
159-
}
160153
}
161154
}
162155

0 commit comments

Comments (0)