
Commit 69455d1

Merge branch 'master' of github.com:apache/spark into task-context
Conflicts: core/src/main/scala/org/apache/spark/TaskContext.scala
2 parents: c471490 + 3308722

86 files changed, +2075 −315 lines


core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 7 additions & 2 deletions
@@ -41,20 +41,25 @@ class TaskContext(
   // List of callback functions to execute when the task completes.
   @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
 
-  // Whether the corresponding task has been killed
+  // Whether the corresponding task has been killed.
   @volatile var interrupted: Boolean = false
 
+  // Whether the task has completed, before the onCompleteCallbacks are executed.
+  @volatile var completed: Boolean = false
+
   /**
    * Add a callback function to be executed on task completion. An example use
    * is for HadoopRDD to register a callback to close the input stream.
+   * Will be called in any situation - success, failure, or cancellation.
    * @param f Callback function.
    */
   def addOnCompleteCallback(f: () => Unit) {
     onCompleteCallbacks += f
   }
 
   def executeOnCompleteCallbacks() {
+    completed = true
     // Process complete callbacks in the reverse order of registration
-    onCompleteCallbacks.reverse.foreach{_()}
+    onCompleteCallbacks.reverse.foreach{ _() }
   }
 }
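
For context on the pattern this hunk relies on: completed is flipped before any callback runs, and callbacks fire in reverse registration order, so the most recently registered cleanup runs first (stack-style release of resources). A minimal standalone sketch of that behavior, using a mock class with hypothetical names rather than Spark's TaskContext:

import scala.collection.mutable.ArrayBuffer

// Standalone mock of the callback pattern (hypothetical class, not org.apache.spark.TaskContext).
class MockTaskContext {
  @volatile var interrupted: Boolean = false
  @volatile var completed: Boolean = false
  private val onCompleteCallbacks = new ArrayBuffer[() => Unit]

  def addOnCompleteCallback(f: () => Unit): Unit = { onCompleteCallbacks += f }

  def executeOnCompleteCallbacks(): Unit = {
    completed = true                              // flag is set before any callback observes it
    onCompleteCallbacks.reverse.foreach { _() }   // newest callback runs first
  }
}

object MockTaskContextDemo extends App {
  val ctx = new MockTaskContext
  ctx.addOnCompleteCallback(() => println("close input stream"))  // registered first, runs last
  ctx.addOnCompleteCallback(() => println("shut down writer"))    // registered last, runs first
  ctx.executeOnCompleteCallbacks()
  println(s"completed = ${ctx.completed}")  // true
}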

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 114 additions & 103 deletions
@@ -56,122 +56,37 @@ private[spark] class PythonRDD[T: ClassTag](
     val env = SparkEnv.get
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
-    // Ensure worker socket is closed on task completion. Closing sockets is idempotent.
-    context.addOnCompleteCallback(() =>
+    // Start a thread to feed the process input from our parent's iterator
+    val writerThread = new WriterThread(env, worker, split, context)
+
+    context.addOnCompleteCallback { () =>
+      writerThread.shutdownOnTaskCompletion()
+
+      // Cleanup the worker socket. This will also cause the Python worker to exit.
       try {
         worker.close()
       } catch {
         case e: Exception => logWarning("Failed to close worker socket", e)
       }
-    )
-
-    @volatile var readerException: Exception = null
-
-    // Start a thread to feed the process input from our parent's iterator
-    new Thread("stdin writer for " + pythonExec) {
-      override def run() {
-        try {
-          SparkEnv.set(env)
-          val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-          val dataOut = new DataOutputStream(stream)
-          // Partition index
-          dataOut.writeInt(split.index)
-          // sparkFilesDir
-          PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
-          // Broadcast variables
-          dataOut.writeInt(broadcastVars.length)
-          for (broadcast <- broadcastVars) {
-            dataOut.writeLong(broadcast.id)
-            dataOut.writeInt(broadcast.value.length)
-            dataOut.write(broadcast.value)
-          }
-          // Python includes (*.zip and *.egg files)
-          dataOut.writeInt(pythonIncludes.length)
-          for (include <- pythonIncludes) {
-            PythonRDD.writeUTF(include, dataOut)
-          }
-          dataOut.flush()
-          // Serialized command:
-          dataOut.writeInt(command.length)
-          dataOut.write(command)
-          // Data values
-          PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
-          dataOut.flush()
-          worker.shutdownOutput()
-        } catch {
-
-          case e: java.io.FileNotFoundException =>
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-
-          case e: IOException =>
-            // This can happen for legitimate reasons if the Python code stops returning data
-            // before we are done passing elements through, e.g., for take(). Just log a message to
-            // say it happened (as it could also be hiding a real IOException from a data source).
-            logInfo("stdin writer to Python finished early (may not be an error)", e)
-
-          case e: Exception =>
-            // We must avoid throwing exceptions here, because the thread uncaught exception handler
-            // will kill the whole executor (see Executor).
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-        }
-      }
-    }.start()
-
-    // Necessary to distinguish between a task that has failed and a task that is finished
-    @volatile var complete: Boolean = false
-
-    // It is necessary to have a monitor thread for python workers if the user cancels with
-    // interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
-    // threads can block indefinitely.
-    new Thread(s"Worker Monitor for $pythonExec") {
-      override def run() {
-        // Kill the worker if it is interrupted or completed
-        // When a python task completes, the context is always set to interupted
-        while (!context.interrupted) {
-          Thread.sleep(2000)
-        }
-        if (!complete) {
-          try {
-            logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-            env.destroyPythonWorker(pythonExec, envVars.toMap)
-          } catch {
-            case e: Exception =>
-              logError("Exception when trying to kill worker", e)
-          }
-        }
-      }
-    }.start()
-
-    /*
-     * Partial fix for SPARK-1019: Attempts to stop reading the input stream since
-     * other completion callbacks might invalidate the input. Because interruption
-     * is not synchronous this still leaves a potential race where the interruption is
-     * processed only after the stream becomes invalid.
-     */
-    context.addOnCompleteCallback{ () =>
-      complete = true // Indicate that the task has completed successfully
-      context.interrupted = true
     }
 
+    writerThread.start()
+    new MonitorThread(env, worker, context).start()
+
     // Return an iterator that read lines from the process's stdout
     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
     val stdoutIterator = new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
         if (hasNext) {
-          // FIXME: can deadlock if worker is waiting for us to
-          // respond to current message (currently irrelevant because
-          // output is shutdown before we read any input)
          _nextObj = read()
         }
         obj
       }
 
       private def read(): Array[Byte] = {
-        if (readerException != null) {
-          throw readerException
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
         }
         try {
           stream.readInt() match {
@@ -190,13 +105,14 @@ private[spark] class PythonRDD[T: ClassTag](
              val total = finishTime - startTime
              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                init, finish))
-              read
+              read()
            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
              // Signals that an exception has been thrown in python
              val exLength = stream.readInt()
              val obj = new Array[Byte](exLength)
              stream.readFully(obj)
-              throw new PythonException(new String(obj, "utf-8"), readerException)
+              throw new PythonException(new String(obj, "utf-8"),
+                writerThread.exception.getOrElse(null))
            case SpecialLengths.END_OF_DATA_SECTION =>
              // We've finished the data section of the output, but we can still
              // read some accumulator updates:
@@ -210,10 +126,15 @@ private[spark] class PythonRDD[T: ClassTag](
              Array.empty[Byte]
          }
        } catch {
-          case e: Exception if readerException != null =>
+
+          case e: Exception if context.interrupted =>
+            logDebug("Exception thrown after task interruption", e)
+            throw new TaskKilledException
+
+          case e: Exception if writerThread.exception.isDefined =>
            logError("Python worker exited unexpectedly (crashed)", e)
-            logError("Python crash may have been caused by prior exception:", readerException)
-            throw readerException
+            logError("This may have been caused by a prior exception:", writerThread.exception.get)
+            throw writerThread.exception.get
 
          case eof: EOFException =>
            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
@@ -224,10 +145,100 @@ private[spark] class PythonRDD[T: ClassTag](
 
       def hasNext = _nextObj.length != 0
     }
-    stdoutIterator
+    new InterruptibleIterator(context, stdoutIterator)
   }
 
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
+  /**
+   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
+   * Python process.
+   */
+  class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+    extends Thread(s"stdout writer for $pythonExec") {
+
+    @volatile private var _exception: Exception = null
+
+    setDaemon(true)
+
+    /** Contains the exception thrown while writing the parent iterator to the Python process. */
+    def exception: Option[Exception] = Option(_exception)
+
+    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
+    def shutdownOnTaskCompletion() {
+      assert(context.completed)
+      this.interrupt()
+    }
+
+    override def run() {
+      try {
+        SparkEnv.set(env)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+        val dataOut = new DataOutputStream(stream)
+        // Partition index
+        dataOut.writeInt(split.index)
+        // sparkFilesDir
+        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+        // Broadcast variables
+        dataOut.writeInt(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          dataOut.writeLong(broadcast.id)
+          dataOut.writeInt(broadcast.value.length)
+          dataOut.write(broadcast.value)
+        }
+        // Python includes (*.zip and *.egg files)
+        dataOut.writeInt(pythonIncludes.length)
+        for (include <- pythonIncludes) {
+          PythonRDD.writeUTF(include, dataOut)
+        }
+        dataOut.flush()
+        // Serialized command:
+        dataOut.writeInt(command.length)
+        dataOut.write(command)
+        // Data values
+        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        dataOut.flush()
+      } catch {
+        case e: Exception if context.completed || context.interrupted =>
+          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+
+        case e: Exception =>
+          // We must avoid throwing exceptions here, because the thread uncaught exception handler
+          // will kill the whole executor (see org.apache.spark.executor.Executor).
+          _exception = e
+      } finally {
+        Try(worker.shutdownOutput()) // kill Python worker process
+      }
+    }
+  }
+
+  /**
+   * It is necessary to have a monitor thread for python workers if the user cancels with
+   * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
+   * threads can block indefinitely.
+   */
+  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
+    extends Thread(s"Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      // Kill the worker if it is interrupted, checking until task completion.
+      // TODO: This has a race condition if interruption occurs, as completed may still become true.
+      while (!context.interrupted && !context.completed) {
+        Thread.sleep(2000)
+      }
+      if (!context.completed) {
+        try {
+          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
+          env.destroyPythonWorker(pythonExec, envVars.toMap)
+        } catch {
          case e: Exception =>
            logError("Exception when trying to kill worker", e)
        }
      }
    }
  }
 }
 
 /** Thrown for exceptions in user Python code. */
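
The refactor above centers on one technique: the writer runs on a daemon thread, stores any failure in a @volatile field, and exposes it as an Option so the reading side can rethrow it on its own thread. A standalone sketch of that pattern follows; the names (ProducerThread, work) are illustrative and not part of the Spark API. Keeping the catch inside run() matters because an uncaught exception in a helper thread would reach the default handler, which in Spark's Executor would take down the whole executor.

// Standalone sketch: a daemon producer thread reports its failure to the consumer thread.
// ProducerThread and work are illustrative names, not Spark classes.
class ProducerThread(work: () => Unit) extends Thread("producer") {
  @volatile private var _exception: Exception = null

  setDaemon(true)  // never keep the JVM alive just for this helper thread

  /** The exception thrown while producing, if any. */
  def exception: Option[Exception] = Option(_exception)

  override def run(): Unit = {
    try {
      work()
    } catch {
      // Swallow and record: letting it escape would hit the thread's uncaught-exception handler.
      case e: Exception => _exception = e
    }
  }
}

object ProducerDemo extends App {
  val producer = new ProducerThread(() => throw new RuntimeException("boom"))
  producer.start()
  producer.join()
  // The consumer checks and rethrows on its own side, much like read() does in the diff above.
  producer.exception.foreach(e => println(s"would rethrow on consumer side: $e"))
}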

core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala

Lines changed: 9 additions & 2 deletions
@@ -47,9 +47,16 @@ object CommandUtils extends Logging {
    */
   def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
     val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M")
-    // Note, this will coalesce multiple options into a single command component
     val extraOpts = command.extraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq())
 
+    // Exists for backwards compatibility with older Spark versions
+    val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString)
+      .getOrElse(Nil)
+    if (workerLocalOpts.length > 0) {
+      logWarning("SPARK_JAVA_OPTS was set on the worker. It is deprecated in Spark 1.0.")
+      logWarning("Set SPARK_LOCAL_DIRS for node-specific storage locations.")
+    }
+
     val libraryOpts =
       if (command.libraryPathEntries.size > 0) {
         val joined = command.libraryPathEntries.mkString(File.pathSeparator)
@@ -66,7 +73,7 @@ object CommandUtils extends Logging {
     val userClassPath = command.classPathEntries ++ Seq(classPath)
 
     Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++
-      libraryOpts ++ extraOpts ++ memoryOpts
+      libraryOpts ++ extraOpts ++ workerLocalOpts ++ memoryOpts
   }
 
   /** Spawn a thread that will redirect a given stream to a file */
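
The CommandUtils change follows a common deprecation pattern: read the legacy environment variable, warn if it is set, and still honor it by appending its options. A rough standalone sketch under those assumptions; buildOpts is a made-up name, and a naive whitespace split stands in for Utils.splitCommandString (which also handles quoting):

// Standalone sketch of the legacy-env-var fallback (illustrative, not Spark's CommandUtils).
object LegacyOptsDemo extends App {
  def buildOpts(memoryMb: Int): Seq[String] = {
    val memoryOpts = Seq(s"-Xms${memoryMb}M", s"-Xmx${memoryMb}M")

    // Honor the deprecated variable if present, but warn about it.
    val workerLocalOpts = Option(System.getenv("SPARK_JAVA_OPTS"))
      .map(_.trim)
      .filter(_.nonEmpty)
      .map(_.split("\\s+").toSeq)   // crude stand-in for Utils.splitCommandString
      .getOrElse(Nil)
    if (workerLocalOpts.nonEmpty) {
      System.err.println("WARN: SPARK_JAVA_OPTS is deprecated; prefer per-application settings.")
    }

    // Deprecated options are appended before the memory flags, mirroring the diff above.
    workerLocalOpts ++ memoryOpts
  }

  println(buildOpts(512).mkString(" "))
}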

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 5 additions & 5 deletions
@@ -128,7 +128,7 @@ abstract class RDD[T: ClassTag](
   @transient var name: String = null
 
   /** Assign a name to this RDD */
-  def setName(_name: String): RDD[T] = {
+  def setName(_name: String): this.type = {
     name = _name
     this
   }
@@ -138,7 +138,7 @@ abstract class RDD[T: ClassTag](
    * it is computed. This can only be used to assign a new storage level if the RDD does not
    * have a storage level set yet..
    */
-  def persist(newLevel: StorageLevel): RDD[T] = {
+  def persist(newLevel: StorageLevel): this.type = {
     // TODO: Handle changes of StorageLevel
     if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
       throw new UnsupportedOperationException(
@@ -152,18 +152,18 @@ abstract class RDD[T: ClassTag](
   }
 
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
-  def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
+  def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)
 
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
-  def cache(): RDD[T] = persist()
+  def cache(): this.type = persist()
 
   /**
    * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
    *
    * @param blocking Whether to block until all blocks are deleted.
    * @return This RDD.
    */
-  def unpersist(blocking: Boolean = true): RDD[T] = {
+  def unpersist(blocking: Boolean = true): this.type = {
     logInfo("Removing RDD " + id + " from persistence list")
     sc.unpersistRDD(id, blocking)
     storageLevel = StorageLevel.NONE