SPARK-977 Added Python RDD.zip function #76

Closed · wants to merge 1 commit
20 changes: 19 additions & 1 deletion python/pyspark/rdd.py
@@ -30,7 +30,7 @@
import warnings

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
-    BatchedSerializer, CloudPickleSerializer, pack_long
+    BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -1057,6 +1057,24 @@ def coalesce(self, numPartitions, shuffle=False):
        jrdd = self._jrdd.coalesce(numPartitions)
        return RDD(jrdd, self.ctx, self._jrdd_deserializer)

    def zip(self, other):
        """
        Zips this RDD with another one, returning key-value pairs consisting
        of the first element in each RDD, the second element in each RDD,
        etc. Assumes that the two RDDs have the same number of partitions
        and the same number of elements in each partition (e.g. one was
        made through a map on the other).

        >>> x = sc.parallelize(range(0,5))
        >>> y = sc.parallelize(range(1000, 1005))
        >>> x.zip(y).collect()
        [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
        """
        pairRDD = self._jrdd.zip(other._jrdd)
        deserializer = PairDeserializer(self._jrdd_deserializer,
                                        other._jrdd_deserializer)
        return RDD(pairRDD, self.ctx, deserializer)
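
A minimal usage sketch of the new method (not part of the patch), assuming a live SparkContext named sc as in the doctest above; zipping an RDD with a map of itself is the canonical safe case, since both sides then share the same partition layout:

# Hypothetical usage sketch; assumes an existing SparkContext `sc`.
x = sc.parallelize(range(0, 5))
y = x.map(lambda n: n * n)      # same partition layout as x, so zip() is safe
print x.zip(y).collect()
# [(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)]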


    # TODO: `lookup` is disabled because we can't make direct comparisons based
    # on the key; we need to compare the hash of the key to the hash of the
    # keys in the pairs. This could be an expensive operation, since those
29 changes: 28 additions & 1 deletion python/pyspark/serializers.py
@@ -204,14 +204,18 @@ def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

-    def load_stream(self, stream):
+    def prepare_keys_values(self, stream):
        key_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_stream = self.val_ser._load_stream_without_unbatching(stream)
        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
        for (keys, vals) in izip(key_stream, val_stream):
            keys = keys if key_is_batched else [keys]
            vals = vals if val_is_batched else [vals]
            yield (keys, vals)

    def load_stream(self, stream):
        for (keys, vals) in self.prepare_keys_values(stream):
            for pair in product(keys, vals):
                yield pair
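
As a plain-Python illustration (not part of the patch) of what the new prepare_keys_values generator normalizes: an unbatched stream yields single items while a batched stream yields lists, so wrapping singles in one-element lists lets a single pairing loop serve both cases. The stand-alone helper name prepare below is hypothetical:

from itertools import izip, product  # Python 2, matching pyspark's imports

def prepare(key_stream, val_stream, key_is_batched, val_is_batched):
    # Normalize both streams so every element is a list, as above.
    for keys, vals in izip(key_stream, val_stream):
        yield (keys if key_is_batched else [keys],
               vals if val_is_batched else [vals])

# Batched keys, unbatched values:
stream = prepare(iter([[1, 2], [3]]), iter(['a', 'b']), True, False)
print [p for keys, vals in stream for p in product(keys, vals)]
# [(1, 'a'), (2, 'a'), (3, 'b')]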

@@ -224,6 +228,29 @@ def __str__(self):
               (str(self.key_ser), str(self.val_ser))


class PairDeserializer(CartesianDeserializer):
    """
    Deserializes the JavaRDD zip() of two PythonRDDs.
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def load_stream(self, stream):
        for (keys, vals) in self.prepare_keys_values(stream):
            for pair in izip(keys, vals):
                yield pair

    def __eq__(self, other):
        return isinstance(other, PairDeserializer) and \
            self.key_ser == other.key_ser and self.val_ser == other.val_ser

    def __str__(self):
        return "PairDeserializer<%s, %s>" % \
               (str(self.key_ser), str(self.val_ser))
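
The only behavioral difference from CartesianDeserializer is the inner pairing step, which this sketch (outside the patch) makes concrete: product() emits the cross product of each key/value batch, while izip() pairs them positionally:

from itertools import izip, product

keys, vals = [1, 2], ['a', 'b']
print list(product(keys, vals))  # CartesianDeserializer: [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
print list(izip(keys, vals))     # PairDeserializer: [(1, 'a'), (2, 'b')]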


class NoOpSerializer(FramedSerializer):

    def loads(self, obj): return obj