Readability improvements to SortShuffleReader #1

Merged
merged 5 commits on Oct 30, 2014
@@ -17,23 +17,32 @@

package org.apache.spark.shuffle.sort

import java.io.{BufferedOutputStream, FileOutputStream, File}
import java.nio.ByteBuffer
import java.io.{BufferedOutputStream, FileOutputStream}
import java.util.Comparator
import java.util.concurrent.{CountDownLatch, TimeUnit, LinkedBlockingQueue}

import org.apache.spark.network.ManagedBuffer

import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark.{Logging, InterruptibleIterator, SparkEnv, TaskContext}
import org.apache.spark.network.ManagedBuffer
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleReader, BaseShuffleHandle}
import org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher
import org.apache.spark.storage._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.util.collection.{MergeUtil, TieredDiskMerger}

/**
* SortShuffleReader merges and aggregates shuffle data that has already been sorted within each
* map output block.
*
* As blocks are fetched, we store them in memory until we fail to acquire space frm the
Owner:
Typo: "from".

* ShuffleMemoryManager. When this occurs, we merge the in-memory blocks to disk and go back to
* fetching.
*
* TieredDiskMerger is responsible for managing the merged on-disk blocks and for supplying an
* iterator with their merged contents. The final iterator that is passed to user code merges this
* on-disk iterator with the in-memory blocks that have not yet been spilled.
*/
private[spark] class SortShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
@@ -44,26 +53,21 @@ private[spark] class SortShuffleReader[K, C](
require(endPartition == startPartition + 1,
"Sort shuffle currently only supports fetching one partition")

sealed trait ShufflePartition
case class MemoryPartition(blockId: BlockId, blockData: ManagedBuffer) extends ShufflePartition
case class FilePartition(blockId: BlockId, mappedFile: File) extends ShufflePartition
private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024

case class MemoryBlock(blockId: BlockId, blockData: ManagedBuffer)

private val mergingGroup = new LinkedBlockingQueue[ShufflePartition]()
private val mergedGroup = new LinkedBlockingQueue[ShufflePartition]()
private var numSplits: Int = 0
private val mergeFinished = new CountDownLatch(1)
private val mergingThread = new MergingThread()
private val tid = Thread.currentThread().getId
private var shuffleRawBlockFetcherItr: ShuffleRawBlockFetcherIterator = null
private var shuffleRawBlockFetcherItr: ShuffleRawBlockFetcherIterator = _

private val dep = handle.dependency
private val conf = SparkEnv.get.conf
private val blockManager = SparkEnv.get.blockManager
private val ser = Serializer.getSerializer(dep.serializer)
private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager

private val ioSortFactor = conf.getInt("spark.shuffle.ioSortFactor", 100)
private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
private val memoryBlocks = new ArrayBuffer[MemoryBlock]()

private val tieredMerger = new TieredDiskMerger(conf, dep, context)

private val keyComparator: Comparator[K] = dep.keyOrdering.getOrElse(new Comparator[K] {
Owner:
Both this class and TieredDiskMerger define this comparator; can we remove one of them and pass it as a parameter?

Collaborator (author):
Good point.

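A minimal sketch of the change discussed above, assuming TieredDiskMerger's constructor were extended to accept the comparator (the signature below is hypothetical, not the one in this PR):

```scala
// Hypothetical sketch: keep the keyComparator defined above as the single copy
// and pass it into the merger, so TieredDiskMerger no longer defines its own.
// The extra constructor parameter is assumed for illustration only.
private val tieredMerger = new TieredDiskMerger(conf, dep, keyComparator, context)
```
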
override def compare(a: K, b: K) = {
@@ -84,71 +88,66 @@ private[spark] class SortShuffleReader[K, C](
}

private def sortShuffleRead(): Iterator[Product2[K, C]] = {
val rawBlockIterator = fetchRawBlock()

mergingThread.setNumSplits(numSplits)
mergingThread.setDaemon(true)
mergingThread.start()
tieredMerger.start()

for ((blockId, blockData) <- rawBlockIterator) {
for ((blockId, blockData) <- fetchRawBlocks()) {
if (blockData.isEmpty) {
throw new IllegalStateException(s"block $blockId is empty for unknown reason")
}

val amountToRequest = blockData.get.size
val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
val shouldSpill = if (granted < amountToRequest) {
memoryBlocks += MemoryBlock(blockId, blockData.get)
Owner:
What about the situation when memory is not enough for even one block?

Collaborator (author):
In that case we'd end up spilling immediately (further down), which is the correct behavior, right?

Owner:
Yes, I think so. I think we should first check whether there is enough memory, put the block in memoryBlocks if there is, and otherwise spill it directly to disk.

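A rough sketch of the flow suggested in this thread: check memory first, buffer only on a full grant, otherwise spill the single block directly. The helper names bufferOrSpill and spillSingleBlockToDisk are hypothetical and not part of this PR.

```scala
// Hypothetical sketch of the suggested flow: reserve memory for the block first,
// keep it in memory only if the full amount is granted, otherwise release the
// partial grant and spill this single block straight to disk.
private def bufferOrSpill(blockId: BlockId, blockData: ManagedBuffer): Unit = {
  val blockSize = blockData.size
  val granted = shuffleMemoryManager.tryToAcquire(blockSize)
  if (granted >= blockSize) {
    memoryBlocks += MemoryBlock(blockId, blockData)
  } else {
    // Not enough memory for even this one block: give back what was granted
    // and write the raw bytes out without deserializing them.
    shuffleMemoryManager.release(granted)
    spillSingleBlockToDisk(blockId, blockData) // hypothetical helper, sketched below
  }
}
```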

// Try to fit block in memory. If this fails, merge in-memory blocks to disk.
val blockSize = blockData.get.size
val granted = shuffleMemoryManager.tryToAcquire(blockData.get.size)
if (granted < blockSize) {
shuffleMemoryManager.release(granted)
logInfo(s"Grant memory $granted less than the amount to request $amountToRequest, " +
s"spilling data to file")
true
} else {
false
}

if (!shouldSpill) {
mergingGroup.offer(MemoryPartition(blockId, blockData.get))
} else {
val itrGroup = memoryBlocksToIterators()
val partialMergedIter =
MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
Owner:
I'm afraid this will make the merge factor too small and produce too many spilled files when memory is scarce, in the extreme case where memory can only fit one block. Besides, the merge factor is now controlled by memory size, not mergeWidth.

Owner:
I think if we have only one block in memoryBlocks, we can write it directly to disk without deserializing and re-serializing it.

Collaborator (author):
> I'm afraid this will make the merge factor too small and produce too many spilled files when memory is scarce

Is there anything we could do differently?

> I think if we have only one block in memoryBlocks, we can write it directly to disk without deserializing and re-serializing it.

Good point.

Owner:
My concern is that to avoid disk spilling we do more merge actions than the previous code, which introduces additional serialization and compression; I'm not sure how to balance disk spilling against the serde/compression overhead.

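A sketch of the direct write suggested above for a lone buffered block, mirroring the raw channel copy in the removed lines of this hunk; spillSingleBlockToDisk is a hypothetical helper, not part of this PR.

```scala
// Hypothetical sketch: when only one block is buffered, copy its raw serialized
// bytes to a temp file instead of deserializing and re-serializing it, then
// register the file with the tiered merger like any other on-disk block.
private def spillSingleBlockToDisk(blockId: BlockId, blockData: ManagedBuffer): Unit = {
  val (tmpBlockId, file) = blockManager.diskBlockManager.createTempBlock()
  val channel = new FileOutputStream(file).getChannel()
  try {
    val buf = blockData.nioByteBuffer()
    while (buf.remaining() > 0) {
      channel.write(buf)
    }
  } finally {
    channel.close()
  }
  tieredMerger.registerOnDiskBlock(tmpBlockId, file)
}
```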

// Write merged blocks to disk
val (tmpBlockId, file) = blockManager.diskBlockManager.createTempBlock()
val channel = new FileOutputStream(file).getChannel()
val byteBuffer = blockData.get.nioByteBuffer()
while (byteBuffer.remaining() > 0) {
channel.write(byteBuffer)
val fos = new BufferedOutputStream(new FileOutputStream(file), fileBufferSize)
blockManager.dataSerializeStream(tmpBlockId, fos, partialMergedIter, ser)
tieredMerger.registerOnDiskBlock(tmpBlockId, file)

for (block <- memoryBlocks) {
shuffleMemoryManager.release(block.blockData.size)
}
channel.close()
mergingGroup.offer(FilePartition(tmpBlockId, file))
memoryBlocks.clear()
}

shuffleRawBlockFetcherItr.currentResult = null
}
tieredMerger.doneRegisteringOnDiskBlocks()

mergeFinished.await()

// Merge the final group for combiner to directly feed to the reducer
val finalMergedPartArray = mergedGroup.toArray(new Array[ShufflePartition](mergedGroup.size()))
val finalItrGroup = getIteratorGroup(finalMergedPartArray)
val mergedItr = if (dep.aggregator.isDefined) {
ExternalSorter.mergeWithAggregation(finalItrGroup, dep.aggregator.get.mergeCombiners,
keyComparator, dep.keyOrdering.isDefined)
} else {
ExternalSorter.mergeSort(finalItrGroup, keyComparator)
}

mergedGroup.clear()

// Release the shuffle used memory of this thread
shuffleMemoryManager.releaseMemoryForThisThread()
// Merge on-disk blocks with in-memory blocks to directly feed to the reducer.
val finalItrGroup = memoryBlocksToIterators() ++ Seq(tieredMerger.readMerged())
val mergedItr =
MergeUtil.mergeSort(finalItrGroup, keyComparator, dep.keyOrdering, dep.aggregator)

// Release the in-memory block and on-disk file when iteration is completed.
val completionItr = CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](
mergedItr, releaseUnusedShufflePartition(finalMergedPartArray))
mergedItr, () => {
memoryBlocks.foreach(block => shuffleMemoryManager.release(block.blockData.size))
memoryBlocks.clear()
})

new InterruptibleIterator(context, completionItr.map(p => (p._1, p._2)))
}

def memoryBlocksToIterators(): Seq[Iterator[Product2[K, C]]] = {
memoryBlocks.map{ case MemoryBlock(id, buf) =>
blockManager.dataDeserialize(id, buf.nioByteBuffer(), ser)
.asInstanceOf[Iterator[Product2[K, C]]]
}
}

override def stop(): Unit = ???

private def fetchRawBlock(): Iterator[(BlockId, Option[ManagedBuffer])] = {
private def fetchRawBlocks(): Iterator[(BlockId, Option[ManagedBuffer])] = {
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(handle.shuffleId, startPartition)
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]()
for (((address, size), index) <- statuses.zipWithIndex) {
@@ -158,10 +157,12 @@ private[spark] class SortShuffleReader[K, C](
case (address, splits) =>
(address, splits.map(s => (ShuffleBlockId(handle.shuffleId, s._1, startPartition), s._2)))
}
var numMapBlocks = 0
blocksByAddress.foreach { case (_, blocks) =>
blocks.foreach { case (_, len) => if (len > 0) numSplits += 1 }
blocks.foreach { case (_, len) => if (len > 0) numMapBlocks += 1 }
}
logInfo(s"Fetch $numSplits partitions for $tid")
val threadId = Thread.currentThread.getId
logInfo(s"Fetching $numMapBlocks blocks for $threadId")

shuffleRawBlockFetcherItr = new ShuffleRawBlockFetcherIterator(
context,
@@ -172,127 +173,9 @@ private[spark] class SortShuffleReader[K, C](

val completionItr = CompletionIterator[
(BlockId, Option[ManagedBuffer]),
Iterator[(BlockId, Option[ManagedBuffer])]](shuffleRawBlockFetcherItr, {
context.taskMetrics.updateShuffleReadMetrics()
})
Iterator[(BlockId, Option[ManagedBuffer])]](shuffleRawBlockFetcherItr,
() => context.taskMetrics.updateShuffleReadMetrics())

new InterruptibleIterator[(BlockId, Option[ManagedBuffer])](context, completionItr)
}

private def getIteratorGroup(shufflePartGroup: Array[ShufflePartition])
: Seq[Iterator[Product2[K, C]]] = {
shufflePartGroup.map { part =>
val itr = part match {
case MemoryPartition(id, buf) =>
// Release memory usage
shuffleMemoryManager.release(buf.size, tid)
blockManager.dataDeserialize(id, buf.nioByteBuffer(), ser)
case FilePartition(id, file) =>
val blockData = blockManager.diskStore.getBytes(id).getOrElse(
throw new IllegalStateException(s"cannot get data from block $id"))
blockManager.dataDeserialize(id, blockData, ser)
}
itr.asInstanceOf[Iterator[Product2[K, C]]]
}.toSeq
}


/**
* Release the left in-memory buffer or on-disk file after merged.
*/
private def releaseUnusedShufflePartition(shufflePartGroup: Array[ShufflePartition]): Unit = {
shufflePartGroup.map { part =>
part match {
case MemoryPartition(id, buf) => buf.release()
case FilePartition(id, file) =>
try {
file.delete()
} catch {
// Swallow the exception
case e: Throwable => logWarning(s"Unexpected errors when deleting file: ${
file.getAbsolutePath}", e)
}
}
}
}

private class MergingThread extends Thread {
private var isLooped = true
private var leftTobeMerged = 0

def setNumSplits(numSplits: Int) {
leftTobeMerged = numSplits
}

override def run() {
while (isLooped) {
if (leftTobeMerged < ioSortFactor && leftTobeMerged > 0) {
var count = leftTobeMerged
while (count > 0) {
val part = mergingGroup.poll(100, TimeUnit.MILLISECONDS)
if (part != null) {
mergedGroup.offer(part)
count -= 1
leftTobeMerged -= 1
}
}
} else if (leftTobeMerged >= ioSortFactor) {
val mergingPartArray = ArrayBuffer[ShufflePartition]()
var count = if (numSplits / ioSortFactor > ioSortFactor) {
ioSortFactor
} else {
val mergedSize = mergedGroup.size()
val left = leftTobeMerged - (ioSortFactor - mergedSize - 1)
if (left <= ioSortFactor) {
left
} else {
ioSortFactor
}
}
val countCopy = count

while (count > 0) {
val part = mergingGroup.poll(100, TimeUnit.MILLISECONDS)
if (part != null) {
mergingPartArray += part
count -= 1
leftTobeMerged -= 1
}
}

// Merge the partitions
val itrGroup = getIteratorGroup(mergingPartArray.toArray)
val partialMergedIter = if (dep.aggregator.isDefined) {
ExternalSorter.mergeWithAggregation(itrGroup, dep.aggregator.get.mergeCombiners,
keyComparator, dep.keyOrdering.isDefined)
} else {
ExternalSorter.mergeSort(itrGroup, keyComparator)
}
// Write merged partitions to disk
val (tmpBlockId, file) = blockManager.diskBlockManager.createTempBlock()
val fos = new BufferedOutputStream(new FileOutputStream(file), fileBufferSize)
blockManager.dataSerializeStream(tmpBlockId, fos, partialMergedIter, ser)
logInfo(s"Merge $countCopy partitions and write into file ${file.getName}")

releaseUnusedShufflePartition(mergingPartArray.toArray)
mergedGroup.add(FilePartition(tmpBlockId, file))
} else {
val mergedSize = mergedGroup.size()
if (mergedSize > ioSortFactor) {
leftTobeMerged = mergedSize

// Swap the merged group and merging group and do merge again,
// since file number is still larger than ioSortFactor
assert(mergingGroup.size() == 0)
mergingGroup.addAll(mergedGroup)
mergedGroup.clear()
} else {
assert(mergingGroup.size() == 0)
isLooped = false
mergeFinished.countDown()
}
}
}
}
}
}