
Commit 550dfd9

WIP fixing 1.1 merge
1 parent 5cdb6fa commit 550dfd9


python/pyspark/streaming/dstream.py

Lines changed: 62 additions & 14 deletions
@@ -25,6 +25,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.streaming.util import rddToFileName, RDDFunction
+from pyspark.rdd import portable_hash, _parse_memory
 from pyspark.traceback_utils import SCCallSiteSync
 
 from py4j.java_collections import ListConverter, MapConverter
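The newly imported _parse_memory is used further down in partitionBy to turn the spark.python.worker.memory setting into a numeric limit. As a rough illustration only (an assumption about its behavior, not the actual pyspark.rdd implementation), a parser of that kind maps a JVM-style size string to megabytes:

def parse_memory_string(s):
    # Hypothetical stand-in for pyspark.rdd._parse_memory: convert a
    # JVM-style size string ("512m", "2g") into megabytes.
    units = {"k": 1.0 / 1024, "m": 1, "g": 1024, "t": 1 << 20}
    suffix = s[-1].lower()
    if suffix not in units:
        raise ValueError("invalid memory string: " + s)
    return int(float(s[:-1]) * units[suffix])

print(parse_memory_string("512m") / 2)  # 256.0, the kind of limit computed in partitionBy below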
@@ -40,6 +41,7 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self._jrdd_deserializer = jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
+        self._partitionFunc = None
 
     def context(self):
         """
@@ -161,32 +163,71 @@ def _mergeCombiners(iterator):
 
         return shuffled.mapPartitions(_mergeCombiners)
 
-    def partitionBy(self, numPartitions, partitionFunc=None):
+    def partitionBy(self, numPartitions, partitionFunc=portable_hash):
         """
         Return a copy of the DStream partitioned using the specified partitioner.
         """
         if numPartitions is None:
             numPartitions = self.ctx._defaultReducePartitions()
 
-        if partitionFunc is None:
-            partitionFunc = lambda x: 0 if x is None else hash(x)
-
         # Transferring O(n) objects to Java is too expensive. Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java. Each object is a (splitNumber, [objects]) pair.
+
         outputSerializer = self.ctx._unbatched_serializer
+        #
+        # def add_shuffle_key(split, iterator):
+        #     buckets = defaultdict(list)
+        #
+        #     for (k, v) in iterator:
+        #         buckets[partitionFunc(k) % numPartitions].append((k, v))
+        #     for (split, items) in buckets.iteritems():
+        #         yield pack_long(split)
+        #         yield outputSerializer.dumps(items)
+        # keyed = PipelinedDStream(self, add_shuffle_key)
+
+        limit = (_parse_memory(self.ctx._conf.get(
+            "spark.python.worker.memory", "512m")) / 2)
 
         def add_shuffle_key(split, iterator):
+
             buckets = defaultdict(list)
+            c, batch = 0, min(10 * numPartitions, 1000)
 
-            for (k, v) in iterator:
+            for k, v in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
-            for (split, items) in buckets.iteritems():
+                c += 1
+
+                # check used memory and avg size of chunk of objects
+                if (c % 1000 == 0 and get_used_memory() > limit
+                        or c > batch):
+                    n, size = len(buckets), 0
+                    for split in buckets.keys():
+                        yield pack_long(split)
+                        d = outputSerializer.dumps(buckets[split])
+                        del buckets[split]
+                        yield d
+                        size += len(d)
+
+                    avg = (size / n) >> 20
+                    # let 1M < avg < 10M
+                    if avg < 1:
+                        batch *= 1.5
+                    elif avg > 10:
+                        batch = max(batch / 1.5, 1)
+                    c = 0
+
+            for split, items in buckets.iteritems():
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)
-        keyed = PipelinedDStream(self, add_shuffle_key)
+
+        keyed = self._mapPartitionsWithIndex(add_shuffle_key)
+
+
+
+
         keyed._bypass_serializer = True
-        with SCCallSiteSync(self.context) as css:
+        with SCCallSiteSync(self.ctx) as css:
             partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                           id(partitionFunc))
             jdstream = self.ctx._jvm.PythonPairwiseDStream(keyed._jdstream.dstream(),
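The rewritten add_shuffle_key above bounds worker memory: pairs are hashed into per-partition buckets, and once the running count passes the current batch size (or the every-1000-records memory probe exceeds half of spark.python.worker.memory), the buckets are serialized and yielded immediately, and the batch size is retuned so the average serialized chunk stays roughly between 1 MB and 10 MB. A minimal standalone sketch of that flush-and-retune loop, with pickle standing in for the output serializer and a stubbed memory probe (all names here are illustrative, not PySpark APIs):

import pickle
from collections import defaultdict

def used_memory_mb():
    # Stub: a real worker would report its resident set size here.
    return 0

def bucketed_chunks(pairs, num_partitions, limit_mb=256):
    # Yield (partition, serialized_items) chunks, flushing early to cap memory.
    buckets = defaultdict(list)
    count, batch = 0, min(10 * num_partitions, 1000)
    for k, v in pairs:
        buckets[hash(k) % num_partitions].append((k, v))
        count += 1
        # Flush when the batch is full or the worker is over its memory limit;
        # memory is only sampled every 1000 records, mirroring the diff.
        if count > batch or (count % 1000 == 0 and used_memory_mb() > limit_mb):
            n, size = len(buckets), 0
            for split in list(buckets):
                data = pickle.dumps(buckets.pop(split))
                size += len(data)
                yield split, data
            # Retune so the average chunk stays roughly between 1 MB and 10 MB.
            avg_mb = (size // n) >> 20
            if avg_mb < 1:
                batch *= 1.5
            elif avg_mb > 10:
                batch = max(batch / 1.5, 1)
            count = 0
    # Emit whatever is left once the iterator is exhausted.
    for split, items in buckets.items():
        yield split, pickle.dumps(items)

# Example: 10,000 key/value pairs shuffled into 4 partitions.
chunks = list(bucketed_chunks(((i, i * i) for i in range(10000)), 4))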
@@ -428,6 +469,10 @@ def get_output(rdd, time):
 
 
 class PipelinedDStream(DStream):
+    """
+    Since PipelinedDStream is the same as PipelinedRDD, if PipelinedRDD is
+    changed, this code should be changed in the same way.
+    """
     def __init__(self, prev, func, preservesPartitioning=False):
         if not isinstance(prev, PipelinedDStream) or not prev._is_pipelinable():
             # This transformation is the first in its stage:
@@ -453,19 +498,22 @@ def pipeline_func(split, iterator):
         self._jdstream_val = None
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
+        self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
 
     @property
     def _jdstream(self):
         if self._jdstream_val:
             return self._jdstream_val
         if self._bypass_serializer:
-            serializer = NoOpSerializer()
-        else:
-            serializer = self.ctx.serializer
-
-        command = (self.func, self._prev_jrdd_deserializer, serializer)
-        ser = CompressedSerializer(CloudPickleSerializer())
+            self._jrdd_deserializer = NoOpSerializer()
+        command = (self.func, self._prev_jrdd_deserializer,
+                   self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
+        ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
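The _jdstream change above drops CompressedSerializer and leans on broadcast for oversized commands instead: the pickled (func, deserializers) tuple is shipped inline when small, but anything over 1 MB is wrapped in a broadcast variable so only a tiny pickled handle travels with each task. A toy sketch of that size check, with a fake in-process registry standing in for SparkContext.broadcast (illustrative names, not the PySpark API):

import pickle

_broadcast_registry = {}  # id -> payload; stands in for Spark's broadcast store

class FakeBroadcast(object):
    # Toy broadcast handle: only the small id is pickled, not the payload.
    def __init__(self, value):
        self.bid = len(_broadcast_registry)
        _broadcast_registry[self.bid] = value

    def value(self):
        return _broadcast_registry[self.bid]

def ship_command(command, threshold=1 << 20):
    # Return the bytes sent per task: the pickled command itself if small,
    # otherwise a pickled handle to a "broadcast" of it (the 1 MB cutoff
    # matches the diff above).
    pickled = pickle.dumps(command)
    if len(pickled) > threshold:
        pickled = pickle.dumps(FakeBroadcast(pickled))
    return pickled

big_command = ("map_func", "x" * (2 << 20))   # pickles to well over 1 MB
print(len(pickle.dumps(big_command)), len(ship_command(big_command)))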
