
Commit 069a94c

fix the number of partitions during window()
1 parent: 338580a

2 files changed, 16 insertions(+), 4 deletions(-)


python/pyspark/streaming/dstream.py

Lines changed: 9 additions & 3 deletions
@@ -552,14 +552,18 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None
 
         def reduceFunc(t, a, b):
             b = b.reduceByKey(func, numPartitions)
-            r = a.union(b).reduceByKey(func, numPartitions) if a else b
+            # use the average of number of partitions, or it will keep increasing
+            partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2
+            r = a.union(b).reduceByKey(func, partitions) if a else b
             if filterFunc:
                 r = r.filter(filterFunc)
             return r
 
         def invReduceFunc(t, a, b):
             b = b.reduceByKey(func, numPartitions)
-            joined = a.leftOuterJoin(b, numPartitions)
+            # use the average of number of partitions, or it will keep increasing
+            partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2
+            joined = a.leftOuterJoin(b, partitions)
             return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
 
         jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
@@ -587,7 +591,9 @@ def reduceFunc(t, a, b):
             if a is None:
                 g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
             else:
-                g = a.cogroup(b, numPartitions)
+                # use the average of number of partitions, or it will keep increasing
+                partitions = numPartitions or (a.getNumPartitions() + b.getNumPartitions())/2
+                g = a.cogroup(b, partitions)
             g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None))
             state = g.mapPartitions(lambda x: updateFunc(x))
             return state.filter(lambda (k, v): v is not None)
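
The comment added in each hunk gives the rationale: when numPartitions is not supplied, the windowed result is built from a union() of the previous window and the new batch, and union() concatenates partitions, so without a cap the partition count grows on every slide (assuming spark.default.parallelism is unset, so the following reduceByKey inherits the union's partition count). A minimal standalone sketch of that arithmetic, not part of the patch, with illustrative helper names:

    # Sketch: how the partition count evolves across window slides.
    def next_partitions_unbounded(prev, batch):
        # union() yields the sum of the two RDDs' partition counts
        return prev + batch

    def next_partitions_capped(prev, batch, numPartitions=None):
        # the patched fallback: explicit numPartitions wins, else the average
        return numPartitions or (prev + batch) // 2

    parts = 4
    for _ in range(5):
        parts = next_partitions_unbounded(parts, 4)
    print(parts)   # 24 -- keeps increasing with every slide

    parts = 4
    for _ in range(5):
        parts = next_partitions_capped(parts, 4)
    print(parts)   # 4 -- stays stable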

python/pyspark/streaming/tests.py

Lines changed: 7 additions & 1 deletion
@@ -22,7 +22,7 @@
 import unittest
 import tempfile
 
-from pyspark.context import SparkContext
+from pyspark.context import SparkContext, RDD
 from pyspark.streaming.context import StreamingContext
 
 
@@ -46,8 +46,13 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
         @param func: wrapped function. This function should return PythonDStream object.
         @param expected: expected output for this testcase.
         """
+        if not isinstance(input[0], RDD):
+            input = [self.sc.parallelize(d, 1) for d in input]
         input_stream = self.ssc.queueStream(input)
+        if input2 and not isinstance(input2[0], RDD):
+            input2 = [self.sc.parallelize(d, 1) for d in input2]
         input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
+
         # Apply test function to stream.
         if input2:
             stream = func(input_stream, input_stream2)
@@ -63,6 +68,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
             current_time = time.time()
             # Check time out.
             if (current_time - start_time) > self.timeout:
+                print "timeout after", self.timeout
                 break
             # StreamingContext.awaitTermination is not used to wait because
             # if py4j server is called every 50 milliseconds, it gets an error.
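
For context, a hedged end-to-end sketch of the API this patch touches; the batch data, durations, and app name are illustrative, and the StreamingContext and reduceByKeyAndWindow argument conventions are assumed from the surrounding code rather than confirmed by this commit. It feeds pre-built RDDs through queueStream, as the updated test helper now allows, and reduces over a sliding window:

    from pyspark.context import SparkContext
    from pyspark.streaming.context import StreamingContext

    sc = SparkContext("local[2]", "window-sketch")
    ssc = StreamingContext(sc, 1)   # 1-second batches (assumed constructor form)

    # pre-built 2-partition RDDs, handed straight to queueStream
    batches = [sc.parallelize([("a", 1), ("b", 1)], 2) for _ in range(6)]
    stream = ssc.queueStream(batches)

    # 3s window sliding every 1s; the inverse function subtracts the batch
    # leaving the window instead of recomputing the whole window
    counts = stream.reduceByKeyAndWindow(lambda a, b: a + b,
                                         lambda a, b: a - b,
                                         3, 1)
    counts.pprint()
    ssc.start()
    ssc.awaitTermination()

With the fix above, counts keeps a stable partition count across slides even though no numPartitions argument is passed.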
