@@ -552,14 +552,18 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None
552
552
553
553
def reduceFunc (t , a , b ):
554
554
b = b .reduceByKey (func , numPartitions )
555
- r = a .union (b ).reduceByKey (func , numPartitions ) if a else b
555
+ # use the average of number of partitions, or it will keep increasing
556
+ partitions = numPartitions or (a .getNumPartitions () + b .getNumPartitions ())/ 2
557
+ r = a .union (b ).reduceByKey (func , partitions ) if a else b
556
558
if filterFunc :
557
559
r = r .filter (filterFunc )
558
560
return r
559
561
560
562
def invReduceFunc (t , a , b ):
561
563
b = b .reduceByKey (func , numPartitions )
562
- joined = a .leftOuterJoin (b , numPartitions )
564
+ # use the average of number of partitions, or it will keep increasing
565
+ partitions = numPartitions or (a .getNumPartitions () + b .getNumPartitions ())/ 2
566
+ joined = a .leftOuterJoin (b , partitions )
563
567
return joined .mapValues (lambda (v1 , v2 ): invFunc (v1 , v2 ) if v2 is not None else v1 )
564
568
565
569
jreduceFunc = RDDFunction (self .ctx , reduceFunc , reduced ._jrdd_deserializer )
@@ -587,7 +591,9 @@ def reduceFunc(t, a, b):
587
591
if a is None :
588
592
g = b .groupByKey (numPartitions ).map (lambda (k , vs ): (k , list (vs ), None ))
589
593
else :
590
- g = a .cogroup (b , numPartitions )
594
+ # use the average of number of partitions, or it will keep increasing
595
+ partitions = numPartitions or (a .getNumPartitions () + b .getNumPartitions ())/ 2
596
+ g = a .cogroup (b , partitions )
591
597
g = g .map (lambda (k , (va , vb )): (k , list (vb ), list (va )[0 ] if len (va ) else None ))
592
598
state = g .mapPartitions (lambda x : updateFunc (x ))
593
599
return state .filter (lambda (k , v ): v is not None )
0 commit comments