Commit c28f520

support updateStateByKey
1 parent d357b70 commit c28f520

4 files changed: +83 -21 lines changed

python/pyspark/streaming/dstream.py

Lines changed: 21 additions & 9 deletions
@@ -366,8 +366,9 @@ def reduceByKeyAndWindow(self, func, invFunc,
                              windowDuration, slideDuration, numPartitions=None):
         reduced = self.reduceByKey(func)
 
-        def reduceFunc(a, t):
-            return a.reduceByKey(func, numPartitions)
+        def reduceFunc(a, b, t):
+            b = b.reduceByKey(func, numPartitions)
+            return a.union(b).reduceByKey(func, numPartitions) if a else b
 
         def invReduceFunc(a, b, t):
             b = b.reduceByKey(func, numPartitions)
@@ -378,19 +379,30 @@ def invReduceFunc(a, b, t):
             windowDuration = Seconds(windowDuration)
         if not isinstance(slideDuration, Duration):
             slideDuration = Seconds(slideDuration)
-        serializer = reduced._jrdd_deserializer
-        jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
+        jreduceFunc = RDDFunction2(self.ctx, reduceFunc, reduced._jrdd_deserializer)
         jinvReduceFunc = RDDFunction2(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
         dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
                                                              jreduceFunc, jinvReduceFunc,
                                                              windowDuration._jduration,
                                                              slideDuration._jduration)
-        return DStream(dstream.asJavaDStream(), self._ssc, serializer)
+        return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer)
+
+    def updateStateByKey(self, updateFunc, numPartitions=None):
+        """
+        :param updateFunc: [(k, vs, s)] -> [(k, s)]
+        """
+        def reduceFunc(a, b, t):
+            if a is None:
+                g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
+            else:
+                g = a.cogroup(b).map(lambda (k, (va, vb)):
+                                     (k, list(vb), list(va)[0] if len(va) else None))
+            return g.mapPartitions(lambda x: updateFunc(x) or [])
 
-    def updateStateByKey(self, updateFunc):
-        # FIXME: convert updateFunc to java JFunction2
-        jFunc = updateFunc
-        return self._jdstream.updateStateByKey(jFunc)
+        jreduceFunc = RDDFunction2(self.ctx, reduceFunc,
+                                   self.ctx.serializer, self._jrdd_deserializer)
+        dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+        return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer)
 
 
 class TransformedDStream(DStream):
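A note on the new API: updateFunc is applied through mapPartitions, so it receives an iterator of (key, new_values, old_state) triples and must return an iterable of (key, state) pairs. old_state is None the first time a key appears, and the "or []" guard lets the updater return None when it has nothing to emit. A minimal conforming updater, as a sketch (the DStream name "pairs" is hypothetical):

```python
def running_sum(triples):
    # triples: iterator of (key, new_values, old_state) for one partition;
    # old_state is None until the key has been seen in an earlier batch
    for key, new_values, old_state in triples:
        yield (key, (old_state or 0) + sum(new_values))

# hypothetical usage on a DStream of (word, 1) pairs named `pairs`:
# totals = pairs.updateStateByKey(running_sum)
```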

python/pyspark/streaming/tests.py

Lines changed: 19 additions & 0 deletions
@@ -294,6 +294,25 @@ def func(dstream):
                     [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
         self._test_func(input, func, expected)
 
+    def update_state_by_key(self):
+
+        def updater(it):
+            for k, vs, s in it:
+                if not s:
+                    s = vs
+                else:
+                    s.extend(vs)
+                yield (k, s)
+
+        input = [[('k', i)] for i in range(5)]
+
+        def func(dstream):
+            return dstream.updateStateByKey(updater)
+
+        expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+        expected = [[('k', v)] for v in expected]
+        self._test_func(input, func, expected)
+
 
 class TestStreamingContext(unittest.TestCase):
     def setUp(self):
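The updater's semantics can be checked without Spark: feeding one ('k', i) pair per simulated batch reproduces the expected running lists. A standalone sketch of that check:

```python
def updater(it):
    # same updater as in the test above
    for k, vs, s in it:
        if not s:
            s = vs
        else:
            s.extend(vs)
        yield (k, s)

state = None
for i in range(5):
    # one (key, new_values, old_state) triple per simulated batch
    (k, state), = updater([('k', [i], state)])
    print((k, state))  # ('k', [0]), ('k', [0, 1]), ..., ('k', [0, 1, 2, 3, 4])
```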

python/pyspark/streaming/util.py

Lines changed: 6 additions & 5 deletions
@@ -50,15 +50,16 @@ class RDDFunction2(object):
     This class is for py4j callback. This class is related with
     org.apache.spark.streaming.api.python.PythonRDDFunction2.
     """
-    def __init__(self, ctx, func, jrdd_deserializer):
+    def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None):
         self.ctx = ctx
         self.func = func
-        self.deserializer = jrdd_deserializer
+        self.jrdd_deserializer = jrdd_deserializer
+        self.jrdd_deserializer2 = jrdd_deserializer2 or jrdd_deserializer
 
     def call(self, jrdd, jrdd2, milliseconds):
         try:
-            rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else None
-            other = RDD(jrdd2, self.ctx, self.deserializer) if jrdd2 else None
+            rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else None
+            other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else None
             r = self.func(rdd, other, milliseconds)
             if r:
                 return r._jrdd
@@ -67,7 +68,7 @@ def call(self, jrdd, jrdd2, milliseconds):
             traceback.print_exc()
 
     def __repr__(self):
-        return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))
+        return "RDDFunction2(%s)" % (str(self.func))
 
     class Java:
         implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2']
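The second deserializer exists because the two RDDs handed to call() may be encoded differently: for updateStateByKey, the state RDD produced by an earlier call() is written with ctx.serializer, while the new batch still carries the parent stream's deserializer. A Spark-free mimic of that split (all names below are illustrative, not part of the patch):

```python
import json

class PairFunction(object):
    """Mimics RDDFunction2: one decoder per argument of call()."""
    def __init__(self, func, deser, deser2=None):
        self.func = func
        self.deser = deser
        self.deser2 = deser2 or deser

    def call(self, blob, blob2, t):
        # either payload may be absent, like jrdd/jrdd2 above
        a = self.deser(blob) if blob is not None else None
        b = self.deser2(blob2) if blob2 is not None else None
        return self.func(a, b, t)

merge = PairFunction(lambda a, b, t: (a or 0) + sum(b),
                     json.loads,        # the "state" payload arrives as JSON
                     lambda xs: xs)     # the "new batch" needs no decoding
print(merge.call(None, [1, 2], 0))      # 3 -- first batch, no prior state
print(merge.call("3", [4], 1))          # 7
```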

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 37 additions & 7 deletions
@@ -118,7 +118,7 @@ private[spark] class PythonTransformed2DStream (parent: DStream[_], parent2: DSt
 private[spark]
 class PythonReducedWindowedDStream(
     parent: DStream[Array[Byte]],
-    reduceFunc: PythonRDDFunction,
+    reduceFunc: PythonRDDFunction2,
     invReduceFunc: PythonRDDFunction2,
     _windowDuration: Duration,
     _slideDuration: Duration
@@ -149,10 +149,6 @@ class PythonReducedWindowedDStream(
   override def parentRememberDuration: Duration = rememberDuration + windowDuration
 
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
-    None
-    val reduceF = reduceFunc
-    val invReduceF = invReduceFunc
-
     val currentTime = validTime
     val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration,
       currentTime)
@@ -196,7 +192,7 @@ class PythonReducedWindowedDStream(
         parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration)
 
       if (newRDDs.size > 0) {
-        Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(newRDDs).union(subbed)), validTime.milliseconds))
+        Some(reduceFunc.call(JavaRDD.fromRDD(subbed), JavaRDD.fromRDD(ssc.sc.union(newRDDs)), validTime.milliseconds))
       } else {
         Some(subbed)
       }
@@ -205,7 +201,7 @@ class PythonReducedWindowedDStream(
       val currentRDDs =
         parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration)
       if (currentRDDs.size > 0) {
-        Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds))
+        Some(reduceFunc.call(null, JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds))
       } else {
         None
       }
@@ -216,6 +212,40 @@ class PythonReducedWindowedDStream(
   }
 
 
+/**
+ * Copied from ReducedWindowedDStream
+ */
+private[spark]
+class PythonStateDStream(
+    parent: DStream[Array[Byte]],
+    reduceFunc: PythonRDDFunction2
+  ) extends DStream[Array[Byte]](parent.ssc) {
+
+  super.persist(StorageLevel.MEMORY_ONLY)
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override val mustCheckpoint = true
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    val lastState = getOrCompute(validTime - slideDuration)
+    val newRDD = parent.getOrCompute(validTime)
+    if (newRDD.isDefined) {
+      if (lastState.isDefined) {
+        Some(reduceFunc.call(JavaRDD.fromRDD(lastState.get), JavaRDD.fromRDD(newRDD.get), validTime.milliseconds))
+      } else {
+        Some(reduceFunc.call(null, JavaRDD.fromRDD(newRDD.get), validTime.milliseconds))
+      }
+    } else {
+      lastState
+    }
+  }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
 /**
  * This is used for foreachRDD() in Python
  */
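Taken together, the pieces could be driven end to end roughly as below. This is a hypothetical sketch against this branch: PythonStateDStream sets mustCheckpoint, so a checkpoint directory is presumably required, and StreamingContext, socketTextStream, pprint, and awaitTermination are assumed to mirror the Scala API:

```python
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext(appName="StatefulWordCount")
ssc = StreamingContext(sc, 1)          # 1-second batches (assumed signature)
ssc.checkpoint("/tmp/checkpoint")      # hypothetical path; mustCheckpoint = true

def running_count(triples):            # [(k, vs, s)] -> [(k, s)]
    for k, vs, s in triples:
        yield (k, (s or 0) + sum(vs))

counts = (ssc.socketTextStream("localhost", 9999)
             .flatMap(lambda line: line.split())
             .map(lambda word: (word, 1))
             .updateStateByKey(running_count))
counts.pprint()

ssc.start()
ssc.awaitTermination()
```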
