
Commit 60e18ce

SPARK-1414. Python API for SparkContext.wholeTextFiles
Also clarified comment on each file having to fit in memory

Author: Matei Zaharia <matei@databricks.com>

Closes #327 from mateiz/py-whole-files and squashes the following commits:

9ad64a5 [Matei Zaharia] SPARK-1414. Python API for SparkContext.wholeTextFiles
1 parent d956cc2 commit 60e18ce
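
For orientation, here is a minimal usage sketch of the Python API this commit introduces. The hdfs://a-hdfs-path directory is a hypothetical placeholder, mirroring the example in the docstring added below:

    from pyspark import SparkContext

    sc = SparkContext("local", "wholeTextFilesExample")

    # Each record is a (path, content) pair; every file is read fully into
    # memory, so this suits many small files rather than a few huge ones.
    pairs = sc.wholeTextFiles("hdfs://a-hdfs-path")
    paths = pairs.map(lambda kv: kv[0]).collect()
    print(sorted(paths))

    sc.stop()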

File tree: 5 files changed, +49 −7 lines


core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 1 addition & 1 deletion
@@ -395,7 +395,7 @@ class SparkContext(
    * (a-hdfs-path/part-nnnnn, its content)
    * }}}
    *
-   * @note Small files are perferred, large file is also allowable, but may cause bad performance.
+   * @note Small files are preferred, as each file will be loaded fully in memory.
    */
   def wholeTextFiles(path: String): RDD[(String, String)] = {
     newAPIHadoopFile(

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * (a-hdfs-path/part-nnnnn, its content)
    * }}}
    *
-   * @note Small files are perferred, large file is also allowable, but may cause bad performance.
+   * @note Small files are preferred, as each file will be loaded fully in memory.
    */
   def wholeTextFiles(path: String): JavaPairRDD[String, String] =
     new JavaPairRDD(sc.wholeTextFiles(path))

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 2 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
+import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
 import scala.collection.JavaConversions._
@@ -206,6 +207,7 @@ private object SpecialLengths {
 }
 
 private[spark] object PythonRDD {
+  val UTF8 = Charset.forName("UTF-8")
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
@@ -266,7 +268,7 @@ private[spark] object PythonRDD {
   }
 
   def writeUTF(str: String, dataOut: DataOutputStream) {
-    val bytes = str.getBytes("UTF-8")
+    val bytes = str.getBytes(UTF8)
     dataOut.writeInt(bytes.length)
     dataOut.write(bytes)
   }
@@ -286,7 +288,7 @@ private[spark] object PythonRDD {
 
 private
 class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
-  override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
+  override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
 }
 
 /**

python/pyspark/context.py

Lines changed: 42 additions & 2 deletions
@@ -28,7 +28,8 @@
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
+from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
+    PairDeserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
@@ -257,6 +258,45 @@ def textFile(self, name, minSplits=None):
         return RDD(self._jsc.textFile(name, minSplits), self,
                    UTF8Deserializer())
 
+    def wholeTextFiles(self, path):
+        """
+        Read a directory of text files from HDFS, a local file system
+        (available on all nodes), or any Hadoop-supported file system
+        URI. Each file is read as a single record and returned in a
+        key-value pair, where the key is the path of each file, the
+        value is the content of each file.
+
+        For example, if you have the following files::
+
+          hdfs://a-hdfs-path/part-00000
+          hdfs://a-hdfs-path/part-00001
+          ...
+          hdfs://a-hdfs-path/part-nnnnn
+
+        Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")},
+        then C{rdd} contains::
+
+          (a-hdfs-path/part-00000, its content)
+          (a-hdfs-path/part-00001, its content)
+          ...
+          (a-hdfs-path/part-nnnnn, its content)
+
+        NOTE: Small files are preferred, as each file will be loaded
+        fully in memory.
+
+        >>> dirPath = os.path.join(tempdir, "files")
+        >>> os.mkdir(dirPath)
+        >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
+        ...    file1.write("1")
+        >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
+        ...    file2.write("2")
+        >>> textFiles = sc.wholeTextFiles(dirPath)
+        >>> sorted(textFiles.collect())
+        [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
+        """
+        return RDD(self._jsc.wholeTextFiles(path), self,
+                   PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))
+
     def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
         return RDD(jrdd, self, input_deserializer)
@@ -425,7 +465,7 @@ def _test():
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['tempdir'] = tempfile.mkdtemp()
     atexit.register(lambda: shutil.rmtree(globs['tempdir']))
-    (failure_count, test_count) = doctest.testmod(globs=globs)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
         exit(-1)

python/pyspark/serializers.py

Lines changed: 1 addition & 1 deletion
@@ -290,7 +290,7 @@ class MarshalSerializer(FramedSerializer):
 
 class UTF8Deserializer(Serializer):
     """
-    Deserializes streams written by getBytes.
+    Deserializes streams written by String.getBytes.
     """
 
     def loads(self, stream):
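
For context on the framing this docstring refers to: writeUTF in PythonRDD.scala (above) writes a 4-byte big-endian length followed by the string's UTF-8 bytes. A minimal reader sketch of that framing, as an illustration rather than PySpark's actual UTF8Deserializer code, could look like this:

    import struct

    def read_utf8_string(stream):
        # Read the 4-byte big-endian length that DataOutputStream.writeInt produces,
        # then read that many bytes and decode them as UTF-8.
        length = struct.unpack("!i", stream.read(4))[0]
        return stream.read(length).decode("utf-8")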
