[SPARK-2871] [PySpark] add zipWithIndex() and zipWithUniqueId() #2092

Closed
wants to merge 2 commits into from
47 changes: 47 additions & 0 deletions python/pyspark/rdd.py
@@ -1715,6 +1715,53 @@ def batch_as(rdd, batchSize):
                                        other._jrdd_deserializer)
        return RDD(pairRDD, self.ctx, deserializer)

    def zipWithIndex(self):
        """
        Zips this RDD with its element indices.

        The ordering is first based on the partition index and then the
        ordering of items within each partition. So the first item in
        the first partition gets index 0, and the last item in the last
        partition receives the largest index.

        This method needs to trigger a Spark job when this RDD contains
        more than one partition.

        >>> sc.parallelize(["a", "b", "c", "d"], 3).zipWithIndex().collect()
        [('a', 0), ('b', 1), ('c', 2), ('d', 3)]
        """
        starts = [0]
        if self.getNumPartitions() > 1:
            nums = self.mapPartitions(lambda it: [sum(1 for i in it)]).collect()
            for i in range(len(nums) - 1):
                starts.append(starts[-1] + nums[i])

        def func(k, it):
            for i, v in enumerate(it, starts[k]):
                yield v, i

        return self.mapPartitionsWithIndex(func)

    def zipWithUniqueId(self):
        """
        Zips this RDD with generated unique Long ids.

        Items in the kth partition will get ids k, n+k, 2*n+k, ..., where
        n is the number of partitions. So there may exist gaps, but this
        method won't trigger a Spark job, which is different from
        L{zipWithIndex}.

        >>> sc.parallelize(["a", "b", "c", "d", "e"], 3).zipWithUniqueId().collect()
        [('a', 0), ('b', 1), ('c', 4), ('d', 2), ('e', 5)]
        """
        n = self.getNumPartitions()

        def func(k, it):
            for i, v in enumerate(it):
                yield v, i * n + k

        return self.mapPartitionsWithIndex(func)

    def name(self):
        """
        Return the name of this RDD.
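To make the index arithmetic in the two new methods easier to check, here is a minimal sketch that reproduces it on plain Python lists, with no Spark involved. The helper names `zip_with_index` and `zip_with_unique_id` and the list-of-lists stand-in for partitions are assumptions made for illustration; they are not part of the patch. The partition layout `[["a"], ["b", "c"], ["d", "e"]]` is inferred from the `zipWithUniqueId` doctest output above.

```python
from itertools import accumulate


def zip_with_index(partitions):
    """Pair each item with a contiguous global index (hypothetical helper)."""
    # Per-partition sizes; in the patch this count is the step that
    # needs a Spark job when there is more than one partition.
    counts = [len(part) for part in partitions]
    # starts[k] is the global index of the first item in partition k.
    starts = [0] + list(accumulate(counts))[:-1]
    return [
        [(v, i) for i, v in enumerate(part, starts[k])]
        for k, part in enumerate(partitions)
    ]


def zip_with_unique_id(partitions):
    """Pair each item with the id i * n + k (hypothetical helper)."""
    n = len(partitions)
    # Item i of partition k gets i * n + k, so no sizes need to be
    # collected first, at the cost of possible gaps in the id range.
    return [
        [(v, i * n + k) for i, v in enumerate(part)]
        for k, part in enumerate(partitions)
    ]


# Layout assumed from the zipWithUniqueId doctest in the patch.
partitions = [["a"], ["b", "c"], ["d", "e"]]
indexed = zip_with_index(partitions)
unique = zip_with_unique_id(partitions)
```

With this layout, `unique` flattens to the same pairs as the `zipWithUniqueId` doctest: id 3 is skipped because partition 0 holds only one item. `indexed` runs from 0 to 4 without gaps, which is why that variant has to count every partition first.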