Move hash shuffle reader behind ShuffleManager interface
mateiz committed Jun 6, 2014
1 parent 55c7717 commit f6f011d
Showing 10 changed files with 47 additions and 17 deletions.
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -77,8 +77,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:

  private var serializer: Option[Serializer] = None

  /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
  def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
    this.serializer = Some(serializer)
    this.serializer = Option(serializer)
    this
  }

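The switch from Some(serializer) to Option(serializer) lets callers pass null to mean "use the default spark.serializer", because Option(null) collapses to None while Some(null) does not. A minimal sketch of the difference (standalone, not from the patch):

import org.apache.spark.serializer.Serializer

// Illustration only: a null Serializer stands for "use the default".
val fromCaller: Serializer = null
Some(fromCaller)    // Some(null) -- looks "set", so the default would never apply
Option(fromCaller)  // None       -- getSerializer can fall back to spark.serializer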
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -44,8 +44,9 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](

  private var serializer: Option[Serializer] = None

  /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
    this.serializer = Some(serializer)
    this.serializer = Option(serializer)
    this
  }

@@ -61,8 +62,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](

  override def compute(split: Partition, context: TaskContext): Iterator[P] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
    SparkEnv.get.shuffleManager
      .getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[P]]
  }
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -56,8 +56,9 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](

  private var serializer: Option[Serializer] = None

  /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
  def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
    this.serializer = Some(serializer)
    this.serializer = Option(serializer)
    this
  }

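All three RDDs now expose the same builder-style setSerializer, so a caller can chain it when constructing a shuffle; omitting the call (or passing null) selects spark.serializer. A usage sketch, assuming a live SparkContext named sc (not code from the patch):

import org.apache.spark.{HashPartitioner, SparkConf}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.KryoSerializer

// Sketch: build a ShuffledRDD directly and pin its shuffle serializer.
val pairs = sc.parallelize(1 to 1000).map(x => (x % 8, x))
val shuffled = new ShuffledRDD[Int, Int, (Int, Int)](pairs, new HashPartitioner(8))
  .setSerializer(new KryoSerializer(new SparkConf()))
shuffled.count()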
core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -52,6 +52,10 @@ object Serializer {
  def getSerializer(serializer: Serializer): Serializer = {
    if (serializer == null) SparkEnv.get.serializer else serializer
  }

  def getSerializer(serializer: Option[Serializer]): Serializer = {
    serializer.getOrElse(SparkEnv.get.serializer)
  }
}


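The new overload resolves an optional per-shuffle serializer against the application-wide default. A brief sketch of both overloads in use (assumes SparkEnv has been initialized, e.g. inside a running job):

import org.apache.spark.SparkConf
import org.apache.spark.serializer.{KryoSerializer, Serializer}

// An explicit serializer comes back unchanged; None (or null) falls back
// to whatever spark.serializer names, via SparkEnv.get.serializer.
val explicit: Option[Serializer] = Some(new KryoSerializer(new SparkConf()))
val s1 = Serializer.getSerializer(explicit)                  // the Kryo instance
val s2 = Serializer.getSerializer(None: Option[Serializer])  // the configured default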
core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
@@ -25,6 +25,6 @@ import org.apache.spark.serializer.Serializer
 */
private[spark] class BaseShuffleHandle[K, V, C](
    shuffleId: Int,
    numMaps: Int,
    dependency: ShuffleDependency[K, V, C])
    val numMaps: Int,
    val dependency: ShuffleDependency[K, V, C])
  extends ShuffleHandle(shuffleId)
core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
@@ -22,4 +22,4 @@ package org.apache.spark.shuffle
 *
 * @param shuffleId ID of the shuffle
 */
private[spark] abstract class ShuffleHandle(val shuffleId: Int) {}
private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {}
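Making ShuffleHandle Serializable matters because the handle is stored in the ShuffleDependency and shipped with tasks to executors, and the vals promoted on BaseShuffleHandle are what the executor-side reader uses to recover numMaps and the dependency. A hypothetical subclass illustrating the pattern (TaggedShuffleHandle is not part of the patch; it assumes a package under org.apache.spark so it can see the private[spark] types):

package org.apache.spark.shuffle

import org.apache.spark.ShuffleDependency

// Hypothetical handle carrying extra, serializable bookkeeping. Because
// ShuffleHandle now extends Serializable, an instance can be captured in a
// ShuffleDependency, serialized with the task, and read back on executors,
// where numMaps and dependency are accessible thanks to the new vals.
private[spark] class TaggedShuffleHandle[K, V, C](
    shuffleId: Int,
    numMaps: Int,
    dependency: ShuffleDependency[K, V, C],
    val tag: String)
  extends BaseShuffleHandle(shuffleId, numMaps, dependency)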
core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -27,7 +27,7 @@ import org.apache.spark.util.CompletionIterator
import org.apache.spark._
import org.apache.spark.storage.ShuffleBlockId

private[hash] class BlockStoreShuffleFetcher extends Logging {
private[hash] object BlockStoreShuffleFetcher extends Logging {
  def fetch[T](
      shuffleId: Int,
      reduceId: Int,
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -1,8 +1,7 @@
package org.apache.spark.shuffle.hash

import org.apache.spark._
import org.apache.spark.shuffle.{ShuffleReader, ShuffleWriter, ShuffleHandle, ShuffleManager}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._

/**
 * A ShuffleManager using the hash-based implementation available up to and including Spark 1.0.
@@ -12,7 +11,9 @@ class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
  override def registerShuffle[K, V, C](
      shuffleId: Int,
      numMaps: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = ???
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    new BaseShuffleHandle(shuffleId, numMaps, dependency)
  }

  /**
   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
@@ -22,15 +23,18 @@ class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): ShuffleReader[K, C] = ???
      context: TaskContext): ShuffleReader[K, C] = {
    new HashShuffleReader(
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
  }

  /** Get a writer for a given partition. Called on executors by map tasks. */
  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
    : ShuffleWriter[K, V] = ???

  /** Remove a shuffle's metadata from the ShuffleManager. */
  override def unregisterShuffle(shuffleId: Int): Unit = ???
  override def unregisterShuffle(shuffleId: Int): Unit = {}

  /** Shut down this ShuffleManager. */
  override def stop(): Unit = ???
  override def stop(): Unit = {}
}
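With registerShuffle and getReader filled in, the hash path flows through the interface end to end: the dependency registers the shuffle and holds a BaseShuffleHandle, and each reduce task asks the manager for a reader over its single partition. A condensed sketch of that call sequence, parameterized over an existing dependency since building one needs a full SparkContext (not code from the patch):

import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.shuffle.hash.HashShuffleManager

// Sketch of the call sequence against the new interface.
def hashShuffleFlow[K, V, C](
    dep: ShuffleDependency[K, V, C],
    numMaps: Int,
    reducePartition: Int,
    context: TaskContext): Iterator[Product2[K, C]] = {
  val manager = new HashShuffleManager(new SparkConf())
  // Driver side: register the shuffle and keep the (now serializable) handle.
  val handle = manager.registerShuffle(dep.shuffleId, numMaps, dep)
  // Executor side: a hash reader over exactly one reduce partition.
  manager.getReader[K, C](handle, reducePartition, reducePartition + 1, context).read()
}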
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,6 +17,26 @@

package org.apache.spark.shuffle.hash

class HashShuffleReader {
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.TaskContext

class HashShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext)
  extends ShuffleReader[K, C]
{
  require(endPartition == startPartition + 1,
    "Hash shuffle currently only supports fetching one partition")

  /** Read the combined key-values for this reduce task */
  override def read(): Iterator[Product2[K, C]] = {
    BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
      Serializer.getSerializer(handle.dependency.serializer))
  }

  /** Close this reader */
  override def stop(): Unit = ???
}
core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
  def newPairRDD = newRDD.map(_ -> 1)
  def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
  def newBroadcast = sc.broadcast(1 to 100)
  def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
  def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
    def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
      rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
        getAllDependencies(dep.rdd)
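The test signature picks up ShuffleDependency's third type parameter, the combined value type C that the new shuffle interfaces thread through. A sketch of where the three parameters come from for a simple reduceByKey (needs a local SparkContext; not part of the test):

import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext}

// reduceByKey over (Int, Int) pairs yields a ShuffleDependency[Int, Int, Int]:
// key type, map-side value type, and combined (reduced) value type.
val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("dep-types"))
val reduced = sc.parallelize(1 to 100).map(x => (x % 10, 1)).reduceByKey(_ + _)
val dep = reduced.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int, Int]]
println(dep.shuffleId)
sc.stop()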
