[SPARK-2044] Pluggable interface for shuffles #1009

Closed · wants to merge 8 commits
core/src/main/scala/org/apache/spark/ContextCleaner.scala (2 changes: 1 addition & 1 deletion)
@@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}

/** Register a ShuffleDependency for cleanup when it is garbage collected. */
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
}

core/src/main/scala/org/apache/spark/Dependency.scala (12 changes: 9 additions & 3 deletions)
@@ -20,6 +20,7 @@ package org.apache.spark
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle

/**
* :: DeveloperApi ::
@@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null,
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
class ShuffleDependency[K, V](
class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializer: Serializer = null)
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {

val shuffleId: Int = rdd.context.newShuffleId()

val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
shuffleId, rdd.partitions.size, this)

rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
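The ShuffleHandle registered here comes from the new org.apache.spark.shuffle package, whose definitions are not included in this excerpt. Below is a minimal sketch of the contract implied by the call sites in this diff (registerShuffle in this constructor, getReader(...).read() in the RDDs further down, and stop() in SparkEnv); the writer side and the exact signatures are assumptions, not the patch's literal code.

    package org.apache.spark.shuffle

    import org.apache.spark.{ShuffleDependency, TaskContext}

    // Opaque token returned by registerShuffle and handed back to the manager later.
    abstract class ShuffleHandle(val shuffleId: Int) extends Serializable

    // Pluggable shuffle system; one instance lives in SparkEnv (see SparkEnv.scala below).
    trait ShuffleManager {
      // Called from ShuffleDependency's constructor, as shown above.
      def registerShuffle[K, V, C](
          shuffleId: Int,
          numMaps: Int,
          dependency: ShuffleDependency[K, V, C]): ShuffleHandle

      // Returns a reader for reduce partitions in the range [startPartition, endPartition).
      def getReader[K, C](
          handle: ShuffleHandle,
          startPartition: Int,
          endPartition: Int,
          context: TaskContext): ShuffleReader[K, C]

      // Map-side counterpart; assumed here, since it is not exercised anywhere in this excerpt.
      def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]

      def stop(): Unit
    }

    trait ShuffleReader[K, C] {
      def read(): Iterator[Product2[K, C]]
    }

    trait ShuffleWriter[K, V] {
      def write(records: Iterator[_ <: Product2[K, V]]): Unit
    }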

core/src/main/scala/org/apache/spark/SparkEnv.scala (28 changes: 18 additions & 10 deletions)
@@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.ConnectionManager
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}

@@ -56,7 +57,7 @@ class SparkEnv (
val closureSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
@@ -80,7 +81,7 @@ class SparkEnv (
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
@@ -163,13 +164,20 @@ object SparkEnv extends Logging {
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
// First try with the constructor that takes SparkConf. If we can't find one,
// use a no-arg constructor instead.
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
.newInstance(conf, new java.lang.Boolean(isDriver))
.asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
}
}
}

@@ -219,9 +227,6 @@ object SparkEnv extends Logging {

val cacheManager = new CacheManager(blockManager)

val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")

val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
@@ -242,6 +247,9 @@
"."
}

val shuffleManager = instantiateClass[ShuffleManager](
"spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")

// Warn about deprecated spark.cache.class property
if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -255,7 +263,7 @@
closureSerializer,
cacheManager,
mapOutputTracker,
shuffleFetcher,
shuffleManager,
broadcastManager,
blockManager,
connectionManager,
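Taken together, the SparkEnv changes mean the shuffle implementation is chosen by the spark.shuffle.manager property (defaulting to the hash-based manager above), and instantiateClass now prefers a (SparkConf, Boolean isDriver) constructor. A sketch of how an external implementation could plug in, assuming the ShuffleManager sketch given after Dependency.scala; the class and package names here are hypothetical.

    package com.example.shuffle  // hypothetical package

    import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
    import org.apache.spark.shuffle.{ShuffleHandle, ShuffleManager, ShuffleReader, ShuffleWriter}

    // The (SparkConf, Boolean) constructor matches the signature that
    // SparkEnv.instantiateClass now looks for first.
    class MyShuffleManager(conf: SparkConf, isDriver: Boolean) extends ShuffleManager {
      def registerShuffle[K, V, C](shuffleId: Int, numMaps: Int,
          dependency: ShuffleDependency[K, V, C]): ShuffleHandle = ???
      def getReader[K, C](handle: ShuffleHandle, startPartition: Int, endPartition: Int,
          context: TaskContext): ShuffleReader[K, C] = ???
      def getWriter[K, V](handle: ShuffleHandle, mapId: Int,
          context: TaskContext): ShuffleWriter[K, V] = ???
      def stop(): Unit = ()
    }

Such a manager would then be selected with new SparkConf().set("spark.shuffle.manager", "com.example.shuffle.MyShuffleManager").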
core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala (22 changes: 12 additions & 10 deletions)
@@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle

private[spark] sealed trait CoGroupSplitDep extends Serializable

@@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep(
}
}

private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep

private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
@@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]

private var serializer: Serializer = null
private var serializer: Option[Serializer] = None

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
this.serializer = serializer
this.serializer = Option(serializer)
this
}

@@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
new ShuffleDependency[Any, Any](rdd, part, serializer)
new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer)
}
}
}
@@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
// Assume each RDD contributed a single dependency, and get it
dependencies(j) match {
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId)
case s: ShuffleDependency[_, _, _] =>
new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))

case ShuffleCoGroupSplitDep(shuffleId) =>
case ShuffleCoGroupSplitDep(handle) =>
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
val ser = Serializer.getSerializer(serializer)
val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
val it = SparkEnv.get.shuffleManager
.getReader(handle, split.index, split.index + 1, context)
.read()
rddIterators += ((it, depNum))
}

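Note that the reader is asked for a half-open range of reduce partitions, [startPartition, endPartition), which is why reading the single partition a task is responsible for is written as (split.index, split.index + 1) above. A small illustration, again assuming the ShuffleManager sketch from earlier; readPartition is a hypothetical helper, not part of the patch.

    import org.apache.spark.{ShuffleDependency, SparkEnv, TaskContext}

    // Fetch one reduce partition of a registered shuffle through the new reader API.
    def readPartition[K, C](dep: ShuffleDependency[K, _, C], partition: Int,
        context: TaskContext): Iterator[Product2[K, C]] = {
      SparkEnv.get.shuffleManager
        .getReader[K, C](dep.shuffleHandle, partition, partition + 1, context)  // range [p, p + 1)
        .read()
    }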
core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala (12 changes: 7 additions & 5 deletions)
@@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
part: Partitioner)
extends RDD[P](prev.context, Nil) {

private var serializer: Serializer = null
private var serializer: Option[Serializer] = None

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
this.serializer = serializer
this.serializer = Option(serializer)
[Inline review thread on "this.serializer = Option(serializer)"]

Contributor: should this be Some(serializer)?

Contributor: Not if serializer is null.

Contributor: for null it will be None, isn't that the expected result?

Contributor: (demonstrating the difference in the REPL)

    scala> var serializer: Option[java.lang.Integer] = null
    serializer: Option[Integer] = null

    scala> serializer = Some(1)
    serializer: Option[Integer] = Some(1)

    scala> serializer.get
    res0: Integer = 1

    scala> serializer = Some(null)
    serializer: Option[Integer] = Some(null)

    scala> serializer.get
    res1: Integer = null

    scala> serializer.isDefined
    res2: Boolean = true

    scala> serializer = Option(null)
    serializer: Option[Integer] = None

    scala> serializer.isDefined
    res3: Boolean = false

this
}

@@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
}

override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
val ser = Serializer.getSerializer(serializer)
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
[Inline review thread on "ShuffleDependency[K, V, V]"]

Contributor: Why is the combiner here also V?

Contributor (author): Because there's no combining, the return type is (K, V), same as the input type (K, V).

SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[P]]
}

override def clearDependencies() {
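The getDependencies side of ShuffledRDD is not shown in this excerpt. Given the new ShuffleDependency signature in Dependency.scala, a dependency that performs no map-side combining would presumably be created along these lines (a sketch, which is also why the combiner type ends up as V in the compute method above):

    // Hypothetical reconstruction of ShuffledRDD.getDependencies under the new API;
    // serializer is the Option[Serializer] field set by setSerializer above.
    override def getDependencies: Seq[Dependency[_]] = {
      List(new ShuffleDependency[K, V, V](prev, part, serializer))
    }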
core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala (17 changes: 9 additions & 8 deletions)
@@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {

private var serializer: Serializer = null
private var serializer: Option[Serializer] = None

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
this.serializer = serializer
this.serializer = Option(serializer)
this
}

@@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
// Each CoGroupPartition will depend on rdd1 and rdd2
array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
dependencies(j) match {
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId)
case s: ShuffleDependency[_, _, _] =>
new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](

override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
Expand All @@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)

case ShuffleCoGroupSplitDep(shuffleId) =>
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
context, ser)
case ShuffleCoGroupSplitDep(handle) =>
val iter = SparkEnv.get.shuffleManager
.getReader(handle, partition.index, partition.index + 1, context)
.read()
iter.foreach(op)
}
// the first dep is rdd1; add all values to the map
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -189,7 +189,7 @@ class DAGScheduler(
* The jobId value passed in will be used if the stage doesn't already exist with
* a lower jobId (jobId always increases across jobs.)
*/
private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = {
private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -209,7 +209,7 @@ class DAGScheduler(
private def newStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
shuffleDep: Option[ShuffleDependency[_, _, _]],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@@ -232,7 +232,7 @@ class DAGScheduler(
private def newOrUsedStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: ShuffleDependency[_,_],
shuffleDep: ShuffleDependency[_, _, _],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@@ -268,7 +268,7 @@ class DAGScheduler(
// we can't do it in its constructor because # of partitions is unknown
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
case shufDep: ShuffleDependency[_, _, _] =>
parents += getShuffleMapStage(shufDep, jobId)
case _ =>
visit(dep.rdd)
@@ -289,7 +289,7 @@ class DAGScheduler(
if (getCacheLocs(rdd).contains(Nil)) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
missing += mapStage
@@ -1087,7 +1087,7 @@ class DAGScheduler(
visitedRdds += rdd
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
visitedStages += mapStage