Skip to content

Commit 88e6b1f

Browse files
committed
handle null in schemaRDD()
1 parent 82624e2 commit 88e6b1f

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

python/pyspark/sql.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,13 @@ def jsonRDD(self, rdd, schema=None):
12311231
... "field3.field5[0] as f3 from table3")
12321232
>>> srdd6.collect()
12331233
[Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
1234+
1235+
>>> sqlCtx.jsonRDD(sc.parallelize(['{}',
1236+
... '{"key0": {"key1": "value1"}}'])).collect()
1237+
[Row(key0=None), Row(key0=Row(key1=u'value1'))]
1238+
>>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
1239+
... '{"key0": {"key1": "value1"}}'])).collect()
1240+
[Row(key0=None), Row(key0=Row(key1=u'value1'))]
12341241
"""
12351242

12361243
def func(iterator):

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -382,21 +382,26 @@ class SchemaRDD(
382382
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
383383
import scala.collection.Map
384384

385-
def toJava(obj: Any, dataType: DataType): Any = dataType match {
386-
case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct)
387-
case array: ArrayType => obj match {
388-
case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
389-
case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
390-
case arr if arr != null && arr.getClass.isArray =>
391-
arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
392-
case other => other
393-
}
394-
case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
385+
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
386+
case (null, _) => null
387+
388+
case (obj: Row, struct: StructType) => rowToArray(obj, struct)
389+
390+
case (seq: Seq[Any], array: ArrayType) =>
391+
seq.map(x => toJava(x, array.elementType)).asJava
392+
case (list: JList[_], array: ArrayType) =>
393+
list.map(x => toJava(x, array.elementType)).asJava
394+
case (arr, array: ArrayType) if arr.getClass.isArray =>
395+
arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
396+
397+
case (obj: Map[_, _], mt: MapType) => obj.map {
395398
case (k, v) => (k, toJava(v, mt.valueType)) // keys are expected to be primitive types, so only values are converted
396399
}.asJava
400+
397401
// Pyrolite can handle Timestamp, so it is returned unchanged by the default case below
398-
case other => obj
402+
case (other, _) => other
399403
}
404+
400405
def rowToArray(row: Row, structType: StructType): Array[Any] = {
401406
val fields = structType.fields.map(field => field.dataType)
402407
row.zip(fields).map {

0 commit comments

Comments
 (0)