Commit 28c6620

giwa (Ken Takagiwa) authored and Ken Takagiwa committed
Implemented DStream.foreachRDD in the Python API using Py4J callback server
2 parents cc2092b + 54e2e8c commit 28c6620

9 files changed: +104 −82 lines changed

examples/src/main/python/streaming/network_wordcount.py

Lines changed: 1 addition & 3 deletions

@@ -17,9 +17,7 @@
 fm_lines = lines.flatMap(lambda x: x.split(" "))
 mapped_lines = fm_lines.map(lambda x: (x, 1))
 reduced_lines = mapped_lines.reduceByKey(add)
-
-fm_lines.pyprint()
-mapped_lines.pyprint()
+
 reduced_lines.pyprint()
 ssc.start()
 ssc.awaitTermination()

python/lib/py4j-0.8.1-src.zip

11 Bytes. Binary file not shown.

python/pyspark/java_gateway.py

Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ def run(self):
     EchoOutputThread(proc.stdout).start()

     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
+    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False, start_callback_server=True)

     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
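The one-line change above is load-bearing: start_callback_server=True makes Py4J open a second, Python-side server socket, which is what lets JVM code invoke methods on Python objects. Any Python object that exposes the target methods and declares a nested Java class with an implements list can then be handed to the JVM as an implementation of a Java interface. A minimal sketch of that pattern (py4j 0.8.x style; the com.example.Listener interface and the port are made-up placeholders, not part of this commit):

from py4j.java_gateway import JavaGateway, GatewayClient

class PingListener(object):
    # A Python object the JVM can call back into.
    def notify(self, message):
        # Must match the method declared on the Java interface below.
        print "JVM said: %s" % message

    class Java:
        # Py4J marker: the Java interface(s) this object implements.
        implements = ["com.example.Listener"]

# start_callback_server=True opens the Python-side socket used for
# JVM -> Python calls, exactly as java_gateway.py now does for PySpark.
gateway = JavaGateway(GatewayClient(port=25333), start_callback_server=True)
# A JVM object that accepts a com.example.Listener could now be handed
# PingListener() and invoke notify() on it across the socket.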

python/pyspark/streaming/dstream.py

Lines changed: 34 additions & 10 deletions

@@ -43,14 +43,6 @@ def print_(self):
         #hack to call print function in DStream
         getattr(self._jdstream, "print")()

-    def pyprint(self):
-        """
-        Print the first ten elements of each RDD generated in this DStream. This is an output
-        operator, so this DStream will be registered as an output stream and there materialized.
-
-        """
-        self._jdstream.pyprint()
-
     def filter(self, f):
         """
         Return DStream containing only the elements that satisfy predicate.

@@ -190,6 +182,38 @@ def getNumPartitions(self):
         # TODO: remove hardcoding. RDD has NumPartitions but DStream does not have.
         return 2

+    def foreachRDD(self, func):
+        """
+        """
+        from utils import RDDFunction
+        wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func)
+        self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), wrapped_func)
+
+    def pyprint(self):
+        """
+        Print the first ten elements of each RDD generated in this DStream. This is an output
+        operator, so this DStream will be registered as an output stream and there materialized.
+
+        """
+        def takeAndPrint(rdd, time):
+            taken = rdd.take(11)
+            print "-------------------------------------------"
+            print "Time: %s" % (str(time))
+            print "-------------------------------------------"
+            for record in taken[:10]:
+                print record
+            if len(taken) > 10:
+                print "..."
+            print
+
+        self.foreachRDD(takeAndPrint)
+
+    #def transform(self, func):
+    #    from utils import RDDFunction
+    #    wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func)
+    #    jdstream = self.ctx._jvm.PythonTransformedDStream(self._jdstream.dstream(), wrapped_func).toJavaDStream
+    #    return DStream(jdstream, self._ssc, ...) ## DO NOT KNOW HOW

 class PipelinedDStream(DStream):
     def __init__(self, prev, func, preservesPartitioning=False):

@@ -209,7 +233,6 @@ def pipeline_func(split, iterator):
         self._prev_jdstream = prev._prev_jdstream  # maintain the pipeline
         self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
-        self.is_checkpointed = False
         self._ssc = prev._ssc
         self.ctx = prev.ctx
         self.prev = prev

@@ -246,4 +269,5 @@ def _jdstream(self):
         return self._jdstream_val

     def _is_pipelinable(self):
-        return not (self.is_cached or self.is_checkpointed)
+        return not (self.is_cached)
+
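With foreachRDD in place, any driver-side action can run once per batch; pyprint is now just a foreachRDD callback, and the JVM reaches takeAndPrint through the Py4J callback server. A hedged usage sketch (the surrounding StreamingContext setup is assumed to match network_wordcount.py; dump_batch is an invented name, not part of this commit):

# Assumed context, mirroring network_wordcount.py:
#   ssc = StreamingContext(...)
#   counts = ssc.socketTextStream(host, port).flatMap(...).map(...).reduceByKey(add)

def dump_batch(rdd, time):
    # Runs on the driver once per batch interval, invoked from the JVM
    # through the Py4J callback server; rdd is an ordinary pyspark RDD,
    # time is the batch time in epoch milliseconds (a Java long).
    print "Batch at %d: %d records" % (time, rdd.count())

counts.foreachRDD(dump_batch)
ssc.start()
ssc.awaitTermination()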

python/pyspark/streaming/utils.py

Lines changed: 22 additions & 0 deletions

@@ -15,6 +15,28 @@
 # limitations under the License.
 #

+from pyspark.rdd import RDD
+
+
+class RDDFunction():
+    def __init__(self, ctx, jrdd_deserializer, func):
+        self.ctx = ctx
+        self.deserializer = jrdd_deserializer
+        self.func = func
+
+    def call(self, jrdd, time):
+        # Wrap JavaRDD into python's RDD class
+        rdd = RDD(jrdd, self.ctx, self.deserializer)
+        # Call user defined RDD function
+        self.func(rdd, time)
+
+    def __str__(self):
+        return "%s, %s" % (str(self.deserializer), str(self.func))
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
+
+

 def msDurationToString(ms):
     """

streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala

Lines changed: 0 additions & 8 deletions

@@ -54,14 +54,6 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
     dstream.print()
   }

-  /**
-   * Print the first ten elements of each PythonRDD generated in the PythonDStream. This is an output
-   * operator, so this PythonDStream will be registered as an output stream and there materialized.
-   * This function is for PythonAPI.
-   */
-  //TODO move this function to PythonDStream
-  def pyprint() = dstream.pyprint()
-
   /**
    * Return a new DStream in which each RDD has a single element generated by counting each RDD
    * of this DStream.

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 38 additions & 0 deletions

@@ -56,6 +56,10 @@ class PythonDStream[T: ClassTag](
     }
   }

+  def foreachRDD(foreachFunc: PythonRDDFunction) {
+    new PythonForeachDStream(this, context.sparkContext.clean(foreachFunc, false)).register()
+  }
+
   val asJavaDStream = JavaDStream.fromDStream(this)
 }

@@ -85,5 +89,39 @@ DStream[Array[Byte]](prev.ssc){
       case None => None
     }
   }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+class PythonForeachDStream(
+    prev: DStream[Array[Byte]],
+    foreachFunction: PythonRDDFunction
+  ) extends ForEachDStream[Array[Byte]](
+    prev,
+    (rdd: RDD[Array[Byte]], time: Time) => {
+      foreachFunction.call(rdd.toJavaRDD(), time.milliseconds)
+    }
+  ) {
+
+  this.register()
+}
+/*
+This does not work. Ignore this for now. -TD
+class PythonTransformedDStream(
+    prev: DStream[Array[Byte]],
+    transformFunction: PythonRDDFunction
+  ) extends DStream[Array[Byte]](prev.ssc) {
+
+  override def dependencies = List(prev)
+
+  override def slideDuration: Duration = prev.slideDuration
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    prev.getOrCompute(validTime).map(rdd => {
+      transformFunction.call(rdd.toJavaRDD(), validTime.milliseconds).rdd
+    })
+  }
+
   val asJavaDStream = JavaDStream.fromDStream(this)
 }
+*/

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonRDDFunction.java

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+package org.apache.spark.streaming.api.python;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.streaming.Time;
+
+public interface PythonRDDFunction {
+  JavaRDD<byte[]> call(JavaRDD<byte[]> rdd, long time);
+}
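Py4J requires nothing from the Python side beyond a call method of matching arity plus the nested Java marker class, which is exactly what RDDFunction in utils.py provides. (The declared JavaRDD<byte[]> return value only matters for the not-yet-working transform path; the foreach path discards it, which is why RDDFunction.call can return None.) A JVM-free illustration of that correspondence, with stand-in names:

class FakeJavaRDD(object):
    # Stand-in for the JavaRDD<byte[]> handle Py4J would deliver.
    pass

class EchoFunction(object):
    def call(self, jrdd, time):
        # Mirrors call(JavaRDD<byte[]> rdd, long time) declared above.
        print "batch at %d: %r" % (time, jrdd)

    class Java:
        implements = ["org.apache.spark.streaming.api.python.PythonRDDFunction"]

# What Py4J effectively does once per batch interval:
EchoFunction().call(FakeJavaRDD(), 1406906910000)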

streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala

Lines changed: 0 additions & 60 deletions

@@ -623,66 +623,6 @@ abstract class DStream[T: ClassTag] (
     new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
   }

-  //TODO: move pyprint to PythonDStream and executed by py4j call back function
-  /**
-   * Print the first ten elements of each PythonRDD generated in this PythonDStream. This is an output
-   * operator, so this PythonDStream will be registered as an output stream and there materialized.
-   * Since serialized Python object is readable by Python, pyprint writes out binary data to
-   * temporary file and run python script to deserialized and print the first ten elements
-   *
-   * Currently call python script directly. We should avoid this
-   */
-  private[streaming] def pyprint() {
-    def foreachFunc = (rdd: RDD[T], time: Time) => {
-      val iter = rdd.take(11).iterator
-
-      // Generate a temporary file
-      val prefix = "spark"
-      val suffix = ".tmp"
-      val tempFile = File.createTempFile(prefix, suffix)
-      val tempFileStream = new DataOutputStream(new FileOutputStream(tempFile.getAbsolutePath))
-      // Write out serialized python object to temporary file
-      PythonRDD.writeIteratorToStream(iter, tempFileStream)
-      tempFileStream.close()
-
-      // pythonExec should be passed from python. Move pyprint to PythonDStream
-      val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON")
-
-      val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
-      // Call python script to deserialize and print result in stdout
-      val pb = new ProcessBuilder(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath)
-      val workerEnv = pb.environment()
-
-      // envVars also should be pass from python
-      val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
-      workerEnv.put("PYTHONPATH", pythonPath)
-      val worker = pb.start()
-      val is = worker.getInputStream()
-      val isr = new InputStreamReader(is)
-      val br = new BufferedReader(isr)
-
-      println ("-------------------------------------------")
-      println ("Time: " + time)
-      println ("-------------------------------------------")
-
-      // Print values which is from python std out
-      var line = ""
-      breakable {
-        while (true) {
-          line = br.readLine()
-          if (line == null) break()
-          println(line)
-        }
-      }
-      // Delete temporary file
-      tempFile.delete()
-      println()
-
-    }
-    new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
-  }
-
   /**
    * Return a new DStream in which each RDD contains all the elements in seen in a
    * sliding window of time over this DStream. The new DStream generates RDDs with
