
Commit be9c176

mateiz authored and rxin committed
Merge pull request alteryx#201 from rxin/mappartitions
Use the proper partition index in mapPartitionsWithIndex

mapPartitionsWithIndex uses TaskContext.partitionId as the partition index. TaskContext.partitionId used to be identical to the partition index within an RDD; however, pull request alteryx#186 introduced a scenario (partition pruning) in which the two can differ. This pull request uses the correct partition index in all mapPartitionsWithIndex-related calls. It also removes the extra MapPartitionsWithContextRDD and moves all of the mapPartitions-related functionality into MapPartitionsRDD.

(cherry picked from commit 14bb465)
Signed-off-by: Reynold Xin <rxin@apache.org>
1 parent 9949561 commit be9c176
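For context, a minimal usage sketch of the contract this change restores (illustrative only, not part of the commit; assumes a running SparkContext named sc): the closure passed to mapPartitionsWithIndex should receive the index of the partition within the RDD being transformed, which after this patch comes from split.index rather than TaskContext.partitionId.

// Illustrative sketch: tag each element with the index of its partition.
// After this patch the index is the partition's index in this RDD, not
// TaskContext.partitionId, which can differ once partitions are pruned.
val rdd = sc.parallelize(1 to 8, 4)
val tagged = rdd.mapPartitionsWithIndex { (index, iter) =>
  iter.map(x => (index, x))
}
println(tagged.collect().mkString(", "))  // expected partition indices 0 through 3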

File tree: 4 files changed (+22 / -70 lines)


core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
Lines changed: 4 additions & 6 deletions

@@ -20,18 +20,16 @@ package org.apache.spark.rdd
 import org.apache.spark.{Partition, TaskContext}


-private[spark]
-class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
-    f: Iterator[T] => Iterator[U],
+    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
     preservesPartitioning: Boolean = false)
   extends RDD[U](prev) {

-  override val partitioner =
-    if (preservesPartitioning) firstParent[T].partitioner else None
+  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

   override def getPartitions: Array[Partition] = firstParent[T].partitions

   override def compute(split: Partition, context: TaskContext) =
-    f(firstParent[T].iterator(split, context))
+    f(context, split.index, firstParent[T].iterator(split, context))
 }
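Because f now takes the widened (TaskContext, Int, Iterator[T]) shape, each public RDD method adapts its own closure before constructing a MapPartitionsRDD. A hedged sketch of that adaptation, with illustrative names (the real call sites are in RDD.scala below):

// Illustrative only: adapt a plain per-partition function to the new
// three-argument signature, ignoring the arguments it does not need.
import org.apache.spark.TaskContext

def toStrings(iter: Iterator[Int]): Iterator[String] = iter.map(_.toString)

val adapted: (TaskContext, Int, Iterator[Int]) => Iterator[String] =
  (context, index, iter) => toStrings(iter)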

core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala

Lines changed: 0 additions & 41 deletions
This file was deleted.

core/src/main/scala/org/apache/spark/rdd/RDD.scala
Lines changed: 18 additions & 21 deletions

@@ -408,7 +408,6 @@ abstract class RDD[T: ClassManifest](
   def pipe(command: String, env: Map[String, String]): RDD[String] =
     new PipedRDD(this, command, env)

-
   /**
    * Return an RDD created by piping elements to a forked external process.
    * The print behavior can be customized by providing two functions.

@@ -442,7 +441,8 @@ abstract class RDD[T: ClassManifest](
    */
   def mapPartitions[U: ClassManifest](
       f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
-    new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }

   /**

@@ -451,8 +451,8 @@ abstract class RDD[T: ClassManifest](
    */
   def mapPartitionsWithIndex[U: ClassManifest](
       f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
-    val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
-    new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }

   /**

@@ -462,7 +462,8 @@ abstract class RDD[T: ClassManifest](
   def mapPartitionsWithContext[U: ClassManifest](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = {
-    new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }

   /**

@@ -483,11 +484,10 @@ abstract class RDD[T: ClassManifest](
   def mapWith[A: ClassManifest, U: ClassManifest]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => U): RDD[U] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.map(t => f(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+    }, preservesPartitioning)
   }

   /**

@@ -498,11 +498,10 @@ abstract class RDD[T: ClassManifest](
   def flatMapWith[A: ClassManifest, U: ClassManifest]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => Seq[U]): RDD[U] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.flatMap(t => f(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+    }, preservesPartitioning)
   }

   /**

@@ -511,11 +510,10 @@ abstract class RDD[T: ClassManifest](
    * partition with the index of that partition.
    */
   def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex { (index, iter) =>
+      val a = constructA(index)
       iter.map(t => {f(t, a); t})
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
+    }.foreach(_ => {})
   }

   /**

@@ -524,11 +522,10 @@ abstract class RDD[T: ClassManifest](
    * partition with the index of that partition.
    */
   def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.filter(t => p(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
+    }, preservesPartitioning = true)
   }

   /**
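The mapWith/flatMapWith/foreachWith/filterWith helpers now all route through mapPartitionsWithIndex, so the per-partition value they construct is keyed by the partition index. A hedged usage sketch (not part of the diff; assumes a SparkContext named sc):

// Illustrative only: mapWith builds one value per partition from its index
// (here a deterministically seeded Random) and passes it to f along with
// each element. The seed now comes from the partition index, not
// TaskContext.partitionId.
import scala.util.Random

val data = sc.parallelize(1 to 100, 4)
val jittered = data.mapWith(index => new Random(index))((x, rng) => x + rng.nextDouble())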

core/src/test/scala/org/apache/spark/CheckpointSuite.scala
Lines changed: 0 additions & 2 deletions

@@ -62,8 +62,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     testCheckpointing(_.sample(false, 0.5, 0))
     testCheckpointing(_.glom())
     testCheckpointing(_.mapPartitions(_.map(_.toString)))
-    testCheckpointing(r => new MapPartitionsWithContextRDD(r,
-      (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
     testCheckpointing(_.pipe(Seq("cat")))
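The deleted test constructed MapPartitionsWithContextRDD directly, which is no longer possible now that the class is gone. A hypothetical equivalent through the public API, not added by this commit, would be:

// Hypothetical replacement, not part of this commit: exercise the same
// map-partitions checkpointing path via the public API.
testCheckpointing(_.mapPartitionsWithIndex((index, iter) => iter.map(_.toString)))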
