
Commit 847f9b9

add more docs, add first(), take()
1 parent e059ca2 commit 847f9b9

File tree: 4 files changed, +243 / -26 lines


python/pyspark/streaming/context.py

Lines changed: 3 additions & 0 deletions
@@ -96,6 +96,9 @@ def _initialize_context(self, sc, duration):
         return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
 
     def _jduration(self, seconds):
+        """
+        Create Duration object given number of seconds
+        """
         return self._jvm.Duration(int(seconds * 1000))
 
     @property
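
The new docstring covers `_jduration`, which wraps a duration given in (possibly fractional) seconds into a JVM `Duration` in whole milliseconds. A minimal sketch of just that conversion, outside the JVM (illustrative only, not part of this commit):

```python
def to_millis(seconds):
    # mirrors _jduration's conversion: seconds (int or float) -> whole milliseconds, truncated
    return int(seconds * 1000)

assert to_millis(1) == 1000
assert to_millis(0.5) == 500      # sub-second batch intervals stay expressible
assert to_millis(2.0009) == 2000  # truncated, not rounded
```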

python/pyspark/streaming/dstream.py

Lines changed: 220 additions & 23 deletions
@@ -17,6 +17,7 @@
 
 from itertools import chain, ifilter, imap
 import operator
+import time
 from datetime import datetime
 
 from pyspark import RDD
@@ -163,6 +164,29 @@ def takeAndPrint(rdd, time):
 
         self.foreachRDD(takeAndPrint)
 
+    def first(self):
+        """
+        Return the first RDD in the stream.
+        """
+        return self.take(1)[0]
+
+    def take(self, n):
+        """
+        Return the first `n` RDDs in the stream (will start and stop).
+        """
+        rdds = []
+
+        def take(rdd, _):
+            if rdd:
+                rdds.append(rdd)
+            if len(rdds) == n:
+                # FIXME: NPE in JVM
+                self._ssc.stop(False)
+        self.foreachRDD(take)
+        self._ssc.start()
+        self._ssc.awaitTermination()
+        return rdds
+
     def collect(self):
         """
         Collect each RDDs into the returned list.
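
For orientation, a minimal usage sketch of the new `first()` and `take()` (assuming a local `SparkContext`, a 1-second batch interval, and `queueStream` accepting plain Python lists, as the tests below do):

```python
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", appName="take-example")
ssc = StreamingContext(sc, 1)                  # 1-second batch interval

dstream = ssc.queueStream([range(5), range(5, 10)])
rdds = dstream.take(2)                         # starts the context, collects 2 RDDs, then stops it
print([rdd.collect() for rdd in rdds])         # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```

Note that `take()` starts and stops the StreamingContext, so the stream cannot be reused afterwards; `first()` is simply `take(1)[0]`.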
@@ -289,93 +313,261 @@ def saveAsPickleFile(rdd, time):
         return self.foreachRDD(saveAsPickleFile)
 
     def transform(self, func):
+        """
+        Return a new DStream in which each RDD is generated by applying a function
+        on each RDD of 'this' DStream.
+        """
         return TransformedDStream(self, lambda a, t: func(a), True)
 
     def transformWithTime(self, func):
+        """
+        Return a new DStream in which each RDD is generated by applying a function
+        on each RDD of 'this' DStream.
+        """
         return TransformedDStream(self, func, False)
 
     def transformWith(self, func, other, keepSerializer=False):
+        """
+        Return a new DStream in which each RDD is generated by applying a function
+        on each RDD of 'this' DStream and 'other' DStream.
+        """
         jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, b), self._jrdd_deserializer)
         dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
                                                           other._jdstream.dstream(), jfunc)
         jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer
         return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer)
 
     def repartitions(self, numPartitions):
+        """
+        Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+        returned DStream has exactly numPartitions partitions.
+        """
         return self.transform(lambda rdd: rdd.repartition(numPartitions))
 
+    @property
+    def _slideDuration(self):
+        """
+        Return the slideDuration in seconds of this DStream
+        """
+        return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0
+
     def union(self, other):
+        """
+        Return a new DStream by unifying data of another DStream with this DStream.
+        @param other: Another DStream having the same interval (i.e., slideDuration) as this DStream.
+        """
+        if self._slideDuration != other._slideDuration:
+            raise ValueError("the two DStreams should have the same slide duration")
         return self.transformWith(lambda a, b: a.union(b), other, True)
 
-    def cogroup(self, other):
-        return self.transformWith(lambda a, b: a.cogroup(b), other)
+    def cogroup(self, other, numPartitions=None):
+        """
+        Return a new DStream by applying 'cogroup' between RDDs of `this`
+        DStream and `other` DStream.
+
+        Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+        """
+        return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other)
+
+    def join(self, other, numPartitions=None):
+        """
+        Return a new DStream by applying 'join' between RDDs of `this` DStream and
+        `other` DStream.
+
+        Hash partitioning is used to generate the RDDs with `numPartitions`
+        partitions.
+        """
+        return self.transformWith(lambda a, b: a.join(b, numPartitions), other)
+
+    def leftOuterJoin(self, other, numPartitions=None):
+        """
+        Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+        `other` DStream.
 
-    def leftOuterJoin(self, other):
-        return self.transformWith(lambda a, b: a.leftOuterJion(b), other)
+        Hash partitioning is used to generate the RDDs with `numPartitions`
+        partitions.
+        """
+        return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other)
+
+    def rightOuterJoin(self, other, numPartitions=None):
+        """
+        Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+        `other` DStream.
+
+        Hash partitioning is used to generate the RDDs with `numPartitions`
+        partitions.
+        """
+        return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other)
+
+    def fullOuterJoin(self, other, numPartitions=None):
+        """
+        Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and
+        `other` DStream.
 
-    def rightOuterJoin(self, other):
-        return self.transformWith(lambda a, b: a.rightOuterJoin(b), other)
+        Hash partitioning is used to generate the RDDs with `numPartitions`
+        partitions.
+        """
+        return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other)
 
-    def _jtime(self, milliseconds):
-        return self.ctx._jvm.Time(milliseconds)
+    def _jtime(self, timestamp):
+        """
+        Convert datetime or unix_timestamp into Time
+        """
+        if isinstance(timestamp, datetime):
+            timestamp = time.mktime(timestamp.timetuple())
+        return self.ctx._jvm.Time(long(timestamp * 1000))
 
     def slice(self, begin, end):
+        """
+        Return all the RDDs from 'begin' to 'end' (both included)
+
+        `begin` and `end` could be datetime.datetime() or unix_timestamp
+        """
         jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
         return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds]
 
+    def _check_window(self, window, slide):
+        duration = self._jdstream.dstream().slideDuration().milliseconds()
+        if int(window * 1000) % duration != 0:
+            raise ValueError("windowDuration must be multiple of the slide duration (%d ms)"
+                             % duration)
+        if slide and int(slide * 1000) % duration != 0:
+            raise ValueError("slideDuration must be multiple of the slide duration (%d ms)"
+                             % duration)
+
     def window(self, windowDuration, slideDuration=None):
+        """
+        Return a new DStream in which each RDD contains all the elements seen in a
+        sliding window of time over this DStream.
+
+        @param windowDuration: width of the window; must be a multiple of this DStream's
+                               batching interval
+        @param slideDuration:  sliding interval of the window (i.e., the interval after which
+                               the new DStream will generate RDDs); must be a multiple of this
+                               DStream's batching interval
+        """
+        self._check_window(windowDuration, slideDuration)
         d = self._ssc._jduration(windowDuration)
         if slideDuration is None:
             return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer)
         s = self._ssc._jduration(slideDuration)
         return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
 
     def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration):
+        """
+        Return a new DStream in which each RDD has a single element generated by reducing all
+        elements in a sliding window over this DStream.
+
+        If `invReduceFunc` is not None, the reduction is done incrementally
+        using the old window's reduced value:
+        1. reduce the new values that entered the window (e.g., adding new counts)
+        2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+        This is more efficient than having `invReduceFunc` be None.
+
+        @param reduceFunc:     associative reduce function
+        @param invReduceFunc:  inverse reduce function of `reduceFunc`
+        @param windowDuration: width of the window; must be a multiple of this DStream's
+                               batching interval
+        @param slideDuration:  sliding interval of the window (i.e., the interval after which
+                               the new DStream will generate RDDs); must be a multiple of this
+                               DStream's batching interval
+        """
         keyed = self.map(lambda x: (1, x))
         reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc,
                                              windowDuration, slideDuration, 1)
         return reduced.map(lambda (k, v): v)
 
     def countByWindow(self, windowDuration, slideDuration):
+        """
+        Return a new DStream in which each RDD has a single element generated
+        by counting the number of elements in a window over this DStream.
+        windowDuration and slideDuration are as defined in the window() operation.
+
+        This is equivalent to window(windowDuration, slideDuration).count(),
+        but will be more efficient if the window is large.
+        """
         return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub,
                                                     windowDuration, slideDuration)
 
     def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+        """
+        Return a new DStream in which each RDD contains the count of distinct elements in
+        RDDs in a sliding window over this DStream.
+
+        @param windowDuration: width of the window; must be a multiple of this DStream's
+                               batching interval
+        @param slideDuration:  sliding interval of the window (i.e., the interval after which
+                               the new DStream will generate RDDs); must be a multiple of this
+                               DStream's batching interval
+        @param numPartitions:  number of partitions of each RDD in the new DStream.
+        """
         keyed = self.map(lambda x: (x, 1))
         counted = keyed.reduceByKeyAndWindow(lambda a, b: a + b, lambda a, b: a - b,
                                              windowDuration, slideDuration, numPartitions)
         return counted.filter(lambda (k, v): v > 0).count()
 
     def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+        """
+        Return a new DStream by applying `groupByKey` over a sliding window.
+        Similar to `DStream.groupByKey()`, but applies it over a sliding window.
+
+        @param windowDuration: width of the window; must be a multiple of this DStream's
+                               batching interval
+        @param slideDuration:  sliding interval of the window (i.e., the interval after which
+                               the new DStream will generate RDDs); must be a multiple of this
+                               DStream's batching interval
+        @param numPartitions:  number of partitions of each RDD in the new DStream.
+        """
         ls = self.mapValues(lambda x: [x])
         grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):],
                                           windowDuration, slideDuration, numPartitions)
         return grouped.mapValues(ResultIterable)
 
-    def reduceByKeyAndWindow(self, func, invFunc,
-                             windowDuration, slideDuration, numPartitions=None):
+    def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None,
+                             numPartitions=None, filterFunc=None):
+        """
+        Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+
+        The reduced value of a new window is calculated using the old window's reduced value:
+        1. reduce the new values that entered the window (e.g., adding new counts)
+        2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
 
-        duration = self._jdstream.dstream().slideDuration().milliseconds()
-        if int(windowDuration * 1000) % duration != 0:
-            raise ValueError("windowDuration must be multiple of the slide duration (%d ms)"
-                             % duration)
-        if int(slideDuration * 1000) % duration != 0:
-            raise ValueError("slideDuration must be multiple of the slide duration (%d ms)"
-                             % duration)
+        `invFunc` can be None; then all the RDDs in the window are reduced, which could be
+        slower than having an `invFunc`.
 
+        @param func:           associative reduce function
+        @param invFunc:        inverse function of `func`
+        @param windowDuration: width of the window; must be a multiple of this DStream's
+                               batching interval
+        @param slideDuration:  sliding interval of the window (i.e., the interval after which
+                               the new DStream will generate RDDs); must be a multiple of this
+                               DStream's batching interval
+        @param numPartitions:  number of partitions of each RDD in the new DStream.
+        @param filterFunc:     function to filter expired key-value pairs;
+                               only pairs that satisfy the function are retained;
+                               set this to None if you do not want to filter
+        """
+        self._check_window(windowDuration, slideDuration)
         reduced = self.reduceByKey(func)
 
         def reduceFunc(a, b, t):
             b = b.reduceByKey(func, numPartitions)
-            return a.union(b).reduceByKey(func, numPartitions) if a else b
+            r = a.union(b).reduceByKey(func, numPartitions) if a else b
+            if filterFunc:
+                r = r.filter(filterFunc)
+            return r
 
         def invReduceFunc(a, b, t):
             b = b.reduceByKey(func, numPartitions)
             joined = a.leftOuterJoin(b, numPartitions)
             return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
 
         jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
-        jinvReduceFunc = RDDFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
+        if invFunc:
+            jinvReduceFunc = RDDFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
+        else:
+            jinvReduceFunc = None
+        if slideDuration is None:
+            slideDuration = self._slideDuration
         dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
                                                              jreduceFunc, jinvReduceFunc,
                                                              self._ssc._jduration(windowDuration),
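
The incremental path above keeps one reduced value per window and updates it batch by batch: values entering the window are reduced in, values leaving it are "inverse reduced" out, and `filterFunc` can drop expired pairs. A pure-Python sketch of that arithmetic for per-key counting (illustrative only, not the DStream API; names are made up):

```python
def incremental_window_counts(batches, window_size):
    # per-key counts over a sliding window of `window_size` batches,
    # maintained incrementally in the spirit of reduceByKeyAndWindow(add, sub)
    counts = {}
    for i, batch in enumerate(batches):
        for k in batch:                        # 1. reduce values that entered the window
            counts[k] = counts.get(k, 0) + 1
        if i >= window_size:                   # 2. "inverse reduce" values that left the window
            for k in batches[i - window_size]:
                counts[k] -= 1
        counts = {k: v for k, v in counts.items() if v > 0}  # what a filterFunc might do
        yield dict(counts)

print(list(incremental_window_counts([["a"], ["a", "b"], ["b"]], window_size=2)))
# [{'a': 1}, {'a': 2, 'b': 1}, {'a': 1, 'b': 2}]
```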
@@ -384,15 +576,20 @@ def invReduceFunc(a, b, t):
 
     def updateStateByKey(self, updateFunc, numPartitions=None):
         """
-        :param updateFunc: [(k, vs, s)] -> [(k, s)]
+        Return a new "state" DStream where the state for each key is updated by applying
+        the given function on the previous state of the key and the new values of the key.
+
+        @param updateFunc: State update function ([(k, vs, s)] -> [(k, s)]).
+                           If `s` is None, then `k` will be eliminated.
         """
         def reduceFunc(a, b, t):
             if a is None:
                 g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
             else:
-                g = a.cogroup(b).map(lambda (k, (va, vb)):
-                                     (k, list(vb), list(va)[0] if len(va) else None))
-            return g.mapPartitions(lambda x: updateFunc(x) or [])
+                g = a.cogroup(b, numPartitions)
+                g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None))
+            state = g.mapPartitions(lambda x: updateFunc(x))
+            return state.filter(lambda (k, v): v is not None)
 
         jreduceFunc = RDDFunction(self.ctx, reduceFunc,
                                   self.ctx.serializer, self._jrdd_deserializer)
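
A hedged sketch of an update function matching the documented shape `[(k, vs, s)] -> [(k, s)]`, keeping a running count per key; the DStream name `words` and the function name are illustrative, and yielding None as the state would drop the key:

```python
def update(records):
    # records: iterator of (key, new_values, old_state) for one partition
    for key, new_values, old_state in records:
        new_state = (old_state or 0) + sum(new_values)
        yield (key, new_state)  # yielding (key, None) would eliminate the key

running_counts = words.map(lambda word: (word, 1)).updateStateByKey(update)
```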

python/pyspark/streaming/tests.py

Lines changed: 15 additions & 0 deletions
@@ -89,6 +89,21 @@ def _sort_result_based_on_key(self, outputs):
 
 
 class TestBasicOperations(PySparkStreamingTestCase):
+
+    def test_take(self):
+        input = [range(i) for i in range(3)]
+        dstream = self.ssc.queueStream(input)
+        rdds = dstream.take(3)
+        self.assertEqual(3, len(rdds))
+        for d, rdd in zip(input, rdds):
+            self.assertEqual(d, rdd.collect())
+
+    def test_first(self):
+        input = [range(10)]
+        dstream = self.ssc.queueStream(input)
+        rdd = dstream.first()
+        self.assertEqual(range(10), rdd.collect())
+
     def test_map(self):
         """Basic operation test for DStream.map."""
         input = [range(1, 5), range(5, 9), range(9, 13)]

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 5 additions & 3 deletions
@@ -207,8 +207,10 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
     // Get the RDD of the reduced value of the previous window
     val previousWindowRDD = getOrCompute(previousWindow.endTime)
 
-    // for small window, reduce once will be better than twice
-    if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) {
+    if (pinvReduceFunc != null && previousWindowRDD.isDefined
+      // for small window, reduce once will be better than twice
+      && windowDuration > slideDuration * 5) {
+
       // subtract the values from old RDDs
       val oldRDDs =
         parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
@@ -238,4 +240,4 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
     }
   }
 }
-}
+}
