Skip to content

Commit 647bc45

Browse files
committed
Fix the scheduler to account for tasks that use more than one CPU.
Move CPUS_PER_TASK to TaskSchedulerImpl, since the value is a constant, and use it in both the Mesos and CoarseGrained scheduler backends.
1 parent 8265dc7 commit 647bc45

File tree

5 files changed

+52
-17
lines changed

5 files changed

+52
-17
lines changed

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ private[spark] class TaskSchedulerImpl(
6262
// Threshold above which we warn user initial TaskSet may be starved
6363
val STARVATION_TIMEOUT = conf.getLong("spark.starvation.timeout", 15000)
6464

65+
// CPUs to request per task
66+
val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1)
67+
6568
// TaskSetManagers are not thread safe, so any access to one should be synchronized
6669
// on this class.
6770
val activeTaskSets = new HashMap[String, TaskSetManager]
@@ -228,16 +231,18 @@ private[spark] class TaskSchedulerImpl(
228231
for (i <- 0 until shuffledOffers.size) {
229232
val execId = shuffledOffers(i).executorId
230233
val host = shuffledOffers(i).host
231-
for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
232-
tasks(i) += task
233-
val tid = task.taskId
234-
taskIdToTaskSetId(tid) = taskSet.taskSet.id
235-
taskIdToExecutorId(tid) = execId
236-
activeExecutorIds += execId
237-
executorsByHost(host) += execId
238-
availableCpus(i) -= taskSet.CPUS_PER_TASK
239-
assert (availableCpus(i) >= 0)
240-
launchedTask = true
234+
if (availableCpus(i) >= CPUS_PER_TASK) {
235+
for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
236+
tasks(i) += task
237+
val tid = task.taskId
238+
taskIdToTaskSetId(tid) = taskSet.taskSet.id
239+
taskIdToExecutorId(tid) = execId
240+
activeExecutorIds += execId
241+
executorsByHost(host) += execId
242+
availableCpus(i) -= CPUS_PER_TASK
243+
assert (availableCpus(i) >= 0)
244+
launchedTask = true
245+
}
241246
}
242247
}
243248
} while (launchedTask)

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ private[spark] class TaskSetManager(
5656
{
5757
val conf = sched.sc.conf
5858

59-
// CPUs to request per task
60-
val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1)
61-
6259
/*
6360
* Sometimes if an executor is dead or in an otherwise invalid state, the driver
6461
* does not realize right away leading to repeated task failures. If enabled,
@@ -388,7 +385,7 @@ private[spark] class TaskSetManager(
388385
maxLocality: TaskLocality.TaskLocality)
389386
: Option[TaskDescription] =
390387
{
391-
if (!isZombie && availableCpus >= CPUS_PER_TASK) {
388+
if (!isZombie) {
392389
val curTime = clock.getTime()
393390

394391
var allowedLocality = getAllowedLocalityLevel(curTime)

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
8989
scheduler.statusUpdate(taskId, state, data.value)
9090
if (TaskState.isFinished(state)) {
9191
if (executorActor.contains(executorId)) {
92-
freeCores(executorId) += 1
92+
freeCores(executorId) += scheduler.CPUS_PER_TASK
9393
makeOffers(executorId)
9494
} else {
9595
// Ignoring the update since we don't know about the executor.
@@ -140,7 +140,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
140140
// Launch tasks returned by a set of resource offers
141141
def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
142142
for (task <- tasks.flatten) {
143-
freeCores(task.executorId) -= 1
143+
freeCores(task.executorId) -= scheduler.CPUS_PER_TASK
144144
executorActor(task.executorId) ! LaunchTask(task)
145145
}
146146
}

core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ private[spark] class MesosSchedulerBackend(
246246
val cpuResource = Resource.newBuilder()
247247
.setName("cpus")
248248
.setType(Value.Type.SCALAR)
249-
.setScalar(Value.Scalar.newBuilder().setValue(1).build())
249+
.setScalar(Value.Scalar.newBuilder().setValue(scheduler.CPUS_PER_TASK).build())
250250
.build()
251251
MesosTaskInfo.newBuilder()
252252
.setTaskId(taskId)

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,37 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
293293
assert(count > 0)
294294
assert(count < numTrials)
295295
}
296+
297+
test("Scheduler correctly accounts for multiple CPUs per task") {
298+
sc = new SparkContext("local", "TaskSchedulerImplSuite")
299+
val taskCpus = 2
300+
301+
sc.conf.set("spark.task.cpus", taskCpus.toString)
302+
val taskScheduler = new TaskSchedulerImpl(sc)
303+
taskScheduler.initialize(new FakeSchedulerBackend)
304+
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
305+
val dagScheduler = new DAGScheduler(sc, taskScheduler) {
306+
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
307+
override def executorAdded(execId: String, host: String) {}
308+
}
309+
310+
val numFreeCores = 1
311+
val singleCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores),
312+
new WorkerOffer("executor1", "host1", numFreeCores))
313+
314+
// No tasks should run: each executor offers only 1 free core, but every task needs 2 (taskCpus).
315+
val taskSet = FakeTask.createTaskSet(1)
316+
taskScheduler.submitTasks(taskSet)
317+
var taskDescriptions = taskScheduler.resourceOffers(singleCoreWorkerOffers).flatten
318+
assert(0 === taskDescriptions.length)
319+
320+
// Now change the offers to have 2 cores in one executor and verify that it
321+
// is chosen.
322+
val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus),
323+
new WorkerOffer("executor1", "host1", numFreeCores))
324+
taskScheduler.submitTasks(taskSet)
325+
taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
326+
assert(1 === taskDescriptions.length)
327+
assert("executor0" === taskDescriptions(0).executorId)
328+
}
296329
}

0 commit comments

Comments
 (0)