
Commit 98ac6c2

support ssc.transform()
1 parent b983f0f commit 98ac6c2

6 files changed, +96 -48 lines changed

python/pyspark/streaming/context.py

Lines changed: 14 additions & 4 deletions

@@ -20,6 +20,7 @@
 from pyspark.context import SparkContext
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.dstream import DStream
+from pyspark.streaming.util import RDDFunction
 
 from py4j.java_collections import ListConverter
 from py4j.java_gateway import java_import
@@ -212,11 +213,20 @@ def queueStream(self, queue, oneAtATime=True, default=None):
 
     def transform(self, dstreams, transformFunc):
         """
-        Create a new DStream in which each RDD is generated by applying a function on RDDs of
-        the DStreams. The order of the JavaRDDs in the transform function parameter will be the
-        same as the order of corresponding DStreams in the list.
+        Create a new DStream in which each RDD is generated by applying
+        a function on RDDs of the DStreams. The order of the JavaRDDs in
+        the transform function parameter will be the same as the order
+        of corresponding DStreams in the list.
         """
-        # TODO
+        jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
+                                            SparkContext._gateway._gateway_client)
+        # change the final serializer to sc.serializer
+        jfunc = RDDFunction(self._sc,
+                            lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+                            *[d._jrdd_deserializer for d in dstreams])
+
+        jdstream = self._jvm.PythonDStream.callTransform(self._jssc, jdstreams, jfunc)
+        return DStream(jdstream, self, self._sc.serializer)
 
     def union(self, *dstreams):
         """

python/pyspark/streaming/dstream.py

Lines changed: 18 additions & 18 deletions

@@ -132,7 +132,7 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
         return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc))
 
     def foreach(self, func):
-        return self.foreachRDD(lambda rdd, _: rdd.foreach(func))
+        return self.foreachRDD(lambda _, rdd: rdd.foreach(func))
 
     def foreachRDD(self, func):
         """
@@ -142,7 +142,7 @@ def foreachRDD(self, func):
         This is an output operator, so this DStream will be registered as an output
         stream and there materialized.
         """
-        jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer)
+        jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
         api = self._ssc._jvm.PythonDStream
         api.callForeachRDD(self._jdstream, jfunc)
 
@@ -151,10 +151,10 @@ def pprint(self):
         Print the first ten elements of each RDD generated in this DStream. This is an output
         operator, so this DStream will be registered as an output stream and there materialized.
         """
-        def takeAndPrint(rdd, time):
+        def takeAndPrint(timestamp, rdd):
             taken = rdd.take(11)
             print "-------------------------------------------"
-            print "Time: %s" % datetime.fromtimestamp(time / 1000.0)
+            print "Time: %s" % datetime.fromtimestamp(timestamp / 1000.0)
             print "-------------------------------------------"
             for record in taken[:10]:
                 print record
@@ -176,15 +176,15 @@ def take(self, n):
         """
         rdds = []
 
-        def take(rdd, _):
-            if rdd:
+        def take(_, rdd):
+            if rdd and len(rdds) < n:
                 rdds.append(rdd)
-                if len(rdds) == n:
-                    # FIXME: NPE in JVM
-                    self._ssc.stop(False)
         self.foreachRDD(take)
+
         self._ssc.start()
-        self._ssc.awaitTermination()
+        while len(rdds) < n:
+            time.sleep(0.01)
+        self._ssc.stop(False, True)
         return rdds
 
     def collect(self):
@@ -195,7 +195,7 @@ def collect(self):
         """
         result = []
 
-        def get_output(rdd, time):
+        def get_output(_, rdd):
             r = rdd.collect()
             result.append(r)
         self.foreachRDD(get_output)
@@ -317,7 +317,7 @@ def transform(self, func):
         Return a new DStream in which each RDD is generated by applying a function
         on each RDD of 'this' DStream.
         """
-        return TransformedDStream(self, lambda a, t: func(a), True)
+        return TransformedDStream(self, lambda t, a: func(a), True)
 
     def transformWithTime(self, func):
         """
@@ -331,7 +331,7 @@ def transformWith(self, func, other, keepSerializer=False):
         Return a new DStream in which each RDD is generated by applying a function
         on each RDD of 'this' DStream and 'other' DStream.
         """
-        jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, b), self._jrdd_deserializer)
+        jfunc = RDDFunction(self.ctx, lambda t, a, b: func(a, b), self._jrdd_deserializer)
         dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
                                                           other._jdstream.dstream(), jfunc)
         jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer
@@ -549,14 +549,14 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None
         self._check_window(windowDuration, slideDuration)
         reduced = self.reduceByKey(func)
 
-        def reduceFunc(a, b, t):
+        def reduceFunc(t, a, b):
             b = b.reduceByKey(func, numPartitions)
             r = a.union(b).reduceByKey(func, numPartitions) if a else b
             if filterFunc:
                 r = r.filter(filterFunc)
             return r
 
-        def invReduceFunc(a, b, t):
+        def invReduceFunc(t, a, b):
             b = b.reduceByKey(func, numPartitions)
             joined = a.leftOuterJoin(b, numPartitions)
             return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
@@ -582,7 +582,7 @@ def updateStateByKey(self, updateFunc, numPartitions=None):
         @param updateFunc State update function ([(k, vs, s)] -> [(k, s)]).
                           If `s` is None, then `k` will be eliminated.
         """
-        def reduceFunc(a, b, t):
+        def reduceFunc(t, a, b):
             if a is None:
                 g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
             else:
@@ -610,7 +610,7 @@ def __init__(self, prev, func, reuse=False):
                 not prev.is_cached and not prev.is_checkpointed):
             prev_func = prev.func
             old_func = func
-            func = lambda rdd, t: old_func(prev_func(rdd, t), t)
+            func = lambda t, rdd: old_func(t, prev_func(t, rdd))
             reuse = reuse and prev.reuse
             prev = prev.prev
 
@@ -625,7 +625,7 @@ def _jdstream(self):
             return self._jdstream_val
 
         func = self.func
-        jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self.prev._jrdd_deserializer)
+        jfunc = RDDFunction(self.ctx, func, self.prev._jrdd_deserializer)
         jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(),
                                                           jfunc, self.reuse).asJavaDStream()
         self._jdstream_val = jdstream
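
Not part of the commit: the recurring change in this file flips the wrapped callbacks from an (rdd, ..., time) order to a time-first (time, rdd, ...) order. A minimal sketch of what a foreachRDD callback now receives under the new convention; `stream` is assumed to be an existing DStream and the output format is illustrative.

from datetime import datetime

def save_batch(milliseconds, rdd):
    # Callbacks are now invoked with the batch time (epoch milliseconds) first, then the RDD.
    batch_time = datetime.fromtimestamp(milliseconds / 1000.0)
    print "batch at %s: %d records" % (batch_time, rdd.count())

stream.foreachRDD(save_batch)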

python/pyspark/streaming/tests.py

Lines changed: 13 additions & 0 deletions

@@ -374,6 +374,19 @@ def test_union(self):
         expected = [i * 2 for i in input]
         self.assertEqual(expected, result[:3])
 
+    def test_transform(self):
+        dstream1 = self.ssc.queueStream([[1]])
+        dstream2 = self.ssc.queueStream([[2]])
+        dstream3 = self.ssc.queueStream([[3]])
+
+        def func(rdds):
+            rdd1, rdd2, rdd3 = rdds
+            return rdd2.union(rdd3).union(rdd1)
+
+        dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)
+
+        self.assertEqual([2, 3, 1], dstream.first().collect())
+
 
 if __name__ == "__main__":
     unittest.main()

python/pyspark/streaming/util.py

Lines changed: 15 additions & 11 deletions

@@ -22,29 +22,33 @@ class RDDFunction(object):
     """
     This class is for py4j callback.
     """
-    def __init__(self, ctx, func, deserializer, deserializer2=None):
+    def __init__(self, ctx, func, *deserializers):
         self.ctx = ctx
         self.func = func
-        self.deserializer = deserializer
-        self.deserializer2 = deserializer2 or deserializer
+        self.deserializers = deserializers
+        emptyRDD = getattr(self.ctx, "_emptyRDD", None)
+        if emptyRDD is None:
+            self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
+        self.emptyRDD = emptyRDD
 
-    def call(self, jrdd, jrdd2, milliseconds):
+    def call(self, milliseconds, jrdds):
         try:
-            emptyRDD = getattr(self.ctx, "_emptyRDD", None)
-            if emptyRDD is None:
-                self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
+            # extend deserializers with the first one
+            sers = self.deserializers
+            if len(sers) < len(jrdds):
+                sers += (sers[0],) * (len(jrdds) - len(sers))
 
-            rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD
-            other = RDD(jrdd2, self.ctx, self.deserializer2) if jrdd2 else emptyRDD
-            r = self.func(rdd, other, milliseconds)
+            rdds = [RDD(jrdd, self.ctx, ser) if jrdd else self.emptyRDD
+                    for jrdd, ser in zip(jrdds, sers)]
+            r = self.func(milliseconds, *rdds)
             if r:
                 return r._jrdd
         except Exception:
            import traceback
            traceback.print_exc()
 
     def __repr__(self):
-        return "RDDFunction2(%s)" % (str(self.func))
+        return "RDDFunction(%s)" % (str(self.func))
 
     class Java:
         implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
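
Not part of the commit: a standalone sketch of the deserializer-padding rule used by RDDFunction.call() above. When fewer deserializers than RDDs are supplied, the first one is reused for the rest; the function name is illustrative, not part of PySpark.

def pad_deserializers(deserializers, num_rdds):
    # Reuse the first deserializer for any RDDs beyond those explicitly covered.
    sers = tuple(deserializers)
    if len(sers) < num_rdds:
        sers += (sers[0],) * (num_rdds - len(sers))
    return sers

assert pad_deserializers(("pickle",), 3) == ("pickle", "pickle", "pickle")
assert pad_deserializers(("utf8", "pickle"), 2) == ("utf8", "pickle")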

streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala

Lines changed: 1 addition & 1 deletion

@@ -413,7 +413,7 @@ class StreamingContext private[streaming] (
       dstreams: Seq[DStream[_]],
       transformFunc: (Seq[RDD[_]], Time) => RDD[T]
     ): DStream[T] = {
-    new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc))
+    new TransformedDStream[T](dstreams, (transformFunc))
   }
 
   /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 35 additions & 14 deletions

@@ -17,30 +17,32 @@
 
 package org.apache.spark.streaming.api.python
 
-import java.util.{ArrayList => JArrayList}
+import java.util.{ArrayList => JArrayList, List => JList}
 import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.apache.spark.api.java._
-import org.apache.spark.api.java.function.{Function2 => JFunction2}
 import org.apache.spark.api.python._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.{Interval, Duration, Time}
 import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.api.java._
 
+
 /**
  * Interface for Python callback function with three arguments
  */
 trait PythonRDDFunction {
-  def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
+  def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
 }
 
-class RDDFunction(pfunc: PythonRDDFunction) extends Serializable {
-
-  def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
-    apply(rdd, None, time)
-  }
+/**
+ * Wrapper for PythonRDDFunction
+ */
+class RDDFunction(pfunc: PythonRDDFunction)
+  extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable {
 
   def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = {
     if (rdd.isDefined) {
@@ -50,14 +52,25 @@ class RDDFunction(pfunc: PythonRDDFunction) extends Serializable {
     }
   }
 
-  def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
-    val r = pfunc.call(wrapRDD(rdd), wrapRDD(rdd2), time.milliseconds)
-    if (r != null) {
-      Some(r.rdd)
+  def some(jrdd: JavaRDD[Array[Byte]]): Option[RDD[Array[Byte]]] = {
+    if (jrdd != null) {
+      Some(jrdd.rdd)
     } else {
       None
     }
   }
+
+  def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    some(pfunc.call(time.milliseconds, List(wrapRDD(rdd)).asJava))
+  }
+
+  def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    some(pfunc.call(time.milliseconds, List(wrapRDD(rdd), wrapRDD(rdd2)).asJava))
+  }
+
+  def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
+    pfunc.call(time.milliseconds, rdds)
+  }
 }
 
 private[python]
@@ -74,8 +87,16 @@ private[spark] object PythonDStream {
 
   // helper function for DStream.foreachRDD(),
  // cannot be `foreachRDD`, it will confusing py4j
-  def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction): Unit = {
-    jdstream.dstream.foreachRDD((rdd, time) => pyfunc.call(rdd, null, time.milliseconds))
+  def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction){
+    val func = new RDDFunction(pyfunc)
+    jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
+  }
+
+  // helper function for ssc.transform()
+  def callTransform(ssc: JavaStreamingContext, jdsteams: JList[JavaDStream[_]], pyfunc: PythonRDDFunction)
+    :JavaDStream[Array[Byte]] = {
+    val func = new RDDFunction(pyfunc)
+    ssc.transform(jdsteams, func)
   }
 
   // convert list of RDD into queue of RDDs, for ssc.queueStream()
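
Not part of the commit: on the Python side, any py4j callback handed to callForeachRDD or callTransform only has to satisfy the new single-list signature of PythonRDDFunction. A toy sketch mirroring the structure of RDDFunction in util.py; the class name and behavior are illustrative, and real callbacks deserialize the JavaRDDs into PySpark RDDs before calling user code.

class EchoRDDFunction(object):
    """Toy py4j callback: returns the first batch RDD unchanged."""

    def call(self, milliseconds, jrdds):
        # milliseconds is the batch time; jrdds is a java.util.List of JavaRDDs
        return jrdds[0] if jrdds else None

    class Java:
        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']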
