
Commit 6340f9b

- Fix the migration to store ShuffleDataBlockId.
- Check that data and index blocks have both been migrated.
- Check that RDD blocks are duplicated, not just broadcast blocks.
- Make the number of partitions smaller so the test runs faster.
- Avoid Thread.sleep in all of the tests except the mid-flight test, where we need it.
- Check for the broadcast blocks landing (further along in scheduling), beyond just task start.
- Force fetching the shuffle block to local disk when in shuffle block test mode.
- Start the job as soon as the first executor comes online.
1 parent 069dd3b commit 6340f9b

2 files changed: +74 -36 lines

core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -180,8 +180,8 @@ private[spark] class IndexShuffleBlockResolver(
       StreamCallbackWithID = {
     val file = blockId match {
       case ShuffleIndexBlockId(shuffleId, mapId, _) =>
-        getIndexFile(shuffleId, mapId)
-      case ShuffleBlockBatchId(shuffleId, mapId, _, _) =>
+        getIndexFile(shuffleId, mapId)
+      case ShuffleDataBlockId(shuffleId, mapId, _) =>
         getDataFile(shuffleId, mapId)
       case _ =>
         throw new Exception(s"Unexpected shuffle block transfer ${blockId} as " +
```

core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionSuite.scala

Lines changed: 72 additions & 34 deletions
```diff
@@ -24,8 +24,7 @@ import scala.concurrent.duration._
 
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, Success,
-  TestUtils}
+import org.apache.spark._
 import org.apache.spark.internal.config
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend
```
```diff
@@ -35,41 +34,51 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
   with ResetSystemProperties with Eventually {
 
   val numExecs = 3
+  val numParts = 3
 
   test(s"verify that an already running task which is going to cache data succeeds " +
     s"on a decommissioned executor") {
-    runDecomTest(true, false)
+    runDecomTest(true, false, true)
   }
 
   test(s"verify that shuffle blocks are migrated.") {
-    runDecomTest(false, true)
+    runDecomTest(false, true, false)
   }
 
   test(s"verify that both migrations can work at the same time.") {
-    runDecomTest(true, true)
+    runDecomTest(true, true, false)
   }
 
-  private def runDecomTest(persist: Boolean, shuffle: Boolean) = {
+  private def runDecomTest(persist: Boolean, shuffle: Boolean, migrateDuring: Boolean) = {
     val master = s"local-cluster[${numExecs}, 1, 1024]"
     val conf = new SparkConf().setAppName("test").setMaster(master)
       .set(config.Worker.WORKER_DECOMMISSION_ENABLED, true)
       .set(config.STORAGE_DECOMMISSION_ENABLED, true)
       .set(config.STORAGE_RDD_DECOMMISSION_ENABLED, persist)
       .set(config.STORAGE_SHUFFLE_DECOMMISSION_ENABLED, shuffle)
+      // Just replicate blocks as fast as we can during testing, there isn't another
+      // workload we need to worry about.
       .set(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL, 1L)
 
+    // Force fetching to local disk
+    if (shuffle) {
+      conf.set("spark.network.maxRemoteBlockSizeFetchToMem", "1")
+    }
+
     sc = new SparkContext(master, "test", conf)
 
     // Create input RDD with 10 partitions
-    val input = sc.parallelize(1 to 10, 10)
+    val input = sc.parallelize(1 to numParts, numParts)
     val accum = sc.longAccumulator("mapperRunAccumulator")
     // Do a count to wait for the executors to be registered.
     input.count()
 
     // Create a new RDD where we have sleep in each partition, we are also increasing
     // the value of accumulator in each partition
     val sleepyRdd = input.mapPartitions { x =>
-      Thread.sleep(250)
+      if (migrateDuring) {
+        Thread.sleep(500)
+      }
       accum.add(1)
       x.map(y => (y, y))
     }
```
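The effect of the forced fetch-to-disk setting, as a minimal sketch (the config key is Spark's real spark.network.maxRemoteBlockSizeFetchToMem; the app name and master string here are placeholders): any remote block larger than the threshold is streamed to local disk rather than buffered in memory, so a one-byte threshold pushes essentially every shuffle block through the disk-backed path the shuffle test wants to exercise.

```scala
import org.apache.spark.SparkConf

// Sketch: with the fetch-to-memory cutoff at one byte, remote shuffle blocks
// are written to local disk on fetch, making the disk path deterministic.
val conf = new SparkConf()
  .setAppName("decom-sketch") // placeholder name
  .setMaster("local-cluster[3, 1, 1024]")
  .set("spark.network.maxRemoteBlockSizeFetchToMem", "1")
```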
```diff
@@ -79,19 +88,26 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
     }
 
     // Listen for the job & block updates
-    val sem = new Semaphore(0)
+    val taskStartSem = new Semaphore(0)
+    val broadcastSem = new Semaphore(0)
     val taskEndEvents = ArrayBuffer.empty[SparkListenerTaskEnd]
     val blocksUpdated = ArrayBuffer.empty[SparkListenerBlockUpdated]
     sc.addSparkListener(new SparkListener {
+
       override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-        sem.release()
+        taskStartSem.release()
       }
 
       override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
         taskEndEvents.append(taskEnd)
       }
 
       override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+        // Once broadcast blocks start landing on the executors we're good to proceed.
+        // We don't only use task start as it can occur before the work is on the executor.
+        if (blockUpdated.blockUpdatedInfo.blockId.isBroadcast) {
+          broadcastSem.release()
+        }
         blocksUpdated.append(blockUpdated)
       }
     })
```
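The listener-plus-semaphore handshake above, isolated as a sketch (names follow the diff; numExecs comes from the suite): a task-start event can fire before the task's broadcast data has actually reached an executor, so the test counts broadcast block updates instead and proceeds only once one has been observed per executor plus the driver.

```scala
import java.util.concurrent.Semaphore

import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}

// Release one permit per broadcast block update; the test thread acquires
// numExecs + 1 permits before it allows decommissioning to begin.
val broadcastSem = new Semaphore(0)
val broadcastListener = new SparkListener {
  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
    if (blockUpdated.blockUpdatedInfo.blockId.isBroadcast) {
      broadcastSem.release()
    }
  }
}
// After sc.addSparkListener(broadcastListener) and submitting the job:
// broadcastSem.acquire(numExecs + 1)
```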
```diff
@@ -102,19 +118,32 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
       testRdd.persist()
     }
 
-    // Wait for all of the executors to start
+    // Wait for the first executor to start
     TestUtils.waitUntilExecutorsUp(sc = sc,
-      numExecutors = numExecs,
+      numExecutors = 1,
       timeout = 10000) // 10s
 
     // Start the computation of RDD - this step will also cache the RDD
     val asyncCount = testRdd.countAsync()
 
-    // Wait for the job to have started
-    sem.acquire(1)
+    // Wait for all of the executors to start
+    TestUtils.waitUntilExecutorsUp(sc = sc,
+      numExecutors = numExecs,
+      timeout = 10000) // 10s
 
-    // Give Spark a tiny bit to start the tasks after the listener says hello
-    Thread.sleep(50)
+    // Wait for the job to have started.
+    taskStartSem.acquire(1)
+    // Wait for each executor + driver to have its broadcast info delivered.
+    broadcastSem.acquire((numExecs + 1))
+
+    // Make sure the job is either mid run or otherwise has data to migrate.
+    if (migrateDuring) {
+      // Give Spark a tiny bit to start executing after the broadcast blocks land.
+      // For me this works at 100, set to 300 for system variance.
+      Thread.sleep(300)
+    } else {
+      ThreadUtils.awaitResult(asyncCount, 15.seconds)
+    }
 
     // Decommission one of the executor
     val sched = sc.schedulerBackend.asInstanceOf[StandaloneSchedulerBackend]
```
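A sketch of the async-count idiom driving this timing logic (illustrative only; note ThreadUtils is private[spark], so code like this compiles only inside Spark's own source tree): countAsync() starts the job without blocking, which is what lets the test decommission an executor while tasks are still in flight and only afterwards wait on the result with a bounded timeout.

```scala
import scala.concurrent.duration._

import org.apache.spark.rdd.RDD
import org.apache.spark.util.ThreadUtils

// Kick off the count in the background, do mid-flight work (for the suite,
// decommissioning an executor), then block on the result with a timeout.
def runAndAwait(rdd: RDD[(Int, Int)]): Long = {
  val asyncCount = rdd.countAsync() // returns a FutureAction[Long]
  // ... trigger decommissioning here while tasks run ...
  ThreadUtils.awaitResult(asyncCount, 15.seconds)
}
```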
```diff
@@ -127,49 +156,58 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
 
     // Wait for job to finish
     val asyncCountResult = ThreadUtils.awaitResult(asyncCount, 15.seconds)
-    assert(asyncCountResult === 10)
-    // All 10 tasks finished, so accum should have been increased 10 times
-    assert(accum.value === 10)
+    assert(asyncCountResult === numParts)
+    // All tasks finished, so accum should have been increased numParts times
+    assert(accum.value === numParts)
 
     // All tasks should be successful, nothing should have failed
     sc.listenerBus.waitUntilEmpty()
     if (shuffle) {
-      // 10 mappers & 10 reducers which succeeded
-      assert(taskEndEvents.count(_.reason == Success) === 20,
-        s"Expected 20 tasks got ${taskEndEvents.size} (${taskEndEvents})")
+      // mappers & reducers which succeeded
+      assert(taskEndEvents.count(_.reason == Success) === 2 * numParts,
+        s"Expected ${2 * numParts} tasks got ${taskEndEvents.size} (${taskEndEvents})")
     } else {
-      // 10 mappers which executed successfully
-      assert(taskEndEvents.count(_.reason == Success) === 10,
-        s"Expected 10 tasks got ${taskEndEvents.size} (${taskEndEvents})")
+      // only mappers which executed successfully
+      assert(taskEndEvents.count(_.reason == Success) === numParts,
+        s"Expected ${numParts} tasks got ${taskEndEvents.size} (${taskEndEvents})")
     }
 
     // Wait for our respective blocks to have migrated
     eventually(timeout(15.seconds), interval(10.milliseconds)) {
       if (persist) {
         // One of our blocks should have moved.
-        val blockLocs = blocksUpdated.map{ update =>
+        val rddUpdates = blocksUpdated.filter{update =>
+          val blockId = update.blockUpdatedInfo.blockId
+          blockId.isRDD}
+        val blockLocs = rddUpdates.map{ update =>
           (update.blockUpdatedInfo.blockId.name,
            update.blockUpdatedInfo.blockManagerId)}
         val blocksToManagers = blockLocs.groupBy(_._1).mapValues(_.toSet.size)
         assert(!blocksToManagers.filter(_._2 > 1).isEmpty,
-          s"We should have a block that has been on multiple BMs in ${blocksUpdated}")
+          s"We should have a block that has been on multiple BMs in rdds:\n ${rddUpdates} from:\n" +
+          s"${blocksUpdated}\n but instead we got:\n ${blocksToManagers}")
       }
       // If we're migrating shuffles we look for any shuffle block updates
       // as there is no block update on the initial shuffle block write.
       if (shuffle) {
-        val numLocs = blocksUpdated.filter{ update =>
+        val numDataLocs = blocksUpdated.filter{ update =>
+          val blockId = update.blockUpdatedInfo.blockId
+          blockId.isInstanceOf[ShuffleDataBlockId]
+        }.toSet.size
+        val numIndexLocs = blocksUpdated.filter{ update =>
           val blockId = update.blockUpdatedInfo.blockId
-          blockId.isShuffle || blockId.isInternalShuffle
+          blockId.isInstanceOf[ShuffleIndexBlockId]
         }.toSet.size
-        assert(numLocs > 0, s"No shuffle block updates in ${blocksUpdated}")
+        assert(numDataLocs >= 1, s"Expect shuffle data block updates in ${blocksUpdated}")
+        assert(numIndexLocs >= 1, s"Expect shuffle index block updates in ${blocksUpdated}")
       }
     }
 
     // Since the RDD is cached or shuffled so further usage of same RDD should use the
     // cached data. Original RDD partitions should not be recomputed i.e. accum
     // should have same value like before
-    assert(testRdd.count() === 10)
-    assert(accum.value === 10)
+    assert(testRdd.count() === numParts)
+    assert(accum.value === numParts)
 
     val storageStatus = sc.env.blockManager.master.getStorageStatus
     val execIdToBlocksMapping = storageStatus.map(
```
```diff
@@ -178,8 +216,8 @@ class BlockManagerDecommissionSuite extends SparkFunSuite with LocalSparkContext
     assert(execIdToBlocksMapping(execToDecommission).keys.filter(_.isRDD).toSeq === Seq(),
       "Cache blocks should be migrated")
     if (persist) {
-      // There should still be all 10 RDD blocks cached
-      assert(execIdToBlocksMapping.values.flatMap(_.keys).count(_.isRDD) === 10)
+      // There should still be all the RDD blocks cached
+      assert(execIdToBlocksMapping.values.flatMap(_.keys).count(_.isRDD) === numParts)
     }
   }
 }
```
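The retry idiom behind the eventually block above, as a small sketch (awaitMigrated is a hypothetical helper; this version uses ScalaTest's SpanSugar for the time literals rather than the suite's scala.concurrent.duration import): instead of sleeping a fixed amount, the assertion is re-run every 10 ms until it passes or 15 s elapse, so the test tracks actual migration progress rather than a guess about timing.

```scala
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

// Poll until some block has been reported by more than one BlockManager,
// i.e. until a migration is visible to the listener.
def awaitMigrated(blockManagerCountsPerBlock: => Map[String, Int]): Unit = {
  eventually(timeout(15.seconds), interval(10.milliseconds)) {
    assert(blockManagerCountsPerBlock.exists(_._2 > 1),
      "expected at least one block on multiple BlockManagers")
  }
}
```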
