
Commit f671cdb

WIP: added PythonTestInputStream

1 parent 56fae45 commit f671cdb

5 files changed: +41 -15 lines


examples/src/main/python/streaming/test_oprations.py

Lines changed: 4 additions & 10 deletions
@@ -6,20 +6,14 @@
 from pyspark.streaming.duration import *
 
 if __name__ == "__main__":
-    if len(sys.argv) != 3:
-        print >> sys.stderr, "Usage: wordcount <hostname> <port>"
-        exit(-1)
     conf = SparkConf()
     conf.setAppName("PythonStreamingNetworkWordCount")
     ssc = StreamingContext(conf=conf, duration=Seconds(1))
 
-    lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
-    words = lines.flatMap(lambda line: line.split(" "))
-    # ssc.checkpoint("checkpoint")
-    mapped_words = words.map(lambda word: (word, 1))
-    count = mapped_words.reduceByKey(add)
+    test_input = ssc._testInputStream([1,1,1,1])
+    mapped = test_input.map(lambda x: (x, 1))
+    mapped.pyprint()
 
-    count.pyprint()
     ssc.start()
-    ssc.awaitTermination()
+    # ssc.awaitTermination()
     # ssc.stop()
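
Side note: the new `_testInputStream` hook could also drive the word-count pipeline this change removes, using canned data instead of a socket. A minimal sketch, assuming the helper keeps the signature above, that `DStream.reduceByKey` from the removed example works against the test stream, and that plain integers round-trip through the WIP serialization path:

    from operator import add

    test_input = ssc._testInputStream([1, 1, 2, 3])
    counts = test_input.map(lambda x: (x, 1)).reduceByKey(add)
    counts.pyprint()
    ssc.start()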

python/pyspark/streaming/context.py

Lines changed: 25 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 import sys
 from signal import signal, SIGTERM, SIGINT
+from tempfile import NamedTemporaryFile
 
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
@@ -138,3 +139,27 @@ def checkpoint(self, directory):
         """
         """
         self._jssc.checkpoint(directory)
+
+    def _testInputStream(self, test_input, numSlices=None):
+
+        numSlices = numSlices or self._sc.defaultParallelism
+        # Calling the Java parallelize() method with an ArrayList is too slow,
+        # because it sends O(n) Py4J commands. As an alternative, serialized
+        # objects are written to a file and loaded through textFile().
+        tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+        # Make sure we distribute data evenly if it's smaller than self.batchSize
+        if "__len__" not in dir(test_input):
+            c = list(test_input)  # Make it a list so we can compute its length
+        batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
+        if batchSize > 1:
+            serializer = BatchedSerializer(self._sc._unbatched_serializer,
+                                           batchSize)
+        else:
+            serializer = self._sc._unbatched_serializer
+        serializer.dump_stream(test_input, tempFile)
+        tempFile.close()
+        print tempFile.name
+        jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
+                                                        tempFile.name,
+                                                        numSlices).asJavaDStream()
+        return DStream(jinput_stream, self, UTF8Deserializer())
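
The rationale in the comment above (one file write instead of O(n) Py4J parallelize() calls) can be illustrated outside Spark. A rough standalone sketch using pickle in place of PySpark's serializers; `dump_test_input` and `max_batch_size` are illustrative names, not part of the API:

    import pickle
    from tempfile import NamedTemporaryFile

    def dump_test_input(test_input, num_slices, max_batch_size=1024):
        data = list(test_input)                  # guarantee len() is available
        batch_size = min(len(data) // num_slices, max_batch_size)
        temp_file = NamedTemporaryFile(delete=False)
        if batch_size > 1:
            # Batched path: pickle slices of the input, like BatchedSerializer.
            for i in range(0, len(data), batch_size):
                pickle.dump(data[i:i + batch_size], temp_file)
        else:
            # Unbatched path: one pickled record per element.
            for item in data:
                pickle.dump(item, temp_file)
        temp_file.close()
        return temp_file.name                    # handed to the JVM side in the real method

The real method makes the same batch-size decision but delegates the actual writing to the SparkContext's configured serializers.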

python/pyspark/streaming/dstream.py

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ def _mergeCombiners(iterator):
                     combiners[k] = v
                 else:
                     combiners[k] = mergeCombiners(combiners[k], v)
+            return combiners.iteritems()
 
         return shuffled._mapPartitions(_mergeCombiners)
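
The one-line fix above matters because `_mapPartitions` expects the partition function to return an iterable; without the `return`, the function implicitly returns None instead of an iterator. A standalone sketch of the patched merge step, in Python 2 to match the `iteritems()` call (names below are illustrative):

    def merge_partition(iterator, merge=lambda a, b: a + b):
        combiners = {}
        for (k, v) in iterator:
            if k not in combiners:
                combiners[k] = v
            else:
                combiners[k] = merge(combiners[k], v)
        return combiners.iteritems()   # the line this commit adds

    # list(merge_partition([("a", 1), ("b", 2), ("a", 3)]))
    # -> [("a", 4), ("b", 2)], in arbitrary dict order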

streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala

Lines changed: 3 additions & 0 deletions
@@ -546,6 +546,9 @@ class JavaStreamingContext(val ssc: StreamingContext) {
  * JavaStreamingContext object contains a number of utility functions.
  */
 object JavaStreamingContext {
+  implicit def fromStreamingContext(ssc: StreamingContext): JavaStreamingContext = new JavaStreamingContext(ssc)
+
+  implicit def toStreamingContext(jssc: JavaStreamingContext): StreamingContext = jssc.ssc
 
   /**
    * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 8 additions & 5 deletions
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.api.java._
 import org.apache.spark.api.python._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.streaming.{StreamingContext, Duration, Time}
@@ -130,27 +131,29 @@ class PythonTransformedDStream(
 /**
  * This is a input stream just for the unitest. This is equivalent to a checkpointable,
  * replayable, reliable message queue like Kafka. It requires a sequence as input, and
- * returns the i_th element at the i_th batch unde manual clock.
+ * returns the i_th element at the i_th batch under manual clock.
  */
-class PythonTestInputStream(ssc_ : StreamingContext, filename: String, numPartitions: Int)
-  extends InputDStream[Array[Byte]](ssc_) {
+class PythonTestInputStream(ssc_ : JavaStreamingContext, filename: String, numPartitions: Int)
+  extends InputDStream[Array[Byte]](JavaStreamingContext.toStreamingContext(ssc_)){
 
   def start() {}
 
   def stop() {}
 
   def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
     logInfo("Computing RDD for time " + validTime)
-    val index = ((validTime - zeroTime) / slideDuration - 1).toInt
+    //val index = ((validTime - zeroTime) / slideDuration - 1).toInt
     //val selectedInput = if (index < input.size) input(index) else Seq[T]()
 
     // lets us test cases where RDDs are not created
     //if (filename == null)
     //  return None
 
     //val rdd = ssc.sc.makeRDD(selectedInput, numPartitions)
-    val rdd = PythonRDD.readRDDFromFile(ssc.sc, filename, numPartitions).rdd
+    val rdd = PythonRDD.readRDDFromFile(JavaSparkContext.fromSparkContext(ssc_.sparkContext), filename, numPartitions).rdd
     logInfo("Created RDD " + rdd.id + " with " + filename)
     Some(rdd)
   }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
 }
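
To make the class comment concrete (element i is replayed at batch i under the manual clock), here is a plain-Python sketch of the indexing that the now commented-out `index`/`selectedInput` lines used to perform; this is illustration only, not code from the commit:

    def element_for_batch(input_seq, batch_index):
        # Batch i replays element i of the test input; later batches get nothing.
        if batch_index < len(input_seq):
            return input_seq[batch_index]
        return None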
