Readability improvements to SortShuffleReader #1
@@ -17,23 +17,32 @@

package org.apache.spark.shuffle.sort

import java.io.{BufferedOutputStream, FileOutputStream, File}
import java.nio.ByteBuffer
import java.io.{BufferedOutputStream, FileOutputStream}
import java.util.Comparator
import java.util.concurrent.{CountDownLatch, TimeUnit, LinkedBlockingQueue}

import org.apache.spark.network.ManagedBuffer

import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark.{Logging, InterruptibleIterator, SparkEnv, TaskContext}
import org.apache.spark.network.ManagedBuffer
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleReader, BaseShuffleHandle}
import org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher
import org.apache.spark.storage._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.util.collection.{MergeUtil, TieredDiskMerger}

/**
 * SortShuffleReader merges and aggregates shuffle data that has already been sorted within each
 * map output block.
 *
 * As blocks are fetched, we store them in memory until we fail to acquire space frm the
 * ShuffleMemoryManager. When this occurs, we merge the in-memory blocks to disk and go back to
 * fetching.
 *
 * TieredDiskMerger is responsible for managing the merged on-disk blocks and for supplying an
 * iterator with their merged contents. The final iterator that is passed to user code merges this
 * on-disk iterator with the in-memory blocks that have not yet been spilled.
 */
private[spark] class SortShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
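The class comment above describes the reader's flow: fetch sorted map-output blocks, hold them in memory while space can be acquired, merge and spill to disk when it cannot, and finally merge the spilled runs with whatever is still in memory. The following is a minimal, self-contained sketch of that control flow, not part of this diff; all names (SpillSketch, Block, memoryBudget) are illustrative stand-ins rather than Spark APIs.

object SpillSketch {
  // Each "block" is already sorted by key, mirroring the sorted map-output blocks.
  type Block = Vector[(Int, String)]

  // Stand-in for a streaming k-way merge of already-sorted runs.
  def mergeSorted(parts: Seq[Block]): Block =
    parts.flatten.sortBy(_._1).toVector

  def read(blocks: Iterator[Block], memoryBudget: Long): Block = {
    val inMemory = scala.collection.mutable.ArrayBuffer[Block]()
    val spilled = scala.collection.mutable.ArrayBuffer[Block]()  // stands in for merged on-disk runs
    var used = 0L

    for (block <- blocks) {
      val size = block.size.toLong
      if (used + size > memoryBudget && inMemory.nonEmpty) {
        // Could not "acquire" space: merge what is held in memory, spill it, keep fetching.
        spilled += mergeSorted(inMemory.toSeq)
        inMemory.clear()
        used = 0L
      }
      inMemory += block
      used += size
    }
    // Final pass: merge the spilled runs together with the blocks still in memory.
    mergeSorted(spilled.toSeq ++ inMemory.toSeq)
  }

  def main(args: Array[String]): Unit = {
    val blocks = Iterator(Vector(1 -> "a", 4 -> "d"), Vector(2 -> "b", 5 -> "e"), Vector(3 -> "c", 6 -> "f"))
    println(read(blocks, memoryBudget = 3))
  }
}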
@@ -44,26 +53,21 @@ private[spark] class SortShuffleReader[K, C](
  require(endPartition == startPartition + 1,
    "Sort shuffle currently only supports fetching one partition")

  sealed trait ShufflePartition
  case class MemoryPartition(blockId: BlockId, blockData: ManagedBuffer) extends ShufflePartition
  case class FilePartition(blockId: BlockId, mappedFile: File) extends ShufflePartition
  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024

  case class MemoryBlock(blockId: BlockId, blockData: ManagedBuffer)

  private val mergingGroup = new LinkedBlockingQueue[ShufflePartition]()
  private val mergedGroup = new LinkedBlockingQueue[ShufflePartition]()
  private var numSplits: Int = 0
  private val mergeFinished = new CountDownLatch(1)
  private val mergingThread = new MergingThread()
  private val tid = Thread.currentThread().getId
  private var shuffleRawBlockFetcherItr: ShuffleRawBlockFetcherIterator = null
  private var shuffleRawBlockFetcherItr: ShuffleRawBlockFetcherIterator = _

  private val dep = handle.dependency
  private val conf = SparkEnv.get.conf
  private val blockManager = SparkEnv.get.blockManager
  private val ser = Serializer.getSerializer(dep.serializer)
  private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager

  private val ioSortFactor = conf.getInt("spark.shuffle.ioSortFactor", 100)
  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
  private val memoryBlocks = new ArrayBuffer[MemoryBlock]()

  private val tieredMerger = new TieredDiskMerger(conf, dep, context)

  private val keyComparator: Comparator[K] = dep.keyOrdering.getOrElse(new Comparator[K] {
Review comment: Both have this keyComparator here and in TieredDiskMerger; can we remove one of them and pass it in as a parameter?
Reply: Good point.
    override def compare(a: K, b: K) = {
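The review thread above suggests building the key comparator once and handing it to TieredDiskMerger instead of duplicating the default in both classes. A rough sketch of that idea, using placeholder class names (TieredDiskMergerLike, SortShuffleReaderLike) rather than the real Spark classes:

import java.util.Comparator

// Illustrative stand-in for TieredDiskMerger; only the constructor signature matters here.
class TieredDiskMergerLike[K](keyComparator: Comparator[K])

class SortShuffleReaderLike[K](keyOrdering: Option[Ordering[K]]) {
  // Built once here; scala.math.Ordering extends java.util.Comparator, so getOrElse works.
  private val keyComparator: Comparator[K] =
    keyOrdering.getOrElse(new Comparator[K] {
      // Hash-code fallback, used purely for illustration.
      override def compare(a: K, b: K): Int =
        java.lang.Integer.compare(a.hashCode(), b.hashCode())
    })

  // The merger receives the same comparator instance instead of rebuilding its own.
  private val tieredMerger = new TieredDiskMergerLike[K](keyComparator)
}

Keeping a single comparator instance also guarantees that the in-memory merge and the on-disk merge order keys identically.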
@@ -84,71 +88,66 @@ private[spark] class SortShuffleReader[K, C](
  }

  private def sortShuffleRead(): Iterator[Product2[K, C]] = {
    val rawBlockIterator = fetchRawBlock()

    mergingThread.setNumSplits(numSplits)
    mergingThread.setDaemon(true)
    mergingThread.start()
    tieredMerger.start()

    for ((blockId, blockData) <- rawBlockIterator) {
    for ((blockId, blockData) <- fetchRawBlocks()) {
      if (blockData.isEmpty) {
        throw new IllegalStateException(s"block $blockId is empty for unknown reason")
      }

      val amountToRequest = blockData.get.size
      val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
      val shouldSpill = if (granted < amountToRequest) {
      memoryBlocks += MemoryBlock(blockId, blockData.get)
Review comment: What about the situation where memory is not enough for even one block?
Reply: In that case we'd end up spilling immediately (further down), which is the correct behavior, right?
Reply: Yes, I think so. I think we should first check whether memory is enough, and only then add the block to memoryBlocks.
(A sketch of this acquire-or-spill decision follows this hunk.)
      // Try to fit block in memory. If this fails, merge in-memory blocks to disk.
      val blockSize = blockData.get.size
      val granted = shuffleMemoryManager.tryToAcquire(blockData.get.size)
      if (granted < blockSize) {
        shuffleMemoryManager.release(granted)
        logInfo(s"Grant memory $granted less than the amount to request $amountToRequest, " +
          s"spilling data to file")
        true
      } else {
        false
      }

      if (!shouldSpill) {
        mergingGroup.offer(MemoryPartition(blockId, blockData.get))
      } else {
        val itrGroup = memoryBlocksToIterators()
        val partialMergedIter =
          MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
Review comment: I'm afraid this will make the merge factor too small and produce too many spilled files when memory is not enough; in the extreme case, memory may fit only one block. Besides, the merge factor is now controlled by memory size, not mergeWidth.
Reply: I think if we have only one block in memory... Is there anything we could do differently?
Reply: Good point.
Review comment: My concern is that, to avoid disk spilling, we now do more merge passes than the previous code, which introduces additional serialization and compression. I'm not sure how to balance disk spilling against the serde/compression overhead.
(A sketch of the ioSortFactor-bounded multi-pass merge from the removed code follows the last hunk.)
        // Write merged blocks to disk
        val (tmpBlockId, file) = blockManager.diskBlockManager.createTempBlock()
        val channel = new FileOutputStream(file).getChannel()
        val byteBuffer = blockData.get.nioByteBuffer()
        while (byteBuffer.remaining() > 0) {
          channel.write(byteBuffer)
        val fos = new BufferedOutputStream(new FileOutputStream(file), fileBufferSize)
        blockManager.dataSerializeStream(tmpBlockId, fos, partialMergedIter, ser)
        tieredMerger.registerOnDiskBlock(tmpBlockId, file)

        for (block <- memoryBlocks) {
          shuffleMemoryManager.release(block.blockData.size)
        }
        channel.close()
        mergingGroup.offer(FilePartition(tmpBlockId, file))
        memoryBlocks.clear()
      }

      shuffleRawBlockFetcherItr.currentResult = null
    }
    tieredMerger.doneRegisteringOnDiskBlocks()

    mergeFinished.await()

    // Merge the final group for combiner to directly feed to the reducer
    val finalMergedPartArray = mergedGroup.toArray(new Array[ShufflePartition](mergedGroup.size()))
    val finalItrGroup = getIteratorGroup(finalMergedPartArray)
    val mergedItr = if (dep.aggregator.isDefined) {
      ExternalSorter.mergeWithAggregation(finalItrGroup, dep.aggregator.get.mergeCombiners,
        keyComparator, dep.keyOrdering.isDefined)
    } else {
      ExternalSorter.mergeSort(finalItrGroup, keyComparator)
    }

    mergedGroup.clear()

    // Release the shuffle used memory of this thread
    shuffleMemoryManager.releaseMemoryForThisThread()
    // Merge on-disk blocks with in-memory blocks to directly feed to the reducer.
    val finalItrGroup = memoryBlocksToIterators() ++ Seq(tieredMerger.readMerged())
    val mergedItr =
      MergeUtil.mergeSort(finalItrGroup, keyComparator, dep.keyOrdering, dep.aggregator)

    // Release the in-memory block and on-disk file when iteration is completed.
    val completionItr = CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](
      mergedItr, releaseUnusedShufflePartition(finalMergedPartArray))
      mergedItr, () => {
        memoryBlocks.foreach(block => shuffleMemoryManager.release(block.blockData.size))
        memoryBlocks.clear()
      })

    new InterruptibleIterator(context, completionItr.map(p => (p._1, p._2)))
  }
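// The CompletionIterator usage above releases in-memory blocks once iteration finishes.
// A generic sketch of that release-on-completion pattern (not Spark's CompletionIterator,
// and not part of this diff): wrap an iterator and invoke a cleanup callback exactly once
// when the underlying iterator is exhausted.
class OnCompletionIterator[A](sub: Iterator[A], onComplete: () => Unit) extends Iterator[A] {
  private var completed = false
  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) { completed = true; onComplete() }
    more
  }
  override def next(): A = sub.next()
}
// Usage, assuming a mergedItr and memoryBlocks like the ones above:
// val iter = new OnCompletionIterator(mergedItr, () => memoryBlocks.clear())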

  def memoryBlocksToIterators(): Seq[Iterator[Product2[K, C]]] = {
    memoryBlocks.map { case MemoryBlock(id, buf) =>
      blockManager.dataDeserialize(id, buf.nioByteBuffer(), ser)
        .asInstanceOf[Iterator[Product2[K, C]]]
    }
  }

  override def stop(): Unit = ???

  private def fetchRawBlock(): Iterator[(BlockId, Option[ManagedBuffer])] = {
  private def fetchRawBlocks(): Iterator[(BlockId, Option[ManagedBuffer])] = {
    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(handle.shuffleId, startPartition)
    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]()
    for (((address, size), index) <- statuses.zipWithIndex) {
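The review threads in the hunk above ask what happens when even a single block cannot be granted memory. A toy sketch of the acquire-or-spill decision, assuming a simplified memory manager; ToyMemoryManager and shouldSpill are stand-ins for illustration, not Spark's ShuffleMemoryManager API.

object SpillDecisionSketch {
  // Like ShuffleMemoryManager, the toy manager may grant less than was requested.
  class ToyMemoryManager(var free: Long) {
    def tryToAcquire(bytes: Long): Long = { val granted = math.min(bytes, free); free -= granted; granted }
    def release(bytes: Long): Unit = free += bytes
  }

  def shouldSpill(mgr: ToyMemoryManager, blockSize: Long): Boolean = {
    val granted = mgr.tryToAcquire(blockSize)
    if (granted < blockSize) {
      // Not enough memory for this block: return the partial grant and tell the caller to spill.
      // If a single block never fits, this branch is taken every time, so data spills immediately,
      // which is the behavior the review thread settled on.
      mgr.release(granted)
      true
    } else {
      false
    }
  }

  def main(args: Array[String]): Unit = {
    val mgr = new ToyMemoryManager(free = 10)
    println(shouldSpill(mgr, blockSize = 4))  // false: 4 of 10 bytes granted and kept
    println(shouldSpill(mgr, blockSize = 8))  // true: only 6 remain, grant returned, caller spills
  }
}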
@@ -158,10 +157,12 @@ private[spark] class SortShuffleReader[K, C](
      case (address, splits) =>
        (address, splits.map(s => (ShuffleBlockId(handle.shuffleId, s._1, startPartition), s._2)))
    }
    var numMapBlocks = 0
    blocksByAddress.foreach { case (_, blocks) =>
      blocks.foreach { case (_, len) => if (len > 0) numSplits += 1 }
      blocks.foreach { case (_, len) => if (len > 0) numMapBlocks += 1 }
    }
    logInfo(s"Fetch $numSplits partitions for $tid")
    val threadId = Thread.currentThread.getId
    logInfo(s"Fetching $numMapBlocks blocks for $threadId")

    shuffleRawBlockFetcherItr = new ShuffleRawBlockFetcherIterator(
      context,
@@ -172,127 +173,9 @@ private[spark] class SortShuffleReader[K, C](

    val completionItr = CompletionIterator[
      (BlockId, Option[ManagedBuffer]),
      Iterator[(BlockId, Option[ManagedBuffer])]](shuffleRawBlockFetcherItr, {
        context.taskMetrics.updateShuffleReadMetrics()
      })
      Iterator[(BlockId, Option[ManagedBuffer])]](shuffleRawBlockFetcherItr,
        () => context.taskMetrics.updateShuffleReadMetrics())

    new InterruptibleIterator[(BlockId, Option[ManagedBuffer])](context, completionItr)
  }

  private def getIteratorGroup(shufflePartGroup: Array[ShufflePartition])
      : Seq[Iterator[Product2[K, C]]] = {
    shufflePartGroup.map { part =>
      val itr = part match {
        case MemoryPartition(id, buf) =>
          // Release memory usage
          shuffleMemoryManager.release(buf.size, tid)
          blockManager.dataDeserialize(id, buf.nioByteBuffer(), ser)
        case FilePartition(id, file) =>
          val blockData = blockManager.diskStore.getBytes(id).getOrElse(
            throw new IllegalStateException(s"cannot get data from block $id"))
          blockManager.dataDeserialize(id, blockData, ser)
      }
      itr.asInstanceOf[Iterator[Product2[K, C]]]
    }.toSeq
  }

  /**
   * Release the left in-memory buffer or on-disk file after merged.
   */
  private def releaseUnusedShufflePartition(shufflePartGroup: Array[ShufflePartition]): Unit = {
    shufflePartGroup.map { part =>
      part match {
        case MemoryPartition(id, buf) => buf.release()
        case FilePartition(id, file) =>
          try {
            file.delete()
          } catch {
            // Swallow the exception
            case e: Throwable => logWarning(s"Unexpected errors when deleting file: ${
              file.getAbsolutePath}", e)
          }
      }
    }
  }

  private class MergingThread extends Thread {
    private var isLooped = true
    private var leftTobeMerged = 0

    def setNumSplits(numSplits: Int) {
      leftTobeMerged = numSplits
    }

    override def run() {
      while (isLooped) {
        if (leftTobeMerged < ioSortFactor && leftTobeMerged > 0) {
          var count = leftTobeMerged
          while (count > 0) {
            val part = mergingGroup.poll(100, TimeUnit.MILLISECONDS)
            if (part != null) {
              mergedGroup.offer(part)
              count -= 1
              leftTobeMerged -= 1
            }
          }
        } else if (leftTobeMerged >= ioSortFactor) {
          val mergingPartArray = ArrayBuffer[ShufflePartition]()
          var count = if (numSplits / ioSortFactor > ioSortFactor) {
            ioSortFactor
          } else {
            val mergedSize = mergedGroup.size()
            val left = leftTobeMerged - (ioSortFactor - mergedSize - 1)
            if (left <= ioSortFactor) {
              left
            } else {
              ioSortFactor
            }
          }
          val countCopy = count

          while (count > 0) {
            val part = mergingGroup.poll(100, TimeUnit.MILLISECONDS)
            if (part != null) {
              mergingPartArray += part
              count -= 1
              leftTobeMerged -= 1
            }
          }

          // Merge the partitions
          val itrGroup = getIteratorGroup(mergingPartArray.toArray)
          val partialMergedIter = if (dep.aggregator.isDefined) {
            ExternalSorter.mergeWithAggregation(itrGroup, dep.aggregator.get.mergeCombiners,
              keyComparator, dep.keyOrdering.isDefined)
          } else {
            ExternalSorter.mergeSort(itrGroup, keyComparator)
          }
          // Write merged partitions to disk
          val (tmpBlockId, file) = blockManager.diskBlockManager.createTempBlock()
          val fos = new BufferedOutputStream(new FileOutputStream(file), fileBufferSize)
          blockManager.dataSerializeStream(tmpBlockId, fos, partialMergedIter, ser)
          logInfo(s"Merge $countCopy partitions and write into file ${file.getName}")

          releaseUnusedShufflePartition(mergingPartArray.toArray)
          mergedGroup.add(FilePartition(tmpBlockId, file))
        } else {
          val mergedSize = mergedGroup.size()
          if (mergedSize > ioSortFactor) {
            leftTobeMerged = mergedSize

            // Swap the merged group and merging group and do merge again,
            // since file number is still larger than ioSortFactor
            assert(mergingGroup.size() == 0)
            mergingGroup.addAll(mergedGroup)
            mergedGroup.clear()
          } else {
            assert(mergingGroup.size() == 0)
            isLooped = false
            mergeFinished.countDown()
          }
        }
      }
    }
  }
}
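The MergingThread removed in the hunk above merges spilled runs in groups of at most ioSortFactor and repeats until a single final merge is possible; the review thread earlier worries about how small that merge factor can get when memory is tight. A compact sketch of the same bounded, multi-pass idea, with in-memory sorted Vectors standing in for the serialized on-disk blocks; names here are illustrative, not the PR's classes.

object TieredMergeSketch {
  // Stand-in for a streaming k-way merge of sorted runs.
  def mergeGroup(runs: Seq[Vector[Int]]): Vector[Int] =
    runs.flatten.sorted.toVector

  def mergeAll(initial: Seq[Vector[Int]], ioSortFactor: Int): Vector[Int] = {
    require(ioSortFactor >= 2, "need a fan-in of at least 2")
    var runs = initial
    // Keep merging in bounded groups until one final merge can cover everything.
    while (runs.size > ioSortFactor) {
      runs = runs.grouped(ioSortFactor).map(mergeGroup).toSeq
    }
    mergeGroup(runs)
  }

  def main(args: Array[String]): Unit = {
    val runs = Seq.tabulate(10)(i => Vector(i, i + 10, i + 20))
    println(mergeAll(runs, ioSortFactor = 3))
  }
}

With a fan-in of at least two, each pass shrinks the number of runs by roughly a factor of ioSortFactor, so the number of merge passes grows only logarithmically with the number of spills.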
Review comment: Typo: "frm" should be "from".