
Commit b1d2a30

tdas authored and giwa committed
Implemented DStream.foreachRDD in the Python API using Py4J callback server.
1 parent 678e854 commit b1d2a30
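
For orientation, a rough sketch of what this change enables from the Python API. It assumes a StreamingContext named ssc and a DStream named lines already exist (their setup is outside this diff), and per_batch is a hypothetical user function:

def per_batch(rdd, time):
    # Runs in the Python driver once per batch; the JVM invokes it through
    # the Py4J callback server that java_gateway.py now starts (see below).
    print "Time: %s, count: %d" % (time, rdd.count())

lines.foreachRDD(per_batch)  # registered as an output operation, like pyprint()
ssc.start()
ssc.awaitTermination()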

File tree

7 files changed: +95 -82 lines changed

examples/src/main/python/streaming/network_wordcount.py

Lines changed: 1 addition & 3 deletions
@@ -17,9 +17,7 @@
 fm_lines = lines.flatMap(lambda x: x.split(" "))
 mapped_lines = fm_lines.map(lambda x: (x, 1))
 reduced_lines = mapped_lines.reduceByKey(add)
-
-fm_lines.pyprint()
-mapped_lines.pyprint()
+
 reduced_lines.pyprint()
 ssc.start()
 ssc.awaitTermination()

python/pyspark/java_gateway.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def run(self):
     EchoOutputThread(proc.stdout).start()

     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
+    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False, start_callback_server=True)

     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
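
Note: start_callback_server=True is what lets the JVM call methods on Python objects, which foreachRDD depends on. Outside of Spark, the plain Py4J pattern looks roughly like the sketch below; the interface name, port, and register() entry-point method are hypothetical and only illustrate the convention that utils.py follows further down:

from py4j.java_gateway import JavaGateway, GatewayClient

class PingListener(object):
    def ping(self):
        # Invoked from the JVM through the callback server
        print "ping received from the JVM"

    class Java:
        # Declares which Java interface this Python object implements
        implements = ['com.example.PingListener']  # hypothetical interface

# With a JVM gateway running, the object can be handed to Java and called back:
# gateway = JavaGateway(GatewayClient(port=25333), start_callback_server=True)
# gateway.entry_point.register(PingListener())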

python/pyspark/streaming/dstream.py

Lines changed: 34 additions & 10 deletions
@@ -43,14 +43,6 @@ def print_(self):
         #hack to call print function in DStream
         getattr(self._jdstream, "print")()

-    def pyprint(self):
-        """
-        Print the first ten elements of each RDD generated in this DStream. This is an output
-        operator, so this DStream will be registered as an output stream and there materialized.
-
-        """
-        self._jdstream.pyprint()
-
     def filter(self, f):
         """
         Return DStream containing only the elements that satisfy predicate.
@@ -203,6 +195,38 @@ def getNumPartitions(self):
             return 2
 >>>>>>> clean up code

+    def foreachRDD(self, func):
+        """
+        """
+        from utils import RDDFunction
+        wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func)
+        self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), wrapped_func)
+
+    def pyprint(self):
+        """
+        Print the first ten elements of each RDD generated in this DStream. This is an output
+        operator, so this DStream will be registered as an output stream and there materialized.
+
+        """
+        def takeAndPrint(rdd, time):
+            taken = rdd.take(11)
+            print "-------------------------------------------"
+            print "Time: %s" % (str(time))
+            print "-------------------------------------------"
+            for record in taken[:10]:
+                print record
+            if len(taken) > 10:
+                print "..."
+            print
+
+        self.foreachRDD(takeAndPrint)
+
+
+    #def transform(self, func):
+    #    from utils import RDDFunction
+    #    wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func)
+    #    jdstream = self.ctx._jvm.PythonTransformedDStream(self._jdstream.dstream(), wrapped_func).toJavaDStream
+    #    return DStream(jdstream, self._ssc, ...) ## DO NOT KNOW HOW

 class PipelinedDStream(DStream):
     def __init__(self, prev, func, preservesPartitioning=False):
@@ -222,7 +246,6 @@ def pipeline_func(split, iterator):
         self._prev_jdstream = prev._prev_jdstream # maintain the pipeline
         self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
-        self.is_checkpointed = False
         self._ssc = prev._ssc
         self.ctx = prev.ctx
         self.prev = prev
@@ -259,4 +282,5 @@ def _jdstream(self):
         return self._jdstream_val

     def _is_pipelinable(self):
-        return not (self.is_cached or self.is_checkpointed)
+        return not (self.is_cached)
+
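
With this change, pyprint() above is just a foreachRDD callback (takeAndPrint) written in Python. User code can build output operations the same way; a minimal sketch, assuming a DStream of (word, count) pairs named counts such as the one in network_wordcount.py:

def save_counts(rdd, time):
    # collect() pulls the whole batch to the driver; fine for small example data
    with open("/tmp/wordcounts-%s.txt" % time, "w") as f:
        for word, n in rdd.collect():
            f.write("%s\t%d\n" % (word, n))

counts.foreachRDD(save_counts)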

python/pyspark/streaming/utils.py

Lines changed: 21 additions & 0 deletions
@@ -15,6 +15,27 @@
 # limitations under the License.
 #

+from pyspark.rdd import RDD
+
+class RDDFunction():
+    def __init__(self, ctx, jrdd_deserializer, func):
+        self.ctx = ctx
+        self.deserializer = jrdd_deserializer
+        self.func = func
+
+    def call(self, jrdd, time):
+        # Wrap JavaRDD into python's RDD class
+        rdd = RDD(jrdd, self.ctx, self.deserializer)
+        # Call user defined RDD function
+        self.func(rdd, time)
+
+    def __str__(self):
+        return "%s, %s" % (str(self.deserializer), str(self.func))
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
+
+

 def msDurationToString(ms):
     """

streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala

Lines changed: 0 additions & 8 deletions
@@ -54,14 +54,6 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
     dstream.print()
   }

-  /**
-   * Print the first ten elements of each PythonRDD generated in the PythonDStream. This is an output
-   * operator, so this PythonDStream will be registered as an output stream and there materialized.
-   * This function is for PythonAPI.
-   */
-  //TODO move this function to PythonDStream
-  def pyprint() = dstream.pyprint()
-
   /**
    * Return a new DStream in which each RDD has a single element generated by counting each RDD
    * of this DStream.

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 38 additions & 0 deletions
@@ -56,6 +56,10 @@ class PythonDStream[T: ClassTag](
     }
   }

+  def foreachRDD(foreachFunc: PythonRDDFunction) {
+    new PythonForeachDStream(this, context.sparkContext.clean(foreachFunc, false)).register()
+  }
+
   val asJavaDStream = JavaDStream.fromDStream(this)

   /**
@@ -160,6 +164,40 @@ DStream[Array[Byte]](prev.ssc){
       case None => None
     }
   }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+class PythonForeachDStream(
+    prev: DStream[Array[Byte]],
+    foreachFunction: PythonRDDFunction
+  ) extends ForEachDStream[Array[Byte]](
+    prev,
+    (rdd: RDD[Array[Byte]], time: Time) => {
+      foreachFunction.call(rdd.toJavaRDD(), time.milliseconds)
+    }
+  ) {
+
+  this.register()
+}
+/*
+This does not work. Ignore this for now. -TD
+class PythonTransformedDStream(
+    prev: DStream[Array[Byte]],
+    transformFunction: PythonRDDFunction
+  ) extends DStream[Array[Byte]](prev.ssc) {
+
+  override def dependencies = List(prev)
+
+  override def slideDuration: Duration = prev.slideDuration
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    prev.getOrCompute(validTime).map(rdd => {
+      transformFunction.call(rdd.toJavaRDD(), validTime.milliseconds).rdd
+    })
+  }
+
   val asJavaDStream = JavaDStream.fromDStream(this)
   //val asJavaPairDStream : JavaPairDStream[Long, Array[Byte]] = JavaPairDStream.fromJavaDStream(this)
 }
+*/

streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala

Lines changed: 0 additions & 60 deletions
@@ -623,66 +623,6 @@ abstract class DStream[T: ClassTag] (
     new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
   }

-  //TODO: move pyprint to PythonDStream and executed by py4j call back function
-  /**
-   * Print the first ten elements of each PythonRDD generated in this PythonDStream. This is an output
-   * operator, so this PythonDStream will be registered as an output stream and there materialized.
-   * Since serialized Python object is readable by Python, pyprint writes out binary data to
-   * temporary file and run python script to deserialized and print the first ten elements
-   *
-   * Currently call python script directly. We should avoid this
-   */
-  private[streaming] def pyprint() {
-    def foreachFunc = (rdd: RDD[T], time: Time) => {
-      val iter = rdd.take(11).iterator
-
-      // Generate a temporary file
-      val prefix = "spark"
-      val suffix = ".tmp"
-      val tempFile = File.createTempFile(prefix, suffix)
-      val tempFileStream = new DataOutputStream(new FileOutputStream(tempFile.getAbsolutePath))
-      // Write out serialized python object to temporary file
-      PythonRDD.writeIteratorToStream(iter, tempFileStream)
-      tempFileStream.close()
-
-      // pythonExec should be passed from python. Move pyprint to PythonDStream
-      val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON")
-
-      val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
-      // Call python script to deserialize and print result in stdout
-      val pb = new ProcessBuilder(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath)
-      val workerEnv = pb.environment()
-
-      // envVars also should be pass from python
-      val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
-      workerEnv.put("PYTHONPATH", pythonPath)
-      val worker = pb.start()
-      val is = worker.getInputStream()
-      val isr = new InputStreamReader(is)
-      val br = new BufferedReader(isr)
-
-      println ("-------------------------------------------")
-      println ("Time: " + time)
-      println ("-------------------------------------------")
-
-      // Print values which is from python std out
-      var line = ""
-      breakable {
-        while (true) {
-          line = br.readLine()
-          if (line == null) break()
-          println(line)
-        }
-      }
-      // Delete temporary file
-      tempFile.delete()
-      println()
-
-    }
-    new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
-  }
-
-
   /**
    * Return a new DStream in which each RDD contains all the elements in seen in a
    * sliding window of time over this DStream. The new DStream generates RDDs with
