
Don't spill more blocks than we need to #3


Merged · 3 commits · Nov 7, 2014
@@ -20,7 +20,7 @@ package org.apache.spark.shuffle.sort
import java.io.{BufferedOutputStream, FileOutputStream}
import java.util.Comparator

import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
import scala.util.{Failure, Success, Try}

import org.apache.spark._
@@ -59,6 +59,9 @@ private[spark] class SortShuffleReader[K, C](
/** Shuffle block fetcher iterator */
private var shuffleRawBlockFetcherItr: ShuffleRawBlockFetcherIterator = _

/** Number of bytes left to fetch */
private var unfetchedBytes: Long = _

private val dep = handle.dependency
private val conf = SparkEnv.get.conf
private val blockManager = SparkEnv.get.blockManager
@@ -68,7 +71,7 @@
private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024

/** ArrayBuffer to store in-memory shuffle blocks */
private val inMemoryBlocks = new ArrayBuffer[MemoryShuffleBlock]()
private val inMemoryBlocks = new Queue[MemoryShuffleBlock]()

/** Manage the BlockManagerId and related shuffle blocks */
private var statuses: Array[(BlockManagerId, Long)] = _
@@ -104,55 +107,29 @@
}
}

inMemoryBlocks += MemoryShuffleBlock(blockId, blockData)

// Try to fit block in memory. If this fails, merge in-memory blocks to disk.
val blockSize = blockData.size
val granted = shuffleMemoryManager.tryToAcquire(blockSize)

val block = MemoryShuffleBlock(blockId, blockData)
if (granted < blockSize) {
logInfo(s"Granted $granted memory is not enough to store shuffle block ($blockSize), " +
s"try to consolidate in-memory blocks to release the memory")
s"spilling in-memory blocks to release the memory")

shuffleMemoryManager.release(granted)

// Write merged blocks to disk
val (tmpBlockId, file) = blockManager.diskBlockManager.createTempShuffleBlock()
val fos = new FileOutputStream(file)
val bos = new BufferedOutputStream(fos, fileBufferSize)

if (inMemoryBlocks.size > 1) {
val itrGroup = inMemoryBlocksToIterators()
val partialMergedItr =
MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
blockManager.dataSerializeStream(tmpBlockId, bos, partialMergedItr, ser)
} else {
val buffer = inMemoryBlocks.map(_.blockData.nioByteBuffer()).head
val channel = fos.getChannel
while (buffer.hasRemaining) {
channel.write(buffer)
}
channel.close()
}

tieredMerger.registerOnDiskBlock(tmpBlockId, file)

logInfo(s"Merge ${inMemoryBlocks.size} in-memory blocks into file ${file.getName}")

for (block <- inMemoryBlocks) {
block.blockData.release()
shuffleMemoryManager.release(block.blockData.size)
}
inMemoryBlocks.clear()
spillInMemoryBlocks(block)
} else {
inMemoryBlocks += block
}

unfetchedBytes -= blockData.size()
shuffleRawBlockFetcherItr.currentResult = null
}
assert(unfetchedBytes == 0)

tieredMerger.doneRegisteringOnDiskBlocks()

// Merge on-disk blocks with in-memory blocks to directly feed to the reducer.
val finalItrGroup = inMemoryBlocksToIterators() ++ Seq(tieredMerger.readMerged())
val finalItrGroup = inMemoryBlocksToIterators(inMemoryBlocks) ++ Seq(tieredMerger.readMerged())
val mergedItr =
MergeUtil.mergeSort(finalItrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
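
For intuition, this final step is a standard k-way merge of sorted iterators. The sketch below is illustrative only and is not the project's MergeUtil, which additionally applies dep.aggregator and dep.keyOrdering:

import scala.collection.mutable

// Merge already-sorted iterators into one sorted iterator via a min-heap.
// Aggregation and the optional key ordering handled by MergeUtil are omitted.
def kWayMerge[K, V](iters: Seq[Iterator[(K, V)]], cmp: Ordering[K]): Iterator[(K, V)] = {
  // PriorityQueue is a max-heap, so reverse the comparator to pop the smallest head.
  val heap = mutable.PriorityQueue.empty[BufferedIterator[(K, V)]](
    Ordering.by((it: BufferedIterator[(K, V)]) => it.head._1)(cmp.reverse))
  iters.map(_.buffered).filter(_.hasNext).foreach(heap.enqueue(_))
  new Iterator[(K, V)] {
    def hasNext: Boolean = heap.nonEmpty
    def next(): (K, V) = {
      val it = heap.dequeue()
      val kv = it.next()
      if (it.hasNext) heap.enqueue(it) // re-insert with its new head
      kv
    }
  }
}

// kWayMerge(Seq(Iterator(1 -> "a", 3 -> "c"), Iterator(2 -> "b")), Ordering.Int)
// yields (1,"a"), (2,"b"), (3,"c").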

@@ -169,8 +146,53 @@
new InterruptibleIterator(context, completionItr.map(p => (p._1, p._2)))
}

private def inMemoryBlocksToIterators(): Seq[Iterator[Product2[K, C]]] = {
inMemoryBlocks.map{ case MemoryShuffleBlock(id, buf) =>
def spillInMemoryBlocks(tippingBlock: MemoryShuffleBlock): Unit = {
// Write merged blocks to disk
val (tmpBlockId, file) = blockManager.diskBlockManager.createTempShuffleBlock()
val fos = new FileOutputStream(file)
val bos = new BufferedOutputStream(fos, fileBufferSize)

// If the remaining unfetched data would fit inside our current allocation, we don't want to
// waste time spilling blocks beyond the space needed for it.
var bytesToSpill = unfetchedBytes
val blocksToSpill = new ArrayBuffer[MemoryShuffleBlock]()
blocksToSpill += tippingBlock
bytesToSpill -= tippingBlock.blockData.size
while (bytesToSpill > 0 && !inMemoryBlocks.isEmpty) {
Owner:

Hi @sryza, I think we should change here to:

if (bytesToSpill > 0) {
  while (!inMemoryBlocks.isEmpty) {
    ....
  }
}

It seems bytesToSpill will go negative and the loop will exit early.

Owner:

Is your purpose to spill only as many blocks as needed to make room for the remaining unfetched bytes, keeping the rest in memory? It seems I misunderstood it.

Author:

Right. That way we don't spill more than we need to.
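
To make that intent concrete, here is a standalone sketch of the selection loop above, with a simplified stand-in for MemoryShuffleBlock (only the size matters here):

import scala.collection.mutable.{ArrayBuffer, Queue}

case class Block(id: String, size: Long)

// Spill the tipping block plus just enough queued blocks to cover the bytes
// still left to fetch; everything beyond that stays in memory.
def selectBlocksToSpill(tipping: Block, inMemory: Queue[Block], unfetchedBytes: Long): ArrayBuffer[Block] = {
  val toSpill = ArrayBuffer(tipping)
  var bytesToSpill = unfetchedBytes - tipping.size
  while (bytesToSpill > 0 && inMemory.nonEmpty) {
    val block = inMemory.dequeue()
    toSpill += block
    bytesToSpill -= block.size
  }
  toSpill
}

With 45 bytes unfetched, a 20-byte tipping block, and queued blocks of 30 and 50 bytes, only the 30-byte block joins the spill and the 50-byte block stays in memory; bytesToSpill going non-positive is the intended early exit, not a bug.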

val block = inMemoryBlocks.dequeue()
blocksToSpill += block
bytesToSpill -= block.blockData.size
}

if (blocksToSpill.size > 1) {
val itrGroup = inMemoryBlocksToIterators(blocksToSpill)
val partialMergedItr =
MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
blockManager.dataSerializeStream(tmpBlockId, bos, partialMergedItr, ser)
} else {
val buffer = blocksToSpill.map(_.blockData.nioByteBuffer()).head
val channel = fos.getChannel
while (buffer.hasRemaining) {
channel.write(buffer)
}
channel.close()
}

tieredMerger.registerOnDiskBlock(tmpBlockId, file)

logInfo(s"Merged ${blocksToSpill.size} in-memory blocks into file ${file.getName}")

for (block <- blocksToSpill) {
block.blockData.release()
if (block != tippingBlock) {
shuffleMemoryManager.release(block.blockData.size)
}
}
}

private def inMemoryBlocksToIterators(blocks: Seq[MemoryShuffleBlock])
: Seq[Iterator[Product2[K, C]]] = {
blocks.map{ case MemoryShuffleBlock(id, buf) =>
blockManager.dataDeserialize(id, buf.nioByteBuffer(), ser)
.asInstanceOf[Iterator[Product2[K, C]]]
}
Expand All @@ -190,6 +212,7 @@ private[spark] class SortShuffleReader[K, C](
}
(address, blocks.toSeq)
}
unfetchedBytes = blocksByAddress.flatMap(a => a._2.map(b => b._2)).sum

shuffleRawBlockFetcherItr = new ShuffleRawBlockFetcherIterator(
context,
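
Taken together, the unfetchedBytes changes maintain a single invariant: the counter starts at the sum of all reported block sizes, shrinks by each arriving block's size, and must reach exactly zero once the fetch iterator is exhausted. A hedged sketch with hypothetical host and block names:

// Mirrors the bookkeeping in this patch with toy data.
val blocksByAddress: Seq[(String, Seq[(String, Long)])] = Seq(
  "host-a" -> Seq("b1" -> 100L, "b2" -> 50L),
  "host-b" -> Seq("b3" -> 25L))

var unfetchedBytes: Long = blocksByAddress.flatMap(a => a._2.map(b => b._2)).sum

// As each block comes off the fetch iterator, its size is subtracted.
for ((_, blocks) <- blocksByAddress; (_, size) <- blocks) {
  unfetchedBytes -= size
}

// When the iterator is exhausted, every reported byte must have arrived.
assert(unfetchedBytes == 0)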