from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.streaming.util import rddToFileName, RDDFunction
+from pyspark.rdd import portable_hash, _parse_memory
from pyspark.traceback_utils import SCCallSiteSync

from py4j.java_collections import ListConverter, MapConverter
@@ -40,6 +41,7 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
        self._jrdd_deserializer = jrdd_deserializer
        self.is_cached = False
        self.is_checkpointed = False
+        self._partitionFunc = None

    def context(self):
        """
@@ -161,32 +163,71 @@ def _mergeCombiners(iterator):
        return shuffled.mapPartitions(_mergeCombiners)

-    def partitionBy(self, numPartitions, partitionFunc=None):
+    def partitionBy(self, numPartitions, partitionFunc=portable_hash):
        """
        Return a copy of the DStream partitioned using the specified partitioner.
        """
        if numPartitions is None:
            numPartitions = self.ctx._defaultReducePartitions()

-        if partitionFunc is None:
-            partitionFunc = lambda x: 0 if x is None else hash(x)
-
        # Transferring O(n) objects to Java is too expensive. Instead, we'll
        # form the hash buckets in Python, transferring O(numPartitions) objects
        # to Java. Each object is a (splitNumber, [objects]) pair.
+
        outputSerializer = self.ctx._unbatched_serializer
+        #
+        # def add_shuffle_key(split, iterator):
+        #     buckets = defaultdict(list)
+        #
+        #     for (k, v) in iterator:
+        #         buckets[partitionFunc(k) % numPartitions].append((k, v))
+        #     for (split, items) in buckets.iteritems():
+        #         yield pack_long(split)
+        #         yield outputSerializer.dumps(items)
+        # keyed = PipelinedDStream(self, add_shuffle_key)
+
+        limit = (_parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m")) / 2)

        def add_shuffle_key(split, iterator):
+
            buckets = defaultdict(list)
+            c, batch = 0, min(10 * numPartitions, 1000)

-            for (k, v) in iterator:
+            for k, v in iterator:
                buckets[partitionFunc(k) % numPartitions].append((k, v))
-            for (split, items) in buckets.iteritems():
+                c += 1
+
+                # check used memory and avg size of chunk of objects
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
+                    for split in buckets.keys():
+                        yield pack_long(split)
+                        d = outputSerializer.dumps(buckets[split])
+                        del buckets[split]
+                        yield d
+                        size += len(d)
+
+                    avg = (size / n) >> 20
+                    # let 1M < avg < 10M
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0
+
+            for split, items in buckets.iteritems():
                yield pack_long(split)
                yield outputSerializer.dumps(items)
-        keyed = PipelinedDStream(self, add_shuffle_key)
+
+        keyed = self._mapPartitionsWithIndex(add_shuffle_key)
        keyed._bypass_serializer = True
-        with SCCallSiteSync(self.context) as css:
+        with SCCallSiteSync(self.ctx) as css:
            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                          id(partitionFunc))
            jdstream = self.ctx._jvm.PythonPairwiseDStream(keyed._jdstream.dstream(),
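A minimal, self-contained sketch of the flushing strategy add_shuffle_key uses above: pairs are hashed into per-partition buckets and spilled as serialized chunks once a count threshold is hit, and the threshold is re-tuned toward a 1-10 MB average chunk. Here pickle stands in for Spark's unbatched serializer and a plain element count replaces the get_used_memory() check (which the patch assumes is available from pyspark.shuffle), so every name below is illustrative rather than part of this patch.

import pickle
from collections import defaultdict

def bucketed_chunks(pairs, num_partitions, partition_func=hash, batch=1000):
    # Hash each (k, v) pair into a bucket per target partition.
    buckets = defaultdict(list)
    count = 0
    for k, v in pairs:
        buckets[partition_func(k) % num_partitions].append((k, v))
        count += 1
        # Flush early once enough pairs have accumulated, so the buffered
        # data stays bounded (the real patch also checks worker memory here).
        if count > batch:
            n, size = len(buckets), 0
            for split in list(buckets):
                chunk = pickle.dumps(buckets.pop(split))
                size += len(chunk)
                yield split, chunk
            # Re-tune the flush threshold so the average serialized chunk
            # stays roughly between 1 MB and 10 MB.
            avg_mb = (size // n) >> 20
            if avg_mb < 1:
                batch = int(batch * 1.5)
            elif avg_mb > 10:
                batch = max(batch // 2, 1)
            count = 0
    # Emit whatever is still buffered at the end of the iterator.
    for split, items in buckets.items():
        yield split, pickle.dumps(items)

# Tiny usage example: 10 pairs spread over 4 partitions.
for split, chunk in bucketed_chunks(((i, i * i) for i in range(10)), 4):
    print(split, len(pickle.loads(chunk)))

Flushing every bucket at each trigger, instead of holding them until the iterator is exhausted, is what caps how much buffered data sits in the Python worker between yields.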
@@ -428,6 +469,10 @@ def get_output(rdd, time):
class PipelinedDStream(DStream):
+    """
+    PipelinedDStream mirrors PipelinedRDD; if PipelinedRDD changes, this
+    code should be updated in the same way.
+    """
    def __init__(self, prev, func, preservesPartitioning=False):
        if not isinstance(prev, PipelinedDStream) or not prev._is_pipelinable():
            # This transformation is the first in its stage:
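The isinstance/_is_pipelinable check above is what lets consecutive narrow transformations collapse into a single Python function, mirroring PipelinedRDD. A rough sketch of that composition idea, with illustrative names rather than the actual PipelinedDStream internals:

def compose_pipeline(prev_func, func):
    # prev_func and func both map (split, iterator) -> iterator; composing
    # them keeps the whole chain a single pass over each partition.
    def pipeline_func(split, iterator):
        return func(split, prev_func(split, iterator))
    return pipeline_func

# Example: square every value, then keep the even results, in one pass.
squares = compose_pipeline(lambda s, it: (x * x for x in it),
                           lambda s, it: (x for x in it if x % 2 == 0))
print(list(squares(0, range(6))))  # [0, 4, 16]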
@@ -453,19 +498,22 @@ def pipeline_func(split, iterator):
        self._jdstream_val = None
        self._jrdd_deserializer = self.ctx.serializer
        self._bypass_serializer = False
+        self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None

    @property
    def _jdstream(self):
        if self._jdstream_val:
            return self._jdstream_val
        if self._bypass_serializer:
-            serializer = NoOpSerializer()
-        else:
-            serializer = self.ctx.serializer
-
-        command = (self.func, self._prev_jrdd_deserializer, serializer)
-        ser = CompressedSerializer(CloudPickleSerializer())
+            self._jrdd_deserializer = NoOpSerializer()
+        command = (self.func, self._prev_jrdd_deserializer,
+                   self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
+        ser = CloudPickleSerializer()
        pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx._gateway._gateway_client)
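The 1 MB cutoff above keeps a large serialized closure out of the task payload by turning it into a broadcast variable and shipping only a small pickled handle. A toy, self-contained sketch of that pattern follows; pickle and an in-process dict stand in for CloudPickleSerializer and SparkContext.broadcast, so none of these names are Spark APIs.

import pickle

_BROADCAST_LIMIT = 1 << 20   # same 1 MB cutoff as the diff above
_broadcast_store = {}        # toy stand-in for Spark broadcast variables

def ship_command(command):
    # Serialize the command; if the payload is big, park it in the shared
    # store and ship only a tiny ("broadcast", key) reference instead.
    payload = pickle.dumps(command)
    if len(payload) > _BROADCAST_LIMIT:
        key = len(_broadcast_store)
        _broadcast_store[key] = payload
        payload = pickle.dumps(("broadcast", key))
    return payload

def load_command(shipped):
    # Resolve the reference on the receiving side before using the command.
    obj = pickle.loads(shipped)
    if isinstance(obj, tuple) and len(obj) == 2 and obj[0] == "broadcast":
        return pickle.loads(_broadcast_store[obj[1]])
    return obj

# A command bigger than 1 MB travels as a handle of a few dozen bytes.
big = ("func", "x" * (2 << 20))
shipped = ship_command(big)
assert len(shipped) < 100 and load_command(shipped) == big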