
Commit 832e640

reduce code duplication
1 parent ef6f986 commit 832e640

3 files changed: 43 additions, 53 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 19 additions & 18 deletions
@@ -28,6 +28,20 @@ object HiveTypeCoercion {
   val numericPrecedence =
     Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
   val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
+
+  def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
+    val valueTypes = Seq(t1, t2).filter(t => t != NullType)
+    if (valueTypes.distinct.size > 1) {
+      // Try and find a promotion rule that contains both types in question.
+      val applicableConversion =
+        HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
+
+      // If found return the widest common type, otherwise None
+      applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+    } else {
+      Some(if (valueTypes.size == 0) NullType else valueTypes.head)
+    }
+  }
 }
 
 /**
@@ -51,22 +65,6 @@ trait HiveTypeCoercion {
     Division ::
     Nil
 
-  trait TypeWidening {
-    def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
-      val valueTypes = Seq(t1, t2).filter(t => t != NullType)
-      if (valueTypes.distinct.size > 1) {
-        // Try and find a promotion rule that contains both types in question.
-        val applicableConversion =
-          HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
-        // If found return the widest common type, otherwise None
-        applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
-      } else {
-        Some(if (valueTypes.size == 0) NullType else valueTypes.head)
-      }
-    }
-  }
-
   /**
    * Applies any changes to [[AttributeReference]] data types that are made by other rules to
    * instances higher in the query tree.
@@ -147,7 +145,8 @@ trait HiveTypeCoercion {
    * - LongType to FloatType
    * - LongType to DoubleType
    */
-  object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
+  object WidenTypes extends Rule[LogicalPlan] {
+    import HiveTypeCoercion._
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
@@ -343,7 +342,9 @@ trait HiveTypeCoercion {
   /**
    * Coerces the type of different branches of a CASE WHEN statement to a common type.
    */
-  object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
+  object CaseWhenCoercion extends Rule[LogicalPlan] {
+    import HiveTypeCoercion._
+
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
         val valueTypes = branches.sliding(2, 2).map {
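As a quick reference, a minimal sketch of how the now-shared helper behaves, going only by the promotion rules visible in this hunk (numericPrecedence over the numeric types, with allPromotions = numericPrecedence :: Nil); the package and import paths are taken from the file paths in this commit:

  import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
  import org.apache.spark.sql.catalyst.types._

  // Both types appear in numericPrecedence, so the wider of the two is returned.
  HiveTypeCoercion.findTightestCommonType(IntegerType, LongType)    // Some(LongType)

  // NullType is filtered out before the lookup, so the remaining concrete type wins.
  HiveTypeCoercion.findTightestCommonType(NullType, DoubleType)     // Some(DoubleType)

  // No promotion sequence contains both types, so there is no tightest common type.
  HiveTypeCoercion.findTightestCommonType(StringType, IntegerType)  // None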

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 2 additions & 5 deletions
@@ -23,16 +23,13 @@ import org.apache.spark.sql.catalyst.types._
 
 class HiveTypeCoercionSuite extends FunSuite {
 
-  val rules = new HiveTypeCoercion { }
-  import rules._
-
   test("tightest common bound for types") {
     def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
-      var found = WidenTypes.findTightestCommonType(t1, t2)
+      var found = HiveTypeCoercion.findTightestCommonType(t1, t2)
       assert(found == tightestCommon,
         s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
       // Test both directions to make sure the widening is symmetric.
-      found = WidenTypes.findTightestCommonType(t2, t1)
+      found = HiveTypeCoercion.findTightestCommonType(t2, t1)
       assert(found == tightestCommon,
         s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
     }
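The widenTest helper above checks both argument orders. A few illustrative calls it would accept, based only on the promotion rules added in this commit (hypothetical cases, not necessarily the ones used elsewhere in the suite):

  widenTest(ByteType, IntegerType, Some(IntegerType))   // widest numeric type wins
  widenTest(NullType, NullType, Some(NullType))         // no non-null value types left
  widenTest(StringType, DoubleType, None)               // no promotion rule covers both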

sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala

Lines changed: 22 additions & 30 deletions
@@ -125,39 +125,31 @@ private[sql] object JsonRDD extends Logging {
    * Returns the most general data type for two given data types.
    */
   private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
-    // Try and find a promotion rule that contains both types in question.
-    val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p
-      .contains(t2))
-
-    // If found return the widest common type, otherwise None
-    val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
-
-    if (returnType.isDefined) {
-      returnType.get
-    } else {
-      // t1 or t2 is a StructType, ArrayType, BooleanType, or an unexpected type.
-      (t1, t2) match {
-        case (other: DataType, NullType) => other
-        case (NullType, other: DataType) => other
-        case (StructType(fields1), StructType(fields2)) => {
-          val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
-            case (name, fieldTypes) => {
-              val dataType = fieldTypes.map(field => field.dataType).reduce(
-                (type1: DataType, type2: DataType) => compatibleType(type1, type2))
-              StructField(name, dataType, true)
+    HiveTypeCoercion.findTightestCommonType(t1,t2) match {
+      case Some(commonType) => commonType
+      case None =>
+        // t1 or t2 is a StructType, ArrayType, BooleanType, or an unexpected type.
+        (t1, t2) match {
+          case (other: DataType, NullType) => other
+          case (NullType, other: DataType) => other
+          case (StructType(fields1), StructType(fields2)) => {
+            val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
+              case (name, fieldTypes) => {
+                val dataType = fieldTypes.map(field => field.dataType).reduce(
+                  (type1: DataType, type2: DataType) => compatibleType(type1, type2))
+                StructField(name, dataType, true)
+              }
             }
+            StructType(newFields.toSeq.sortBy {
+              case StructField(name, _, _) => name
+            })
           }
-          StructType(newFields.toSeq.sortBy {
-            case StructField(name, _, _) => name
-          })
+          case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+            ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+          // TODO: We should use JsonObjectStringType to mark that values of field will be
+          // strings and every string is a Json object.
+          case (_, _) => StringType
         }
-        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
-          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
-        // TODO: We should use JsonObjectStringType to mark that values of field will be
-        // strings and every string is a Json object.
-        case (BooleanType, BooleanType) => BooleanType
-        case (_, _) => StringType
-      }
     }
   }
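To illustrate the refactored path, a hypothetical schema-merge example (the field names are made up, and since compatibleType is private[json] this would have to run inside the org.apache.spark.sql.json package):

  import org.apache.spark.sql.catalyst.types._

  val schemaA = StructType(
    StructField("id", IntegerType, true) ::
    StructField("name", StringType, true) :: Nil)
  val schemaB = StructType(
    StructField("id", DoubleType, true) ::
    StructField("active", BooleanType, true) :: Nil)

  // "id" goes through HiveTypeCoercion.findTightestCommonType and widens to DoubleType;
  // fields present on only one side keep their type; the result is sorted by field name:
  // active: BooleanType, id: DoubleType, name: StringType.
  JsonRDD.compatibleType(schemaA, schemaB)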
