[SPARK-7160][SQL] Support converting DataFrames to typed RDDs. #5713

Closed · wants to merge 1 commit
@@ -24,12 +24,14 @@ import java.util.{Map => JavaMap}
import javax.annotation.Nullable

import scala.language.existentials
import scala.reflect.ClassTag

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
* Functions to convert Scala types to Catalyst types and vice versa.
@@ -39,6 +41,8 @@ object CatalystTypeConverters {
// Since the map values can be mutable, we explicitly import scala.collection.Map here.
import scala.collection.Map

lazy val universe = ScalaReflection.universe

private def isPrimitive(dataType: DataType): Boolean = {
dataType match {
case BooleanType => true
@@ -454,4 +458,166 @@ object CatalystTypeConverters {
def convertToScala(catalystValue: Any, dataType: DataType): Any = {
createToScalaConverter(dataType)(catalystValue)
}

/**
* Like createToScalaConverter(DataType), creates a function that converts a Catalyst object to a
* Scala object; however, in this case, the Scala object is an instance of a subtype of Product
* (e.g. a case class).
*
* If the given Scala type is not compatible with the given structType, the returned converter
* throws a ClassCastException when it is invoked.
*
* The typical use case is converting a collection of rows that share the same schema: call
* this function once to get a converter, then apply it to every row.
*/
private[sql] def createToProductConverter[T <: Product](
structType: StructType)(implicit classTag: ClassTag[T]): InternalRow => T = {

// Use ScalaReflectionLock to avoid reflection thread-safety issues in Scala 2.10.
// https://issues.scala-lang.org/browse/SI-6240
// http://docs.scala-lang.org/overviews/reflection/thread-safety.html
ScalaReflectionLock.synchronized { createToProductConverter(classTag, structType) }
}
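
// Editorial sketch (not part of this patch): hypothetical usage, assuming a case
// class whose fields line up positionally with a DataFrame's schema.
//
//   case class Person(name: String, age: Int)
//   val convert = CatalystTypeConverters.createToProductConverter[Person](df.schema)
//   val people = df.queryExecution.toRdd.map(convert)  // RDD[Person]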

private[sql] def createToProductConverter[T <: Product](
classTag: ClassTag[T], structType: StructType): InternalRow => T = {

import universe._

val constructorMirror = {
val mirror = runtimeMirror(Utils.getContextOrSparkClassLoader)
val classSymbol = mirror.classSymbol(classTag.runtimeClass)
val classMirror = mirror.reflectClass(classSymbol)
val constructorSymbol = {
// Adapted from ScalaReflection to find primary constructor.
// https://issues.apache.org/jira/browse/SPARK-4791
val symbol = classSymbol.toType.declaration(nme.CONSTRUCTOR)
if (symbol.isMethod) {
symbol.asMethod
} else {
val candidateSymbol =
symbol.asTerm.alternatives.find { s => s.isMethod && s.asMethod.isPrimaryConstructor }
if (candidateSymbol.isDefined) {
candidateSymbol.get.asMethod
} else {
throw new IllegalArgumentException(s"No primary constructor for ${symbol.name}")
}
}
}
classMirror.reflectConstructor(constructorSymbol)
}

val params = constructorMirror.symbol.paramss.head.toSeq
val paramTypes = params.map { _.asTerm.typeSignature }
val fields = structType.fields
val dataTypes = fields.map { _.dataType }
val converters: Seq[Any => Any] =
paramTypes.zip(dataTypes).map { case (pt, dt) => createToScalaConverter(pt, dt) }

(row: InternalRow) => if (row == null) {
null.asInstanceOf[T]
} else {
val convertedArgs =
converters.zip(row.toSeq(dataTypes)).map { case (converter, arg) => converter(arg) }
try {
constructorMirror.apply(convertedArgs: _*).asInstanceOf[T]
} catch {
case e: IllegalArgumentException => // argument type mismatch
val message =
s"""|Error constructing ${classTag.runtimeClass.getName}: ${e.getMessage};
|paramTypes: ${paramTypes}, dataTypes: ${dataTypes},
|convertedArgs: ${convertedArgs}""".stripMargin.replace("\n", " ")
Author: Addressing feedback, this error message provides more details when things go wrong calling the constructor.
throw new ClassCastException(message)
}
}
}
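
// Editorial note (not part of this patch): a schema/type mismatch surfaces when the
// converter runs, not when it is built. E.g. (hypothetical) a converter for
//   case class P(n: Int)
// built against StructType(Seq(StructField("n", StringType))) is created without error,
// but applying it to a row rethrows the reflective IllegalArgumentException as the
// ClassCastException constructed above, listing paramTypes and dataTypes.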

/**
* Like createToScalaConverter(DataType), but with a Scala type hint.
*
* Please keep in sync with createToScalaConverter(DataType) and ScalaReflection.schemaFor[T].
*/
private[sql] def createToScalaConverter(
universeType: universe.Type, dataType: DataType): Any => Any = {
Author: Addressing feedback, this method has all the conversions, as requested.

import universe._

(universeType, dataType) match {
case (t, dt) if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val converter: Any => Any = createToScalaConverter(elementType, dt)
(catalystValue: Any) => Option(converter(catalystValue))

case (t, udt: UserDefinedType[_]) =>
(catalystValue: Any) => if (catalystValue == null) null else udt.deserialize(catalystValue)

case (t, bt: BinaryType) => identity

case (t, at: ArrayType) if t <:< typeOf[Array[_]] =>
throw new UnsupportedOperationException("Array[_] is not supported; try using Seq instead.")

case (t, at: ArrayType) if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val converter: Any => Any = createToScalaConverter(elementType, at.elementType)
(catalystValue: Any) => catalystValue match {
case arrayData: ArrayData => arrayData.toArray[Any](at.elementType).map(converter).toSeq
case o => o
}

case (t, mt: MapType) if t <:< typeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val keyConverter: Any => Any = createToScalaConverter(keyType, mt.keyType)
val valueConverter: Any => Any = createToScalaConverter(valueType, mt.valueType)
(catalystValue: Any) => catalystValue match {
case mapData: MapData =>
val keys = mapData.keyArray().toArray[Any](mt.keyType)
val values = mapData.valueArray().toArray[Any](mt.valueType)
keys.map(keyConverter).zip(values.map(valueConverter)).toMap
case o => o
}

case (t, st: StructType) if t <:< typeOf[Product] =>
val className = t.erasure.typeSymbol.asClass.fullName
val classTag = if (Utils.classIsLoadable(className)) {
scala.reflect.ClassTag(Utils.classForName(className))
} else {
throw new IllegalArgumentException(s"$className is not loadable")
}
createToProductConverter(classTag, st).asInstanceOf[Any => Any]

case (t, StringType) if t <:< typeOf[String] =>
(catalystValue: Any) => catalystValue match {
case utf8: UTF8String => utf8.toString
case o => o
}

case (t, DateType) if t <:< typeOf[Date] =>
(catalystValue: Any) => catalystValue match {
case i: Int => DateTimeUtils.toJavaDate(i)
case o => o
}

case (t, TimestampType) if t <:< typeOf[Timestamp] =>
(catalystValue: Any) => catalystValue match {
case x: Long => DateTimeUtils.toJavaTimestamp(x)
case o => o
}

case (t, _: DecimalType) if t <:< typeOf[BigDecimal] =>
(catalystValue: Any) => catalystValue match {
case d: Decimal => d.toBigDecimal
case o => o
}

case (t, _: DecimalType) if t <:< typeOf[java.math.BigDecimal] =>
(catalystValue: Any) => catalystValue match {
case d: Decimal => d.toJavaBigDecimal
case o => o
}

// Pass non-string primitives through. (Strings are converted from UTF8Strings above.)
// For everything else, hope for the best.
case (t, o) => identity
}
}
}
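
Taken together, a minimal end-to-end sketch of the new API (editorial, not part of this patch; the case class, schema, and input rows are invented, and the calls are written as if from within org.apache.spark.sql, since the methods are private[sql]):

  case class Event(id: Long, tags: Seq[String], note: Option[String])

  val schema = StructType(Seq(
    StructField("id", LongType, nullable = false),
    StructField("tags", ArrayType(StringType)),
    StructField("note", StringType)))

  // One reflective converter is built up front...
  val convert: InternalRow => Event =
    CatalystTypeConverters.createToProductConverter[Event](schema)

  // ...then applied per row. Seq and Option fields go through the type-hinted
  // createToScalaConverter cases above; a null "note" value becomes None.
  val events = internalRows.map(convert)  // e.g. over an RDD[InternalRow]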