@@ -56,122 +56,37 @@ private[spark] class PythonRDD[T: ClassTag](
    val env = SparkEnv.get
    val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)

-    // Ensure worker socket is closed on task completion. Closing sockets is idempotent.
-    context.addOnCompleteCallback(() =>
+    // Start a thread to feed the process input from our parent's iterator
+    val writerThread = new WriterThread(env, worker, split, context)
+
+    context.addOnCompleteCallback { () =>
+      writerThread.shutdownOnTaskCompletion()
+
+      // Cleanup the worker socket. This will also cause the Python worker to exit.
      try {
        worker.close()
      } catch {
        case e: Exception => logWarning("Failed to close worker socket", e)
      }
-    )
-
-    @volatile var readerException: Exception = null
-
-    // Start a thread to feed the process input from our parent's iterator
-    new Thread("stdin writer for " + pythonExec) {
-      override def run() {
-        try {
-          SparkEnv.set(env)
-          val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-          val dataOut = new DataOutputStream(stream)
-          // Partition index
-          dataOut.writeInt(split.index)
-          // sparkFilesDir
-          PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
-          // Broadcast variables
-          dataOut.writeInt(broadcastVars.length)
-          for (broadcast <- broadcastVars) {
-            dataOut.writeLong(broadcast.id)
-            dataOut.writeInt(broadcast.value.length)
-            dataOut.write(broadcast.value)
-          }
-          // Python includes (*.zip and *.egg files)
-          dataOut.writeInt(pythonIncludes.length)
-          for (include <- pythonIncludes) {
-            PythonRDD.writeUTF(include, dataOut)
-          }
-          dataOut.flush()
-          // Serialized command:
-          dataOut.writeInt(command.length)
-          dataOut.write(command)
-          // Data values
-          PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
-          dataOut.flush()
-          worker.shutdownOutput()
-        } catch {
-
-          case e: java.io.FileNotFoundException =>
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-
-          case e: IOException =>
-            // This can happen for legitimate reasons if the Python code stops returning data
-            // before we are done passing elements through, e.g., for take(). Just log a message to
-            // say it happened (as it could also be hiding a real IOException from a data source).
-            logInfo("stdin writer to Python finished early (may not be an error)", e)
-
-          case e: Exception =>
-            // We must avoid throwing exceptions here, because the thread uncaught exception handler
-            // will kill the whole executor (see Executor).
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-        }
-      }
-    }.start()
-
-    // Necessary to distinguish between a task that has failed and a task that is finished
-    @volatile var complete: Boolean = false
-
-    // It is necessary to have a monitor thread for python workers if the user cancels with
-    // interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
-    // threads can block indefinitely.
-    new Thread(s"Worker Monitor for $pythonExec") {
-      override def run() {
-        // Kill the worker if it is interrupted or completed
-        // When a python task completes, the context is always set to interupted
-        while (!context.interrupted) {
-          Thread.sleep(2000)
-        }
-        if (!complete) {
-          try {
-            logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-            env.destroyPythonWorker(pythonExec, envVars.toMap)
-          } catch {
-            case e: Exception =>
-              logError("Exception when trying to kill worker", e)
-          }
-        }
-      }
-    }.start()
-
-    /*
-     * Partial fix for SPARK-1019: Attempts to stop reading the input stream since
-     * other completion callbacks might invalidate the input. Because interruption
-     * is not synchronous this still leaves a potential race where the interruption is
-     * processed only after the stream becomes invalid.
-     */
-    context.addOnCompleteCallback{ () =>
-      complete = true // Indicate that the task has completed successfully
-      context.interrupted = true
    }

+    writerThread.start()
+    new MonitorThread(env, worker, context).start()
+
    // Return an iterator that read lines from the process's stdout
    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
    val stdoutIterator = new Iterator[Array[Byte]] {
      def next(): Array[Byte] = {
        val obj = _nextObj
        if (hasNext) {
-          // FIXME: can deadlock if worker is waiting for us to
-          // respond to current message (currently irrelevant because
-          // output is shutdown before we read any input)
          _nextObj = read()
        }
        obj
      }

      private def read(): Array[Byte] = {
-        if (readerException != null) {
-          throw readerException
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
        }
        try {
          stream.readInt() match {
@@ -190,13 +105,14 @@ private[spark] class PythonRDD[T: ClassTag](
              val total = finishTime - startTime
              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                init, finish))
-              read
+              read()
            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
              // Signals that an exception has been thrown in python
              val exLength = stream.readInt()
              val obj = new Array[Byte](exLength)
              stream.readFully(obj)
-              throw new PythonException(new String(obj, "utf-8"), readerException)
+              throw new PythonException(new String(obj, "utf-8"),
+                writerThread.exception.getOrElse(null))
            case SpecialLengths.END_OF_DATA_SECTION =>
              // We've finished the data section of the output, but we can still
              // read some accumulator updates:
@@ -210,10 +126,15 @@ private[spark] class PythonRDD[T: ClassTag](
              Array.empty[Byte]
          }
        } catch {
-          case e: Exception if readerException != null =>
+
+          case e: Exception if context.interrupted =>
+            logDebug("Exception thrown after task interruption", e)
+            throw new TaskKilledException
+
+          case e: Exception if writerThread.exception.isDefined =>
            logError("Python worker exited unexpectedly (crashed)", e)
-            logError("Python crash may have been caused by prior exception:", readerException)
-            throw readerException
+            logError("This may have been caused by a prior exception:", writerThread.exception.get)
+            throw writerThread.exception.get

          case eof: EOFException =>
            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
@@ -224,10 +145,100 @@ private[spark] class PythonRDD[T: ClassTag](

      def hasNext = _nextObj.length != 0
    }
-    stdoutIterator
+    new InterruptibleIterator(context, stdoutIterator)
  }

  val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
+  /**
+   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
+   * Python process.
+   */
+  class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+    extends Thread(s"stdout writer for $pythonExec") {
+
+    @volatile private var _exception: Exception = null
+
+    setDaemon(true)
+
+    /** Contains the exception thrown while writing the parent iterator to the Python process. */
+    def exception: Option[Exception] = Option(_exception)
+
+    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
+    def shutdownOnTaskCompletion() {
+      assert(context.completed)
+      this.interrupt()
+    }
+
+    override def run() {
+      try {
+        SparkEnv.set(env)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+        val dataOut = new DataOutputStream(stream)
+        // Partition index
+        dataOut.writeInt(split.index)
+        // sparkFilesDir
+        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+        // Broadcast variables
+        dataOut.writeInt(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          dataOut.writeLong(broadcast.id)
+          dataOut.writeInt(broadcast.value.length)
+          dataOut.write(broadcast.value)
+        }
+        // Python includes (*.zip and *.egg files)
+        dataOut.writeInt(pythonIncludes.length)
+        for (include <- pythonIncludes) {
+          PythonRDD.writeUTF(include, dataOut)
+        }
+        dataOut.flush()
+        // Serialized command:
+        dataOut.writeInt(command.length)
+        dataOut.write(command)
+        // Data values
+        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        dataOut.flush()
+      } catch {
+        case e: Exception if context.completed || context.interrupted =>
+          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+
+        case e: Exception =>
+          // We must avoid throwing exceptions here, because the thread uncaught exception handler
+          // will kill the whole executor (see org.apache.spark.executor.Executor).
+          _exception = e
+      } finally {
+        Try(worker.shutdownOutput()) // kill Python worker process
+      }
+    }
+  }
+
+  /**
+   * It is necessary to have a monitor thread for python workers if the user cancels with
+   * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
+   * threads can block indefinitely.
+   */
+  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
+    extends Thread(s"Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      // Kill the worker if it is interrupted, checking until task completion.
+      // TODO: This has a race condition if interruption occurs, as completed may still become true.
+      while (!context.interrupted && !context.completed) {
+        Thread.sleep(2000)
+      }
+      if (!context.completed) {
+        try {
+          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
+          env.destroyPythonWorker(pythonExec, envVars.toMap)
+        } catch {
+          case e: Exception =>
+            logError("Exception when trying to kill worker", e)
+        }
+      }
+    }
+  }
 }

/** Thrown for exceptions in user Python code. */
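
The change above boils down to one pattern: the writer runs on a daemon thread, records any failure in a @volatile field exposed as an Option, and the reading side rethrows that failure instead of surfacing a confusing EOF from the socket. Below is a minimal, self-contained Scala sketch of that pattern, not Spark code; the names SketchWriter and produce are hypothetical.

import java.io.{DataOutputStream, OutputStream}
import scala.util.Try

// Hypothetical, simplified analogue of WriterThread: write on a daemon thread and
// capture any exception instead of letting it reach the uncaught-exception handler.
class SketchWriter(out: OutputStream, produce: DataOutputStream => Unit)
  extends Thread("sketch writer") {

  @volatile private var _exception: Exception = null

  setDaemon(true)

  /** The exception thrown while writing, if any. */
  def exception: Option[Exception] = Option(_exception)

  override def run(): Unit = {
    try {
      val dataOut = new DataOutputStream(out)
      produce(dataOut)
      dataOut.flush()
    } catch {
      // Record the failure for the reader; throwing here could kill the process.
      case e: Exception => _exception = e
    } finally {
      Try(out.close()) // best-effort cleanup, analogous to worker.shutdownOutput()
    }
  }
}

// Reader side: consult the writer before blaming the input stream.
//   val writer = new SketchWriter(socket.getOutputStream, _.writeInt(42))
//   writer.start()
//   ...on a read failure: writer.exception.foreach(e => throw e)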
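
In the same hedged spirit, a sketch of the MonitorThread idea: a daemon thread polls a pair of flags and force-kills the external worker only if the task was interrupted before it completed. TaskFlags and killWorker are hypothetical stand-ins for TaskContext and env.destroyPythonWorker.

import java.util.concurrent.atomic.AtomicBoolean

// Hypothetical stand-ins for TaskContext.interrupted / TaskContext.completed.
class TaskFlags {
  val interrupted = new AtomicBoolean(false)
  val completed = new AtomicBoolean(false)
}

// Poll every two seconds, mirroring the loop in MonitorThread, and kill the
// worker only when the task was interrupted without completing.
class SketchMonitor(flags: TaskFlags, killWorker: () => Unit)
  extends Thread("sketch monitor") {

  setDaemon(true)

  override def run(): Unit = {
    while (!flags.interrupted.get() && !flags.completed.get()) {
      Thread.sleep(2000)
    }
    if (!flags.completed.get()) {
      try {
        killWorker()
      } catch {
        case e: Exception => e.printStackTrace() // keep the failure on this thread
      }
    }
  }
}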