Skip to content

Commit 3344803

Browse files
ueshinmarmbrus
authored andcommitted
[SPARK-4293][SQL] Make Cast be able to handle complex types.
Inserting data of type including `ArrayType.containsNull == false` or `MapType.valueContainsNull == false` or `StructType.fields.exists(_.nullable == false)` into Hive table will fail because `Cast` inserted by `HiveMetastoreCatalog.PreInsertionCasts` rule of `Analyzer` can't handle these types correctly. Complex type cast rule proposal: - Cast for non-complex types should be able to cast the same as before. - Cast for `ArrayType` can evaluate if - Element type can cast - Nullability rule doesn't break - Cast for `MapType` can evaluate if - Key type can cast - Nullability for casted key type is `false` - Value type can cast - Nullability rule for value type doesn't break - Cast for `StructType` can evaluate if - The field size is the same - Each field can cast - Nullability rule for each field doesn't break - The nested structure should be the same. Nullability rule: - If the casted type is `nullable == true`, the target nullability should be `true` Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #3150 from ueshin/issues/SPARK-4293 and squashes the following commits: e935939 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293 ba14003 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293 8999868 [Takuya UESHIN] Fix a test title. f677c30 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4293 287f410 [Takuya UESHIN] Add tests to insert data of types ArrayType / MapType / StructType with nullability is false into Hive table. 4f71bb8 [Takuya UESHIN] Make Cast be able to handle complex types.
1 parent c152dde commit 3344803

File tree

2 files changed

+353
-44
lines changed

2 files changed

+353
-44
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 117 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,14 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
2727

2828
/** Cast the child expression to the target data type. */
2929
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
30+
31+
override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
32+
3033
override def foldable = child.foldable
3134

