Skip to content

Commit 6c2691d

Browse files
kanzhangpwendell
authored andcommitted
[SPARK-1690] Tolerating empty elements when saving Python RDD to text files
Tolerate empty strings in PythonRDD Author: Kan Zhang <kzhang@apache.org> Closes apache#644 from kanzhang/SPARK-1690 and squashes the following commits: c62ad33 [Kan Zhang] Adding Python doctest 473ec4b [Kan Zhang] [SPARK-1690] Tolerating empty elements when saving Python RDD to text files
1 parent 3776f2f commit 6c2691d

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ private[spark] class PythonRDD[T: ClassTag](
9494
val obj = new Array[Byte](length)
9595
stream.readFully(obj)
9696
obj
97+
case 0 => Array.empty[Byte]
9798
case SpecialLengths.TIMING_DATA =>
9899
// Timing data from worker
99100
val bootTime = stream.readLong()
@@ -123,7 +124,7 @@ private[spark] class PythonRDD[T: ClassTag](
123124
stream.readFully(update)
124125
accumulator += Collections.singletonList(update)
125126
}
126-
Array.empty[Byte]
127+
null
127128
}
128129
} catch {
129130

@@ -143,7 +144,7 @@ private[spark] class PythonRDD[T: ClassTag](
143144

144145
var _nextObj = read()
145146

146-
def hasNext = _nextObj.length != 0
147+
def hasNext = _nextObj != null
147148
}
148149
new InterruptibleIterator(context, stdoutIterator)
149150
}

python/pyspark/rdd.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,14 @@ def saveAsTextFile(self, path):
891891
>>> from glob import glob
892892
>>> ''.join(sorted(input(glob(tempFile.name + "/part-0000*"))))
893893
'0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
894+
895+
Empty lines are tolerated when saving to text files.
896+
897+
>>> tempFile2 = NamedTemporaryFile(delete=True)
898+
>>> tempFile2.close()
899+
>>> sc.parallelize(['', 'foo', '', 'bar', '']).saveAsTextFile(tempFile2.name)
900+
>>> ''.join(sorted(input(glob(tempFile2.name + "/part-0000*"))))
901+
'\\n\\n\\nbar\\nfoo\\n'
894902
"""
895903
def func(split, iterator):
896904
for x in iterator:

0 commit comments

Comments
 (0)