Commit c5e68e5

SPARK-2566. Update ShuffleWriteMetrics incrementally
1 parent 1aad911 commit c5e68e5

11 files changed, +162 -82 lines changed

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 2 additions & 2 deletions
@@ -190,10 +190,10 @@ class ShuffleWriteMetrics extends Serializable {
   /**
    * Number of bytes written for the shuffle by this task
    */
-  var shuffleBytesWritten: Long = _
+  @volatile var shuffleBytesWritten: Long = _
 
   /**
    * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
    */
-  var shuffleWriteTime: Long = _
+  @volatile var shuffleWriteTime: Long = _
 }
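Why @volatile: the two counters are no longer set once at commit time but are bumped incrementally by the writing thread while other threads (for example a metrics reporter) may read them. A minimal standalone sketch of that access pattern, using hypothetical names rather than Spark's classes:

    // Sketch only: WriteCounters stands in for ShuffleWriteMetrics.
    class WriteCounters extends Serializable {
      @volatile var bytesWritten: Long = 0L
      @volatile var writeTimeNanos: Long = 0L
    }

    object VolatileCountersDemo extends App {
      val counters = new WriteCounters

      // Single writer thread: bumps the counters incrementally (only one thread ever writes,
      // so a plain += on a @volatile field is acceptable here).
      val writer = new Thread(new Runnable {
        def run(): Unit = for (_ <- 1 to 200) {
          val start = System.nanoTime()
          Thread.sleep(1) // stands in for a blocking disk write
          counters.bytesWritten += 1024
          counters.writeTimeNanos += System.nanoTime() - start
        }
      })

      // Reader thread: @volatile guarantees it sees recent values without locking.
      val reporter = new Thread(new Runnable {
        def run(): Unit = for (_ <- 1 to 5) {
          Thread.sleep(50)
          println(s"so far: ${counters.bytesWritten} bytes in ${counters.writeTimeNanos} ns")
        }
      })

      writer.start(); reporter.start()
      writer.join(); reporter.join()
    }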

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala

Lines changed: 5 additions & 11 deletions
@@ -39,10 +39,14 @@ private[spark] class HashShuffleWriter[K, V](
   // we don't try deleting files, etc twice.
   private var stopping = false
 
+  val writeMetrics = new ShuffleWriteMetrics()
+  metrics.shuffleWriteMetrics = Some(writeMetrics)
+
   private val blockManager = SparkEnv.get.blockManager
   private val shuffleBlockManager = blockManager.shuffleBlockManager
   private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
-  private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
+  private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser,
+    writeMetrics)
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
@@ -99,22 +103,12 @@
 
   private def commitWritesAndBuildStatus(): MapStatus = {
     // Commit the writes. Get the size of each bucket block (total block size).
-    var totalBytes = 0L
-    var totalTime = 0L
     val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
       writer.commitAndClose()
       val size = writer.fileSegment().length
-      totalBytes += size
-      totalTime += writer.timeWriting()
       MapOutputTracker.compressSize(size)
     }
 
-    // Update shuffle metrics.
-    val shuffleMetrics = new ShuffleWriteMetrics
-    shuffleMetrics.shuffleBytesWritten = totalBytes
-    shuffleMetrics.shuffleWriteTime = totalTime
-    metrics.shuffleWriteMetrics = Some(shuffleMetrics)
-
     new MapStatus(blockManager.blockManagerId, compressedSizes)
   }
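The shape of the change here: the writer allocates a single ShuffleWriteMetrics up front, registers it on the task's metrics, and hands the same instance to every per-bucket disk writer, so totals accumulate while records are written instead of being summed in commitWritesAndBuildStatus. A rough standalone sketch of that sharing pattern (hypothetical names, not Spark's API):

    // Sketch only: hypothetical stand-ins for ShuffleWriteMetrics / BlockObjectWriter.
    class WriteCounters {
      @volatile var bytesWritten: Long = 0L
    }

    // Each per-bucket writer receives the same counters instance and bumps it on every write.
    class BucketWriter(counters: WriteCounters) {
      def write(record: Array[Byte]): Unit = {
        // ... append the record to this bucket's file ...
        counters.bytesWritten += record.length
      }
    }

    // The map-task writer owns the counters and hands them to all of its bucket writers,
    // so the running total is visible while the task is still writing.
    class MapTaskWriter(numBuckets: Int) {
      val counters = new WriteCounters
      private val buckets = Array.fill(numBuckets)(new BucketWriter(counters))
      def write(bucket: Int, record: Array[Byte]): Unit = buckets(bucket).write(record)
    }

    object SharedCountersDemo extends App {
      val task = new MapTaskWriter(numBuckets = 4)
      task.write(0, Array.fill(128)(0: Byte))
      task.write(3, Array.fill(64)(0: Byte))
      println(s"bytes written so far: ${task.counters.bytesWritten}") // 192
    }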

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 5 additions & 11 deletions
@@ -52,6 +52,9 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private var mapStatus: MapStatus = null
 
+  val writeMetrics = new ShuffleWriteMetrics()
+  context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)
+
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
     // Get an iterator with the elements for each partition ID
@@ -84,32 +87,23 @@ private[spark] class SortShuffleWriter[K, V, C](
     val offsets = new Array[Long](numPartitions + 1)
     val lengths = new Array[Long](numPartitions)
 
-    // Statistics
-    var totalBytes = 0L
-    var totalTime = 0L
-
     for ((id, elements) <- partitions) {
       if (elements.hasNext) {
-        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
+        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize,
+          writeMetrics)
         for (elem <- elements) {
          writer.write(elem)
        }
        writer.commitAndClose()
        val segment = writer.fileSegment()
        offsets(id + 1) = segment.offset + segment.length
        lengths(id) = segment.length
-        totalTime += writer.timeWriting()
-        totalBytes += segment.length
      } else {
        // The partition is empty; don't create a new writer to avoid writing headers, etc
        offsets(id + 1) = offsets(id)
      }
    }
 
-    val shuffleMetrics = new ShuffleWriteMetrics
-    shuffleMetrics.shuffleBytesWritten = totalBytes
-    shuffleMetrics.shuffleWriteTime = totalTime
-    context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)
     context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
     context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
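Per-partition sizes are still taken from writer.fileSegment() rather than from the shared metrics, since the metrics aggregate all partitions. A small standalone illustration (hypothetical names) of how the offsets and lengths arrays are built from consecutive segments of the single output file:

    // Sketch only: Segment mimics FileSegment's (offset, length) view of the output file.
    case class Segment(offset: Long, length: Long)

    object PartitionOffsetsDemo extends App {
      val numPartitions = 4
      // Bytes produced per partition; 0 means the partition was empty and no writer was opened.
      val partitionBytes = Array(100L, 0L, 250L, 40L)

      val offsets = new Array[Long](numPartitions + 1)
      val lengths = new Array[Long](numPartitions)

      for (id <- 0 until numPartitions) {
        if (partitionBytes(id) > 0) {
          // Each partition's segment starts where the previous one ended.
          val segment = Segment(offsets(id), partitionBytes(id))
          offsets(id + 1) = segment.offset + segment.length
          lengths(id) = segment.length
        } else {
          // Empty partition: zero-length segment at the current offset.
          offsets(id + 1) = offsets(id)
        }
      }

      println(offsets.mkString(", ")) // 0, 100, 100, 350, 390
      println(lengths.mkString(", ")) // 100, 0, 250, 40
    }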

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 7 additions & 5 deletions
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
 import sun.nio.ch.DirectBuffer
 
 import org.apache.spark._
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.serializer.Serializer
@@ -560,17 +560,19 @@ private[spark] class BlockManager(
 
   /**
    * A short circuited method to get a block writer that can write data directly to disk.
-   * The Block will be appended to the File specified by filename. This is currently used for
-   * writing shuffle files out. Callers should handle error cases.
+   * The Block will be appended to the File specified by filename. Callers should handle error
+   * cases.
    */
   def getDiskWriter(
       blockId: BlockId,
       file: File,
       serializer: Serializer,
-      bufferSize: Int): BlockObjectWriter = {
+      bufferSize: Int,
+      writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
-    new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites)
+    new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites,
+      writeMetrics)
   }
 
   /**
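A hedged sketch of how an internal caller uses the new signature; this only compiles inside Spark's own code base (BlockManager and getDiskWriter are private[spark]), and blockId, file and serializer are assumed to be supplied by the caller:

    import java.io.File

    import org.apache.spark.SparkEnv
    import org.apache.spark.executor.ShuffleWriteMetrics
    import org.apache.spark.serializer.Serializer
    import org.apache.spark.storage.BlockId

    object GetDiskWriterSketch {
      // The caller now owns the ShuffleWriteMetrics instance and can read it at any time,
      // even while the writer is still open.
      def writeBlock(blockId: BlockId, file: File, serializer: Serializer,
          records: Iterator[AnyRef]): ShuffleWriteMetrics = {
        val writeMetrics = new ShuffleWriteMetrics()
        val writer = SparkEnv.get.blockManager.getDiskWriter(
          blockId, file, serializer, 32 * 1024, writeMetrics)
        records.foreach(writer.write)
        writer.commitAndClose()
        writeMetrics // shuffleBytesWritten / shuffleWriteTime now cover this writer's output
      }
    }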

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 42 additions & 35 deletions
@@ -22,6 +22,7 @@ import java.nio.channels.FileChannel
 
 import org.apache.spark.Logging
 import org.apache.spark.serializer.{SerializationStream, Serializer}
+import org.apache.spark.executor.ShuffleWriteMetrics
 
 /**
  * An interface for writing JVM objects to some underlying storage. This interface allows
@@ -60,41 +61,26 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
    * This is only valid after commitAndClose() has been called.
    */
   def fileSegment(): FileSegment
-
-  /**
-   * Cumulative time spent performing blocking writes, in ns.
-   */
-  def timeWriting(): Long
-
-  /**
-   * Number of bytes written so far
-   */
-  def bytesWritten: Long
 }
 
-/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+/**
+ * BlockObjectWriter which writes directly to a file on disk. Appends to the given file.
+ * The given write metrics will be updated incrementally, but will not necessarily be current until
+ * commitAndClose is called.
+ */
 private[spark] class DiskBlockObjectWriter(
     blockId: BlockId,
     file: File,
     serializer: Serializer,
     bufferSize: Int,
     compressStream: OutputStream => OutputStream,
-    syncWrites: Boolean)
+    syncWrites: Boolean,
+    writeMetrics: ShuffleWriteMetrics)
   extends BlockObjectWriter(blockId)
   with Logging
 {
-
   /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
   private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
-    def timeWriting = _timeWriting
-    private var _timeWriting = 0L
-
-    private def callWithTiming(f: => Unit) = {
-      val start = System.nanoTime()
-      f
-      _timeWriting += (System.nanoTime() - start)
-    }
-
     def write(i: Int): Unit = callWithTiming(out.write(i))
     override def write(b: Array[Byte]) = callWithTiming(out.write(b))
     override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
@@ -111,7 +97,11 @@ private[spark] class DiskBlockObjectWriter(
   private val initialPosition = file.length()
   private var finalPosition: Long = -1
   private var initialized = false
-  private var _timeWriting = 0L
+
+  /** Calling channel.position() to update the write metrics can be a little bit expensive, so we
+    * only call it every N writes */
+  private var writesSinceMetricsUpdate = 0
+  private var lastPosition = initialPosition
 
   override def open(): BlockObjectWriter = {
     fos = new FileOutputStream(file, true)
@@ -128,14 +118,11 @@ private[spark] class DiskBlockObjectWriter(
       if (syncWrites) {
        // Force outstanding writes to disk and track how long it takes
        objOut.flush()
-        val start = System.nanoTime()
-        fos.getFD.sync()
-        _timeWriting += System.nanoTime() - start
+        def sync = fos.getFD.sync()
+        callWithTiming(sync)
      }
      objOut.close()
 
-      _timeWriting += ts.timeWriting
-
      channel = null
      bs = null
      fos = null
@@ -153,6 +140,7 @@ private[spark] class DiskBlockObjectWriter(
       //       serializer stream and the lower level stream.
       objOut.flush()
       bs.flush()
+      updateBytesWritten()
       close()
     }
     finalPosition = file.length()
@@ -162,6 +150,8 @@ private[spark] class DiskBlockObjectWriter(
   // truncating the file to its initial position.
   override def revertPartialWritesAndClose() {
     try {
+      writeMetrics.shuffleBytesWritten -= (lastPosition - initialPosition)
+
       if (initialized) {
         objOut.flush()
         bs.flush()
@@ -184,19 +174,36 @@ private[spark] class DiskBlockObjectWriter(
     if (!initialized) {
       open()
     }
+
     objOut.writeObject(value)
+
+    if (writesSinceMetricsUpdate == 32) {
+      writesSinceMetricsUpdate = 0
+      updateBytesWritten()
+    } else {
+      writesSinceMetricsUpdate += 1
+    }
   }
 
   override def fileSegment(): FileSegment = {
-    new FileSegment(file, initialPosition, bytesWritten)
+    new FileSegment(file, initialPosition, writeMetrics.shuffleBytesWritten)
   }
 
-  // Only valid if called after close()
-  override def timeWriting() = _timeWriting
+  private def updateBytesWritten() {
+    val pos = channel.position()
+    writeMetrics.shuffleBytesWritten += (pos - lastPosition)
+    lastPosition = pos
+  }
+
+  private def callWithTiming(f: => Unit) = {
+    val start = System.nanoTime()
+    f
+    writeMetrics.shuffleWriteTime += (System.nanoTime() - start)
+  }
 
-  // Only valid if called after commit()
-  override def bytesWritten: Long = {
-    assert(finalPosition != -1, "bytesWritten is only valid after successful commit()")
-    finalPosition - initialPosition
+  // For testing
+  private[spark] def flush() {
+    objOut.flush()
+    bs.flush()
   }
 }
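The byte counter is now driven by sampling the file channel's position: the writer counts records and folds (position - lastPosition) into the metrics only every 32 writes, plus once more in commitAndClose so the final value is exact. A standalone sketch of that sampling trick with hypothetical names:

    import java.io.{File, FileOutputStream}

    // Sketch only: SampledByteCounter mirrors the counting logic, not DiskBlockObjectWriter itself.
    class SampledByteCounter(file: File) {
      private val fos = new FileOutputStream(file, true)
      private val channel = fos.getChannel
      private var lastPosition = channel.position()
      private var writesSinceUpdate = 0
      var bytesWritten = 0L

      def write(record: Array[Byte]): Unit = {
        fos.write(record)
        // Sample the (relatively expensive) channel position only every 32 writes.
        if (writesSinceUpdate == 32) {
          writesSinceUpdate = 0
          update()
        } else {
          writesSinceUpdate += 1
        }
      }

      def commit(): Unit = {
        fos.flush()
        update() // final sample makes the counter exact
        fos.close()
      }

      private def update(): Unit = {
        val pos = channel.position()
        bytesWritten += pos - lastPosition
        lastPosition = pos
      }
    }

    object SampledByteCounterDemo extends App {
      val f = File.createTempFile("sampled-counter", ".bin")
      val w = new SampledByteCounter(f)
      (1 to 100).foreach(_ => w.write(Array.fill(10)(0: Byte)))
      w.commit()
      println(s"${w.bytesWritten} bytes") // 1000
      f.delete()
    }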

core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala

Lines changed: 6 additions & 3 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
 import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.executor.ShuffleWriteMetrics
 
 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
@@ -111,7 +112,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   * Get a ShuffleWriterGroup for the given map task, which will register it as complete
   * when the writers are closed successfully
   */
-  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
+      writeMetrics: ShuffleWriteMetrics) = {
     new ShuffleWriterGroup {
       shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
       private val shuffleState = shuffleStates(shuffleId)
@@ -121,7 +123,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
         fileGroup = getUnusedFileGroup()
         Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
-          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
+            writeMetrics)
         }
       } else {
         Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
@@ -136,7 +139,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
             logWarning(s"Failed to remove existing shuffle file $blockFile")
           }
         }
-          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
        }
      }

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 12 additions & 5 deletions
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
+import org.apache.spark.executor.ShuffleWriteMetrics
 
 /**
  * :: DeveloperApi ::
@@ -102,6 +103,10 @@ class ExternalAppendOnlyMap[K, V, C](
   private var _diskBytesSpilled = 0L
 
   private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
+
+  // Write metrics for current spill
+  private var curWriteMetrics: ShuffleWriteMetrics = _
+
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()
 
@@ -172,7 +177,9 @@ class ExternalAppendOnlyMap[K, V, C](
     logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
       .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
     val (blockId, file) = diskBlockManager.createTempBlock()
-    var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+    curWriteMetrics = new ShuffleWriteMetrics()
+    var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
+      curWriteMetrics)
     var objectsWritten = 0
 
     // List of batch sizes (bytes) in the order they are written to disk
@@ -183,9 +190,8 @@ class ExternalAppendOnlyMap[K, V, C](
       val w = writer
       writer = null
       w.commitAndClose()
-      val bytesWritten = w.bytesWritten
-      batchSizes.append(bytesWritten)
-      _diskBytesSpilled += bytesWritten
+      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
+      batchSizes.append(curWriteMetrics.shuffleBytesWritten)
       objectsWritten = 0
     }
 
@@ -199,7 +205,8 @@ class ExternalAppendOnlyMap[K, V, C](
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+          writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
+            curWriteMetrics)
         }
       }
       if (objectsWritten > 0) {
0 commit comments

Comments
 (0)