Move write part of ShuffleMapTask to ShuffleManager
mateiz committed Jun 8, 2014
1 parent f6f011d commit 4f681ba
Showing 7 changed files with 113 additions and 61 deletions.
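For orientation, the write path after this change looks roughly as follows. This is a paraphrase of the new ShuffleMapTask.runTask shown below, not a literal excerpt; dep, rdd, split, partitionId and context stand in for the values a map task already has in scope.

  // A map task no longer drives ShuffleBlockManager directly: it asks the
  // pluggable ShuffleManager for a ShuffleWriter and pushes records through it.
  val manager = SparkEnv.get.shuffleManager
  val writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
  for (elem <- rdd.iterator(split, context)) {
    writer.write(elem.asInstanceOf[Product2[Any, Any]])
  }
  // stop(success = true) commits the writes and returns Some(MapStatus);
  // stop(success = false) reverts them and returns None.
  val status = writer.stop(success = true).get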
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/Dependency.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleHandle

/**
* :: DeveloperApi ::
@@ -65,7 +66,7 @@ class ShuffleDependency[K, V, C](

val shuffleId: Int = rdd.context.newShuffleId()

-val shuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
+val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
shuffleId, rdd.partitions.size, this)

rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
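Concretely, constructing a ShuffleDependency now registers the shuffle eagerly and keeps the returned handle around for later map tasks. A hedged driver-side sketch (pairRdd and the HashPartitioner argument are placeholders, and ShuffleDependency's remaining optional constructor parameters are elided):

  // Hypothetical usage: the handle is obtained as a side effect of building
  // the dependency, before any task runs.
  val dep = new ShuffleDependency[Int, String, String](pairRdd, new HashPartitioner(4))
  val handle: ShuffleHandle = dep.shuffleHandle  // filled in by registerShuffle above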
66 changes: 11 additions & 55 deletions core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
+import org.apache.spark.shuffle.ShuffleWriter

private[spark] object ShuffleMapTask {

@@ -141,66 +142,21 @@ private[spark] class ShuffleMapTask(
}

override def runTask(context: TaskContext): MapStatus = {
-  val numOutputSplits = dep.partitioner.numPartitions
   metrics = Some(context.taskMetrics)
-
-  val blockManager = SparkEnv.get.blockManager
-  val shuffleBlockManager = blockManager.shuffleBlockManager
-  var shuffle: ShuffleWriterGroup = null
-  var success = false
-
+  var writer: ShuffleWriter[Any, Any] = null
   try {
-    // Obtain all the block writers for shuffle blocks.
-    val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
-    shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
-
-    // Write the map output to its associated buckets.
+    val manager = SparkEnv.get.shuffleManager
+    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
     for (elem <- rdd.iterator(split, context)) {
-      val pair = elem.asInstanceOf[Product2[Any, Any]]
-      val bucketId = dep.partitioner.getPartition(pair._1)
-      shuffle.writers(bucketId).write(pair)
+      writer.write(elem.asInstanceOf[Product2[Any, Any]])
     }
-
-    // Commit the writes. Get the size of each bucket block (total block size).
-    var totalBytes = 0L
-    var totalTime = 0L
-    val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
-      writer.commit()
-      writer.close()
-      val size = writer.fileSegment().length
-      totalBytes += size
-      totalTime += writer.timeWriting()
-      MapOutputTracker.compressSize(size)
-    }
-
-    // Update shuffle metrics.
-    val shuffleMetrics = new ShuffleWriteMetrics
-    shuffleMetrics.shuffleBytesWritten = totalBytes
-    shuffleMetrics.shuffleWriteTime = totalTime
-    metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
-
-    success = true
-    new MapStatus(blockManager.blockManagerId, compressedSizes)
-  } catch { case e: Exception =>
-    // If there is an exception from running the task, revert the partial writes
-    // and throw the exception upstream to Spark.
-    if (shuffle != null && shuffle.writers != null) {
-      for (writer <- shuffle.writers) {
-        writer.revertPartialWrites()
-        writer.close()
-      }
-    }
-    throw e
+    return writer.stop(success = true).get
+  } catch {
+    case e: Exception =>
+      if (writer != null) {
+        writer.stop(success = false)
+      }
+      throw e
   } finally {
-    // Release the writers back to the shuffle block manager.
-    if (shuffle != null && shuffle.writers != null) {
-      try {
-        shuffle.releaseWriters(success)
-      } catch {
-        case e: Exception => logError("Failed to release shuffle writers", e)
-      }
-    }
     // Execute the callbacks on task completion.
     context.executeOnCompleteCallbacks()
   }
}
core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -53,5 +53,5 @@ private[spark] trait ShuffleManager {
def unregisterShuffle(shuffleId: Int)

/** Shut down this ShuffleManager. */
-def stop()
+def stop(): Unit
}
core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
@@ -25,5 +25,5 @@ private[spark] trait ShuffleReader[K, C] {
def read(): Iterator[Product2[K, C]]

/** Close this reader */
-def stop()
+def stop(): Unit
}
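As a sanity check on this contract, the smallest conceivable implementation might look like the sketch below. It is illustrative only and not part of this commit; EmptyShuffleReader is an invented name, and the sketch assumes read() and stop() are the trait's only abstract members, as in the excerpt above.

  package org.apache.spark.shuffle

  // Illustrative only, not part of this commit: a do-nothing reader that
  // produces no records and has nothing to clean up.
  private[spark] class EmptyShuffleReader[K, C] extends ShuffleReader[K, C] {
    override def read(): Iterator[Product2[K, C]] = Iterator.empty
    override def stop(): Unit = {}
  }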
core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -17,6 +17,8 @@

package org.apache.spark.shuffle

+import org.apache.spark.scheduler.MapStatus

/**
* Obtained inside a map task to write out records to the shuffle system.
*/
@@ -25,5 +27,5 @@ private[spark] trait ShuffleWriter[K, V] {
def write(record: Product2[K, V]): Unit

/** Close this writer, passing along whether the map completed */
-def stop(success: Boolean)
+def stop(success: Boolean): Option[MapStatus]
}
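To make the new Option[MapStatus] return type concrete, here is a toy writer sketch. CountingShuffleWriter is invented for illustration and is not in this commit; it only counts records and, on success, reports a MapStatus with all-zero partition sizes, using just calls that appear elsewhere in this diff. The real hash-based implementation follows below.

  package org.apache.spark.shuffle

  import org.apache.spark.{MapOutputTracker, SparkEnv}
  import org.apache.spark.scheduler.MapStatus

  // Illustrative only: buffers nothing and just counts records.
  private[spark] class CountingShuffleWriter[K, V](numPartitions: Int)
    extends ShuffleWriter[K, V] {

    private var records = 0L

    override def write(record: Product2[K, V]): Unit = {
      records += 1
    }

    override def stop(success: Boolean): Option[MapStatus] = {
      if (!success) {
        None
      } else {
        // Report a size of zero for every reduce partition.
        val compressedSizes = Array.fill(numPartitions)(MapOutputTracker.compressSize(0L))
        Some(new MapStatus(SparkEnv.get.blockManager.blockManagerId, compressedSizes))
      }
    }
  }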
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -30,7 +30,9 @@ class HashShuffleManager(conf: SparkConf) extends ShuffleManager {

/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
-  : ShuffleWriter[K, V] = ???
+  : ShuffleWriter[K, V] = {
+    new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+  }

/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Unit = {}
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -17,6 +17,97 @@

package org.apache.spark.shuffle.hash

-class HashShuffleWriter {
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
+import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.storage.{BlockObjectWriter, ShuffleWriterGroup}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.scheduler.MapStatus
+
+class HashShuffleWriter[K, V](
+    handle: BaseShuffleHandle[K, V, _],
+    mapId: Int,
+    context: TaskContext)
+  extends ShuffleWriter[K, V] with Logging {
+
+  private val dep = handle.dependency
+  private val numOutputSplits = dep.partitioner.numPartitions
+  private val metrics = context.taskMetrics
+  private var success = false
+  private var stopping = false
+
+  private val blockManager = SparkEnv.get.blockManager
+  private val shuffleBlockManager = blockManager.shuffleBlockManager
+  private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
+  private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
+
+  /** Write a record to this task's output */
+  override def write(record: Product2[K, V]): Unit = {
+    val pair = record.asInstanceOf[Product2[Any, Any]]
+    val bucketId = dep.partitioner.getPartition(pair._1)
+    shuffle.writers(bucketId).write(pair)
+  }
+
+  /** Close this writer, passing along whether the map completed */
+  override def stop(success: Boolean): Option[MapStatus] = {
+    try {
+      if (stopping) {
+        return None
+      }
+      stopping = true
+      if (success) {
+        try {
+          return Some(commitWritesAndBuildStatus())
+        } catch {
+          case e: Exception =>
+            revertWrites()
+            throw e
+        }
+      } else {
+        revertWrites()
+        return None
+      }
+    } finally {
+      // Release the writers back to the shuffle block manager.
+      if (shuffle != null && shuffle.writers != null) {
+        try {
+          shuffle.releaseWriters(success)
+        } catch {
+          case e: Exception => logError("Failed to release shuffle writers", e)
+        }
+      }
+    }
+  }
+
+  private def commitWritesAndBuildStatus(): MapStatus = {
+    // Commit the writes. Get the size of each bucket block (total block size).
+    var totalBytes = 0L
+    var totalTime = 0L
+    val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
+      writer.commit()
+      writer.close()
+      val size = writer.fileSegment().length
+      totalBytes += size
+      totalTime += writer.timeWriting()
+      MapOutputTracker.compressSize(size)
+    }
+
+    // Update shuffle metrics.
+    val shuffleMetrics = new ShuffleWriteMetrics
+    shuffleMetrics.shuffleBytesWritten = totalBytes
+    shuffleMetrics.shuffleWriteTime = totalTime
+    metrics.shuffleWriteMetrics = Some(shuffleMetrics)
+
+    success = true
+    new MapStatus(blockManager.blockManagerId, compressedSizes)
+  }
+
+  private def revertWrites(): Unit = {
+    if (shuffle != null && shuffle.writers != null) {
+      for (writer <- shuffle.writers) {
+        writer.revertPartialWrites()
+        writer.close()
+      }
+    }
+  }
+}
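For completeness, the calling convention this writer expects mirrors the new ShuffleMapTask.runTask above. Roughly, with handle, mapId, context and a records iterator assumed to be in scope:

  val writer = new HashShuffleWriter[Any, Any](handle, mapId, context)
  try {
    records.foreach(writer.write)      // one write() call per record
    writer.stop(success = true).get    // commit and obtain the MapStatus
  } catch {
    case e: Exception =>
      writer.stop(success = false)     // revert any partial writes
      throw e
  }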
