[SPARK-25286][CORE] Removing the dangerous parmap #22292

Closed · wants to merge 3 commits
Changes from all commits
17 changes: 9 additions & 8 deletions core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
```diff
@@ -20,13 +20,12 @@ package org.apache.spark.rdd
 import java.io.{IOException, ObjectOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
+import scala.collection.parallel.ForkJoinTaskSupport
 import scala.concurrent.forkjoin.ForkJoinPool
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.ThreadUtils.parmap
 import org.apache.spark.util.Utils
 
 /**
@@ -60,7 +59,8 @@ private[spark] class UnionPartition[T: ClassTag](
 }
 
 object UnionRDD {
-  private[spark] lazy val threadPool = new ForkJoinPool(8)
+  private[spark] lazy val partitionEvalTaskSupport =
+    new ForkJoinTaskSupport(new ForkJoinPool(8))
 }
 
 @DeveloperApi
@@ -74,13 +74,14 @@ class UnionRDD[T: ClassTag](
     rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
 
   override def getPartitions: Array[Partition] = {
-    val partitionLengths = if (isPartitionListingParallel) {
-      implicit val ec = ExecutionContext.fromExecutor(UnionRDD.threadPool)
-      parmap(rdds)(_.partitions.length)
+    val parRDDs = if (isPartitionListingParallel) {
+      val parArray = rdds.par
+      parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
+      parArray
     } else {
-      rdds.map(_.partitions.length)
+      rdds
     }
-    val array = new Array[Partition](partitionLengths.sum)
+    val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
     var pos = 0
     for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
       array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
```
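
For context on the replacement pattern: calling `.par` on a collection gives a parallel view whose `map` runs on the global ForkJoinPool by default; assigning `tasksupport` pins it to a dedicated, size-bounded pool instead, which is what `UnionRDD.partitionEvalTaskSupport` provides above. A minimal self-contained sketch of that pattern, assuming Scala 2.11/2.12 where parallel collections ship in the standard library (`ParListingSketch` and `expensiveLookup` are illustrative names, not Spark code):

```scala
import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.forkjoin.ForkJoinPool

object ParListingSketch {
  // One bounded pool shared by all parallel evaluations, mirroring
  // UnionRDD.partitionEvalTaskSupport above (8 threads).
  private lazy val taskSupport = new ForkJoinTaskSupport(new ForkJoinPool(8))

  // Stand-in for rdd.partitions.length, which may touch the file system.
  private def expensiveLookup(i: Int): Int = {
    Thread.sleep(10)
    i % 4 + 1
  }

  def main(args: Array[String]): Unit = {
    val inputs = 1 to 100
    val par = inputs.par          // parallel view of the collection
    par.tasksupport = taskSupport // pin it to the bounded pool
    // map runs on at most 8 threads; .seq returns to a sequential view.
    println(par.map(expensiveLookup).seq.sum)
  }
}
```

The trade-off called out in the removed `parmap` scaladoc (see the ThreadUtils.scala diff below) applies here: a parallel collection's `map` cannot be interrupted once started, but the dedicated bounded pool keeps the parallelism predictable.
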
32 changes: 4 additions & 28 deletions core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
```diff
@@ -284,36 +284,12 @@ private[spark] object ThreadUtils {
     try {
       implicit val ec = ExecutionContext.fromExecutor(pool)
 
-      parmap(in)(f)
+      val futures = in.map(x => Future(f(x)))
+      val futureSeq = Future.sequence(futures)
+
+      awaitResult(futureSeq, Duration.Inf)
     } finally {
       pool.shutdownNow()
     }
   }
-
-  /**
-   * Transforms input collection by applying the given function to each element in parallel fashion.
-   * Comparing to the map() method of Scala parallel collections, this method can be interrupted
-   * at any time. This is useful on canceling of task execution, for example.
-   *
-   * @param in - the input collection which should be transformed in parallel.
-   * @param f - the lambda function will be applied to each element of `in`.
-   * @param ec - an execution context for parallel applying of the given function `f`.
-   * @tparam I - the type of elements in the input collection.
-   * @tparam O - the type of elements in resulted collection.
-   * @return new collection in which each element was given from the input collection `in` by
-   *         applying the lambda function `f`.
-   */
-  def parmap[I, O, Col[X] <: TraversableLike[X, Col[X]]]
-      (in: Col[I])
-      (f: I => O)
-      (implicit
-        cbf: CanBuildFrom[Col[I], Future[O], Col[Future[O]]], // For in.map
-        cbf2: CanBuildFrom[Col[Future[O]], O, Col[O]], // for Future.sequence
-        ec: ExecutionContext
-      ): Col[O] = {
-    val futures = in.map(x => Future(f(x)))
-    val futureSeq = Future.sequence(futures)
-
-    awaitResult(futureSeq, Duration.Inf)
-  }
 }
```
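
The deleted helper's body survives inline above: create one `Future` per element, flip `Seq[Future[_]]` into `Future[Seq[_]]` with `Future.sequence`, and block for the combined result. A stripped-down sketch of the same pattern using plain `Await.result` where Spark uses its `awaitResult` wrapper (the pool size and the `FutureMapSketch` name are illustrative):

```scala
import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object FutureMapSketch {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(4)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    try {
      // One Future per element...
      val futures = (1 to 10).map(x => Future(x * x))
      // ...then Seq[Future[Int]] => Future[Seq[Int]] and a blocking wait.
      val all = Future.sequence(futures)
      println(Await.result(all, Duration.Inf).sum)
    } finally {
      // shutdownNow interrupts in-flight tasks; this is the
      // interruptibility the removed scaladoc described, which a
      // parallel collection's map does not offer.
      pool.shutdownNow()
    }
  }
}
```
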
streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala

```diff
@@ -312,10 +312,11 @@ private[streaming] object FileBasedWriteAheadLog {
       handler: I => Iterator[O]): Iterator[O] = {
     val taskSupport = new ExecutionContextTaskSupport(executionContext)
     val groupSize = taskSupport.parallelismLevel.max(8)
-    implicit val ec = executionContext
 
     source.grouped(groupSize).flatMap { group =>
-      ThreadUtils.parmap(group)(handler)
+      val parallelCollection = group.par
+      parallelCollection.tasksupport = taskSupport
+      parallelCollection.map(handler)
     }.flatten
   }
 }
```
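
Putting the resulting method together in one place: groups of at most `groupSize` elements are pulled lazily from the source iterator, and each group is mapped in parallel on the caller's `ExecutionContext` through `ExecutionContextTaskSupport`. A self-contained sketch under the same Scala 2.11/2.12 assumption (the `main` driver and sample data are illustrative; the method body mirrors the merged code):

```scala
import scala.collection.parallel.ExecutionContextTaskSupport
import scala.concurrent.ExecutionContext

object GroupedParSketch {
  // Process a lazy iterator in parallel batches: one batch of up to
  // groupSize elements is materialized and fanned out at a time.
  def seqToParIterator[I, O](
      ec: ExecutionContext,
      source: Iterator[I],
      handler: I => Iterator[O]): Iterator[O] = {
    val taskSupport = new ExecutionContextTaskSupport(ec)
    val groupSize = taskSupport.parallelismLevel.max(8)
    source.grouped(groupSize).flatMap { group =>
      val par = group.par
      par.tasksupport = taskSupport // run on the caller's ExecutionContext
      par.map(handler)
    }.flatten
  }

  def main(args: Array[String]): Unit = {
    val out = seqToParIterator(
      ExecutionContext.global,
      Iterator.range(0, 32),
      (i: Int) => Iterator(i, i * i))
    println(out.toList)
  }
}
```

Grouping before going parallel bounds how much of the lazy source is materialized at once while still fanning each batch out across the pool.
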