@@ -27,9 +27,14 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
2727
2828/** Cast the child expression to the target data type. */
2929case class Cast (child : Expression , dataType : DataType ) extends UnaryExpression with Logging {
30+
/**
 * A cast is resolved once all children are resolved and `resolve` accepts the conversion
 * from the child's data type to the requested `dataType`.
 */
override lazy val resolved: Boolean =
  childrenResolved && resolve(child.dataType, dataType)
32+
/** Foldability is delegated to the child: the cast is foldable exactly when its input is. */
override def foldable: Boolean = child.foldable
3134
/**
 * The result may be null when `forceNullable` reports that the conversion itself can
 * produce null for this pair of types, or when the child may be null.
 */
override def nullable: Boolean = {
  val conversionMayYieldNull = forceNullable(child.dataType, dataType)
  conversionMayYieldNull || child.nullable
}
36+
37+ private [this ] def forceNullable (from : DataType , to : DataType ) = (from, to) match {
3338 case (StringType , _ : NumericType ) => true
3439 case (StringType , TimestampType ) => true
3540 case (DoubleType , TimestampType ) => true
@@ -41,8 +46,62 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
4146 case (DateType , BooleanType ) => true
4247 case (DoubleType , _ : DecimalType ) => true
4348 case (FloatType , _ : DecimalType ) => true
44- case (_, DecimalType .Fixed (_, _)) => true // TODO: not all upcasts here can really give null
45- case _ => child.nullable
49+ case (_, DecimalType .Fixed (_, _)) => true // TODO: not all upcasts here can really give null
50+ case _ => false
51+ }
52+
/**
 * Nullability compatibility check, i.e. the implication `from => to`: a source that may
 * hold nulls is only acceptable when the target may hold nulls as well.
 */
private[this] def resolvableNullability(from: Boolean, to: Boolean): Boolean =
  if (from) to else true
54+
/**
 * Decides whether a cast from `from` to `to` is accepted.
 *
 * Scalar conversions are listed explicitly, grouped by target type. Arrays, maps and structs
 * are checked recursively: the nested types must themselves be castable, and the nullability
 * of the source (including nulls that `forceNullable` says the conversion may introduce) must
 * be allowed by the target. Map keys must additionally never be forced to null.
 */
private[this] def resolve(from: DataType, to: DataType): Boolean = {
  (from, to) match {
    // Identical types always resolve.
    case _ if from == to => true

    case (NullType, _) => true

    case (_, StringType) => true

    case (StringType, BinaryType) => true

    // Sources accepted for a boolean target.
    case (StringType | DateType | TimestampType | _: NumericType, BooleanType) => true

    // Sources accepted for a timestamp target.
    case (StringType | BooleanType | DateType | _: NumericType, TimestampType) => true

    case (_, DateType) => true

    // Sources accepted for a numeric target.
    case (StringType | BooleanType | DateType | TimestampType | _: NumericType,
        _: NumericType) => true

    case (ArrayType(fromElement, fromContainsNull), ArrayType(toElement, toContainsNull)) =>
      resolve(fromElement, toElement) &&
        resolvableNullability(
          fromContainsNull || forceNullable(fromElement, toElement),
          toContainsNull)

    case (MapType(fromKey, fromValue, fromValueNull), MapType(toKey, toValue, toValueNull)) =>
      resolve(fromKey, toKey) &&
        !forceNullable(fromKey, toKey) &&
        resolve(fromValue, toValue) &&
        resolvableNullability(
          fromValueNull || forceNullable(fromValue, toValue),
          toValueNull)

    case (StructType(fromFields), StructType(toFields)) =>
      // Fields are matched positionally, so both structs need the same arity.
      fromFields.size == toFields.size &&
        fromFields.zip(toFields).forall { case (source, target) =>
          resolve(source.dataType, target.dataType) &&
            resolvableNullability(
              source.nullable || forceNullable(source.dataType, target.dataType),
              target.nullable)
        }

    case _ => false
  }
}
47106
/** Renders the expression as `CAST(child, dataType)`. */
override def toString: String = s"CAST($child, $dataType)"
@@ -53,20 +112,20 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
/**
 * Applies `func` to `a` after viewing `a` as a `T`. The cast is unchecked, so callers must
 * pick a `T` that matches the runtime type of `a`.
 */
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = {
  val typed = a.asInstanceOf[T]
  func(typed)
}
54113
55114 // UDFToString
/**
 * Builds the converter from values of type `from` to strings: binary data is decoded as
 * UTF-8, dates and timestamps go through their dedicated formatters, and every other type
 * falls back to `toString`.
 */
private[this] def castToString(from: DataType): Any => Any = from match {
  case BinaryType =>
    value => buildCast[Array[Byte]](value, bytes => new String(bytes, "UTF-8"))
  case DateType =>
    value => buildCast[Date](value, dateToString)
  case TimestampType =>
    value => buildCast[Timestamp](value, timestampToString)
  case _ =>
    value => buildCast[Any](value, other => other.toString)
}
62121
63122 // BinaryConverter
/**
 * Builds the converter from values of type `from` to binary. Only strings are handled (encoded
 * as UTF-8); any other source type fails the match.
 */
private[this] def castToBinary(from: DataType): Any => Any = from match {
  case StringType =>
    value => buildCast[String](value, text => text.getBytes("UTF-8"))
}
67126
68127 // UDFToBoolean
69- private [this ] def castToBoolean : Any => Any = child.dataType match {
128+ private [this ] def castToBoolean ( from : DataType ) : Any => Any = from match {
70129 case StringType =>
71130 buildCast[String ](_, _.length() != 0 )
72131 case TimestampType =>
@@ -91,7 +150,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
91150 }
92151
93152 // TimestampConverter
94- private [this ] def castToTimestamp : Any => Any = child.dataType match {
153+ private [this ] def castToTimestamp ( from : DataType ) : Any => Any = from match {
95154 case StringType =>
96155 buildCast[String ](_, s => {
97156 // Throw away extra if more than 9 decimal places
@@ -133,7 +192,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
133192 })
134193 }
135194
136- private [this ] def decimalToTimestamp (d : Decimal ) = {
195+ private [this ] def decimalToTimestamp (d : Decimal ) = {
137196 val seconds = Math .floor(d.toDouble).toLong
138197 val bd = (d.toBigDecimal - seconds) * 1000000000
139198 val nanos = bd.intValue()
@@ -172,11 +231,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
172231 }
173232
174233 // DateConverter
175- private [this ] def castToDate : Any => Any = child.dataType match {
234+ private [this ] def castToDate ( from : DataType ) : Any => Any = from match {
176235 case StringType =>
177236 buildCast[String ](_, s =>
178- try Date .valueOf(s) catch { case _ : java.lang.IllegalArgumentException => null }
179- )
237+ try Date .valueOf(s) catch { case _ : java.lang.IllegalArgumentException => null })
180238 case TimestampType =>
181239 // throw valid precision more than seconds, according to Hive.
182240 // Timestamp.nanos is in 0 to 999,999,999, no more than a second.
@@ -199,7 +257,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
199257 }
200258
201259 // LongConverter
202- private [this ] def castToLong : Any => Any = child.dataType match {
260+ private [this ] def castToLong ( from : DataType ) : Any => Any = from match {
203261 case StringType =>
204262 buildCast[String ](_, s => try s.toLong catch {
205263 case _ : NumberFormatException => null
@@ -210,14 +268,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
210268 buildCast[Date ](_, d => dateToLong(d))
211269 case TimestampType =>
212270 buildCast[Timestamp ](_, t => timestampToLong(t))
213- case DecimalType () =>
214- buildCast[Decimal ](_, _.toLong)
215271 case x : NumericType =>
216272 b => x.numeric.asInstanceOf [Numeric [Any ]].toLong(b)
217273 }
218274
219275 // IntConverter
220- private [this ] def castToInt : Any => Any = child.dataType match {
276+ private [this ] def castToInt ( from : DataType ) : Any => Any = from match {
221277 case StringType =>
222278 buildCast[String ](_, s => try s.toInt catch {
223279 case _ : NumberFormatException => null
@@ -228,14 +284,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
228284 buildCast[Date ](_, d => dateToLong(d))
229285 case TimestampType =>
230286 buildCast[Timestamp ](_, t => timestampToLong(t).toInt)
231- case DecimalType () =>
232- buildCast[Decimal ](_, _.toInt)
233287 case x : NumericType =>
234288 b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b)
235289 }
236290
237291 // ShortConverter
238- private [this ] def castToShort : Any => Any = child.dataType match {
292+ private [this ] def castToShort ( from : DataType ) : Any => Any = from match {
239293 case StringType =>
240294 buildCast[String ](_, s => try s.toShort catch {
241295 case _ : NumberFormatException => null
@@ -246,14 +300,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
246300 buildCast[Date ](_, d => dateToLong(d))
247301 case TimestampType =>
248302 buildCast[Timestamp ](_, t => timestampToLong(t).toShort)
249- case DecimalType () =>
250- buildCast[Decimal ](_, _.toShort)
251303 case x : NumericType =>
252304 b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toShort
253305 }
254306
255307 // ByteConverter
256- private [this ] def castToByte : Any => Any = child.dataType match {
308+ private [this ] def castToByte ( from : DataType ) : Any => Any = from match {
257309 case StringType =>
258310 buildCast[String ](_, s => try s.toByte catch {
259311 case _ : NumberFormatException => null
@@ -264,8 +316,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
264316 buildCast[Date ](_, d => dateToLong(d))
265317 case TimestampType =>
266318 buildCast[Timestamp ](_, t => timestampToLong(t).toByte)
267- case DecimalType () =>
268- buildCast[Decimal ](_, _.toByte)
269319 case x : NumericType =>
270320 b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toByte
271321 }
@@ -285,7 +335,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
285335 }
286336 }
287337
288- private [this ] def castToDecimal (target : DecimalType ): Any => Any = child.dataType match {
338+ private [this ] def castToDecimal (from : DataType , target : DecimalType ): Any => Any = from match {
289339 case StringType =>
290340 buildCast[String ](_, s => try changePrecision(Decimal (s.toDouble), target) catch {
291341 case _ : NumberFormatException => null
@@ -301,7 +351,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
301351 b => changePrecision(b.asInstanceOf [Decimal ].clone(), target)
302352 case LongType =>
303353 b => changePrecision(Decimal (b.asInstanceOf [Long ]), target)
304- case x : NumericType => // All other numeric types can be represented precisely as Doubles
354+ case x : NumericType => // All other numeric types can be represented precisely as Doubles
305355 b => try {
306356 changePrecision(Decimal (x.numeric.asInstanceOf [Numeric [Any ]].toDouble(b)), target)
307357 } catch {
@@ -310,7 +360,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
310360 }
311361
312362 // DoubleConverter
313- private [this ] def castToDouble : Any => Any = child.dataType match {
363+ private [this ] def castToDouble ( from : DataType ) : Any => Any = from match {
314364 case StringType =>
315365 buildCast[String ](_, s => try s.toDouble catch {
316366 case _ : NumberFormatException => null
@@ -321,14 +371,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
321371 buildCast[Date ](_, d => dateToDouble(d))
322372 case TimestampType =>
323373 buildCast[Timestamp ](_, t => timestampToDouble(t))
324- case DecimalType () =>
325- buildCast[Decimal ](_, _.toDouble)
326374 case x : NumericType =>
327375 b => x.numeric.asInstanceOf [Numeric [Any ]].toDouble(b)
328376 }
329377
330378 // FloatConverter
331- private [this ] def castToFloat : Any => Any = child.dataType match {
379+ private [this ] def castToFloat ( from : DataType ) : Any => Any = from match {
332380 case StringType =>
333381 buildCast[String ](_, s => try s.toFloat catch {
334382 case _ : NumberFormatException => null
@@ -339,28 +387,53 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
339387 buildCast[Date ](_, d => dateToDouble(d))
340388 case TimestampType =>
341389 buildCast[Timestamp ](_, t => timestampToDouble(t).toFloat)
342- case DecimalType () =>
343- buildCast[Decimal ](_, _.toFloat)
344390 case x : NumericType =>
345391 b => x.numeric.asInstanceOf [Numeric [Any ]].toFloat(b)
346392 }
347393
348- private [this ] lazy val cast : Any => Any = dataType match {
/**
 * Builds the converter between two array types: each non-null element is converted with the
 * element-type cast, null elements are passed through unchanged.
 */
private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
  val convertElement = cast(from.elementType, to.elementType)
  buildCast[Seq[Any]](_, elements => elements.map {
    case null => null
    case element => convertElement(element)
  })
}
398+
/**
 * Builds the converter between two map types: every key is converted with the key-type cast,
 * and every non-null value with the value-type cast (null values are passed through).
 */
private[this] def castMap(from: MapType, to: MapType): Any => Any = {
  val convertKey = cast(from.keyType, to.keyType)
  val convertValue = cast(from.valueType, to.valueType)
  buildCast[Map[Any, Any]](_, entries => entries.map { case (k, v) =>
    val newKey = convertKey(k)
    val newValue = if (v == null) null else convertValue(v)
    (newKey, newValue)
  })
}
406+
/**
 * Builds the converter between two struct types. Fields are paired positionally; each non-null
 * field value is converted with the cast for its field pair, null values are passed through.
 */
private[this] def castStruct(from: StructType, to: StructType): Any => Any = {
  // One converter per field position, prepared once up front.
  val fieldCasts = from.fields.zip(to.fields).map { case (source, target) =>
    cast(source.dataType, target.dataType)
  }
  buildCast[Row](_, row => {
    val converted = row.zip(fieldCasts).map { case (value, fieldCast) =>
      if (value == null) null else fieldCast(value)
    }
    Row(converted: _*)
  })
}
415+
/**
 * Builds the conversion function from values of type `from` to values of type `to`.
 *
 * Besides the top-level cast, this is invoked recursively by `castArray`, `castMap` and
 * `castStruct` for element, key, value and field types, so every decision here must be based
 * on the `from` argument rather than on `child.dataType`.
 *
 * @param from data type of the incoming values
 * @param to   data type the values are converted to
 * @return a function converting a non-null value of type `from` into a value of type `to`
 */
private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
  // Same type on both sides: nothing to convert. This must compare against `from`; comparing
  // against `child.dataType` misses nested same-type pairs (e.g. a binary field that stays
  // binary inside a struct), which then hit `castToBinary` and fail with a MatchError.
  case dt if dt == from => identity[Any]
  // `resolve` accepts NullType as a source for any target. Values of NullType are always
  // null and every caller skips null values, so this placeholder is never applied; it only
  // prevents the branches below from failing while the function is being built (e.g.
  // `from.asInstanceOf[ArrayType]` for a nested NullType field).
  case _ if from == NullType => (_: Any) => null
  case StringType => castToString(from)
  case BinaryType => castToBinary(from)
  case DateType => castToDate(from)
  case decimal: DecimalType => castToDecimal(from, decimal)
  case TimestampType => castToTimestamp(from)
  case BooleanType => castToBoolean(from)
  case ByteType => castToByte(from)
  case ShortType => castToShort(from)
  case IntegerType => castToInt(from)
  case FloatType => castToFloat(from)
  case LongType => castToLong(from)
  case DoubleType => castToDouble(from)
  // For complex targets `resolve` only accepts a source of the same kind (identical and
  // NullType sources are handled above), which is what the downcasts below rely on.
  case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array)
  case map: MapType => castMap(from.asInstanceOf[MapType], map)
  case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
}
363434
/**
 * Conversion function for this expression, built on first use from the child's data type to
 * the target `dataType`.
 */
private[this] lazy val cast: Any => Any =
  cast(child.dataType, dataType)
436+
364437 override def eval (input : Row ): Any = {
365438 val evaluated = child.eval(input)
366439 if (evaluated == null ) null else cast(evaluated)
0 commit comments