Commit 74dad41

SPARK-2565. Update ShuffleReadMetrics as blocks are fetched

1 parent 47ccd5e commit 74dad41

10 files changed, +85 -65 lines changed

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 1 addition & 0 deletions

@@ -374,6 +374,7 @@ private[spark] class Executor(
         for (taskRunner <- runningTasks.values()) {
           if (!taskRunner.attemptedTask.isEmpty) {
             Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
+              metrics.updateShuffleReadMetrics
              tasksMetrics += ((taskRunner.taskId, metrics))
            }
          }

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 43 additions & 14 deletions

@@ -17,6 +17,8 @@
 
 package org.apache.spark.executor
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.storage.{BlockId, BlockStatus}
 
@@ -81,11 +83,26 @@ class TaskMetrics extends Serializable {
   var inputMetrics: Option[InputMetrics] = None
 
   /**
-   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here
+   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
+   * This includes read metrics aggregated over all the task's shuffle dependencies.
    */
   private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None
 
-  def shuffleReadMetrics = _shuffleReadMetrics
+  def shuffleReadMetrics() = _shuffleReadMetrics
+
+  /**
+   * This should only be used when recreating TaskMetrics, not when updating read metrics in
+   * executors.
+   */
+  private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) {
+    _shuffleReadMetrics = shuffleReadMetrics
+  }
+
+  /**
+   * ShuffleReadMetrics per dependency for collecting independently while task is in progress.
+   */
+  @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] =
+    new ArrayBuffer[ShuffleReadMetrics]()
 
   /**
   * If this task writes to shuffle output, metrics on the written shuffle data will be collected
@@ -98,19 +115,31 @@ class TaskMetrics extends Serializable {
    */
   var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None
 
-  /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. */
-  def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized {
-    _shuffleReadMetrics match {
-      case Some(existingMetrics) =>
-        existingMetrics.shuffleFinishTime = math.max(
-          existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime)
-        existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime
-        existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched
-        existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched
-        existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead
-      case None =>
-        _shuffleReadMetrics = Some(newMetrics)
+  /**
+   * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization
+   * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each
+   * dependency, and merge these metrics before reporting them to the driver. This method returns
+   * a ShuffleReadMetrics for a dependency and registers it for merging later.
+   */
+  private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized {
+    val readMetrics = new ShuffleReadMetrics()
+    depsShuffleReadMetrics += readMetrics
+    readMetrics
+  }
+
+  /**
+   * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
+   */
+  private[spark] def updateShuffleReadMetrics() = synchronized {
+    val merged = new ShuffleReadMetrics()
+    for (depMetrics <- depsShuffleReadMetrics) {
+      merged.fetchWaitTime += depMetrics.fetchWaitTime
+      merged.localBlocksFetched += depMetrics.localBlocksFetched
+      merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
+      merged.remoteBytesRead += depMetrics.remoteBytesRead
+      merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime)
     }
+    _shuffleReadMetrics = Some(merged)
   }
 }
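
The per-dependency scheme above can be summarized with a small self-contained Scala sketch (hypothetical ReadCounters/TaskCounters classes, not Spark's ShuffleReadMetrics/TaskMetrics): each shuffle reader gets its own accumulator, and the accumulators are folded into one aggregate only when the task reports metrics, mirroring createShuffleReadMetricsForDependency() and updateShuffleReadMetrics().

import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for ShuffleReadMetrics, just to illustrate the pattern.
class ReadCounters {
  var shuffleFinishTime: Long = 0L
  var fetchWaitTime: Long = 0L
  var localBlocksFetched: Int = 0
  var remoteBlocksFetched: Int = 0
  var remoteBytesRead: Long = 0L
}

// Hypothetical stand-in for TaskMetrics: registers one accumulator per dependency.
class TaskCounters {
  private val perDependency = new ArrayBuffer[ReadCounters]()

  // One accumulator per shuffle dependency; the returned object is mutated by a single reader.
  def newDependencyCounters(): ReadCounters = synchronized {
    val c = new ReadCounters
    perDependency += c
    c
  }

  // Fold all per-dependency counters into one aggregate, as updateShuffleReadMetrics() does above.
  def merge(): ReadCounters = synchronized {
    val total = new ReadCounters
    for (c <- perDependency) {
      total.fetchWaitTime += c.fetchWaitTime
      total.localBlocksFetched += c.localBlocksFetched
      total.remoteBlocksFetched += c.remoteBlocksFetched
      total.remoteBytesRead += c.remoteBytesRead
      total.shuffleFinishTime = math.max(total.shuffleFinishTime, c.shuffleFinishTime)
    }
    total
  }
}

object PerDependencyMetricsSketch extends App {
  val task = new TaskCounters
  val dep1 = task.newDependencyCounters()
  val dep2 = task.newDependencyCounters()
  dep1.remoteBytesRead += 1024; dep1.remoteBlocksFetched += 2
  dep2.localBlocksFetched += 3; dep2.fetchWaitTime += 15
  val total = task.merge()
  println(s"remote=${total.remoteBytesRead}B, local=${total.localBlocksFetched} blocks, wait=${total.fetchWaitTime}ms")
}

Because each reader thread mutates only its own accumulator, nothing is shared on the hot path; only registration and the final merge are synchronized.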

core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala

Lines changed: 4 additions & 9 deletions

@@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer)
+      serializer: Serializer,
+      shuffleMetrics: ShuffleReadMetrics)
     : Iterator[T] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
@@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       }
     }
 
-    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
+    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics)
     val itr = blockFetcherItr.flatMap(unpackBlock)
 
     val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      val shuffleMetrics = new ShuffleReadMetrics
-      shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
-      shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
-      shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
-      shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
-      shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
-      context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics)
+      context.taskMetrics.updateShuffleReadMetrics()
     })
 
     new InterruptibleIterator[T](context, completionIter)
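
The completion callback above is what triggers the merge once the shuffle fetch is fully consumed. A minimal sketch of that wrapper idea, assuming a hypothetical OnCompletionIterator rather than Spark's CompletionIterator:

// Runs `onComplete` exactly once, when the wrapped iterator reports it has no more elements.
class OnCompletionIterator[A](underlying: Iterator[A], onComplete: () => Unit) extends Iterator[A] {
  private var completed = false
  override def hasNext: Boolean = {
    val more = underlying.hasNext
    if (!more && !completed) { completed = true; onComplete() }
    more
  }
  override def next(): A = underlying.next()
}

object OnCompletionIteratorSketch extends App {
  val wrapped = new OnCompletionIterator(Iterator(1, 2, 3), () => println("fetch done, merge metrics here"))
  wrapped.foreach(println)   // prints 1, 2, 3, then the completion message
}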

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 3 additions & 1 deletion

@@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser,
+      readMetrics)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {

core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala

Lines changed: 15 additions & 25 deletions

@@ -27,6 +27,7 @@ import scala.util.{Failure, Success}
 import io.netty.buffer.ByteBuf
 
 import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.executor.ShuffleReadMetrics
 import org.apache.spark.network.BufferMessage
 import org.apache.spark.network.ConnectionManagerId
 import org.apache.spark.network.netty.ShuffleCopier
@@ -47,10 +48,6 @@ import org.apache.spark.util.Utils
 private[storage]
 trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
   def initialize()
-  def numLocalBlocks: Int
-  def numRemoteBlocks: Int
-  def fetchWaitTime: Long
-  def remoteBytesRead: Long
 }
 
 
@@ -72,14 +69,12 @@ object BlockFetcherIterator {
   class BasicBlockFetcherIterator(
       private val blockManager: BlockManager,
       val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer)
+      serializer: Serializer,
+      readMetrics: ShuffleReadMetrics)
     extends BlockFetcherIterator {
 
     import blockManager._
 
-    private var _remoteBytesRead = 0L
-    private var _fetchWaitTime = 0L
-
     if (blocksByAddress == null) {
       throw new IllegalArgumentException("BlocksByAddress is null")
     }
@@ -89,13 +84,9 @@
 
     protected var startTime = System.currentTimeMillis
 
-    // This represents the number of local blocks, also counting zero-sized blocks
-    private var numLocal = 0
     // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
     protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
 
-    // This represents the number of remote blocks, also counting zero-sized blocks
-    private var numRemote = 0
     // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
     protected val remoteBlocksToFetch = new HashSet[BlockId]()
 
@@ -132,7 +123,10 @@
           val networkSize = blockMessage.getData.limit()
           results.put(new FetchResult(blockId, sizeMap(blockId),
             () => dataDeserialize(blockId, blockMessage.getData, serializer)))
-          _remoteBytesRead += networkSize
+          // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can
+          // be incrementing bytes read at the same time (SPARK-2625).
+          readMetrics.remoteBytesRead += networkSize
+          readMetrics.remoteBlocksFetched += 1
           logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
         }
       }
@@ -155,14 +149,14 @@
       // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
      // at most maxBytesInFlight in order to limit the amount of data in flight.
       val remoteRequests = new ArrayBuffer[FetchRequest]
+      var totalBlocks = 0
       for ((address, blockInfos) <- blocksByAddress) {
+        totalBlocks += blockInfos.size
         if (address == blockManagerId) {
-          numLocal = blockInfos.size
           // Filter out zero-sized blocks
           localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
           _numBlocksToFetch += localBlocksToFetch.size
         } else {
-          numRemote += blockInfos.size
           val iterator = blockInfos.iterator
           var curRequestSize = 0L
           var curBlocks = new ArrayBuffer[(BlockId, Long)]
@@ -192,7 +186,7 @@
         }
       }
       logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
-        (numLocal + numRemote) + " blocks")
+        totalBlocks + " blocks")
       remoteRequests
     }
 
@@ -205,6 +199,7 @@
           // getLocalFromDisk never return None but throws BlockException
           val iter = getLocalFromDisk(id, serializer).get
           // Pass 0 as size since it's not in flight
+          readMetrics.localBlocksFetched += 1
          results.put(new FetchResult(id, 0, () => iter))
           logDebug("Got local block " + id)
         } catch {
@@ -238,12 +233,6 @@
       logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
     }
 
-    override def numLocalBlocks: Int = numLocal
-    override def numRemoteBlocks: Int = numRemote
-    override def fetchWaitTime: Long = _fetchWaitTime
-    override def remoteBytesRead: Long = _remoteBytesRead
-
-
     // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue
     // as they arrive.
     @volatile protected var resultsGotten = 0
@@ -255,7 +244,7 @@
       val startFetchWait = System.currentTimeMillis()
       val result = results.take()
      val stopFetchWait = System.currentTimeMillis()
-      _fetchWaitTime += (stopFetchWait - startFetchWait)
+      readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
       if (! result.failed) bytesInFlight -= result.size
       while (!fetchRequests.isEmpty &&
         (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
@@ -269,8 +258,9 @@
   class NettyBlockFetcherIterator(
       blockManager: BlockManager,
       blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer)
-    extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
+      serializer: Serializer,
+      readMetrics: ShuffleReadMetrics)
+    extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) {
 
     import blockManager._
 
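
The fetchWaitTime bookkeeping in this file charges only the time the consumer spends blocked on results.take(). A rough standalone sketch of that measurement, assuming a single producer thread feeding a blocking queue (hypothetical names, not the BlockFetcherIterator internals):

import java.util.concurrent.LinkedBlockingQueue

object FetchWaitTimeSketch {
  def main(args: Array[String]): Unit = {
    val results = new LinkedBlockingQueue[String]()
    var fetchWaitTime = 0L

    // Simulated fetcher thread that delivers a block result after 50 ms.
    new Thread(new Runnable {
      override def run(): Unit = { Thread.sleep(50); results.put("shuffle_0_0_0") }
    }).start()

    val startFetchWait = System.currentTimeMillis()
    val result = results.take()                        // blocks until a result is available
    val stopFetchWait = System.currentTimeMillis()
    fetchWaitTime += (stopFetchWait - startFetchWait)  // only the blocked time is counted

    println(s"got $result after waiting ~${fetchWaitTime} ms")
  }
}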

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 7 additions & 4 deletions

@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
 import sun.nio.ch.DirectBuffer
 
 import org.apache.spark._
-import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics}
+import org.apache.spark.executor._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.serializer.Serializer
@@ -539,12 +539,15 @@
    */
   def getMultiple(
       blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer): BlockFetcherIterator = {
+      serializer: Serializer,
+      readMetrics: ShuffleReadMetrics): BlockFetcherIterator = {
     val iter =
       if (conf.getBoolean("spark.shuffle.use.netty", false)) {
-        new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
+        new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer,
+          readMetrics)
       } else {
-        new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
+        new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer,
+          readMetrics)
       }
     iter.initialize()
     iter
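
With this change, getMultiple threads the caller's ShuffleReadMetrics into whichever fetcher the spark.shuffle.use.netty flag selects. A schematic sketch of that selection shape, using hypothetical Fetcher classes rather than the real BlockFetcherIterator hierarchy:

object FetcherSelectionSketch {
  class ReadMetrics { var remoteBytesRead: Long = 0L }   // hypothetical stand-in for ShuffleReadMetrics

  trait Fetcher { def initialize(): Unit }
  class BasicFetcher(readMetrics: ReadMetrics) extends Fetcher { def initialize(): Unit = () }
  class NettyFetcher(readMetrics: ReadMetrics) extends Fetcher { def initialize(): Unit = () }

  // Both branches receive the caller's metrics object, so accounting is identical
  // regardless of which transport is configured.
  def getFetcher(useNetty: Boolean, readMetrics: ReadMetrics): Fetcher = {
    val iter = if (useNetty) new NettyFetcher(readMetrics) else new BasicFetcher(readMetrics)
    iter.initialize()
    iter
  }

  def main(args: Array[String]): Unit = {
    val metrics = new ReadMetrics
    val fetcher = getFetcher(useNetty = false, metrics)
    println(fetcher.getClass.getSimpleName)   // BasicFetcher
  }
}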

core/src/main/scala/org/apache/spark/util/JsonProtocol.scala

Lines changed: 2 additions & 3 deletions

@@ -560,9 +560,8 @@ private[spark] object JsonProtocol {
     metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long]
     metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long]
     metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long]
-    Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics =>
-      metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics))
-    }
+    metrics.setShuffleReadMetrics(
+      Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson))
     metrics.shuffleWriteMetrics =
       Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
     metrics.inputMetrics =

core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala

Lines changed: 7 additions & 6 deletions

@@ -33,6 +33,7 @@ import org.mockito.invocation.InvocationOnMock
 
 import org.apache.spark.storage.BlockFetcherIterator._
 import org.apache.spark.network.{ConnectionManager, Message}
+import org.apache.spark.executor.ShuffleReadMetrics
 
 class BlockFetcherIteratorSuite extends FunSuite with Matchers {
 
@@ -70,8 +71,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
      (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
    )
 
-    val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+    val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
+      new ShuffleReadMetrics())
 
    iterator.initialize()
 
@@ -121,8 +122,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
      (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
    )
 
-    val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+    val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
+      new ShuffleReadMetrics())
 
    iterator.initialize()
 
@@ -165,7 +166,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
    )
 
    val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+      blocksByAddress, null, new ShuffleReadMetrics())
 
    iterator.initialize()
    iterator.foreach{
@@ -219,7 +220,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
    )
 
    val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+      blocksByAddress, null, new ShuffleReadMetrics())
    iterator.initialize()
    iterator.foreach{
      case (_, r) => {

core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -65,7 +65,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
 
    // finish this task, should get updated shuffleRead
    shuffleReadMetrics.remoteBytesRead = 1000
-    taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
+    taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
    var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
    taskInfo.finishTime = 1
    var task = new ShuffleMapTask(0)
@@ -142,7 +142,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
    val taskMetrics = new TaskMetrics()
    val shuffleReadMetrics = new ShuffleReadMetrics()
    val shuffleWriteMetrics = new ShuffleWriteMetrics()
-    taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
+    taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
    taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
    shuffleReadMetrics.remoteBytesRead = base + 1
    shuffleReadMetrics.remoteBlocksFetched = base + 2

core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -539,7 +539,7 @@ class JsonProtocolSuite extends FunSuite {
      sr.localBlocksFetched = e
      sr.fetchWaitTime = a + d
      sr.remoteBlocksFetched = f
-      t.updateShuffleReadMetrics(sr)
+      t.setShuffleReadMetrics(Some(sr))
    }
    sw.shuffleBytesWritten = a + b + c
    sw.shuffleWriteTime = b + c + d
