Skip to content

Commit fce0ef5

Browse files
committed
rafactor of foreachRDD()
1 parent 7001b51 commit fce0ef5

File tree

2 files changed

+26
-32
lines changed

2 files changed

+26
-32
lines changed

python/pyspark/streaming/dstream.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def foreachRDD(self, func):
142142
stream and there materialized.
143143
"""
144144
jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer)
145-
self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc)
145+
api = self._ssc._jvm.PythonDStream
146+
api.callForeachRDD(self._jdstream, jfunc)
146147

147148
def pprint(self):
148149
"""

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

+24-31
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ package org.apache.spark.streaming.api.python
2020
import java.util.{ArrayList => JArrayList}
2121
import scala.collection.JavaConversions._
2222

23-
import org.apache.spark.rdd.RDD
2423
import org.apache.spark.api.java._
24+
import org.apache.spark.api.java.function.{Function2 => JFunction2}
2525
import org.apache.spark.api.python._
26+
import org.apache.spark.rdd.RDD
2627
import org.apache.spark.storage.StorageLevel
2728
import org.apache.spark.streaming.{Interval, Duration, Time}
2829
import org.apache.spark.streaming.dstream._
@@ -35,19 +36,22 @@ trait PythonRDDFunction {
3536
def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
3637
}
3738

38-
class RDDFunction(pfunc: PythonRDDFunction) {
39-
def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
40-
val jrdd = if (rdd.isDefined) {
39+
class RDDFunction(pfunc: PythonRDDFunction) extends Serializable {
40+
41+
def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
42+
apply(rdd, None, time)
43+
}
44+
45+
def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = {
46+
if (rdd.isDefined) {
4147
JavaRDD.fromRDD(rdd.get)
4248
} else {
4349
null
4450
}
45-
val jrdd2 = if (rdd2.isDefined) {
46-
JavaRDD.fromRDD(rdd2.get)
47-
} else {
48-
null
49-
}
50-
val r = pfunc.call(jrdd, jrdd2, time.milliseconds)
51+
}
52+
53+
def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
54+
val r = pfunc.call(wrapRDD(rdd), wrapRDD(rdd2), time.milliseconds)
5155
if (r != null) {
5256
Some(r.rdd)
5357
} else {
@@ -66,7 +70,13 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p
6670
val asJavaDStream = JavaDStream.fromDStream(this)
6771
}
6872

69-
object PythonDStream {
73+
private[spark] object PythonDStream {
74+
75+
// helper function for DStream.foreachRDD(),
76+
// cannot be `foreachRDD`, it will confusing py4j
77+
def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction): Unit = {
78+
jdstream.dstream.foreachRDD((rdd, time) => pyfunc.call(rdd, null, time.milliseconds))
79+
}
7080

7181
// convert list of RDD into queue of RDDs, for ssc.queueStream()
7282
def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
@@ -97,7 +107,7 @@ private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: Python
97107
if (reuse && lastResult != null) {
98108
Some(lastResult.copyTo(rdd1.get))
99109
} else {
100-
val r = func(rdd1, None, validTime)
110+
val r = func(rdd1, validTime)
101111
if (reuse && r.isDefined && lastResult == null) {
102112
r.get match {
103113
case rdd: PythonRDD =>
@@ -206,8 +216,9 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
206216
// Get the RDD of the reduced value of the previous window
207217
val previousWindowRDD = getOrCompute(previousWindow.endTime)
208218

219+
// for small window, reduce once will be better than twice
209220
if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) {
210-
// subtle the values from old RDDs
221+
// subtract the values from old RDDs
211222
val oldRDDs =
212223
parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
213224
val subbed = if (oldRDDs.size > 0) {
@@ -236,22 +247,4 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
236247
}
237248
}
238249
}
239-
}
240-
241-
/**
242-
* This is used for foreachRDD() in Python
243-
*/
244-
class PythonForeachDStream(
245-
prev: DStream[Array[Byte]],
246-
foreachFunction: PythonRDDFunction
247-
) extends ForEachDStream[Array[Byte]](
248-
prev,
249-
(rdd: RDD[Array[Byte]], time: Time) => {
250-
if (rdd != null) {
251-
foreachFunction.call(rdd, null, time.milliseconds)
252-
}
253-
}
254-
) {
255-
256-
this.register()
257250
}

0 commit comments

Comments
 (0)