
Commit d988e7f

Add an option to migrate shuffle blocks as well as the current cache blocks during decommissioning
1 parent dc0709f commit d988e7f

22 files changed (+655, -73 lines)


core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 30 additions & 2 deletions
@@ -49,7 +49,7 @@ import org.apache.spark.util._
  *
  * All public methods of this class are thread-safe.
  */
-private class ShuffleStatus(numPartitions: Int) {
+private class ShuffleStatus(numPartitions: Int) extends Logging {
 
   private val (readLock, writeLock) = {
     val lock = new ReentrantReadWriteLock()
@@ -121,12 +121,28 @@ private class ShuffleStatus(numPartitions: Int) {
     mapStatuses(mapIndex) = status
   }
 
+  /**
+   * Update the map output location (e.g. during migration).
+   */
+  def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock {
+    val mapStatusOpt = mapStatuses.find(_.mapId == mapId)
+    mapStatusOpt match {
+      case Some(mapStatus) =>
+        logInfo(s"Updating map output for ${mapId} to ${bmAddress}")
+        mapStatus.updateLocation(bmAddress)
+        invalidateSerializedMapOutputStatusCache()
+      case None =>
+        logError(s"Asked to update map output ${mapId} for untracked map status.")
+    }
+  }
+
   /**
    * Remove the map output which was served by the specified block manager.
    * This is a no-op if there is no registered map output or if the registered output is from a
    * different block manager.
    */
   def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
+    logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
     if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
       _numAvailableOutputs -= 1
       mapStatuses(mapIndex) = null
@@ -139,6 +155,7 @@ private class ShuffleStatus(numPartitions: Int) {
    * outputs which are served by an external shuffle server (if one exists).
    */
   def removeOutputsOnHost(host: String): Unit = withWriteLock {
+    logDebug(s"Removing outputs for host ${host}")
     removeOutputsByFilter(x => x.host == host)
   }
 
@@ -148,6 +165,7 @@ private class ShuffleStatus(numPartitions: Int) {
    * still registered with that execId.
    */
   def removeOutputsOnExecutor(execId: String): Unit = withWriteLock {
+    logDebug(s"Removing outputs for execId ${execId}")
     removeOutputsByFilter(x => x.executorId == execId)
   }
 
@@ -265,7 +283,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case GetMapOutputStatuses(shuffleId: Int) =>
       val hostPort = context.senderAddress.hostPort
-      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
+      logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}")
       tracker.post(new GetMapOutputMessage(shuffleId, context))
 
     case StopMapOutputTracker =>
@@ -479,6 +497,16 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.updateMapOutput(mapId, bmAddress)
+        shuffleStatus.invalidateSerializedMapOutputStatusCache()
+      case None =>
+        logError(s"Asked to update map output for unknown shuffle ${shuffleId}")
+    }
+  }
+
   def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
     shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
   }
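
For orientation, a hedged sketch (not part of this diff) of how driver-side code might apply the new updateMapOutput hook once a migrated shuffle block has been registered on another executor. The mapOutputTracker reference, the ids, and the BlockManagerId values are made up for illustration:

    import org.apache.spark.storage.BlockManagerId

    // Hypothetical driver-side caller: after the destination block manager reports the
    // migrated shuffle index block, repoint the map output at its new location.
    val newLocation = BlockManagerId("exec-2", "host-b", 7337)
    mapOutputTracker.updateMapOutput(shuffleId = 0, mapId = 12L, bmAddress = newLocation)
    // Reducers that fetch map output statuses afterwards read shuffle 0 / map 12 from host-b.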

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 11 additions & 1 deletion
@@ -57,7 +57,7 @@ import org.apache.spark.resource._
 import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend}
 import org.apache.spark.scheduler.local.LocalSchedulerBackend
 import org.apache.spark.shuffle.ShuffleDataIOUtils
 import org.apache.spark.shuffle.api.ShuffleDriverComponents
@@ -1725,6 +1725,16 @@ class SparkContext(config: SparkConf) extends Logging {
     }
   }
 
+
+  private[spark] def decommissionExecutors(executorIds: Seq[String]): Unit = {
+    schedulerBackend match {
+      case b: CoarseGrainedSchedulerBackend =>
+        executorIds.foreach(b.decommissionExecutor)
+      case _ =>
+        logWarning("Decommissioning executors is not supported by current scheduler.")
+    }
+  }
+
   /** The version of Spark on which this application is running. */
   def version: String = SPARK_VERSION
 
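
A brief, hedged usage sketch for the new helper; since decommissionExecutors is private[spark], the caller below is assumed to live in Spark-internal code (for example a test suite), and the executor ids are illustrative:

    // Hypothetical internal caller: ask the backend to decommission two executors.
    // On anything other than a CoarseGrainedSchedulerBackend this only logs a warning.
    sc.decommissionExecutors(Seq("0", "1"))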

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 2 additions & 1 deletion
@@ -367,7 +367,8 @@ object SparkEnv extends Logging {
           externalShuffleClient
         } else {
           None
-        }, blockManagerInfo)),
+        }, blockManagerInfo,
+        mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])),
       registerOrLookupEndpoint(
         BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
         new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 15 additions & 0 deletions
@@ -420,6 +420,21 @@ package object config {
     .booleanConf
     .createWithDefault(false)
 
+  private[spark] val STORAGE_SHUFFLE_DECOMMISSION_ENABLED =
+    ConfigBuilder("spark.storage.decommission.shuffle_blocks")
+      .doc("Whether to transfer shuffle blocks during block manager decommissioning. Requires " +
+        "an indexed shuffle resolver (like sort-based shuffle).")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  private[spark] val STORAGE_RDD_DECOMMISSION_ENABLED =
+    ConfigBuilder("spark.storage.decommission.rdd_blocks")
+      .doc("Whether to transfer RDD blocks during block manager decommissioning.")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK =
     ConfigBuilder("spark.storage.decommission.maxReplicationFailuresPerBlock")
       .internal()
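
For reference, a hedged sketch of opting in to the new behaviour from application code. Both new flags default to false; the surrounding spark.storage.decommission.enabled switch is assumed to exist outside this diff:

    import org.apache.spark.SparkConf

    // Hypothetical application setup: migrate both shuffle and RDD blocks when a
    // block manager is decommissioned (the two flags are introduced above).
    val conf = new SparkConf()
      .set("spark.storage.decommission.enabled", "true")  // assumed pre-existing switch
      .set("spark.storage.decommission.shuffle_blocks", "true")
      .set("spark.storage.decommission.rdd_blocks", "true")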

core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala

Lines changed: 11 additions & 1 deletion
@@ -33,9 +33,11 @@ import org.apache.spark.util.Utils
  * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
  */
 private[spark] sealed trait MapStatus {
-  /** Location where this task was run. */
+  /** Location where this task's output is. */
   def location: BlockManagerId
 
+  def updateLocation(bm: BlockManagerId): Unit
+
   /**
    * Estimated size for the reduce block, in bytes.
    *
@@ -126,6 +128,10 @@ private[spark] class CompressedMapStatus(
 
   override def location: BlockManagerId = loc
 
+  override def updateLocation(bm: BlockManagerId): Unit = {
+    loc = bm
+  }
+
   override def getSizeForBlock(reduceId: Int): Long = {
     MapStatus.decompressSize(compressedSizes(reduceId))
   }
@@ -178,6 +184,10 @@ private[spark] class HighlyCompressedMapStatus private (
 
   override def location: BlockManagerId = loc
 
+  override def updateLocation(bm: BlockManagerId): Unit = {
+    loc = bm
+  }
+
   override def getSizeForBlock(reduceId: Int): Long = {
     assert(hugeBlockSizes != null)
     if (emptyBlocks.contains(reduceId)) {

core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ private[spark] class StandaloneSchedulerBackend(
   with StandaloneAppClientListener
   with Logging {
 
-  private var client: StandaloneAppClient = null
+  private[spark] var client: StandaloneAppClient = null
   private val stopping = new AtomicBoolean(false)
   private val launcherBackend = new LauncherBackend() {
     override protected def conf: SparkConf = sc.conf

core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala

Lines changed: 104 additions & 2 deletions
@@ -18,15 +18,18 @@
 package org.apache.spark.shuffle
 
 import java.io._
+import java.nio.ByteBuffer
 import java.nio.channels.Channels
 import java.nio.file.Files
 
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.NioBufferedFileInputStream
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.client.StreamCallbackWithID
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExecutorDiskUtils
+import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -46,7 +49,7 @@ private[spark] class IndexShuffleBlockResolver(
     conf: SparkConf,
     _blockManager: BlockManager = null)
   extends ShuffleBlockResolver
-  with Logging {
+  with Logging with MigratableResolver {
 
   private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
 
@@ -55,6 +58,25 @@ private[spark] class IndexShuffleBlockResolver(
 
   def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, mapId, None)
 
+  /**
+   * Get the shuffle files that are stored locally. Used for block migrations.
+   */
+  override def getStoredShuffles(): Set[(Int, Long)] = {
+    // Matches ShuffleIndexBlockId name
+    val pattern = "shuffle_(\\d+)_(\\d+)_.+\\.index".r
+    val rootDirs = blockManager.diskBlockManager.localDirs
+    // ExecutorDiskUtils puts things inside one level of hashed sub-directories
+    val searchDirs = rootDirs.flatMap(_.listFiles()).filter(_.isDirectory()) ++ rootDirs
+    val filenames = searchDirs.flatMap(_.list())
+    logDebug(s"Got block files ${filenames.toList}")
+    filenames.flatMap { fname =>
+      pattern.findAllIn(fname).matchData.map {
+        matched => (matched.group(1).toInt, matched.group(2).toLong)
+      }
+    }.toSet
+  }
+
+
   /**
    * Get the shuffle data file.
    *
@@ -148,6 +170,86 @@ private[spark] class IndexShuffleBlockResolver(
     }
   }
 
+  /**
+   * Write a provided shuffle block as a stream. Used for block migrations.
+   * ShuffleBlockBatchIds must contain the full range represented in the ShuffleIndexBlock.
+   * Requires the caller to delete any shuffle index blocks where the shuffle block fails to
+   * put.
+   */
+  override def putShuffleBlockAsStream(blockId: BlockId, serializerManager: SerializerManager):
+      StreamCallbackWithID = {
+    val file = blockId match {
+      case ShuffleIndexBlockId(shuffleId, mapId, _) =>
+        getIndexFile(shuffleId, mapId)
+      case ShuffleDataBlockId(shuffleId, mapId, _) =>
+        getDataFile(shuffleId, mapId)
+      case _ =>
+        throw new Exception(s"Unexpected shuffle block transfer ${blockId} as " +
+          s"${blockId.getClass().getSimpleName()}")
+    }
+    val fileTmp = Utils.tempFileWith(file)
+    val channel = Channels.newChannel(
+      serializerManager.wrapStream(blockId,
+        new FileOutputStream(fileTmp)))
+
+    new StreamCallbackWithID {
+
+      override def getID: String = blockId.name
+
+      override def onData(streamId: String, buf: ByteBuffer): Unit = {
+        while (buf.hasRemaining) {
+          channel.write(buf)
+        }
+      }
+
+      override def onComplete(streamId: String): Unit = {
+        logTrace(s"Done receiving block $blockId, now putting into local shuffle service")
+        channel.close()
+        val diskSize = fileTmp.length()
+        this.synchronized {
+          if (file.exists()) {
+            file.delete()
+          }
+          if (!fileTmp.renameTo(file)) {
+            throw new IOException(s"fail to rename file ${fileTmp} to ${file}")
+          }
+        }
+        blockManager.reportBlockStatus(blockId, BlockStatus(
+          StorageLevel(
+            useDisk = true,
+            useMemory = false,
+            useOffHeap = false,
+            deserialized = false,
+            replication = 0),
+          0, diskSize))
+      }
+
+      override def onFailure(streamId: String, cause: Throwable): Unit = {
+        // the framework handles the connection itself, we just need to do local cleanup
+        channel.close()
+        fileTmp.delete()
+      }
+    }
+  }
+
+  /**
+   * Get the index & data block for migration.
+   */
+  def getMigrationBlocks(shuffleId: Int, mapId: Long): List[(BlockId, ManagedBuffer)] = {
+    // Load the index block
+    val indexFile = getIndexFile(shuffleId, mapId)
+    val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
+    val indexFileSize = indexFile.length()
+    val indexBlockData = new FileSegmentManagedBuffer(transportConf, indexFile, 0, indexFileSize)
+
+    // Load the data block
+    val dataFile = getDataFile(shuffleId, mapId)
+    val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
+    val dataBlockData = new FileSegmentManagedBuffer(transportConf, dataFile, 0, dataFile.length())
+    List((indexBlockId, indexBlockData), (dataBlockId, dataBlockData))
+  }
+
+
   /**
    * Write an index file with the offsets of each block, plus a final offset at the end for the
    * end of the output file. This will be used by getBlockData to figure out where each block
@@ -169,7 +271,7 @@ private[spark] class IndexShuffleBlockResolver(
     val dataFile = getDataFile(shuffleId, mapId)
     // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
     // the following check and rename are atomic.
-    synchronized {
+    this.synchronized {
       val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
       if (existingLengths != null) {
         // Another attempt for the same task has already written our map outputs successfully,
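
As a small aside on getStoredShuffles above, a self-contained sketch of how the index-file name pattern resolves to a (shuffleId, mapId) pair; the sample file name is invented:

    // The resolver scans its local dirs for names like "shuffle_<shuffleId>_<mapId>_0.index".
    val pattern = "shuffle_(\\d+)_(\\d+)_.+\\.index".r
    val sample = "shuffle_3_42_0.index"  // hypothetical file written by a map task
    val ids = pattern.findAllIn(sample).matchData
      .map(m => (m.group(1).toInt, m.group(2).toLong)).toSet
    // ids == Set((3, 42L)), i.e. shuffle 3, map output with mapId 42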
core/src/main/scala/org/apache/spark/shuffle/MigratableResolver.scala

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.StreamCallbackWithID
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.storage.BlockId
+
+/**
+ * :: Experimental ::
+ * An experimental trait to allow Spark to migrate shuffle blocks.
+ */
+@Experimental
+trait MigratableResolver {
+  /**
+   * Get the shuffle ids that are stored locally. Used for block migrations.
+   */
+  def getStoredShuffles(): Set[(Int, Long)]
+
+  /**
+   * Write a provided shuffle block as a stream. Used for block migrations.
+   */
+  def putShuffleBlockAsStream(blockId: BlockId, serializerManager: SerializerManager):
+      StreamCallbackWithID
+
+  /**
+   * Get the blocks for migration for a particular shuffle and map.
+   */
+  def getMigrationBlocks(shuffleId: Int, mapId: Long): List[(BlockId, ManagedBuffer)]
+}
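
To make the contract concrete, a hedged, toy implementation of the trait; it is not part of this commit, keeps blocks in memory rather than on disk, and ignores the SerializerManager for simplicity. It is assumed to compile inside Spark's own org.apache.spark.shuffle package, since SerializerManager is package-private:

    package org.apache.spark.shuffle

    import java.io.ByteArrayOutputStream
    import java.nio.ByteBuffer

    import scala.collection.concurrent.TrieMap

    import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
    import org.apache.spark.network.client.StreamCallbackWithID
    import org.apache.spark.serializer.SerializerManager
    import org.apache.spark.storage.{BlockId, ShuffleDataBlockId, ShuffleIndexBlockId}

    // Toy resolver: stores migrated blocks in memory, keyed by BlockId.
    class InMemoryMigratableResolver extends MigratableResolver {
      private val blocks = TrieMap.empty[BlockId, Array[Byte]]

      // A shuffle/map pair counts as "stored" once its index block has arrived.
      override def getStoredShuffles(): Set[(Int, Long)] =
        blocks.keys.collect { case ShuffleIndexBlockId(s, m, _) => (s, m) }.toSet

      // Buffer the incoming stream and register it under its BlockId on completion.
      override def putShuffleBlockAsStream(
          blockId: BlockId,
          serializerManager: SerializerManager): StreamCallbackWithID =
        new StreamCallbackWithID {
          private val out = new ByteArrayOutputStream()
          override def getID: String = blockId.name
          override def onData(streamId: String, buf: ByteBuffer): Unit = {
            val bytes = new Array[Byte](buf.remaining())
            buf.get(bytes)
            out.write(bytes)
          }
          override def onComplete(streamId: String): Unit = {
            blocks.put(blockId, out.toByteArray)
          }
          override def onFailure(streamId: String, cause: Throwable): Unit = out.close()
        }

      // Hand back whatever index/data blocks we hold for this shuffle/map pair.
      override def getMigrationBlocks(shuffleId: Int, mapId: Long): List[(BlockId, ManagedBuffer)] = {
        val candidates: List[BlockId] =
          List(ShuffleIndexBlockId(shuffleId, mapId, 0), ShuffleDataBlockId(shuffleId, mapId, 0))
        candidates.flatMap { id =>
          blocks.get(id).map(bytes => (id, new NioManagedBuffer(ByteBuffer.wrap(bytes))))
        }
      }
    }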

core/src/main/scala/org/apache/spark/storage/BlockId.scala

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ sealed abstract class BlockId {
   def isRDD: Boolean = isInstanceOf[RDDBlockId]
   def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId]
   def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]
+  def isInternalShuffle: Boolean = {
+    isInstanceOf[ShuffleDataBlockId] || isInstanceOf[ShuffleIndexBlockId]
+  }
 
   override def toString: String = name
 }
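
A short example of the distinction this adds: the new predicate is true for the on-disk index and data file blocks that migration moves, while isShuffle still covers the logical shuffle blocks reducers fetch.

    ShuffleIndexBlockId(0, 5L, 0).isInternalShuffle  // true
    ShuffleIndexBlockId(0, 5L, 0).isShuffle          // false
    ShuffleBlockId(0, 5L, 0).isShuffle               // true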