32-
override def nullable = (child.dataType, dataType) match {
35+
override def nullable = forceNullable(child.dataType, dataType) || child.nullable
36+
37+
private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match {
3338
case (StringType, _: NumericType) => true
3439
case (StringType, TimestampType) => true
3540
case (DoubleType, TimestampType) => true
@@ -41,8 +46,62 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
4146
case (DateType, BooleanType) => true
4247
case (DoubleType, _: DecimalType) => true
4348
case (FloatType, _: DecimalType) => true
44-
case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
45-
case _ => child.nullable
49+
case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
50+
case _ => false
51+
}
52+
53+
private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to
54+
55+
private[this] def resolve(from: DataType, to: DataType): Boolean = {
56+
(from, to) match {
57+
case (from, to) if from == to => true
58+
59+
case (NullType, _) => true
60+
61+
case (_, StringType) => true
62+
63+
case (StringType, BinaryType) => true
64+
65+
case (StringType, BooleanType) => true
66+
case (DateType, BooleanType) => true
67+
case (TimestampType, BooleanType) => true
68+
case (_: NumericType, BooleanType) => true
69+
70+
case (StringType, TimestampType) => true
71+
case (BooleanType, TimestampType) => true
72+
case (DateType, TimestampType) => true
73+
case (_: NumericType, TimestampType) => true
74+
75+
case (_, DateType) => true
76+
77+
case (StringType, _: NumericType) => true
78+
case (BooleanType, _: NumericType) => true
79+
case (DateType, _: NumericType) => true
80+
case (TimestampType, _: NumericType) => true
81+
case (_: NumericType, _: NumericType) => true
82+
83+
case (ArrayType(from, fn), ArrayType(to, tn)) =>
84+
resolve(from, to) &&
85+
resolvableNullability(fn || forceNullable(from, to), tn)
86+
87+
case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
88+
resolve(fromKey, toKey) &&
89+
(!forceNullable(fromKey, toKey)) &&
90+
resolve(fromValue, toValue) &&
91+
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
92+
93+
case (StructType(fromFields), StructType(toFields)) =>
94+
fromFields.size == toFields.size &&
95+
fromFields.zip(toFields).forall {
96+
case (fromField, toField) =>
97+
resolve(fromField.dataType, toField.dataType) &&
98+
resolvableNullability(
99+
fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
100+
toField.nullable)
101+
}
102+
103+
case _ => false
104+
}
46105
}
47106

48107
override def toString = s"CAST($child, $dataType)"
@@ -53,20 +112,20 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
53112
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
54113

55114
// UDFToString
56-
private[this] def castToString: Any => Any = child.dataType match {
115+
private[this] def castToString(from: DataType): Any => Any = from match {
57116
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
58117
case DateType => buildCast[Date](_, dateToString)
59118
case TimestampType => buildCast[Timestamp](_, timestampToString)
60119
case _ => buildCast[Any](_, _.toString)
61120
}
62121

63122
// BinaryConverter
64-
private[this] def castToBinary: Any => Any = child.dataType match {
123+
private[this] def castToBinary(from: DataType): Any => Any = from match {
65124
case StringType => buildCast[String](_, _.getBytes("UTF-8"))
66125
}
67126

68127
// UDFToBoolean
69-
private[this] def castToBoolean: Any => Any = child.dataType match {
128+
private[this] def castToBoolean(from: DataType): Any => Any = from match {
70129
case StringType =>
71130
buildCast[String](_, _.length() != 0)
72131
case TimestampType =>
@@ -91,7 +150,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
91150
}
92151

93152
// TimestampConverter
94-
private[this] def castToTimestamp: Any => Any = child.dataType match {
153+
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
95154
case StringType =>
96155
buildCast[String](_, s => {
97156
// Throw away extra if more than 9 decimal places
@@ -133,7 +192,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
133192
})
134193
}
135194

136-
private[this] def decimalToTimestamp(d: Decimal) = {
195+
private[this] def decimalToTimestamp(d: Decimal) = {
137196
val seconds = Math.floor(d.toDouble).toLong
138197
val bd = (d.toBigDecimal - seconds) * 1000000000
139198
val nanos = bd.intValue()
@@ -172,11 +231,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
172231
}
173232

174233
// DateConverter
175-
private[this] def castToDate: Any => Any = child.dataType match {
234+
private[this] def castToDate(from: DataType): Any => Any = from match {
176235
case StringType =>
177236
buildCast[String](_, s =>
178-
try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }
179-
)
237+
try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null })
180238
case TimestampType =>
181239
// throw valid precision more than seconds, according to Hive.
182240
// Timestamp.nanos is in 0 to 999,999,999, no more than a second.
@@ -199,7 +257,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
199257
}
200258

201259
// LongConverter
202-
private[this] def castToLong: Any => Any = child.dataType match {
260+
private[this] def castToLong(from: DataType): Any => Any = from match {
203261
case StringType =>
204262
buildCast[String](_, s => try s.toLong catch {
205263
case _: NumberFormatException => null
@@ -210,14 +268,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
210268
buildCast[Date](_, d => dateToLong(d))
211269
case TimestampType =>
212270
buildCast[Timestamp](_, t => timestampToLong(t))
213-
case DecimalType() =>
214-
buildCast[Decimal](_, _.toLong)
215271
case x: NumericType =>
216272
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
217273
}
218274

219275
// IntConverter
220-
private[this] def castToInt: Any => Any = child.dataType match {
276+
private[this] def castToInt(from: DataType): Any => Any = from match {
221277
case StringType =>
222278
buildCast[String](_, s => try s.toInt catch {
223279
case _: NumberFormatException => null
@@ -228,14 +284,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
228284
buildCast[Date](_, d => dateToLong(d))
229285
case TimestampType =>
230286
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
231-
case DecimalType() =>
232-
buildCast[Decimal](_, _.toInt)
233287
case x: NumericType =>
234288
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
235289
}
236290

237291
// ShortConverter
238-
private[this] def castToShort: Any => Any = child.dataType match {
292+
private[this] def castToShort(from: DataType): Any => Any = from match {
239293
case StringType =>
240294
buildCast[String](_, s => try s.toShort catch {
241295
case _: NumberFormatException => null
@@ -246,14 +300,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
246300
buildCast[Date](_, d => dateToLong(d))
247301
case TimestampType =>
248302
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
249-
case DecimalType() =>
250-
buildCast[Decimal](_, _.toShort)
251303
case x: NumericType =>
252304
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
253305
}
254306

255307
// ByteConverter
256-
private[this] def castToByte: Any => Any = child.dataType match {
308+
private[this] def castToByte(from: DataType): Any => Any = from match {
257309
case StringType =>
258310
buildCast[String](_, s => try s.toByte catch {
259311
case _: NumberFormatException => null
@@ -264,8 +316,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
264316
buildCast[Date](_, d => dateToLong(d))
265317
case TimestampType =>
266318
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
267-
case DecimalType() =>
268-
buildCast[Decimal](_, _.toByte)
269319
case x: NumericType =>
270320
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
271321
}
@@ -285,7 +335,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
285335
}
286336
}
287337

288-
private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match {
338+
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
289339
case StringType =>
290340
buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
291341
case _: NumberFormatException => null
@@ -301,7 +351,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
301351
b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
302352
case LongType =>
303353
b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
304-
case x: NumericType => // All other numeric types can be represented precisely as Doubles
354+
case x: NumericType => // All other numeric types can be represented precisely as Doubles
305355
b => try {
306356
changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
307357
} catch {
@@ -310,7 +360,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
310360
}
311361

312362
// DoubleConverter
313-
private[this] def castToDouble: Any => Any = child.dataType match {
363+
private[this] def castToDouble(from: DataType): Any => Any = from match {
314364
case StringType =>
315365
buildCast[String](_, s => try s.toDouble catch {
316366
case _: NumberFormatException => null
@@ -321,14 +371,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
321371
buildCast[Date](_, d => dateToDouble(d))
322372
case TimestampType =>
323373
buildCast[Timestamp](_, t => timestampToDouble(t))
324-
case DecimalType() =>
325-
buildCast[Decimal](_, _.toDouble)
326374
case x: NumericType =>
327375
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
328376
}
329377

330378
// FloatConverter
331-
private[this] def castToFloat: Any => Any = child.dataType match {
379+
private[this] def castToFloat(from: DataType): Any => Any = from match {
332380
case StringType =>
333381
buildCast[String](_, s => try s.toFloat catch {
334382
case _: NumberFormatException => null
@@ -339,28 +387,53 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
339387
buildCast[Date](_, d => dateToDouble(d))
340388
case TimestampType =>
341389
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
342-
case DecimalType() =>
343-
buildCast[Decimal](_, _.toFloat)
344390
case x: NumericType =>
345391
b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
346392
}
347393

348-
private[this] lazy val cast: Any => Any = dataType match {
394+
private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
395+
val elementCast = cast(from.elementType, to.elementType)
396+
buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v)))
397+
}
398+
399+
private[this] def castMap(from: MapType, to: MapType): Any => Any = {
400+
val keyCast = cast(from.keyType, to.keyType)
401+
val valueCast = cast(from.valueType, to.valueType)
402+
buildCast[Map[Any, Any]](_, _.map {
403+
case (key, value) => (keyCast(key), if (value == null) null else valueCast(value))
404+
})
405+
}
406+
407+
private[this] def castStruct(from: StructType, to: StructType): Any => Any = {
408+
val casts = from.fields.zip(to.fields).map {
409+
case (fromField, toField) => cast(fromField.dataType, toField.dataType)
410+
}
411+
buildCast[Row](_, row => Row(row.zip(casts).map {
412+
case (v, cast) => if (v == null) null else cast(v)
413+
}: _*))
414+
}
415+
416+
private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
349417
case dt if dt == child.dataType => identity[Any]
350-
case StringType => castToString
351-
case BinaryType => castToBinary
352-
case DateType => castToDate
353-
case decimal: DecimalType => castToDecimal(decimal)
354-
case TimestampType => castToTimestamp
355-
case BooleanType => castToBoolean
356-
case ByteType => castToByte
357-
case ShortType => castToShort
358-
case IntegerType => castToInt
359-
case FloatType => castToFloat
360-
case LongType => castToLong
361-
case DoubleType => castToDouble
418+
case StringType => castToString(from)
419+
case BinaryType => castToBinary(from)
420+
case DateType => castToDate(from)
421+
case decimal: DecimalType => castToDecimal(from, decimal)
422+
case TimestampType => castToTimestamp(from)
423+
case BooleanType => castToBoolean(from)
424+
case ByteType => castToByte(from)
425+
case ShortType => castToShort(from)
426+
case IntegerType => castToInt(from)
427+
case FloatType => castToFloat(from)
428+
case LongType => castToLong(from)
429+
case DoubleType => castToDouble(from)
430+
case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array)
431+
case map: MapType => castMap(from.asInstanceOf[MapType], map)
432+
case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
362433
}
363434

435+
private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
436+
364437
override def eval(input: Row): Any = {
365438
val evaluated = child.eval(input)
366439
if (evaluated == null) null else cast(evaluated)

0 commit comments

Comments
 (0)