Commit 1fdf659

SPARK-1601 & SPARK-1602: two bug fixes related to cancellation
This should go into 1.0 since it would return wrong data when the bug happens (which is pretty likely if cancellation is used). Test case attached.

1. Do not put partially executed partitions into cache (in task killing).
2. Iterator returned by CacheManager#getOrCompute was not an InterruptibleIterator, and was thus leading to uninterruptible jobs.

Thanks @aarondav and @ahirreddy for reporting and helping debug.

Author: Reynold Xin <rxin@apache.org>

Closes #521 from rxin/kill and squashes the following commits:

401033f [Reynold Xin] Merge branch 'master' of https://git-wip-us.apache.org/repos/asf/spark into kill
7a7bdd2 [Reynold Xin] Add a new line in the end of JobCancellationSuite.scala.
35cd9f7 [Reynold Xin] Fixed a bug that partially executed partitions can be put into cache (in task killing).
1 parent dd681f5 commit 1fdf659
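For context, the idea behind both fixes is the iterator-wrapping pattern: the iterator handed back to a task checks a cancellation flag on every hasNext and fails fast once the task has been killed. Below is a minimal, self-contained sketch of that idea (simplified names and types, not the Spark classes in this diff):

```scala
import java.util.concurrent.atomic.AtomicBoolean

// Simplified stand-in for Spark's InterruptibleIterator: every hasNext consults a
// cancellation flag, so a long-running consumer of the iterator stops promptly.
class CancellableIterator[T](cancelled: AtomicBoolean, delegate: Iterator[T]) extends Iterator[T] {
  def hasNext: Boolean = {
    if (cancelled.get()) throw new RuntimeException("task killed")
    delegate.hasNext
  }
  def next(): T = delegate.next()
}

object CancellableIteratorDemo extends App {
  val cancelled = new AtomicBoolean(false)
  val it = new CancellableIterator(cancelled, Iterator.from(1))
  println(it.take(3).toList)          // List(1, 2, 3)
  cancelled.set(true)                 // simulate the task being killed
  try it.hasNext catch {
    case e: RuntimeException => println("iteration stopped: " + e.getMessage)
  }
}
```

The real pieces involved are InterruptibleIterator, TaskContext.interrupted, and the new TaskKilledException, all shown in the diffs below.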

File tree: 5 files changed, 86 additions & 15 deletions

core/src/main/scala/org/apache/spark/CacheManager.scala

Lines changed: 11 additions & 4 deletions
@@ -47,7 +47,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           if (loading.contains(key)) {
             logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
             while (loading.contains(key)) {
-              try {loading.wait()} catch {case _ : Throwable =>}
+              try {
+                loading.wait()
+              } catch {
+                case e: Exception =>
+                  logWarning(s"Got an exception while waiting for another thread to load $key", e)
+              }
             }
             logInfo("Finished waiting for %s".format(key))
             /* See whether someone else has successfully loaded it. The main way this would fail
@@ -72,7 +77,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         val computedValues = rdd.computeOrReadCheckpoint(split, context)

         // Persist the result, so long as the task is not running locally
-        if (context.runningLocally) { return computedValues }
+        if (context.runningLocally) {
+          return computedValues
+        }

         // Keep track of blocks with updated statuses
         var updatedBlocks = Seq[(BlockId, BlockStatus)]()
@@ -88,7 +95,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
           blockManager.get(key) match {
             case Some(values) =>
-              new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+              values.asInstanceOf[Iterator[T]]
             case None =>
               logInfo("Failure to store %s".format(key))
               throw new Exception("Block manager failed to return persisted valued")
@@ -107,7 +114,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         val metrics = context.taskMetrics
         metrics.updatedBlocks = Some(updatedBlocks)

-        returnValue
+        new InterruptibleIterator(context, returnValue)

       } finally {
         loading.synchronized {
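One design point worth noting in this file: rather than wrapping the iterator inside the `case Some(values)` branch only, getOrCompute now computes `returnValue` in whichever branch applies and wraps it once at the end (aside from the `runningLocally` early return), so no path hands back an unwrapped iterator. A rough, plain-Scala sketch of that structure, with hypothetical names (`interruptible` and the toy `getOrCompute` below are stand-ins, not the Spark signatures):

```scala
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable

object WrapAtReturnDemo extends App {
  val cancelled = new AtomicBoolean(false)

  // Wrap any iterator so that hasNext fails fast once cancellation is flagged.
  def interruptible[T](it: Iterator[T]): Iterator[T] = new Iterator[T] {
    def hasNext: Boolean = {
      if (cancelled.get()) throw new RuntimeException("task killed")
      it.hasNext
    }
    def next(): T = it.next()
  }

  val cache = mutable.Map[Int, Vector[Int]]()

  def getOrCompute(key: Int): Iterator[Int] = {
    val values = cache.get(key) match {
      case Some(cached) => cached.iterator                // cache hit
      case None =>
        val computed = (1 to 5).map(_ * key).toVector     // cache miss: compute and store
        cache(key) = computed
        computed.iterator
    }
    interruptible(values)   // wrap once, at the single return point
  }

  println(getOrCompute(2).toList)   // List(2, 4, 6, 8, 10), freshly computed
  println(getOrCompute(2).toList)   // served from the cache, still interruptible
}
```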

core/src/main/scala/org/apache/spark/InterruptibleIterator.scala

Lines changed: 11 additions & 1 deletion
@@ -24,7 +24,17 @@ package org.apache.spark
 private[spark] class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
   extends Iterator[T] {

-  def hasNext: Boolean = !context.interrupted && delegate.hasNext
+  def hasNext: Boolean = {
+    // TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
+    // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
+    // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
+    // introduces an expensive read fence.
+    if (context.interrupted) {
+      throw new TaskKilledException
+    } else {
+      delegate.hasNext
+    }
+  }

   def next(): T = delegate.next()
 }
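The reason hasNext now throws rather than simply returning false ties the two JIRAs together: when a killed task's iterator silently reports exhaustion, a consumer that drains it into the block store (as CacheManager#getOrCompute does above) sees a normal end-of-data and caches a truncated partition. Throwing makes the kill propagate as a failure instead. A simplified, plain-Scala illustration of the difference (not Spark code):

```scala
object PartialCacheDemo extends App {
  @volatile var interrupted = false

  // Old behavior: hasNext returns false once interrupted, so a consumer that drains the
  // iterator into a cache sees a "normal" end and stores a truncated partition.
  def silentlyStopping(data: Iterator[Int]): Iterator[Int] = new Iterator[Int] {
    def hasNext: Boolean = !interrupted && data.hasNext
    def next(): Int = data.next()
  }

  // New behavior: hasNext throws, so the failure propagates and nothing partial is stored.
  def throwing(data: Iterator[Int]): Iterator[Int] = new Iterator[Int] {
    def hasNext: Boolean = {
      if (interrupted) throw new RuntimeException("task killed")
      data.hasNext
    }
    def next(): Int = data.next()
  }

  val source = (1 to 1000).iterator.map { x => if (x > 10) interrupted = true; x }
  println(silentlyStopping(source).toVector.size)   // prints 11: a silently truncated result

  interrupted = false
  val source2 = (1 to 1000).iterator.map { x => if (x > 10) interrupted = true; x }
  try throwing(source2).toVector
  catch { case e: RuntimeException => println("new behavior: " + e.getMessage + ", nothing stored") }
}
```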
core/src/main/scala/org/apache/spark/TaskKilledException.scala

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Exception for a task getting killed.
+ */
+private[spark] class TaskKilledException extends RuntimeException

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 3 additions & 5 deletions
@@ -161,8 +161,6 @@ private[spark] class Executor(
   class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
     extends Runnable {

-    object TaskKilledException extends Exception
-
     @volatile private var killed = false
     @volatile private var task: Task[Any] = _

@@ -200,7 +198,7 @@ private[spark] class Executor(
           // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
           // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
           // for the task.
-          throw TaskKilledException
+          throw new TaskKilledException
         }

         attemptedTask = Some(task)
@@ -214,7 +212,7 @@ private[spark] class Executor(
         // If the task has been killed, let's fail it.
         if (task.killed) {
-          throw TaskKilledException
+          throw new TaskKilledException
         }

         val resultSer = SparkEnv.get.serializer.newInstance()
@@ -257,7 +255,7 @@ private[spark] class Executor(
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
         }

-        case TaskKilledException | _: InterruptedException if task.killed => {
+        case _: TaskKilledException | _: InterruptedException if task.killed => {
          logInfo("Executor killed task " + taskId)
          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
        }
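The catch-clause change follows directly from turning TaskKilledException from a nested singleton object into a top-level class: a bare `case TaskKilledException` matched the one singleton instance by equality, whereas instances of a class need a type pattern. A small illustration with made-up names (not the Spark classes):

```scala
object KilledPatternDemo extends App {
  // Old style: a singleton object, matched by equality against that single instance.
  object SingletonKilled extends Exception
  // New style: a class, so each throw creates a fresh instance and only a type
  // pattern (`_: InstanceKilled`) will match it.
  class InstanceKilled extends RuntimeException

  def classify(t: Throwable): String = t match {
    case SingletonKilled   => "matched the singleton by equality"
    case _: InstanceKilled => "matched an instance by type"
    case _                 => "unhandled"
  }

  println(classify(SingletonKilled))      // matched the singleton by equality
  println(classify(new InstanceKilled))   // matched an instance by type
}
```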

core/src/test/scala/org/apache/spark/JobCancellationSuite.scala

Lines changed: 38 additions & 5 deletions
@@ -84,6 +84,35 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     assert(sc.parallelize(1 to 10, 2).count === 10)
   }

+  test("do not put partially executed partitions into cache") {
+    // In this test case, we create a scenario in which a partition is only partially executed,
+    // and make sure CacheManager does not put that partially executed partition into the
+    // BlockManager.
+    import JobCancellationSuite._
+    sc = new SparkContext("local", "test")
+
+    // Run from 1 to 10, and then block and wait for the task to be killed.
+    val rdd = sc.parallelize(1 to 1000, 2).map { x =>
+      if (x > 10) {
+        taskStartedSemaphore.release()
+        taskCancelledSemaphore.acquire()
+      }
+      x
+    }.cache()
+
+    val rdd1 = rdd.map(x => x)
+
+    future {
+      taskStartedSemaphore.acquire()
+      sc.cancelAllJobs()
+      taskCancelledSemaphore.release(100000)
+    }
+
+    intercept[SparkException] { rdd1.count() }
+    // If the partial block is put into cache, rdd.count() would return a number less than 1000.
+    assert(rdd.count() === 1000)
+  }
+
   test("job group") {
     sc = new SparkContext("local[2]", "test")

@@ -114,7 +143,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     assert(jobB.get() === 100)
   }

-
   test("job group with interruption") {
     sc = new SparkContext("local[2]", "test")

@@ -145,15 +173,14 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     assert(jobB.get() === 100)
   }

-  /*
-  test("two jobs sharing the same stage") {
+  ignore("two jobs sharing the same stage") {
     // sem1: make sure cancel is issued after some tasks are launched
     // sem2: make sure the first stage is not finished until cancel is issued
     val sem1 = new Semaphore(0)
     val sem2 = new Semaphore(0)

     sc = new SparkContext("local[2]", "test")
-    sc.dagScheduler.addSparkListener(new SparkListener {
+    sc.addSparkListener(new SparkListener {
       override def onTaskStart(taskStart: SparkListenerTaskStart) {
         sem1.release()
       }
@@ -179,7 +206,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     intercept[SparkException] { f1.get() }
     intercept[SparkException] { f2.get() }
   }
-  */
+
   def testCount() {
     // Cancel before launching any tasks
     {
@@ -238,3 +265,9 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     }
   }
 }
+
+
+object JobCancellationSuite {
+  val taskStartedSemaphore = new Semaphore(0)
+  val taskCancelledSemaphore = new Semaphore(0)
+}
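The new test coordinates the task and the cancelling thread with two semaphores kept in the suite's companion object, so the closure shipped into the task and the driver-side future both see the same instances. A minimal sketch of that handshake, with hypothetical names and no Spark involved:

```scala
import java.util.concurrent.Semaphore

// Stripped-down version of the coordination used above: the worker announces that it has
// started, the controlling thread issues the cancellation, and only then is the worker
// released so it can observe the interruption.
object HandshakeDemo extends App {
  val started   = new Semaphore(0)
  val cancelled = new Semaphore(0)

  val worker = new Thread(new Runnable {
    def run(): Unit = {
      started.release()       // "the first few elements have been produced"
      cancelled.acquire()     // block until cancellation has been issued
      println("worker resumed after cancellation was issued")
    }
  })
  worker.start()

  started.acquire()           // wait until the worker reaches its blocking point
  println("issuing cancellation")
  cancelled.release()
  worker.join()
}
```

The final assert(rdd.count() === 1000) is the actual regression check: if the killed task's partial partition had been stored, the second count would come back short.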
