[SPARK-3554] [PySpark] use broadcast automatically for large closure #2417

Closed · wants to merge 2 commits

Changes from all commits
python/pyspark/rdd.py (4 additions, 0 deletions)

@@ -2061,8 +2061,12 @@ def _jrdd(self):
             self._jrdd_deserializer = NoOpSerializer()
         command = (self.func, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
+        # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
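This driver-side change is the heart of the PR: once the pickled command crosses 1 MB, the bytes are shipped as a broadcast variable and only a small pickled handle travels through Py4J. A minimal standalone sketch of the same decision, where prepare_command, MAX_INLINE, and broadcast_fn are hypothetical stand-ins rather than names from the PR:

    import pickle

    MAX_INLINE = 1 << 20  # 1 MB, the threshold used in the PR

    def prepare_command(command, broadcast_fn):
        """Pickle a command; if the payload is large, send it via broadcast.

        broadcast_fn stands in for SparkContext.broadcast and must return
        an object that is itself cheap to pickle (a handle, not the data).
        """
        payload = pickle.dumps(command, 2)
        if len(payload) > MAX_INLINE:
            handle = broadcast_fn(payload)   # big bytes travel out of band
            payload = pickle.dumps(handle, 2)  # only the small handle inline
        return payload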
python/pyspark/sql.py (6 additions, 2 deletions)

@@ -27,7 +27,7 @@
 from array import array
 from operator import itemgetter
 
-from pyspark.rdd import RDD, PipelinedRDD
+from pyspark.rdd import RDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
 
@@ -974,7 +974,11 @@ def registerFunction(self, name, f, returnType=StringType()):
         command = (func,
                    BatchedSerializer(PickleSerializer(), 1024),
                    BatchedSerializer(PickleSerializer(), 1024))
-        pickled_command = CloudPickleSerializer().dumps(command)
+        ser = CloudPickleSerializer()
+        pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = self._sc.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
             self._sc._gateway._gateway_client)
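The UDF-registration path gets the same treatment, since a registered Python function can close over arbitrarily large local data. A hedged usage sketch in the Python 2 style of this era (sqlCtx and the captured dict are assumptions for illustration, not code from the PR):

    # big_table is captured by the lambda's closure; pickling the command
    # drags it along, and past 1 MB it now travels as a broadcast instead
    # of being pushed inline through Py4J.
    big_table = dict((i, str(i)) for i in xrange(1000000))
    sqlCtx.registerFunction("bigTableSize", lambda x: str(len(big_table)))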
python/pyspark/tests.py (6 additions, 0 deletions)

@@ -434,6 +434,12 @@ def test_large_broadcast(self):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
 
+    def test_large_closure(self):
+        N = 1000000
+        data = [float(i) for i in xrange(N)]
+        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
+        self.assertEquals(N, m)
+
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
         b = self.sc.parallelize(range(100, 105))
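A back-of-envelope check on the test's sizing (my arithmetic, not from the PR): in pickle protocol 2 each float costs roughly 9 bytes (opcode plus 8 data bytes), so a list of one million floats pickles to around 9 MB, far past the 1 MB threshold, guaranteeing the new broadcast path is exercised:

    import pickle

    data = [float(i) for i in range(1000000)]  # the test's closure data
    payload = pickle.dumps(data, 2)
    assert len(payload) > (1 << 20)            # comfortably above 1 MB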
python/pyspark/worker.py (3 additions, 1 deletion)

@@ -77,10 +77,12 @@ def main(infile, outfile):
             _broadcastRegistry[bid] = Broadcast(bid, value)
         else:
             bid = - bid - 1
-            _broadcastRegistry.remove(bid)
+            _broadcastRegistry.pop(bid)
 
     _accumulatorRegistry.clear()
     command = pickleSer._read_with_length(infile)
+    if isinstance(command, Broadcast):
+        command = pickleSer.loads(command.value)
     (func, deserializer, serializer) = command
     init_time = time.time()
     iterator = deserializer.load_stream(infile)
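Two things happen here. First, a drive-by fix: _broadcastRegistry is used as a dict (it is indexed by bid just above), and a dict has no remove method, so pop(bid) is the correct deletion. Second, the worker mirrors the driver: if what arrives is a Broadcast handle rather than the command itself, the real command is unpickled from the broadcast's value. A minimal round-trip sketch with plain pickle, where FakeBroadcast is a stand-in for pyspark.broadcast.Broadcast and not PR code:

    import pickle

    class FakeBroadcast(object):
        """Stand-in for pyspark.broadcast.Broadcast: just holds a value."""
        def __init__(self, value):
            self.value = value

    def read_command(blob):
        command = pickle.loads(blob)
        if isinstance(command, FakeBroadcast):     # closure came via broadcast
            command = pickle.loads(command.value)  # unwrap the real command
        return command

    # Driver side: wrap a command the way rdd.py now does, then recover it
    # on the "worker" side.
    cmd = ('func', 'deserializer', 'serializer')
    blob = pickle.dumps(FakeBroadcast(pickle.dumps(cmd, 2)), 2)
    assert read_command(blob) == cmd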