
Commit f5b4f48

add UT
1 parent 467731c commit f5b4f48

2 files changed: 64 additions, 23 deletions

core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala

Lines changed: 27 additions & 9 deletions
@@ -42,15 +42,23 @@ object FakeTask {
    * locations for each task (given as varargs) if this sequence is not empty.
    */
   def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
-    createTaskSet(numTasks, stageAttemptId = 0, prefLocs: _*)
+    createTaskSet(numTasks, stageId = 0, stageAttemptId = 0, priority = 0, prefLocs: _*)
   }
 
-  def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
-    createTaskSet(numTasks, stageId = 0, stageAttemptId, prefLocs: _*)
+  def createTaskSet(
+      numTasks: Int,
+      stageId: Int,
+      stageAttemptId: Int,
+      prefLocs: Seq[TaskLocation]*): TaskSet = {
+    createTaskSet(numTasks, stageId, stageAttemptId, priority = 0, prefLocs: _*)
   }
 
-  def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*):
-    TaskSet = {
+  def createTaskSet(
+      numTasks: Int,
+      stageId: Int,
+      stageAttemptId: Int,
+      priority: Int,
+      prefLocs: Seq[TaskLocation]*): TaskSet = {
     if (prefLocs.size != 0 && prefLocs.size != numTasks) {
       throw new IllegalArgumentException("Wrong number of task locations")
     }
@@ -65,6 +73,15 @@ object FakeTask {
       stageId: Int,
       stageAttemptId: Int,
       prefLocs: Seq[TaskLocation]*): TaskSet = {
+    createShuffleMapTaskSet(numTasks, stageId, stageAttemptId, priority = 0, prefLocs: _*)
+  }
+
+  def createShuffleMapTaskSet(
+      numTasks: Int,
+      stageId: Int,
+      stageAttemptId: Int,
+      priority: Int,
+      prefLocs: Seq[TaskLocation]*): TaskSet = {
     if (prefLocs.size != 0 && prefLocs.size != numTasks) {
       throw new IllegalArgumentException("Wrong number of task locations")
     }
@@ -74,24 +91,25 @@ object FakeTask {
       }, prefLocs(i), new Properties,
         SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array())
     }
-    new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null)
+    new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null)
   }
 
   def createBarrierTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
-    createBarrierTaskSet(numTasks, stageId = 0, stageAttempId = 0, prefLocs: _*)
+    createBarrierTaskSet(numTasks, stageId = 0, stageAttemptId = 0, priority = 0, prefLocs: _*)
   }
 
   def createBarrierTaskSet(
       numTasks: Int,
       stageId: Int,
-      stageAttempId: Int,
+      stageAttemptId: Int,
+      priority: Int,
       prefLocs: Seq[TaskLocation]*): TaskSet = {
     if (prefLocs.size != 0 && prefLocs.size != numTasks) {
       throw new IllegalArgumentException("Wrong number of task locations")
     }
     val tasks = Array.tabulate[Task[_]](numTasks) { i =>
       new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true)
     }
-    new TaskSet(tasks, stageId, stageAttempId, priority = 0, null)
+    new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null)
   }
 }
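The overloads above thread a new `priority` argument down into the `TaskSet` constructor, defaulting to 0 everywhere it was previously hardcoded, which is what lets the new test below put two task sets in a deliberate FIFO order. A minimal standalone sketch of the ordering semantics involved, assuming the behavior of Spark's FIFOSchedulingAlgorithm (lower priority value is offered resources first, ties broken by stageId); `OrderSketch`, `FakeSet`, and `fifoLessThan` are illustrative names, not Spark API:

```scala
// Sketch of FIFO task-set ordering, assuming the semantics of Spark's
// FIFOSchedulingAlgorithm: compare priority first (lower wins), then stageId.
// FakeSet and fifoLessThan are illustrative names, not Spark API.
object OrderSketch {
  final case class FakeSet(priority: Int, stageId: Int)

  def fifoLessThan(s1: FakeSet, s2: FakeSet): Boolean = {
    if (s1.priority != s2.priority) s1.priority < s2.priority
    else s1.stageId < s2.stageId
  }

  def main(args: Array[String]): Unit = {
    // Values from the SPARK-29263 test in the next file: the priority-0 set
    // is ordered before the priority-1 barrier set despite its larger stageId.
    val highPrio = FakeSet(priority = 0, stageId = 1)
    val barrier = FakeSet(priority = 1, stageId = 0)
    assert(fifoLessThan(highPrio, barrier))
  }
}
```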

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 37 additions & 14 deletions
@@ -228,19 +228,19 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       taskScheduler.taskSetManagerForAttempt(taskset.stageId, taskset.stageAttemptId).get.isZombie
     }
 
-    val attempt1 = FakeTask.createTaskSet(1, 0)
+    val attempt1 = FakeTask.createTaskSet(1, stageId = 0, stageAttemptId = 0)
     taskScheduler.submitTasks(attempt1)
     // The first submitted taskset is active
     assert(!isTasksetZombie(attempt1))
 
-    val attempt2 = FakeTask.createTaskSet(1, 1)
+    val attempt2 = FakeTask.createTaskSet(1, stageId = 0, stageAttemptId = 1)
     taskScheduler.submitTasks(attempt2)
     // The first submitted taskset is zombie now
     assert(isTasksetZombie(attempt1))
     // The newly submitted taskset is active
     assert(!isTasksetZombie(attempt2))
 
-    val attempt3 = FakeTask.createTaskSet(1, 2)
+    val attempt3 = FakeTask.createTaskSet(1, stageId = 0, stageAttemptId = 2)
     taskScheduler.submitTasks(attempt3)
     // The first submitted taskset remains zombie
     assert(isTasksetZombie(attempt1))
@@ -255,7 +255,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
 
     val numFreeCores = 1
     val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores))
-    val attempt1 = FakeTask.createTaskSet(10)
+    val attempt1 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 0)
 
     // submit attempt 1, offer some resources, some tasks get scheduled
     taskScheduler.submitTasks(attempt1)
@@ -271,7 +271,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(0 === taskDescriptions2.length)
 
     // if we schedule another attempt for the same stage, it should get scheduled
-    val attempt2 = FakeTask.createTaskSet(10, 1)
+    val attempt2 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 1)
 
     // submit attempt 2, offer some resources, some tasks get scheduled
     taskScheduler.submitTasks(attempt2)
@@ -287,7 +287,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
 
     val numFreeCores = 10
     val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", numFreeCores))
-    val attempt1 = FakeTask.createTaskSet(10)
+    val attempt1 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 0)
 
     // submit attempt 1, offer some resources, some tasks get scheduled
     taskScheduler.submitTasks(attempt1)
@@ -303,7 +303,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(0 === taskDescriptions2.length)
 
     // submit attempt 2
-    val attempt2 = FakeTask.createTaskSet(10, 1)
+    val attempt2 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 1)
     taskScheduler.submitTasks(attempt2)
 
     // attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were
@@ -497,7 +497,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
 
   test("abort stage when all executors are blacklisted and we cannot acquire new executor") {
     taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
-    val taskSet = FakeTask.createTaskSet(numTasks = 10, stageAttemptId = 0)
+    val taskSet = FakeTask.createTaskSet(numTasks = 10)
     taskScheduler.submitTasks(taskSet)
     val tsm = stageToMockTaskSetManager(0)
 
@@ -539,7 +539,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       config.UNSCHEDULABLE_TASKSET_TIMEOUT.key -> "0")
 
     // We have only 1 task remaining with 1 executor
-    val taskSet = FakeTask.createTaskSet(numTasks = 1, stageAttemptId = 0)
+    val taskSet = FakeTask.createTaskSet(numTasks = 1)
     taskScheduler.submitTasks(taskSet)
     val tsm = stageToMockTaskSetManager(0)
 
@@ -571,7 +571,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       config.UNSCHEDULABLE_TASKSET_TIMEOUT.key -> "10")
 
     // We have only 1 task remaining with 1 executor
-    val taskSet = FakeTask.createTaskSet(numTasks = 1, stageAttemptId = 0)
+    val taskSet = FakeTask.createTaskSet(numTasks = 1)
     taskScheduler.submitTasks(taskSet)
     val tsm = stageToMockTaskSetManager(0)
 
@@ -910,7 +910,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
   test("SPARK-16106 locality levels updated if executor added to existing host") {
     val taskScheduler = setupScheduler()
 
-    taskScheduler.submitTasks(FakeTask.createTaskSet(2, 0,
+    taskScheduler.submitTasks(FakeTask.createTaskSet(2, stageId = 0, stageAttemptId = 0,
       (0 until 2).map { _ => Seq(TaskLocation("host0", "executor2")) }: _*
     ))
 
@@ -948,7 +948,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
   test("scheduler checks for executors that can be expired from blacklist") {
     taskScheduler = setupScheduler()
 
-    taskScheduler.submitTasks(FakeTask.createTaskSet(1, 0))
+    taskScheduler.submitTasks(FakeTask.createTaskSet(1, stageId = 0, stageAttemptId = 0))
     taskScheduler.resourceOffers(IndexedSeq(
       new WorkerOffer("executor0", "host0", 1)
     )).flatten
@@ -1154,6 +1154,29 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(3 === taskDescriptions.length)
   }
 
+  test("SPARK-29263: barrier TaskSet can't schedule when higher prio taskset takes the slots") {
+    val taskCpus = 2
+    val taskScheduler = setupSchedulerWithMaster(
+      s"local[$taskCpus]",
+      config.CPUS_PER_TASK.key -> taskCpus.toString)
+
+    val numFreeCores = 3
+    val workerOffers = IndexedSeq(
+      new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")),
+      new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627")),
+      new WorkerOffer("executor2", "host2", numFreeCores, Some("192.168.0.101:49629")))
+    val barrier = FakeTask.createBarrierTaskSet(3, stageId = 0, stageAttemptId = 0, priority = 1)
+    val highPrio = FakeTask.createTaskSet(1, stageId = 1, stageAttemptId = 0, priority = 0)
+
+    // submit highPrio and barrier taskSet
+    taskScheduler.submitTasks(highPrio)
+    taskScheduler.submitTasks(barrier)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    // it schedules the highPrio task first, and then will not have enough slots to schedule
+    // the barrier taskset
+    assert(1 === taskDescriptions.length)
+  }
+
   test("cancelTasks shall kill all the running tasks and fail the stage") {
     val taskScheduler = setupScheduler()
 
@@ -1169,7 +1192,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       }
     })
 
-    val attempt1 = FakeTask.createTaskSet(10, 0)
+    val attempt1 = FakeTask.createTaskSet(10)
     taskScheduler.submitTasks(attempt1)
 
     val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
@@ -1200,7 +1223,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
       }
     })
 
-    val attempt1 = FakeTask.createTaskSet(10, 0)
+    val attempt1 = FakeTask.createTaskSet(10)
     taskScheduler.submitTasks(attempt1)
 
     val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
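For reference, the arithmetic behind the new test's `assert(1 === taskDescriptions.length)`, as a self-contained sketch; the numbers are copied from the test, while the object and variable names are illustrative only:

```scala
// Why the SPARK-29263 test schedules exactly one task: each executor offers
// 3 cores at 2 cores per task, i.e. one slot each, 3 slots in total. The
// priority-0 task takes one slot; the 3-task barrier set needs 3 slots but
// only 2 remain, and a barrier stage launches all of its tasks or none.
object SlotMath {
  def main(args: Array[String]): Unit = {
    val cpusPerTask = 2                 // config.CPUS_PER_TASK in the test
    val coresPerExecutor = 3            // numFreeCores per WorkerOffer
    val numExecutors = 3

    val totalSlots = numExecutors * (coresPerExecutor / cpusPerTask) // 3
    val slotsLeftAfterHighPrio = totalSlots - 1                      // 2
    val barrierTasksNeeded = 3

    // 2 < 3, so the barrier set cannot launch in this round of offers.
    assert(slotsLeftAfterHighPrio < barrierTasksNeeded)
    println(s"total=$totalSlots, left=$slotsLeftAfterHighPrio, needed=$barrierTasksNeeded")
  }
}
```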
