Skip to content

Commit ae3278d

Browse files
committed
Throw ClassCastException errors during inbound conversions.
1 parent 7ca7fcb commit ae3278d

File tree

1 file changed

+47
-36
lines changed

1 file changed

+47
-36
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ object CatalystTypeConverters {
3737
// Since the map values can be mutable, we explicitly import scala.collection.Map at here.
3838
import scala.collection.Map
3939

40-
private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any] = {
40+
private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
4141
val converter = dataType match {
4242
case udt: UserDefinedType[_] => UDTConverter(udt)
4343
case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
@@ -55,81 +55,78 @@ object CatalystTypeConverters {
5555
case DoubleType => DoubleConverter
5656
case _ => IdentityConverter
5757
}
58-
converter.asInstanceOf[CatalystTypeConverter[Any, Any]]
58+
converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
5959
}
6060

6161
/**
6262
* Converts a Scala type to its Catalyst equivalent (and vice versa).
63+
*
64+
* @tparam ScalaInputType The type of Scala values that can be converted to Catalyst.
65+
* @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala.
66+
* @tparam CatalystType The internal Catalyst type used to represent values of this Scala type.
6367
*/
64-
private abstract class CatalystTypeConverter[ScalaType, CatalystType] extends Serializable {
68+
private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType]
69+
extends Serializable {
6570

6671
/**
6772
* Converts a Scala type to its Catalyst equivalent while automatically handling nulls
6873
* and Options.
6974
*/
7075
final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = {
7176
maybeScalaValue match {
72-
case opt: Option[ScalaType] =>
77+
case opt: Option[ScalaInputType] =>
7378
if (opt.isDefined) {
7479
toCatalystImpl(opt.get)
7580
} else {
7681
null.asInstanceOf[CatalystType]
7782
}
7883
case null => null.asInstanceOf[CatalystType]
79-
case scalaValue: ScalaType => toCatalystImpl(scalaValue)
84+
case scalaValue: ScalaInputType => toCatalystImpl(scalaValue)
8085
}
8186
}
8287

8388
/**
8489
* Given a Catalyst row, convert the value at column `column` to its Scala equivalent.
8590
*/
86-
final def toScala(row: Row, column: Int): Any = {
87-
if (row.isNullAt(column)) null else toScalaImpl(row, column)
91+
final def toScala(row: Row, column: Int): ScalaOutputType = {
92+
if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column)
8893
}
8994

9095
/**
9196
* Convert a Catalyst value to its Scala equivalent.
9297
*/
93-
def toScala(@Nullable catalystValue: CatalystType): ScalaType
98+
def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType
9499

95100
/**
96101
* Converts a Scala value to its Catalyst equivalent.
97102
* @param scalaValue the Scala value, guaranteed not to be null.
98103
* @return the Catalyst value.
99104
*/
100-
protected def toCatalystImpl(scalaValue: ScalaType): CatalystType
105+
protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType
101106

102107
/**
103108
* Given a Catalyst row, convert the value at column `column` to its Scala equivalent.
104109
* This method will only be called on non-null columns.
105110
*/
106-
protected def toScalaImpl(row: Row, column: Int): ScalaType
107-
}
108-
109-
/**
110-
* Convenience wrapper to write type converters for primitives. We use a converter for primitives
111-
* so that we can use type-specific field accessors when converting Catalyst rows to Scala rows.
112-
*/
113-
private abstract class PrimitiveCatalystTypeConverter[T] extends CatalystTypeConverter[T, T] {
114-
override final def toScala(catalystValue: T): T = catalystValue
115-
override final def toCatalystImpl(scalaValue: T): T = scalaValue
111+
protected def toScalaImpl(row: Row, column: Int): ScalaOutputType
116112
}
117113

118-
private object IdentityConverter extends CatalystTypeConverter[Any, Any] {
114+
private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] {
119115
override def toCatalystImpl(scalaValue: Any): Any = scalaValue
120116
override def toScala(catalystValue: Any): Any = catalystValue
121117
override def toScalaImpl(row: Row, column: Int): Any = row(column)
122118
}
123119

124-
private case class UDTConverter(udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any] {
120+
private case class UDTConverter(
121+
udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
125122
override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
126123
override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
127124
override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column))
128125
}
129126

