SPARK-1601 & SPARK-1602: two bug fixes related to cancellation #521

Closed · wants to merge 3 commits
15 changes: 11 additions & 4 deletions core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -47,7 +47,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (loading.contains(key)) {
logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
while (loading.contains(key)) {
- try {loading.wait()} catch {case _ : Throwable =>}
+ try {
+   loading.wait()
+ } catch {
+   case e: Exception =>
+     logWarning(s"Got an exception while waiting for another thread to load $key", e)
+ }
}
logInfo("Finished waiting for %s".format(key))
/* See whether someone else has successfully loaded it. The main way this would fail
@@ -72,7 +77,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val computedValues = rdd.computeOrReadCheckpoint(split, context)

// Persist the result, so long as the task is not running locally
- if (context.runningLocally) { return computedValues }
+ if (context.runningLocally) {
+   return computedValues
+ }

// Keep track of blocks with updated statuses
var updatedBlocks = Seq[(BlockId, BlockStatus)]()
@@ -88,7 +95,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
blockManager.get(key) match {
case Some(values) =>
- new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+ values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Failure to store %s".format(key))
throw new Exception("Block manager failed to return persisted valued")
@@ -107,7 +114,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val metrics = context.taskMetrics
metrics.updatedBlocks = Some(updatedBlocks)

- returnValue
+ new InterruptibleIterator(context, returnValue)

} finally {
loading.synchronized {
12 changes: 11 additions & 1 deletion core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -24,7 +24,17 @@ package org.apache.spark
private[spark] class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
extends Iterator[T] {

- def hasNext: Boolean = !context.interrupted && delegate.hasNext
+ def hasNext: Boolean = {
+   // TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
+   // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
+   // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
+   // introduces an expensive read fence.
+   if (context.interrupted) {
+     throw new TaskKilledException
+   } else {
+     delegate.hasNext
+   }
+ }

def next(): T = delegate.next()
}
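
For context on why hasNext now throws rather than returning false: if a wrapped iterator simply reports exhaustion when its task is interrupted, whoever is materializing the partition (for example the CacheManager draining it into the BlockManager) sees a short but apparently complete result. Throwing aborts that write instead. Below is a minimal, self-contained sketch of the difference using plain Scala stand-ins — `CancelledException`, `interruptible` and `cachePartition` are illustrative names, not Spark APIs:

```scala
// Illustrative stand-ins only; these are not Spark's classes, they just mimic
// the shape of the change above.
object CancellationSketch {
  class CancelledException extends RuntimeException

  // Plays the role of TaskContext.interrupted.
  @volatile private var cancelled = false

  // strict = false mimics the old hasNext (silently report exhaustion on cancellation);
  // strict = true mimics the new one (throw so the caller aborts its write).
  private def interruptible[T](underlying: Iterator[T], strict: Boolean): Iterator[T] =
    new Iterator[T] {
      def hasNext: Boolean =
        if (!cancelled) underlying.hasNext
        else if (strict) throw new CancelledException
        else false
      def next(): T = underlying.next()
    }

  // "Caching" a partition here just means draining its iterator into memory.
  private def cachePartition(it: Iterator[Int]): Vector[Int] = it.toVector

  def main(args: Array[String]): Unit = {
    // A 1000-element partition whose task gets cancelled while element 10 is produced.
    def partition = (1 to 1000).iterator.map { x => if (x == 10) cancelled = true; x }

    cancelled = false
    val truncated = cachePartition(interruptible(partition, strict = false))
    println(s"old behaviour: quietly cached ${truncated.size} of 1000 elements")

    cancelled = false
    try cachePartition(interruptible(partition, strict = true))
    catch { case _: CancelledException => println("new behaviour: nothing cached, task reported as killed") }
  }
}
```

Running the sketch, the lenient variant quietly "caches" ten of the 1000 elements, while the strict variant stores nothing — the same property the new test in JobCancellationSuite checks against the real CacheManager.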
23 changes: 23 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskKilledException.scala
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

/**
* Exception for a task getting killed.
*/
private[spark] class TaskKilledException extends RuntimeException
8 changes: 3 additions & 5 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -161,8 +161,6 @@ private[spark] class Executor(
class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {

- object TaskKilledException extends Exception
-
@volatile private var killed = false
@volatile private var task: Task[Any] = _

@@ -200,7 +198,7 @@ private[spark] class Executor(
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
- throw TaskKilledException
+ throw new TaskKilledException
}

attemptedTask = Some(task)
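
The NonLocalReturnControl caveat in the comment above is ordinary Scala behaviour rather than anything Spark-specific: a `return` inside a closure is compiled into a thrown scala.runtime.NonLocalReturnControl, which a broad catch-all will happily intercept. A small self-contained sketch (names are made up for the example):

```scala
// Demonstrates why `return` inside a closure is dangerous under a broad catch:
// the return is implemented as a thrown scala.runtime.NonLocalReturnControl.
object NonLocalReturnSketch {
  def findFirstEven(xs: Seq[Int]): Option[Int] = {
    try {
      xs.foreach { x =>
        if (x % 2 == 0) return Some(x)   // throws NonLocalReturnControl under the hood
      }
      None
    } catch {
      // A catch-all written for "real" failures accidentally intercepts the return.
      case t: Throwable =>
        println(s"caught ${t.getClass.getName}")
        None
    }
  }

  def main(args: Array[String]): Unit =
    println(findFirstEven(Seq(1, 3, 4)))   // the catch-all swallows the return: prints None, not Some(4)
}
```

This is why the executor code throws an explicit exception here instead of using `return` inside the block.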
@@ -214,7 +212,7 @@

// If the task has been killed, let's fail it.
if (task.killed) {
- throw TaskKilledException
+ throw new TaskKilledException
}

val resultSer = SparkEnv.get.serializer.newInstance()
@@ -257,7 +255,7 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}

- case TaskKilledException | _: InterruptedException if task.killed => {
+ case _: TaskKilledException | _: InterruptedException if task.killed => {
logInfo("Executor killed task " + taskId)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
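Note the shape of the updated catch clause: because TaskKilledException is now a class rather than a singleton object, the pattern has to be a type test (`_: TaskKilledException`), and the `if task.killed` guard keeps the case from firing for interrupts that were not caused by a kill. A minimal standalone illustration of that pattern (helper names are invented for the example, not Spark APIs):

```scala
object CatchPatternSketch {
  // Stand-in for org.apache.spark.TaskKilledException, just for this example.
  class TaskKilledException extends RuntimeException

  // Run a body and classify the outcome the way the executor does: killed-style
  // exceptions count as KILLED only when the kill flag is actually set.
  def classify(body: => Unit, killed: Boolean): String =
    try { body; "FINISHED" }
    catch {
      case _: TaskKilledException | _: InterruptedException if killed => "KILLED"
      case t: Throwable => s"FAILED: ${t.getClass.getSimpleName}"
    }

  def main(args: Array[String]): Unit = {
    println(classify(throw new TaskKilledException, killed = true))   // KILLED
    println(classify(throw new InterruptedException, killed = true))  // KILLED
    println(classify(throw new TaskKilledException, killed = false))  // FAILED: TaskKilledException
  }
}
```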
43 changes: 38 additions & 5 deletions core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -84,6 +84,35 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(sc.parallelize(1 to 10, 2).count === 10)
}

test("do not put partially executed partitions into cache") {
// In this test case, we create a scenario in which a partition is only partially executed,
// and make sure CacheManager does not put that partially executed partition into the
// BlockManager.
import JobCancellationSuite._
sc = new SparkContext("local", "test")

// Run from 1 to 10, and then block and wait for the task to be killed.
val rdd = sc.parallelize(1 to 1000, 2).map { x =>
if (x > 10) {
taskStartedSemaphore.release()
taskCancelledSemaphore.acquire()
}
x
}.cache()

val rdd1 = rdd.map(x => x)

future {
taskStartedSemaphore.acquire()
sc.cancelAllJobs()
taskCancelledSemaphore.release(100000)
}

intercept[SparkException] { rdd1.count() }
// If the partial block is put into cache, rdd.count() would return a number less than 1000.
assert(rdd.count() === 1000)
}

test("job group") {
sc = new SparkContext("local[2]", "test")

@@ -114,7 +143,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(jobB.get() === 100)
}


test("job group with interruption") {
sc = new SparkContext("local[2]", "test")

@@ -145,15 +173,14 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(jobB.get() === 100)
}

- /*
- test("two jobs sharing the same stage") {
+ ignore("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
// sem2: make sure the first stage is not finished until cancel is issued
val sem1 = new Semaphore(0)
val sem2 = new Semaphore(0)

sc = new SparkContext("local[2]", "test")
- sc.dagScheduler.addSparkListener(new SparkListener {
+ sc.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart) {
sem1.release()
}
@@ -179,7 +206,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
intercept[SparkException] { f1.get() }
intercept[SparkException] { f2.get() }
}
- */

def testCount() {
// Cancel before launching any tasks
{
@@ -238,3 +265,9 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
}
}
}


Contributor:

nit: extra new line

Contributor Author:
actually, using two blank lines to separate top-level objects is a habit and sometimes a recommended style :)

object JobCancellationSuite {
val taskStartedSemaphore = new Semaphore(0)
val taskCancelledSemaphore = new Semaphore(0)
}
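
For reference, the two semaphores above implement a simple handshake in the new test: each task signals that it has really started, the cancelling side waits for that signal before calling cancelAllJobs, and then releases a generous number of permits so every blocked task can proceed and observe the kill. A stripped-down version of the same pattern with plain threads (no Spark involved, names are illustrative):

```scala
import java.util.concurrent.Semaphore

// Minimal standalone version of the handshake used in the test above,
// with a plain thread standing in for a Spark task.
object SemaphoreHandshakeSketch {
  val started   = new Semaphore(0)
  val cancelled = new Semaphore(0)

  def main(args: Array[String]): Unit = {
    val worker = new Thread(() => {
      started.release()        // signal: "I have started real work"
      cancelled.acquire()      // block until the canceller says it has acted
      println("worker observed cancellation signal")
    })
    worker.start()

    started.acquire()          // wait until the worker is genuinely running
    println("cancelling...")   // in the test this is where sc.cancelAllJobs() happens
    cancelled.release(100000)  // release generously so every blocked worker wakes up
    worker.join()
  }
}
```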