Commit f257071

Davies Liu authored and committed
add tests for null in RDD
1 parent 23b039a commit f257071

6 files changed (+52, -69 lines)

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 15 additions & 56 deletions
@@ -373,68 +373,27 @@ private[spark] object PythonRDD extends Logging {
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
 
-    def writeBytes(bytes: Array[Byte]) {
-      if (bytes == null) {
+    def write(obj: Any): Unit = obj match {
+      case null =>
         dataOut.writeInt(SpecialLengths.NULL)
-      } else {
-        dataOut.writeInt(bytes.length)
-        dataOut.write(bytes)
-      }
-    }
 
-    def writeString(str: String) {
-      if (str == null) {
-        dataOut.writeInt(SpecialLengths.NULL)
-      } else {
+      case arr: Array[Byte] =>
+        dataOut.writeInt(arr.length)
+        dataOut.write(arr)
+      case str: String =>
         writeUTF(str, dataOut)
-      }
-    }
 
-    // The right way to implement this would be to use TypeTags to get the full
-    // type of T. Since I don't want to introduce breaking changes throughout the
-    // entire Spark API, I have to use this hacky approach:
-    if (iter.hasNext) {
-      val first = iter.next()
-      val newIter = Seq(first).iterator ++ iter
-      first match {
-        case arr: Array[Byte] =>
-          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(writeBytes)
-        case string: String =>
-          newIter.asInstanceOf[Iterator[String]].foreach(writeString)
-        case stream: PortableDataStream =>
-          newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
-            writeBytes(stream.toArray())
-          }
-        case (key: String, stream: PortableDataStream) =>
-          newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
-            case (key, stream) =>
-              writeString(key)
-              writeBytes(stream.toArray())
-          }
-        case (key: String, value: String) =>
-          newIter.asInstanceOf[Iterator[(String, String)]].foreach {
-            case (key, value) =>
-              writeString(key)
-              writeString(value)
-          }
-        case (key: Array[Byte], value: Array[Byte]) =>
-          newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
-            case (key, value) =>
-              writeBytes(key)
-              writeBytes(value)
-          }
-        // key is null
-        case (null, value: Array[Byte]) =>
-          newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
-            case (key, value) =>
-              writeBytes(key)
-              writeBytes(value)
-          }
+      case stream: PortableDataStream =>
+        write(stream.toArray())
+      case (key, value) =>
+        write(key)
+        write(value)
 
-        case other =>
-          throw new SparkException("Unexpected element type " + other.getClass)
-      }
+      case other =>
+        throw new SparkException("Unexpected element type " + other.getClass)
     }
+
+    iter.foreach(write)
   }
 
   /**
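
Note: the rewritten writeIteratorToStream frames every record the same way — a 4-byte big-endian length followed by the payload bytes, a special negative length for null, and tuples flattened into consecutive frames by the recursive write(key); write(value) calls. As a rough illustration of the receiving side, here is a minimal Python sketch of a reader for that framing; NULL_SENTINEL, read_int and read_frames are illustrative names and assumptions, not the actual PySpark serializer code.

    import struct

    # Illustrative sentinel value; the real constant is SpecialLengths.NULL,
    # shared between PythonRDD.scala and pyspark/serializers.py.
    NULL_SENTINEL = -5

    def read_int(stream):
        # Matches java.io.DataOutputStream.writeInt: 4 bytes, big-endian.
        data = stream.read(4)
        if len(data) < 4:
            raise EOFError
        return struct.unpack("!i", data)[0]

    def read_frames(stream):
        # Yield the raw bytes of each frame, or None wherever the writer
        # emitted the null sentinel instead of a length.
        while True:
            try:
                length = read_int(stream)
            except EOFError:
                return
            yield None if length == NULL_SENTINEL else stream.read(length)

A (key, value) pair written on the Scala side simply arrives as two consecutive frames, which is why a null key or null value needs no special casing beyond the sentinel.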

core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,7 @@ import java.io.{File, InputStream, IOException, OutputStream}
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkContext
+import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
 
 private[spark] object PythonUtils {
   /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */
@@ -39,4 +40,8 @@ private[spark] object PythonUtils {
   def mergePythonPaths(paths: String*): String = {
     paths.filter(_ != "").mkString(File.pathSeparator)
   }
+
+  def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = {
+    sc.parallelize(List("a", null, "b"))
+  }
 }

core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala

Lines changed: 16 additions & 6 deletions
@@ -23,11 +23,21 @@ import org.scalatest.FunSuite
 
 class PythonRDDSuite extends FunSuite {
 
-    test("Writing large strings to the worker") {
-      val input: List[String] = List("a"*100000)
-      val buffer = new DataOutputStream(new ByteArrayOutputStream)
-      PythonRDD.writeIteratorToStream(input.iterator, buffer)
-    }
+  test("Writing large strings to the worker") {
+    val input: List[String] = List("a"*100000)
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    PythonRDD.writeIteratorToStream(input.iterator, buffer)
+  }
 
-}
+  test("Handle nulls gracefully") {
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    PythonRDD.writeIteratorToStream(List("a", null).iterator, buffer)
+    PythonRDD.writeIteratorToStream(List(null, "a").iterator, buffer)
+    PythonRDD.writeIteratorToStream(List("a".getBytes, null).iterator, buffer)
+    PythonRDD.writeIteratorToStream(List(null, "a".getBytes).iterator, buffer)
 
+    PythonRDD.writeIteratorToStream(List((null, null), ("a", null), (null, "b")).iterator, buffer)
+    PythonRDD.writeIteratorToStream(
+      List((null, null), ("a".getBytes, null), (null, "b".getBytes)).iterator, buffer)
+  }
+}

python/pyspark/serializers.py

Lines changed: 2 additions & 0 deletions
@@ -134,6 +134,8 @@ def load_stream(self, stream):
 
     def _write_with_length(self, obj, stream):
         serialized = self.dumps(obj)
+        if serialized is None:
+            raise ValueError("serialized value should not be None")
         if len(serialized) > (1 << 31):
             raise ValueError("can not serialize object larger than 2G")
         write_int(len(serialized), stream)
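
Note: the None check turns a confusing downstream failure (len(None)) into an explicit error as soon as a serializer's dumps returns None. A hedged sketch of how it surfaces; BrokenSerializer is a hypothetical subclass written for illustration, not part of PySpark.

    import io
    from pyspark.serializers import FramedSerializer

    class BrokenSerializer(FramedSerializer):
        # Hypothetical serializer whose dumps() mistakenly returns None.
        def dumps(self, obj):
            return None

        def loads(self, obj):
            return obj

    buf = io.BytesIO()
    try:
        # dump_stream calls _write_with_length for each element.
        BrokenSerializer().dump_stream(["anything"], buf)
    except ValueError as e:
        print(e)  # serialized value should not be None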

python/pyspark/streaming/kafka.py

Lines changed: 7 additions & 6 deletions
@@ -33,7 +33,7 @@ def utf8_decoder(s):
 class KafkaUtils(object):
 
     @staticmethod
-    def createStream(ssc, zkQuorum, groupId, topics,
+    def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
                      storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
                      keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
         """
@@ -44,22 +44,23 @@ def createStream(ssc, zkQuorum, groupId, topics,
         :param groupId: The group id for this consumer.
         :param topics: Dict of (topic_name -> numPartitions) to consume.
                        Each partition is consumed in its own thread.
+        :param kafkaParams: Additional params for Kafka
         :param storageLevel: RDD storage level.
-        :param keyDecoder: A function used to decode key
-        :param valueDecoder: A function used to decode value
+        :param keyDecoder: A function used to decode key (default is utf8_decoder)
+        :param valueDecoder: A function used to decode value (default is utf8_decoder)
         :return: A DStream object
         """
         java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils")
 
-        param = {
+        kafkaParams.update({
             "zookeeper.connect": zkQuorum,
             "group.id": groupId,
            "zookeeper.connection.timeout.ms": "10000",
-        }
+        })
         if not isinstance(topics, dict):
             raise TypeError("topics should be dict")
         jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client)
-        jparam = MapConverter().convert(param, ssc.sparkContext._gateway._gateway_client)
+        jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
 
         def getClassByName(name):
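
Note: the new kafkaParams argument lets a caller pass extra consumer settings that get merged with the zookeeper.connect / group.id defaults filled in by createStream. A hedged usage sketch follows; the quorum address, group id, topic name, and the extra setting are illustrative, and it assumes the spark-streaming-kafka assembly jar is available to the driver.

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kafka import KafkaUtils, utf8_decoder

    sc = SparkContext(appName="kafka-params-demo")
    ssc = StreamingContext(sc, batchDuration=2)

    # Extra Kafka consumer settings ride along in kafkaParams and are merged
    # with the defaults that createStream sets itself.
    stream = KafkaUtils.createStream(
        ssc, "zk-host:2181", "demo-group", {"demo-topic": 1},
        kafkaParams={"auto.offset.reset": "smallest"},
        keyDecoder=utf8_decoder, valueDecoder=utf8_decoder)
    stream.pprint()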

python/pyspark/tests.py

Lines changed: 7 additions & 1 deletion
@@ -46,9 +46,10 @@
 
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
+from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
-    CloudPickleSerializer, CompressedSerializer
+    CloudPickleSerializer, CompressedSerializer, UTF8Deserializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
@@ -714,6 +715,11 @@ def test_sample(self):
         wr_s21 = rdd.sample(True, 0.4, 21).collect()
         self.assertNotEqual(set(wr_s11), set(wr_s21))
 
+    def test_null_in_rdd(self):
+        jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
+        rdd = RDD(jrdd, self.sc, UTF8Deserializer())
+        self.assertEqual([u"a", None, u"b"], rdd.collect())
+
 
 class ProfilerTests(PySparkTestCase):
 