
[SPARK-2790] [PySpark] fix zip with serializers which have different batch sizes. #1894


Closed · wants to merge 4 commits
25 changes: 25 additions & 0 deletions python/pyspark/rdd.py
@@ -1685,6 +1685,31 @@ def zip(self, other):
        >>> x.zip(y).collect()
        [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
        """
        if self.getNumPartitions() != other.getNumPartitions():
            raise ValueError("Can only zip with RDD which has the same number of partitions")

        def get_batch_size(ser):
            # batch size of a BatchedSerializer; 0 means not batched
            if isinstance(ser, BatchedSerializer):
                return ser.batchSize
            return 0

        def batch_as(rdd, batchSize):
            # re-serialize rdd with the given batch size
            ser = rdd._jrdd_deserializer
            if isinstance(ser, BatchedSerializer):
                ser = ser.serializer
            return rdd._reserialize(BatchedSerializer(ser, batchSize))

        my_batch = get_batch_size(self._jrdd_deserializer)
        other_batch = get_batch_size(other._jrdd_deserializer)
        if my_batch != other_batch:
            # use the greater batch size to re-batch the other RDD
            if my_batch > other_batch:
                other = batch_as(other, my_batch)
            else:
                self = batch_as(self, other_batch)

        # the JVM-side zip raises an exception if corresponding
        # partitions have different numbers of items
        pairRDD = self._jrdd.zip(other._jrdd)
        deserializer = PairDeserializer(self._jrdd_deserializer,
                                        other._jrdd_deserializer)
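For readers unfamiliar with BatchedSerializer: the JVM zips opaque serialized batches, not individual items, so both sides must use the same batch size for items to line up. A minimal sketch in plain Python (no Spark required; batches is a hypothetical helper mimicking how BatchedSerializer groups items):

def batches(items, batch_size):
    # mimic BatchedSerializer: group items into chunks of batch_size
    items = list(items)
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

left = batches(range(5), 2)          # [[0, 1], [2, 3], [4]]          -> 3 batches
right = batches(range(100, 105), 5)  # [[100, 101, 102, 103, 104]]    -> 1 batch

# The JVM pairs the i-th batch of one side with the i-th batch of the
# other, so unequal batch sizes misalign items even though both sides
# hold exactly 5 elements:
print(list(zip(left, right)))  # only 1 batch pair survives instead of 3

This is why the patch re-batches the RDD with the smaller batch size before handing both RDDs to the JVM.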
3 changes: 3 additions & 0 deletions python/pyspark/serializers.py
@@ -255,6 +255,9 @@ def __init__(self, key_ser, val_ser):

    def load_stream(self, stream):
        for (keys, vals) in self.prepare_keys_values(stream):
            if len(keys) != len(vals):
                raise ValueError("Can not deserialize RDD with different number of items"
                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
            for pair in izip(keys, vals):
                yield pair
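A standalone sketch of what this new guard protects against (zip_batches is a hypothetical helper, not Spark API; Python 2 to match the izip import in pyspark/serializers.py):

from itertools import izip  # Python 2, as in pyspark/serializers.py

def zip_batches(keys, vals):
    # mirrors the new check in PairDeserializer.load_stream
    if len(keys) != len(vals):
        raise ValueError("Can not deserialize RDD with different number of items"
                         " in pair: (%d, %d)" % (len(keys), len(vals)))
    return list(izip(keys, vals))

print(zip_batches([0, 1], [100, 101]))  # [(0, 100), (1, 101)]
zip_batches([0, 1, 2], [100, 101])      # raises ValueError: (3, 2)

Without the guard, izip would silently drop the unpaired items instead of failing.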
27 changes: 26 additions & 1 deletion python/pyspark/tests.py
@@ -39,7 +39,7 @@

from pyspark.context import SparkContext
from pyspark.files import SparkFiles
-from pyspark.serializers import read_int
+from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger

_have_scipy = False
@@ -330,6 +330,31 @@ def test_large_broadcast(self):
        m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
        self.assertEquals(N, m)

    def test_zip_with_different_serializers(self):
        a = self.sc.parallelize(range(5))
        b = self.sc.parallelize(range(100, 105))
        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
        a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
        b = b._reserialize(MarshalSerializer())
        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])

    def test_zip_with_different_number_of_items(self):
        a = self.sc.parallelize(range(5), 2)
        # different number of partitions
        b = self.sc.parallelize(range(100, 106), 3)
        self.assertRaises(ValueError, lambda: a.zip(b))
        # different number of batched items in JVM
        b = self.sc.parallelize(range(100, 104), 2)
        self.assertRaises(Exception, lambda: a.zip(b).count())
        # different number of items in one pair
        b = self.sc.parallelize(range(100, 106), 2)
        self.assertRaises(Exception, lambda: a.zip(b).count())
        # same total number of items, but different distributions
        a = self.sc.parallelize([2, 3], 2).flatMap(range)
        b = self.sc.parallelize([3, 2], 2).flatMap(range)
        self.assertEquals(a.count(), b.count())
        self.assertRaises(Exception, lambda: a.zip(b).count())


class TestIO(PySparkTestCase):
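The last test case above is the subtle one: a.count() == b.count() is not sufficient for zip. A plain-Python sketch of the per-partition view (a_parts and b_parts are hypothetical stand-ins for the partition contents, not Spark API):

# parallelize([2, 3], 2).flatMap(range) yields partitions [0, 1] and [0, 1, 2];
# parallelize([3, 2], 2).flatMap(range) yields [0, 1, 2] and [0, 1].
a_parts = [[0, 1], [0, 1, 2]]
b_parts = [[0, 1, 2], [0, 1]]

print(sum(map(len, a_parts)) == sum(map(len, b_parts)))  # True: same count()

# zip pairs items partition by partition, so it must still fail here:
for ka, kb in zip(a_parts, b_parts):
    if len(ka) != len(kb):
        print("partition sizes differ: %d vs %d" % (len(ka), len(kb)))

This is exactly the mismatch the new PairDeserializer guard and the JVM-side exception are meant to surface instead of silently mispairing items.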