130-
// Converter for array, seq, iterables.
127+
/** Converter for arrays, sequences, and Java iterables. */
131128
private case class ArrayConverter(
132-
elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any]] {
129+
elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] {
133130

134131
private[this] val elementConverter = getConverterForType(elementType)
135132

@@ -162,8 +159,8 @@ object CatalystTypeConverters {
162159

163160
private case class MapConverter(
164161
keyType: DataType,
165-
valueType: DataType
166-
) extends CatalystTypeConverter[Any, Map[Any, Any]] {
162+
valueType: DataType)
163+
extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] {
167164

168165
private[this] val keyConverter = getConverterForType(keyType)
169166
private[this] val valueConverter = getConverterForType(valueType)
@@ -200,7 +197,7 @@ object CatalystTypeConverters {
200197
}
201198

202199
private case class StructConverter(
203-
structType: StructType) extends CatalystTypeConverter[Any, Row] {
200+
structType: StructType) extends CatalystTypeConverter[Any, Row, Row] {
204201

205202
private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) }
206203

@@ -242,7 +239,7 @@ object CatalystTypeConverters {
242239
override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row])
243240
}
244241

245-
private object StringConverter extends CatalystTypeConverter[Any, Any] {
242+
private object StringConverter extends CatalystTypeConverter[Any, String, Any] {
246243
override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
247244
case str: String => UTF8String(str)
248245
case utf8: UTF8String => utf8
@@ -255,14 +252,14 @@ object CatalystTypeConverters {
255252
override def toScalaImpl(row: Row, column: Int): String = row(column).toString
256253
}
257254

258-
private object DateConverter extends CatalystTypeConverter[Date, Any] {
255+
private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
259256
override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue)
260257
override def toScala(catalystValue: Any): Date =
261258
if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int])
262259
override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column))
263260
}
264261

265-
private object BigDecimalConverter extends CatalystTypeConverter[Any, Decimal] {
262+
private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
266263
override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
267264
case d: BigDecimal => Decimal(d)
268265
case d: JavaBigDecimal => Decimal(d)
@@ -275,32 +272,46 @@ object CatalystTypeConverters {
275272
}
276273
}
277274

278-
private object BooleanConverter extends PrimitiveCatalystTypeConverter[Boolean] {
275+
private object BooleanConverter extends CatalystTypeConverter[Boolean, Boolean, Boolean] {
279276
override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column)
277+
override def toScala(catalystValue: Boolean): Boolean = catalystValue
278+
override protected def toCatalystImpl(scalaValue: Boolean): Boolean = scalaValue
280279
}
281280

282-
private object ByteConverter extends PrimitiveCatalystTypeConverter[Byte] {
281+
private object ByteConverter extends CatalystTypeConverter[Byte, Byte, Byte] {
283282
override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column)
283+
override def toScala(catalystValue: Byte): Byte = catalystValue
284+
override protected def toCatalystImpl(scalaValue: Byte): Byte = scalaValue
284285
}
285286

286-
private object ShortConverter extends PrimitiveCatalystTypeConverter[Short] {
287+
private object ShortConverter extends CatalystTypeConverter[Short, Short, Short] {
287288
override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column)
289+
override def toScala(catalystValue: Short): Short = catalystValue
290+
override protected def toCatalystImpl(scalaValue: Short): Short = scalaValue
288291
}
289292

290-
private object IntConverter extends PrimitiveCatalystTypeConverter[Int] {
293+
private object IntConverter extends CatalystTypeConverter[Int, Int, Int] {
291294
override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column)
295+
override def toScala(catalystValue: Int): Int = catalystValue
296+
override protected def toCatalystImpl(scalaValue: Int): Int = scalaValue
292297
}
293298

294-
private object LongConverter extends PrimitiveCatalystTypeConverter[Long] {
299+
private object LongConverter extends CatalystTypeConverter[Long, Long, Long] {
295300
override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column)
301+
override def toScala(catalystValue: Long): Long = catalystValue
302+
override protected def toCatalystImpl(scalaValue: Long): Long = scalaValue
296303
}
297304

298-
private object FloatConverter extends PrimitiveCatalystTypeConverter[Float] {
305+
private object FloatConverter extends CatalystTypeConverter[Float, Float, Float] {
299306
override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column)
307+
override def toScala(catalystValue: Float): Float = catalystValue
308+
override protected def toCatalystImpl(scalaValue: Float): Float = scalaValue
300309
}
301310

302-
private object DoubleConverter extends PrimitiveCatalystTypeConverter[Double] {
311+
private object DoubleConverter extends CatalystTypeConverter[Double, Double, Double] {
303312
override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column)
313+
override def toScala(catalystValue: Double): Double = catalystValue
314+
override protected def toCatalystImpl(scalaValue: Double): Double = scalaValue
304315
}
305316

306317
/**

0 commit comments

Comments
 (0)