
Commit 7677aec

Author: pgandhi
Committed: January 2, 2019
Message: [SPARK-25250] : Addressing Reviews
Parent: ee5bc68

File tree

2 files changed: +20 −11 lines changed


core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 10 additions & 9 deletions
```diff
@@ -295,16 +295,17 @@ private[spark] class TaskSchedulerImpl(
   override def markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(
       partitionId: Int, stageId: Int): Unit = {
     taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
-      val index: Option[Int] = tsm.partitionToIndex.get(partitionId)
-      if (!index.isEmpty) {
-        tsm.markPartitionIdAsCompletedForTaskAttempt(index.get)
-        val taskInfoList = tsm.taskAttempts(index.get)
-        taskInfoList.foreach { taskInfo =>
-          if (taskInfo.running) {
-            killTaskAttempt(taskInfo.taskId, false, "Corresponding Partition Id " + partitionId +
-              " has been marked as Completed")
+      tsm.partitionToIndex.get(partitionId) match {
+        case Some(index) =>
+          tsm.markPartitionIdAsCompletedForTaskAttempt(index)
+          val taskInfoList = tsm.taskAttempts(index)
+          taskInfoList.filter(_.running).foreach { taskInfo =>
+            killTaskAttempt(taskInfo.taskId, false,
+              s"Corresponding Partition ID $partitionId has been marked as Completed")
           }
-        }
+
+        case None =>
+          logError(s"No corresponding index found for partition ID $partitionId")
       }
     }
   }
```
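
The refactor above replaces an `isEmpty` guard plus `Option.get` with an explicit pattern match, and folds the inner `if (taskInfo.running)` into a `filter`. A minimal, self-contained sketch of that idiom, for illustration only: the object, the `Map`, and the task list here are hypothetical stand-ins, not the real `TaskSetManager` internals.

```scala
object OptionMatchSketch {
  case class TaskInfo(taskId: Long, running: Boolean)

  // Stand-ins for tsm.partitionToIndex and tsm.taskAttempts.
  val partitionToIndex: Map[Int, Int] = Map(2 -> 0)
  val taskAttempts: Array[List[TaskInfo]] =
    Array(List(TaskInfo(1L, running = true), TaskInfo(2L, running = false)))

  def killTasksForPartition(partitionId: Int): Unit = {
    // Matching on the Option handles the None case explicitly,
    // instead of guarding with `!index.isEmpty` and calling `index.get`.
    partitionToIndex.get(partitionId) match {
      case Some(index) =>
        // `filter(_.running).foreach` replaces a foreach wrapping an inner `if`.
        taskAttempts(index).filter(_.running).foreach { taskInfo =>
          println(s"killing task ${taskInfo.taskId} for partition $partitionId")
        }
      case None =>
        Console.err.println(s"No corresponding index found for partition ID $partitionId")
    }
  }

  def main(args: Array[String]): Unit = killTasksForPartition(2)
}
```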

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 10 additions & 2 deletions
```diff
@@ -18,9 +18,9 @@
 package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
-import java.util.HashSet
 
 import scala.collection.mutable.HashMap
+import scala.collection.mutable.Set
 import scala.concurrent.duration._
 
 import org.mockito.Matchers.{anyInt, anyObject, anyString, eq => meq}
@@ -40,7 +40,7 @@ class FakeSchedulerBackend extends SchedulerBackend {
   def reviveOffers() {}
   def defaultParallelism(): Int = 1
   def maxNumConcurrentTasks(): Int = 0
-  val killedTaskIds: HashSet[Long] = new HashSet[Long]()
+  val killedTaskIds: Set[Long] = Set[Long]()
   override def killTask(
       taskId: Long,
       executorId: String,
@@ -1328,22 +1328,30 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, TaskKilled("test"))
     assert(tsm.isZombie)
   }
+
   test("SPARK-25250 On successful completion of a task attempt on a partition id, kill other" +
     " running task attempts on that same partition") {
     val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
+
     val firstAttempt = FakeTask.createTaskSet(10, stageAttemptId = 0)
     taskScheduler.submitTasks(firstAttempt)
+
     val offersFirstAttempt = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
     taskScheduler.resourceOffers(offersFirstAttempt)
+
     val tsm0 = taskScheduler.taskSetManagerForAttempt(0, 0).get
     val matchingTaskInfoFirstAttempt = tsm0.taskAttempts(0).head
     tsm0.handleFailedTask(matchingTaskInfoFirstAttempt.taskId, TaskState.FAILED,
       FetchFailed(null, 0, 0, 0, "fetch failed"))
+
     val secondAttempt = FakeTask.createTaskSet(10, stageAttemptId = 1)
     taskScheduler.submitTasks(secondAttempt)
+
     val offersSecondAttempt = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
     taskScheduler.resourceOffers(offersSecondAttempt)
+
     taskScheduler.markPartitionIdAsCompletedAndKillCorrespondingTaskAttempts(2, 0)
+
     val tsm1 = taskScheduler.taskSetManagerForAttempt(0, 1).get
     val indexInTsm = tsm1.partitionToIndex(2)
     val matchingTaskInfoSecondAttempt = tsm1.taskAttempts.flatten.filter(_.index == indexInTsm).head
```
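
The test-side change swaps `java.util.HashSet` for `scala.collection.mutable.Set` so killed task IDs can be recorded with Scala collection idioms (`+=`) rather than Java's `add`. A hedged sketch of that pattern under simplified assumptions: `FakeKillRecorder` is a hypothetical stand-in, not the suite's real `FakeSchedulerBackend`, which implements the full `SchedulerBackend` trait.

```scala
import scala.collection.mutable.Set

class FakeKillRecorder {
  // scala.collection.mutable.Set supports `+=`, unlike java.util.HashSet's add().
  val killedTaskIds: Set[Long] = Set[Long]()

  // Same shape as the killTask override in the diff above: record the ID, do nothing else.
  def killTask(taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = {
    killedTaskIds += taskId
  }
}

object FakeKillRecorderDemo {
  def main(args: Array[String]): Unit = {
    val backend = new FakeKillRecorder
    backend.killTask(42L, "exec-0", interruptThread = false, reason = "partition completed")
    assert(backend.killedTaskIds.contains(42L))
    println(backend.killedTaskIds)
  }
}
```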
