Commit 16a73c2

sryza authored and mateiz committed
SPARK-2978. Transformation with MR shuffle semantics
I didn't add this to the transformations list in the docs because it's kind of obscure, but would be happy to do so if others think it would be helpful.

Author: Sandy Ryza <sandy@cloudera.com>

Closes #2274 from sryza/sandy-spark-2978 and squashes the following commits:

4a5332a [Sandy Ryza] Fix Java test
c04b447 [Sandy Ryza] Fix Python doc and add back deleted code
433ad5b [Sandy Ryza] Add Java test
4c25a54 [Sandy Ryza] Add s at the end and a couple other fixes
9b0ba99 [Sandy Ryza] Fix compilation
36e0571 [Sandy Ryza] Fix import ordering
48c12c2 [Sandy Ryza] Add Java version and additional doc
e5381cd [Sandy Ryza] Fix python style warnings
f147634 [Sandy Ryza] SPARK-2978. Transformation with MR shuffle semantics
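In MapReduce terms, the new primitive gives an RDD the guarantee a reducer gets over its input: records are routed by the given partitioner and arrive sorted by key within each partition. A minimal sketch of the Scala call, assuming local mode and toy data (neither is part of the patch):

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
    import org.apache.spark.SparkContext._  // implicit conversion to OrderedRDDFunctions

    object MrShuffleExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("mr-shuffle").setMaster("local[2]"))
        val pairs = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)))

        // Partition by key, then sort by key within each partition -- the
        // contract a MapReduce reducer sees for its input.
        val shuffled = pairs.repartitionAndSortWithinPartitions(new HashPartitioner(2))

        shuffled.glom().collect().foreach(partition => println(partition.mkString(", ")))
        sc.stop()
      }
    }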
1 parent e16a8e7 commit 16a73c2

File tree

6 files changed: +115 -1 lines changed

core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 26 additions & 0 deletions
@@ -758,6 +758,32 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
     rdd.saveAsHadoopDataset(conf)
   }

+  /**
+   * Repartition the RDD according to the given partitioner and, within each resulting partition,
+   * sort records by their keys.
+   *
+   * This is more efficient than calling `repartition` and then sorting within each partition
+   * because it can push the sorting down into the shuffle machinery.
+   */
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner): JavaPairRDD[K, V] = {
+    val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
+    repartitionAndSortWithinPartitions(partitioner, comp)
+  }
+
+  /**
+   * Repartition the RDD according to the given partitioner and, within each resulting partition,
+   * sort records by their keys.
+   *
+   * This is more efficient than calling `repartition` and then sorting within each partition
+   * because it can push the sorting down into the shuffle machinery.
+   */
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner, comp: Comparator[K])
+    : JavaPairRDD[K, V] = {
+    implicit val ordering = comp  // Allow implicit conversion of Comparator to Ordering.
+    fromRDD(
+      new OrderedRDDFunctions[K, V, (K, V)](rdd).repartitionAndSortWithinPartitions(partitioner))
+  }
+
   /**
    * Sort the RDD by key, so that each partition contains a sorted range of the elements in
    * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an
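Two Java overloads are needed because Java callers cannot supply the implicit Ordering[K] that the Scala API picks up: the comparator-free variant falls back on Guava's Ordering.natural(), and the Comparator variant is bound as an implicit before delegating, as the hunk shows. For comparison, a hypothetical Scala caller changes the key order by shadowing the implicit instead:

    import org.apache.spark.{HashPartitioner, SparkContext}
    import org.apache.spark.SparkContext._

    // Sketch only: assumes an existing SparkContext `sc`.
    def descendingByKey(sc: SparkContext) = {
      val pairs = sc.parallelize(Seq((1, "a"), (3, "c"), (2, "b")))

      // A local implicit Ordering plays the role the explicit Comparator
      // plays in the Java overload above.
      implicit val desc: Ordering[Int] = Ordering[Int].reverse

      pairs.repartitionAndSortWithinPartitions(new HashPartitioner(1))
    }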

core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala

Lines changed: 13 additions & 1 deletion
@@ -19,7 +19,7 @@ package org.apache.spark.rdd

 import scala.reflect.ClassTag

-import org.apache.spark.{Logging, RangePartitioner}
+import org.apache.spark.{Logging, Partitioner, RangePartitioner}
 import org.apache.spark.annotation.DeveloperApi

 /**
@@ -64,4 +64,16 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
     new ShuffledRDD[K, V, V](self, part)
       .setKeyOrdering(if (ascending) ordering else ordering.reverse)
   }
+
+  /**
+   * Repartition the RDD according to the given partitioner and, within each resulting partition,
+   * sort records by their keys.
+   *
+   * This is more efficient than calling `repartition` and then sorting within each partition
+   * because it can push the sorting down into the shuffle machinery.
+   */
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = {
+    new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
+  }
 }
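The scaladoc's efficiency claim rests on this hunk: the fused call builds a single ShuffledRDD and registers the key ordering with setKeyOrdering, so sorting is handled by the shuffle rather than by a second pass over the shuffled data. A rough sketch of the two shapes, assuming pairs: RDD[(Int, Int)] and a partitioner like the ones in the tests below:

    // Two-step form: one shuffle, then a separate in-memory sort per partition.
    val twoStep = pairs
      .partitionBy(partitioner)
      .mapPartitions(iter => iter.toSeq.sortBy(_._1).iterator, preservesPartitioning = true)

    // Fused form: the ordering rides along with the shuffle, so records come
    // out of the shuffle machinery already sorted.
    val fused = pairs.repartitionAndSortWithinPartitions(partitioner)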

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 30 additions & 0 deletions
@@ -189,6 +189,36 @@ public void sortByKey() {
     Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
   }

+  @Test
+  public void repartitionAndSortWithinPartitions() {
+    List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+    pairs.add(new Tuple2<Integer, Integer>(0, 5));
+    pairs.add(new Tuple2<Integer, Integer>(3, 8));
+    pairs.add(new Tuple2<Integer, Integer>(2, 6));
+    pairs.add(new Tuple2<Integer, Integer>(0, 8));
+    pairs.add(new Tuple2<Integer, Integer>(3, 8));
+    pairs.add(new Tuple2<Integer, Integer>(1, 3));
+
+    JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+
+    Partitioner partitioner = new Partitioner() {
+      public int numPartitions() {
+        return 2;
+      }
+      public int getPartition(Object key) {
+        return ((Integer) key).intValue() % 2;
+      }
+    };
+
+    JavaPairRDD<Integer, Integer> repartitioned =
+      rdd.repartitionAndSortWithinPartitions(partitioner);
+    List<List<Tuple2<Integer, Integer>>> partitions = repartitioned.glom().collect();
+    Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2<Integer, Integer>(0, 5),
+      new Tuple2<Integer, Integer>(0, 8), new Tuple2<Integer, Integer>(2, 6)));
+    Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2<Integer, Integer>(1, 3),
+      new Tuple2<Integer, Integer>(3, 8), new Tuple2<Integer, Integer>(3, 8)));
+  }
+
   @Test
   public void emptyRDD() {
     JavaRDD<String> rdd = sc.emptyRDD();

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 14 additions & 0 deletions
@@ -682,6 +682,20 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered)
   }

+  test("repartitionAndSortWithinPartitions") {
+    val data = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)), 2)
+
+    val partitioner = new Partitioner {
+      def numPartitions: Int = 2
+      def getPartition(key: Any): Int = key.asInstanceOf[Int] % 2
+    }
+
+    val repartitioned = data.repartitionAndSortWithinPartitions(partitioner)
+    val partitions = repartitioned.glom().collect()
+    assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6)))
+    assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8)))
+  }
+
   test("intersection") {
     val all = sc.parallelize(1 to 10)
     val evens = sc.parallelize(2 to 10 by 2)
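The test exercises the primitive directly; the payoff the commit title alludes to is MapReduce-style secondary sort. A hypothetical sketch (the composite key, partitioner, and data are illustrative, not from the patch):

    import org.apache.spark.{Partitioner, SparkContext}
    import org.apache.spark.SparkContext._

    // Route each record by the first key field only; the shuffle still sorts
    // on the full (user, timestamp) key via the implicit tuple Ordering.
    class FirstFieldPartitioner(override val numPartitions: Int) extends Partitioner {
      def getPartition(key: Any): Int = key match {
        case (user: String, _) => math.abs(user.hashCode) % numPartitions
      }
    }

    def eventsInTimeOrder(sc: SparkContext) = {
      val events = sc.parallelize(Seq(
        (("alice", 3L), "click"), (("bob", 1L), "view"), (("alice", 1L), "view")))
      // Each partition holds whole users, each user's events in time order.
      events.repartitionAndSortWithinPartitions(new FirstFieldPartitioner(2))
    }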

python/pyspark/rdd.py

Lines changed: 24 additions & 0 deletions
@@ -520,6 +520,30 @@ def __add__(self, other):
             raise TypeError
         return self.union(other)

+    def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=portable_hash,
+                                           ascending=True, keyfunc=lambda x: x):
+        """
+        Repartition the RDD according to the given partitioner and, within each resulting partition,
+        sort records by their keys.
+
+        >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
+        >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, 2)
+        >>> rdd2.glom().collect()
+        [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
+        """
+        if numPartitions is None:
+            numPartitions = self._defaultReducePartitions()
+
+        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true")
+        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+        serializer = self._jrdd_deserializer
+
+        def sortPartition(iterator):
+            sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
+            return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending)))
+
+        return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)
+
     def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         """
         Sorts this RDD, which is assumed to consist of (key, value) pairs.
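Note the asymmetry with the JVM path: the Python implementation cannot reach into the shuffle machinery, so it composes partitionBy with a per-partition sort in the worker, spilling through ExternalSorter when spark.shuffle.spill is enabled so that a large partition need not fit within spark.python.worker.memory.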

python/pyspark/tests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,14 @@ def test_histogram(self):
545545
self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
546546
self.assertRaises(TypeError, lambda: rdd.histogram(2))
547547

548+
def test_repartitionAndSortWithinPartitions(self):
549+
rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
550+
551+
repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
552+
partitions = repartitioned.glom().collect()
553+
self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)])
554+
self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])
555+
548556

549557
class TestSQL(PySparkTestCase):
550558
