
Commit 9160149

SPARK-2532: Minimal shuffle consolidation fixes
All changes from this PR are by @mridulm and are drawn from his work in #1609. This patch is intended to fix all major issues related to shuffle file consolidation that @mridulm found, while minimizing changes to the code, with the hope that it may be more easily merged into 1.1. This patch is **not** intended as a replacement for #1609, which provides many additional benefits, including fixes to ExternalAppendOnlyMap, improvements to DiskBlockObjectWriter's API, and several new unit tests. If it is feasible to merge #1609 for the 1.1 deadline, that is a preferable option.
1 parent e966284 commit 9160149

File tree: 6 files changed (+142 lines added, -47 lines removed)

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala

Lines changed: 7 additions & 7 deletions
```diff
@@ -65,23 +65,25 @@ private[spark] class HashShuffleWriter[K, V](
   }
 
   /** Close this writer, passing along whether the map completed */
-  override def stop(success: Boolean): Option[MapStatus] = {
+  override def stop(initiallySuccess: Boolean): Option[MapStatus] = {
+    var success = initiallySuccess
     try {
       if (stopping) {
         return None
       }
       stopping = true
       if (success) {
         try {
-          return Some(commitWritesAndBuildStatus())
+          Some(commitWritesAndBuildStatus())
         } catch {
           case e: Exception =>
+            success = false
            revertWrites()
            throw e
         }
       } else {
         revertWrites()
-        return None
+        None
       }
     } finally {
       // Release the writers back to the shuffle block manager.
@@ -100,8 +102,7 @@ private[spark] class HashShuffleWriter[K, V](
     var totalBytes = 0L
     var totalTime = 0L
     val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
-      writer.commit()
-      writer.close()
+      writer.commitAndClose()
       val size = writer.fileSegment().length
       totalBytes += size
       totalTime += writer.timeWriting()
@@ -120,8 +121,7 @@ private[spark] class HashShuffleWriter[K, V](
   private def revertWrites(): Unit = {
     if (shuffle != null && shuffle.writers != null) {
       for (writer <- shuffle.writers) {
-        writer.revertPartialWrites()
-        writer.close()
+        writer.revertPartialWritesAndClose()
       }
     }
   }
```
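
To make the control-flow change above easier to follow, here is a stand-alone sketch (hypothetical names, not Spark code) of the pattern `stop()` now uses: `success` becomes a mutable flag, so when the commit itself throws, partial writes are reverted and the `finally` block that releases the writers observes the failure.

```scala
// Hypothetical stand-alone sketch (not Spark code) of the commit-or-revert flow in stop().
object StopFlowSketch {
  def stop(initiallySuccess: Boolean, commitFails: Boolean): Option[String] = {
    var success = initiallySuccess
    try {
      if (success) {
        try {
          if (commitFails) throw new RuntimeException("simulated commit failure")
          Some("map-status")               // stand-in for commitWritesAndBuildStatus()
        } catch {
          case e: Exception =>
            success = false                // make sure cleanup below sees the failure
            println("reverting partial writes")
            throw e
        }
      } else {
        println("reverting partial writes")
        None
      }
    } finally {
      // In the real writer this is where writers are released back to the shuffle
      // block manager, passing along the (possibly updated) success flag.
      println(s"releasing writers, success = $success")
    }
  }

  def main(args: Array[String]): Unit = {
    println(stop(initiallySuccess = true, commitFails = false))   // Some(map-status)
    try stop(initiallySuccess = true, commitFails = true)
    catch { case _: RuntimeException => () }                      // prints success = false
  }
}
```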

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 31 additions & 22 deletions
```diff
@@ -39,16 +39,16 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
   def isOpen: Boolean
 
   /**
-   * Flush the partial writes and commit them as a single atomic block. Return the
-   * number of bytes written for this commit.
+   * Flush the partial writes and commit them as a single atomic block.
    */
-  def commit(): Long
+  def commitAndClose(): Unit
 
   /**
    * Reverts writes that haven't been flushed yet. Callers should invoke this function
-   * when there are runtime exceptions.
+   * when there are runtime exceptions. This method will not throw, though it may be
+   * unsuccessful in truncating written data.
    */
-  def revertPartialWrites()
+  def revertPartialWritesAndClose()
 
   /**
    * Writes an object.
@@ -57,6 +57,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
 
   /**
    * Returns the file segment of committed data that this Writer has written.
+   * This is only valid after commitAndClose() has been called.
    */
   def fileSegment(): FileSegment
 
@@ -108,15 +109,14 @@ private[spark] class DiskBlockObjectWriter(
   private var ts: TimeTrackingOutputStream = null
   private var objOut: SerializationStream = null
   private val initialPosition = file.length()
-  private var lastValidPosition = initialPosition
+  private var finalPosition: Long = -1
   private var initialized = false
   private var _timeWriting = 0L
 
   override def open(): BlockObjectWriter = {
     fos = new FileOutputStream(file, true)
     ts = new TimeTrackingOutputStream(fos)
     channel = fos.getChannel()
-    lastValidPosition = initialPosition
     bs = compressStream(new BufferedOutputStream(ts, bufferSize))
     objOut = serializer.newInstance().serializeStream(bs)
     initialized = true
@@ -147,28 +147,36 @@ private[spark] class DiskBlockObjectWriter(
 
   override def isOpen: Boolean = objOut != null
 
-  override def commit(): Long = {
+  override def commitAndClose(): Unit = {
     if (initialized) {
       // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
       // serializer stream and the lower level stream.
       objOut.flush()
       bs.flush()
-      val prevPos = lastValidPosition
-      lastValidPosition = channel.position()
-      lastValidPosition - prevPos
-    } else {
-      // lastValidPosition is zero if stream is uninitialized
-      lastValidPosition
+      close()
+      finalPosition = file.length()
     }
   }
 
-  override def revertPartialWrites() {
-    if (initialized) {
-      // Discard current writes. We do this by flushing the outstanding writes and
-      // truncate the file to the last valid position.
-      objOut.flush()
-      bs.flush()
-      channel.truncate(lastValidPosition)
+  // Discard current writes. We do this by flushing the outstanding writes and then
+  // truncating the file to its initial position.
+  override def revertPartialWritesAndClose() {
+    try {
+      if (initialized) {
+        objOut.flush()
+        bs.flush()
+        close()
+      }
+
+      val truncateStream = new FileOutputStream(file, true)
+      try {
+        truncateStream.getChannel.truncate(initialPosition)
+      } finally {
+        truncateStream.close()
+      }
+    } catch {
+      case e: Exception =>
+        logError("Uncaught exception while reverting partial writes to file " + file, e)
     }
   }
 
@@ -188,6 +196,7 @@ private[spark] class DiskBlockObjectWriter(
 
   // Only valid if called after commit()
   override def bytesWritten: Long = {
-    lastValidPosition - initialPosition
+    assert(finalPosition != -1, "bytesWritten is only valid after successful commit()")
+    finalPosition - initialPosition
   }
 }
```
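
For context, the sketch below (a simplified, hypothetical writer, not the Spark class) shows the contract the new API implies: remember the file's initial position, record the final position only in commitAndClose(), make bytesWritten meaningful only after a successful commit, and revert by truncating back to the initial position without rethrowing.

```scala
import java.io.{File, FileOutputStream, OutputStreamWriter}
import java.nio.charset.StandardCharsets

// Hypothetical, simplified writer that follows the same lifecycle as the revised API:
// open -> write -> commitAndClose, or revertPartialWritesAndClose on failure.
final class AppendOnlyFileWriter(file: File) {
  private val initialPosition: Long = file.length()
  private var finalPosition: Long = -1
  private var out: OutputStreamWriter = _

  def open(): this.type = {
    out = new OutputStreamWriter(new FileOutputStream(file, true), StandardCharsets.UTF_8)
    this
  }

  def write(s: String): Unit = out.write(s)

  // Flush and close, then record the committed end of the segment.
  def commitAndClose(): Unit = {
    if (out != null) {
      out.flush()
      out.close()
    }
    finalPosition = file.length()
  }

  // Best-effort revert: close the stream and truncate back to where we started.
  // Like the Spark method, this logs rather than rethrows.
  def revertPartialWritesAndClose(): Unit = {
    try {
      if (out != null) out.close()
      val truncateStream = new FileOutputStream(file, true)
      try truncateStream.getChannel.truncate(initialPosition)
      finally truncateStream.close()
    } catch {
      case e: Exception => System.err.println(s"failed to revert writes to $file: $e")
    }
  }

  // Only meaningful after commitAndClose(), mirroring fileSegment()/bytesWritten.
  def bytesWritten: Long = {
    assert(finalPosition != -1, "bytesWritten is only valid after commitAndClose()")
    finalPosition - initialPosition
  }
}

object AppendOnlyFileWriterDemo {
  def main(args: Array[String]): Unit = {
    val f = File.createTempFile("segment", ".dat")
    val w = new AppendOnlyFileWriter(f).open()
    w.write("hello")
    w.commitAndClose()
    println(w.bytesWritten)   // 5
  }
}
```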

core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala

Lines changed: 15 additions & 13 deletions
```diff
@@ -144,7 +144,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
         if (consolidateShuffleFiles) {
           if (success) {
             val offsets = writers.map(_.fileSegment().offset)
-            fileGroup.recordMapOutput(mapId, offsets)
+            val lengths = writers.map(_.fileSegment().length)
+            fileGroup.recordMapOutput(mapId, offsets, lengths)
           }
           recycleFileGroup(fileGroup)
         } else {
@@ -247,47 +248,48 @@ object ShuffleBlockManager {
    * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
    */
   private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+    private var numBlocks: Int = 0
+
     /**
      * Stores the absolute index of each mapId in the files of this group. For instance,
      * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
      */
     private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
 
     /**
-     * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
-     * This ordering allows us to compute block lengths by examining the following block offset.
+     * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by
+     * position in the file.
      * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every
      * reducer.
      */
     private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
       new PrimitiveVector[Long]()
     }
-
-    def numBlocks = mapIdToIndex.size
+    private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
+      new PrimitiveVector[Long]()
+    }
 
     def apply(bucketId: Int) = files(bucketId)
 
-    def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+    def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) {
+      assert(offsets.length == lengths.length)
       mapIdToIndex(mapId) = numBlocks
+      numBlocks += 1
       for (i <- 0 until offsets.length) {
         blockOffsetsByReducer(i) += offsets(i)
+        blockLengthsByReducer(i) += lengths(i)
       }
     }
 
     /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
     def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
       val file = files(reducerId)
       val blockOffsets = blockOffsetsByReducer(reducerId)
+      val blockLengths = blockLengthsByReducer(reducerId)
       val index = mapIdToIndex.getOrElse(mapId, -1)
       if (index >= 0) {
         val offset = blockOffsets(index)
-        val length =
-          if (index + 1 < numBlocks) {
-            blockOffsets(index + 1) - offset
-          } else {
-            file.length() - offset
-          }
-        assert(length >= 0)
+        val length = blockLengths(index)
         Some(new FileSegment(file, offset, length))
       } else {
         None
```
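
The reason for carrying explicit lengths is easiest to see in a small sketch (simplified, hypothetical names, not the Spark class): each committed map output stores its own (offset, length) pair, so looking up a block no longer depends on the next block's offset or on the current file length, both of which can change while other map tasks are still appending to the same consolidated reducer file.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical bookkeeping for one reducer file in a consolidated shuffle group:
// record an explicit (offset, length) per map output instead of inferring the length.
final class ReducerFileIndex {
  private val offsets = new ArrayBuffer[Long]()
  private val lengths = new ArrayBuffer[Long]()
  private val mapIdToIndex = scala.collection.mutable.Map[Int, Int]()

  def recordMapOutput(mapId: Int, offset: Long, length: Long): Unit = {
    mapIdToIndex(mapId) = offsets.length
    offsets += offset
    lengths += length
  }

  // Returns (offset, length) for a map task, independent of whatever has been
  // appended to the file since this block was committed.
  def segmentFor(mapId: Int): Option[(Long, Long)] =
    mapIdToIndex.get(mapId).map(i => (offsets(i), lengths(i)))
}

object ReducerFileIndexDemo {
  def main(args: Array[String]): Unit = {
    val index = new ReducerFileIndex
    index.recordMapOutput(mapId = 1, offset = 0L, length = 120L)
    // A later map task appends to the same file; block 1's recorded length is unaffected,
    // whereas "file.length() - offset" would now overstate it.
    index.recordMapOutput(mapId = 2, offset = 120L, length = 80L)
    println(index.segmentFor(1))   // Some((0,120))
  }
}
```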

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -199,7 +199,7 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Flush the disk writer's contents to disk, and update relevant variables
     def flush() = {
-      writer.commit()
+      writer.commitAndClose()
      val bytesWritten = writer.bytesWritten
      batchSizes.append(bytesWritten)
      _diskBytesSpilled += bytesWritten
```

core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala

Lines changed: 86 additions & 1 deletion
```diff
@@ -22,11 +22,14 @@ import java.io.{File, FileWriter}
 import scala.collection.mutable
 import scala.language.reflectiveCalls
 
+import akka.actor.Props
 import com.google.common.io.Files
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
 
 import org.apache.spark.SparkConf
-import org.apache.spark.util.Utils
+import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.{AkkaUtils, Utils}
 
 class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
   private val testConf = new SparkConf(false)
@@ -121,6 +124,88 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
     newFile.delete()
   }
 
+  private def checkSegments(segment1: FileSegment, segment2: FileSegment) {
+    assert (segment1.file.getCanonicalPath === segment2.file.getCanonicalPath)
+    assert (segment1.offset === segment2.offset)
+    assert (segment1.length === segment2.length)
+  }
+
+  test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
+
+    val serializer = new JavaSerializer(testConf)
+    val confCopy = testConf.clone
+    // Reset after EACH object write. This ensures that there are bytes appended after an
+    // object is written, so if the codepaths assume writeObject is the end of the data,
+    // this should flush those bugs out. This was a common bug in ExternalAppendOnlyMap, etc.
+    confCopy.set("spark.serializer.objectStreamReset", "1")
+
+    val securityManager = new org.apache.spark.SecurityManager(confCopy)
+    // Do not use the shuffleBlockManager above!
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, confCopy,
+      securityManager)
+    val master = new BlockManagerMaster(
+      actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))),
+      confCopy)
+    val store = new BlockManager("<driver>", actorSystem, master, serializer, confCopy,
+      securityManager, null)
+
+    try {
+
+      val shuffleManager = store.shuffleBlockManager
+
+      val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer)
+      for (writer <- shuffle1.writers) {
+        writer.write("test1")
+        writer.write("test2")
+      }
+      for (writer <- shuffle1.writers) {
+        writer.commitAndClose()
+      }
+
+      val shuffle1Segment = shuffle1.writers(0).fileSegment()
+      shuffle1.releaseWriters(success = true)
+
+      val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf))
+
+      for (writer <- shuffle2.writers) {
+        writer.write("test3")
+        writer.write("test4")
+      }
+      for (writer <- shuffle2.writers) {
+        writer.commitAndClose()
+      }
+      val shuffle2Segment = shuffle2.writers(0).fileSegment()
+      shuffle2.releaseWriters(success = true)
+
+      // Now comes the test:
+      // Write to shuffle 3 and close it, but before registering it, check that the file segments
+      // recorded for the previous task are unchanged. Earlier, we inferred a block's length from
+      // the remaining data in the file, which could mess things up when there are concurrent
+      // reads and writes happening to the same shuffle group.
+
+      val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf))
+      for (writer <- shuffle3.writers) {
+        writer.write("test3")
+        writer.write("test4")
+      }
+      for (writer <- shuffle3.writers) {
+        writer.commitAndClose()
+      }
+      // Check before we register.
+      checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0)))
+      shuffle3.releaseWriters(success = true)
+      checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0)))
+      shuffleManager.removeShuffle(1)
+    } finally {
+
+      if (store != null) {
+        store.stop()
+      }
+      actorSystem.shutdown()
+      actorSystem.awaitTermination()
+    }
+  }
+
   def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
     val segment = diskBlockManager.getBlockLocation(blockId)
     assert(segment.file.getName === filename)
```

tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala

Lines changed: 2 additions & 3 deletions
```diff
@@ -61,10 +61,9 @@ object StoragePerfTester {
       for (i <- 1 to recordsPerMap) {
         writers(i % numOutputSplits).write(writeData)
       }
-      writers.map {w =>
-        w.commit()
+      writers.map { w =>
+        w.commitAndClose()
         total.addAndGet(w.fileSegment().length)
-        w.close()
       }
 
       shuffle.releaseWriters(true)
```
