[SPARK-5785] [PySpark] narrow dependency for cogroup/join in PySpark #4629

Closed · wants to merge 7 commits

11 changes: 9 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -961,11 +961,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}

/** Build the union of a list of RDDs. */
def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = {
val partitioners = rdds.flatMap(_.partitioner).toSet
if (partitioners.size == 1) {
new PartitionerAwareUnionRDD(this, rdds)
} else {
new UnionRDD(this, rdds)
}
}

/** Build the union of a list of RDDs passed as variable-length arguments. */
def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] =
Contributor:

Can we change this method to call the union method that you modified so the change will take effect here, too?

Contributor (Author):

fixed

new UnionRDD(this, Seq(first) ++ rest)
union(Seq(first) ++ rest)

/** Get an RDD that has no partitions or elements. */
def emptyRDD[T: ClassTag] = new EmptyRDD[T](this)
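In practice, this means PySpark RDDs that share a hash partitioner no longer multiply their partitions when unioned. A minimal sketch of the effect (assuming an existing SparkContext named sc; variable names are illustrative):

# Sketch only: both inputs were partitioned with the same width and the default
# portable_hash, so their JVM-side partitioners compare equal and the union
# should go through PartitionerAwareUnionRDD.
a = sc.parallelize(range(10)).map(lambda x: (x, x)).partitionBy(2)
b = sc.parallelize(range(10)).map(lambda x: (x, 2 * x)).partitionBy(2)
print(sc.union([a, b]).getNumPartitions())   # 2 with this change, 4 before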
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -303,6 +303,7 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Long, Array[Byte])](prev) {
override def getPartitions = prev.partitions
override val partitioner = prev.partitioner
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (Utils.deserializeLongValue(a), b)
@@ -329,6 +330,15 @@ private[spark] object PythonRDD extends Logging {
}
}

/**
* Return an RDD of values from an RDD of (Long, Array[Byte]), with preservesPartitioning=true.
*
* This is useful for PySpark so that it can keep the partitioner after partitionBy().
*/
def valueOfPair(pair: JavaPairRDD[Long, Array[Byte]]): JavaRDD[Array[Byte]] = {
Contributor:

I think that JavaPairRDD.values should do the same thing; is there a reason why we can't call that directly from Python?

Contributor (Author):

In the Scala/Java API, RDD.values() changes an RDD of (K, V) into an RDD of V, so preservesPartitioning should not be true there.

For PySpark, it changes the RDD from (hash, [(K, V)]) to (K, V), so preservesPartitioning should be true.

pair.rdd.mapPartitions(it => it.map(_._2), true)
}

/**
* Adapter for calling SparkContext#runJob from Python.
*
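Seen from Python, the point of valueOfPair is that the RDD coming back from partitionBy() still carries its partitioner, which join(), cogroup() and lookup() can then reuse. A short sketch, assuming a SparkContext sc:

# Sketch only: partitionBy() now goes through PythonRDD.valueOfPair, so the
# JVM RDD keeps its partitioner and the Python RDD records a matching
# Partitioner object (see the change in rdd.py below).
parted = sc.parallelize(range(10)).map(lambda x: (x, x)).partitionBy(2)
print(parted.getNumPartitions())         # 2
print(parted.partitioner is not None)    # True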
8 changes: 7 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -462,7 +462,13 @@ abstract class RDD[T: ClassTag](
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
def union(other: RDD[T]): RDD[T] = {
if (partitioner.isDefined && other.partitioner == partitioner) {
new PartitionerAwareUnionRDD(sc, Array(this, other))
} else {
new UnionRDD(sc, Array(this, other))
}
}

/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
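When the two sides are not co-partitioned, union() still falls back to UnionRDD and the partition counts simply add up. A sketch of both cases, assuming a SparkContext sc:

# Sketch only.
rdd = sc.parallelize(range(10), 4).map(lambda x: (x, x))   # 4 partitions, no partitioner
parted = rdd.partitionBy(2)                                # 2 partitions, hash-partitioned
print(parted.union(parted).getNumPartitions())   # 2: partitioner-aware union
print(parted.union(rdd).getNumPartitions())      # 6: plain UnionRDD, 2 + 4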
8 changes: 4 additions & 4 deletions python/pyspark/join.py
@@ -35,8 +35,8 @@


def _do_python_join(rdd, other, numPartitions, dispatch):
vs = rdd.map(lambda (k, v): (k, (1, v)))
ws = other.map(lambda (k, v): (k, (2, v)))
vs = rdd.mapValues(lambda v: (1, v))
ws = other.mapValues(lambda v: (2, v))
return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__()))


@@ -98,8 +98,8 @@ def dispatch(seq):

def python_cogroup(rdds, numPartitions):
def make_mapper(i):
return lambda (k, v): (k, (i, v))
vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
return lambda v: (i, v)
vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)]
union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
rdd_len = len(vrdds)

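Because mapValues() preserves the partitioner while the old map(lambda (k, v): ...) did not, joining two co-partitioned RDDs should no longer need an extra shuffle stage. A sketch mirroring the stage-count check in the new test, assuming a SparkContext sc (the job-group name is arbitrary):

# Sketch only.
parted = sc.parallelize(range(10)).map(lambda x: (x, x)).partitionBy(2)
sc.setJobGroup("narrow_join_demo", "demo", True)
parted.join(parted).collect()
tracker = sc.statusTracker()
job_id = tracker.getJobIdsForGroup("narrow_join_demo")[0]
print(len(tracker.getJobInfo(job_id).stageIds))   # 2 stages instead of 3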
49 changes: 33 additions & 16 deletions python/pyspark/rdd.py
@@ -111,6 +111,19 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])


class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
self.partitionFunc = partitionFunc

def __eq__(self, other):
return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions
and self.partitionFunc == other.partitionFunc)

def __call__(self, k):
return self.partitionFunc(k) % self.numPartitions
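Two Partitioner instances compare equal when both the partition count and the partition function match; that equality is what lets partitionBy(), union() and join() detect co-partitioned RDDs. A small sketch, assuming this patch is applied:

# Sketch only: Partitioner and portable_hash both live in pyspark.rdd once
# this change is in.
from pyspark.rdd import Partitioner, portable_hash
p1 = Partitioner(2, portable_hash)
p2 = Partitioner(2, portable_hash)
print(p1 == p2)    # True: same width, same partition function
print(p1("key"))   # portable_hash("key") % 2, i.e. 0 or 1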


class RDD(object):

"""
@@ -126,7 +139,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri
self.ctx = ctx
self._jrdd_deserializer = jrdd_deserializer
self._id = jrdd.id()
self._partitionFunc = None
self.partitioner = None

def _pickled(self):
return self._reserialize(AutoBatchedSerializer(PickleSerializer()))
@@ -450,14 +463,17 @@ def union(self, other):
if self._jrdd_deserializer == other._jrdd_deserializer:
rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
self._jrdd_deserializer)
return rdd
else:
# These RDDs contain data in different serialized formats, so we
# must normalize them to the default serializer.
self_copy = self._reserialize()
other_copy = other._reserialize()
return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
self.ctx.serializer)
rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
self.ctx.serializer)
if (self.partitioner == other.partitioner and
self.getNumPartitions() == rdd.getNumPartitions()):
rdd.partitioner = self.partitioner
return rdd

def intersection(self, other):
"""
@@ -1588,6 +1604,9 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
"""
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
partitioner = Partitioner(numPartitions, partitionFunc)
if self.partitioner == partitioner:
return self

# Transferring O(n) objects to Java is too expensive.
# Instead, we'll form the hash buckets in Python,
@@ -1632,18 +1651,16 @@ def add_shuffle_key(split, iterator):
yield pack_long(split)
yield outputSerializer.dumps(items)

keyed = self.mapPartitionsWithIndex(add_shuffle_key)
keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True)
keyed._bypass_serializer = True
with SCCallSiteSync(self.context) as css:
pairRDD = self.ctx._jvm.PairwiseRDD(
keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner))
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique,
# even if partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
rdd.partitioner = partitioner
return rdd

# TODO: add control over map-side aggregation
@@ -1689,7 +1706,7 @@ def combineLocally(iterator):
merger.mergeValues(iterator)
return merger.iteritems()

locally_combined = self.mapPartitions(combineLocally)
locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
shuffled = locally_combined.partitionBy(numPartitions)

def _mergeCombiners(iterator):
Expand All @@ -1698,7 +1715,7 @@ def _mergeCombiners(iterator):
merger.mergeCombiners(iterator)
return merger.iteritems()

return shuffled.mapPartitions(_mergeCombiners, True)
return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)
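Since the map-side combine now preserves partitioning, an aggregation over an RDD that is already hash-partitioned to the requested width should not trigger a new shuffle: the partitionBy() inside combineByKey() sees an equal Partitioner and returns self. A sketch, assuming a SparkContext sc:

# Sketch only: `parted` is already hash-partitioned into 3 partitions, so
# reducing with numPartitions=3 reuses the existing layout.
parted = sc.parallelize(range(30)).map(lambda x: (x % 3, x)).partitionBy(3)
sums = parted.reduceByKey(lambda a, b: a + b, 3)
print(sums.getNumPartitions())   # 3, with no extra shuffle stage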

def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
@@ -2077,8 +2094,8 @@ def lookup(self, key):
"""
values = self.filter(lambda (k, v): k == key).values()

if self._partitionFunc is not None:
return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False)
if self.partitioner is not None:
return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)

return values.collect()
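With the partitioner recorded on the Python side, lookup() can run the job on just the partition the key hashes to instead of scanning all of them. A sketch, assuming a SparkContext sc:

# Sketch only.
parted = sc.parallelize(range(100)).map(lambda x: (x, x * x)).partitionBy(4)
print(parted.lookup(42))   # [1764], computed from a single partition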

@@ -2243,7 +2260,7 @@ def pipeline_func(split, iterator):
self._id = None
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
self.partitioner = prev.partitioner if self.preservesPartitioning else None
self._broadcast = None

def __del__(self):
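The PipelinedRDD change means that transformations carry the partitioner forward only when they declare preservesPartitioning, so key-preserving operations keep the co-partitioning information while a plain map() drops it. A sketch, assuming a SparkContext sc:

# Sketch only.
parted = sc.parallelize(range(10)).map(lambda x: (x, x)).partitionBy(2)
print(parted.mapValues(str).partitioner is not None)   # True: keys untouched
print(parted.map(lambda kv: kv).partitioner is None)   # True: map() may rekey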
2 changes: 1 addition & 1 deletion python/pyspark/streaming/dstream.py
@@ -578,7 +578,7 @@ def reduceFunc(t, a, b):
if a is None:
g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
else:
g = a.cogroup(b, numPartitions)
g = a.cogroup(b.partitionBy(numPartitions), numPartitions)
g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
return state.filter(lambda (k, v): v is not None)
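The streaming change applies the same idea to updateStateByKey(): partitioning the new batch the same way as the state before the cogroup keeps that dependency narrow. A plain-RDD sketch of the pattern, assuming a SparkContext sc and numPartitions = 2 (state and batch are stand-in names):

# Sketch only: `state` stands in for the previous state RDD, `batch` for the
# newly arrived data, as in reduceFunc above.
state = sc.parallelize(range(10)).map(lambda x: (x, 1)).partitionBy(2)
batch = sc.parallelize(range(5)).map(lambda x: (x, 1))
g = state.cogroup(batch.partitionBy(2), 2)
print(g.getNumPartitions())   # 2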
38 changes: 37 additions & 1 deletion python/pyspark/tests.py
@@ -727,7 +727,6 @@ def test_multiple_python_java_RDD_conversions(self):
(u'1', {u'director': u'David Lean'}),
(u'2', {u'director': u'Andrew Dominik'})
]
from pyspark.rdd import RDD
data_rdd = self.sc.parallelize(data)
data_java_rdd = data_rdd._to_java_object_rdd()
data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd)
@@ -740,6 +739,43 @@ def test_multiple_python_java_RDD_conversions(self):
converted_rdd = RDD(data_python_rdd, self.sc)
self.assertEqual(2, converted_rdd.count())

def test_narrow_dependency_in_join(self):
rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x))
Contributor:

do these tests actually check for a narrow dependency at all? I think they will pass even without it.

I'm not sure of a better suggestion, though. I had to use getNarrowDependencies in another PR to check this:
https://github.com/apache/spark/pull/4449/files#diff-4bc3643ce90b54113cad7104f91a075bR582

but I don't think that is even exposed in pyspark ...

Contributor (Author):

This test only checks correctness; I will add more checks for the narrow dependency based on the Python progress API (#3027).

Contributor:

I've merged #3027, so I think we can now test this by setting a job group, running a job, then querying the statusTracker to determine how many stages were actually run as part of that job.

Contributor (Author):

done!

Contributor:

nice!

parted = rdd.partitionBy(2)
self.assertEqual(2, parted.union(parted).getNumPartitions())
self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())

self.sc.setJobGroup("test1", "test", True)
tracker = self.sc.statusTracker()

d = sorted(parted.join(parted).collect())
self.assertEqual(10, len(d))
self.assertEqual((0, (0, 0)), d[0])
jobId = tracker.getJobIdsForGroup("test1")[0]
self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))

self.sc.setJobGroup("test2", "test", True)
d = sorted(parted.join(rdd).collect())
self.assertEqual(10, len(d))
self.assertEqual((0, (0, 0)), d[0])
jobId = tracker.getJobIdsForGroup("test2")[0]
self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))

self.sc.setJobGroup("test3", "test", True)
d = sorted(parted.cogroup(parted).collect())
self.assertEqual(10, len(d))
self.assertEqual([[0], [0]], map(list, d[0][1]))
jobId = tracker.getJobIdsForGroup("test3")[0]
self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))

self.sc.setJobGroup("test4", "test", True)
d = sorted(parted.cogroup(rdd).collect())
self.assertEqual(10, len(d))
self.assertEqual([[0], [0]], map(list, d[0][1]))
jobId = tracker.getJobIdsForGroup("test4")[0]
self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))


class ProfilerTests(PySparkTestCase):
