Skip to content

Commit

Permalink
[SPARK-2044] Pluggable interface for shuffles
Browse files Browse the repository at this point in the history
This is a first cut at moving shuffle logic behind a pluggable interface, as described at https://issues.apache.org/jira/browse/SPARK-2044, to let us more easily experiment with new shuffle implementations. It moves the existing shuffle code to a class HashShuffleManager behind a general ShuffleManager interface.

Two things are still missing to make this complete:
* MapOutputTracker needs to be hidden behind the ShuffleManager interface; this will also require adding methods to ShuffleManager that will let the DAGScheduler interact with it as it does with the MapOutputTracker today
* The code to do map-sides and reduce-side combine in ShuffledRDD, PairRDDFunctions, etc needs to be moved into the ShuffleManager's readers and writers

However, some of these may also be done later after we merge the current interface.

Author: Matei Zaharia <matei@databricks.com>

Closes apache#1009 from mateiz/pluggable-shuffle and squashes the following commits:

7a09862 [Matei Zaharia] review comments
be33d3f [Matei Zaharia] review comments
1513d4e [Matei Zaharia] Add ASF header
ac56831 [Matei Zaharia] Bug fix and better error message
4f681ba [Matei Zaharia] Move write part of ShuffleMapTask to ShuffleManager
f6f011d [Matei Zaharia] Move hash shuffle reader behind ShuffleManager interface
55c7717 [Matei Zaharia] Changed RDD code to use ShuffleReader
75cc044 [Matei Zaharia] Partial work to move hash shuffle in
  • Loading branch information
mateiz authored and rxin committed Jun 12, 2014
1 parent d920335 commit 508fd37
Show file tree
Hide file tree
Showing 22 changed files with 459 additions and 130 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/ContextCleaner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}

/** Register a ShuffleDependency for cleanup when it is garbage collected. */
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
}

Expand Down
12 changes: 9 additions & 3 deletions core/src/main/scala/org/apache/spark/Dependency.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null,
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
class ShuffleDependency[K, V](
class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializer: Serializer = null)
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {

val shuffleId: Int = rdd.context.newShuffleId()

val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
shuffleId, rdd.partitions.size, this)

rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}

Expand Down
28 changes: 18 additions & 10 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.ConnectionManager
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}

Expand All @@ -56,7 +57,7 @@ class SparkEnv (
val closureSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
Expand All @@ -80,7 +81,7 @@ class SparkEnv (
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
Expand Down Expand Up @@ -163,13 +164,20 @@ object SparkEnv extends Logging {
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
// First try with the constructor that takes SparkConf. If we can't find one,
// use a no-arg constructor instead.
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
.newInstance(conf, new java.lang.Boolean(isDriver))
.asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
}
}
}

Expand Down Expand Up @@ -219,9 +227,6 @@ object SparkEnv extends Logging {

val cacheManager = new CacheManager(blockManager)

val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")

val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
Expand All @@ -242,6 +247,9 @@ object SparkEnv extends Logging {
"."
}

val shuffleManager = instantiateClass[ShuffleManager](
"spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")

// Warn about deprecated spark.cache.class property
if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
Expand All @@ -255,7 +263,7 @@ object SparkEnv extends Logging {
closureSerializer,
cacheManager,
mapOutputTracker,
shuffleFetcher,
shuffleManager,
broadcastManager,
blockManager,
connectionManager,
Expand Down
22 changes: 12 additions & 10 deletions core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle

private[spark] sealed trait CoGroupSplitDep extends Serializable

Expand All @@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep(
}
}

private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep

private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
Expand Down Expand Up @@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]

private var serializer: Serializer = null
private var serializer: Option[Serializer] = None

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
this.serializer = serializer
this.serializer = Option(serializer)
this
}

Expand All @@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
new ShuffleDependency[Any, Any](rdd, part, serializer)
new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer)
}
}
}
Expand All @@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
// Assume each RDD contributed a single dependency, and get it
dependencies(j) match {
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId)
case s: ShuffleDependency[_, _, _] =>
new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
Expand All @@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))

case ShuffleCoGroupSplitDep(shuffleId) =>
case ShuffleCoGroupSplitDep(handle) =>
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
val ser = Serializer.getSerializer(serializer)
val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
val it = SparkEnv.get.shuffleManager
.getReader(handle, split.index, split.index + 1, context)
.read()
rddIterators += ((it, depNum))
}

Expand Down
12 changes: 7 additions & 5 deletions core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
part: Partitioner)
extends RDD[P](prev.context, Nil) {

private var serializer: Serializer = null
private var serializer: Option[Serializer] = None

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
this.serializer = serializer
this.serializer = Option(serializer)
this
}

Expand All @@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
}

override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
val ser = Serializer.getSerializer(serializer)
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[P]]
}

override def clearDependencies() {
Expand Down
17 changes: 9 additions & 8 deletions core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {

private var serializer: Serializer = null
private var serializer: Option[Serializer] = None

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
this.serializer = serializer
this.serializer = Option(serializer)
this
}

Expand All @@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
// Each CoGroupPartition will depend on rdd1 and rdd2
array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
dependencies(j) match {
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId)
case s: ShuffleDependency[_, _, _] =>
new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
Expand All @@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](

override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
Expand All @@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)

case ShuffleCoGroupSplitDep(shuffleId) =>
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
context, ser)
case ShuffleCoGroupSplitDep(handle) =>
val iter = SparkEnv.get.shuffleManager
.getReader(handle, partition.index, partition.index + 1, context)
.read()
iter.foreach(op)
}
// the first dep is rdd1; add all values to the map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class DAGScheduler(
* The jobId value passed in will be used if the stage doesn't already exist with
* a lower jobId (jobId always increases across jobs.)
*/
private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = {
private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
Expand All @@ -210,7 +210,7 @@ class DAGScheduler(
private def newStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
shuffleDep: Option[ShuffleDependency[_, _, _]],
jobId: Int,
callSite: Option[String] = None)
: Stage =
Expand All @@ -233,7 +233,7 @@ class DAGScheduler(
private def newOrUsedStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: ShuffleDependency[_,_],
shuffleDep: ShuffleDependency[_, _, _],
jobId: Int,
callSite: Option[String] = None)
: Stage =
Expand Down Expand Up @@ -269,7 +269,7 @@ class DAGScheduler(
// we can't do it in its constructor because # of partitions is unknown
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
case shufDep: ShuffleDependency[_, _, _] =>
parents += getShuffleMapStage(shufDep, jobId)
case _ =>
visit(dep.rdd)
Expand All @@ -290,7 +290,7 @@ class DAGScheduler(
if (getCacheLocs(rdd).contains(Nil)) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
missing += mapStage
Expand Down Expand Up @@ -1088,7 +1088,7 @@ class DAGScheduler(
visitedRdds += rdd
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
visitedStages += mapStage
Expand Down
Loading

0 comments on commit 508fd37

Please sign in to comment.