-
Notifications
You must be signed in to change notification settings - Fork 28.6k
[SPARK-27992][PYTHON] Allow Python to join with connection thread to propagate errors #24834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
519926f
3a52960
b209e0a
c9f7fe9
2fddb43
785ce4f
ead8978
5fd8684
20eb748
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
|
||
package org.apache.spark.security | ||
|
||
import java.io.{BufferedOutputStream, OutputStream} | ||
import java.net.{InetAddress, ServerSocket, Socket} | ||
|
||
import scala.concurrent.Promise | ||
|
@@ -25,12 +26,15 @@ import scala.util.Try | |
|
||
import org.apache.spark.SparkEnv | ||
import org.apache.spark.network.util.JavaUtils | ||
import org.apache.spark.util.ThreadUtils | ||
import org.apache.spark.util.{ThreadUtils, Utils} | ||
|
||
|
||
/** | ||
* Creates a server in the JVM to communicate with external processes (e.g., Python and R) for | ||
* handling one batch of data, with authentication and error handling. | ||
* | ||
* The socket server accepts only one connection, and closes if no connection is made | ||
* within 15 seconds. | ||
*/ | ||
private[spark] abstract class SocketAuthServer[T]( | ||
authHelper: SocketAuthHelper, | ||
|
@@ -41,10 +45,30 @@ private[spark] abstract class SocketAuthServer[T]( | |
|
||
private val promise = Promise[T]() | ||
|
||
val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock => | ||
promise.complete(Try(handleConnection(sock))) | ||
/**
 * Opens a server socket bound to 127.0.0.1 on a system-chosen port and starts a daemon thread
 * named `threadName` that serves a single client: it accepts one connection, authenticates it
 * through `authHelper`, and completes `promise` with the outcome of `handleConnection`.
 *
 * If accepting or authenticating the client fails (including the 15 second accept timeout),
 * the failure is also recorded in `promise`, so that anything waiting on the promise observes
 * the error instead of waiting on a promise that is never completed. The exception is then
 * rethrown, so the serving thread still terminates with the original error.
 *
 * The server socket and the client socket (if any) are always closed when the thread finishes.
 *
 * @return the local port the server socket is bound to, and the secret held by `authHelper`.
 */
private def startServer(): (Int, String) = {
  val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
  // Stop waiting for a client after 15 seconds; the socket is then closed in `finally` below.
  serverSocket.setSoTimeout(15000)

  new Thread(threadName) {
    setDaemon(true)
    override def run(): Unit = {
      var sock: Socket = null
      try {
        sock = serverSocket.accept()
        authHelper.authClient(sock)
        // Errors raised by handleConnection are captured by Try and stored in the promise.
        promise.complete(Try(handleConnection(sock)))
      } catch {
        // Failures before handleConnection (accept timeout, failed authentication) would
        // otherwise leave the promise uncompleted. tryFailure is a no-op if the promise has
        // already been completed.
        case scala.util.control.NonFatal(e) =>
          promise.tryFailure(e)
          throw e
      } finally {
        JavaUtils.closeQuietly(serverSocket)
        JavaUtils.closeQuietly(sock)
      }
    }
  }.start()
  (serverSocket.getLocalPort, authHelper.secret)
}
|
||
val (port, secret) = startServer() | ||
|
||
/** | ||
* Handle a connection which has already been authenticated. Any error from this function | ||
* will clean up this connection and the entire server, and get propagated to [[getResult]]. | ||
|
@@ -66,42 +90,50 @@ private[spark] abstract class SocketAuthServer[T]( | |
|
||
} | ||
|
||
/**
 * A [[SocketAuthServer]] whose connection handler is a caller-supplied function. Once a client
 * has connected and been authenticated, `func` is invoked with the socket and may read from
 * and write to its input/output streams.
 *
 * @param authHelper helper used to authenticate the connecting client
 * @param threadName name passed on to [[SocketAuthServer]] for the serving thread
 * @param func user function to run on the connected, authenticated socket
 */
private[spark] class SocketFuncServer(
    authHelper: SocketAuthHelper,
    threadName: String,
    func: Socket => Unit)
  extends SocketAuthServer[Unit](authHelper, threadName) {

  /** Delegates the authenticated connection to the user function. */
  override def handleConnection(sock: Socket): Unit = func(sock)
}
|
||
private[spark] object SocketAuthServer { | ||
|
||
/** | ||
* Create a socket server and run user function on the socket in a background thread. | ||
* Convenience function to create a socket server and run a user function in a background | ||
* thread to write to an output stream. | ||
* | ||
* The socket server can only accept one connection, or close if no connection | ||
* in 15 seconds. | ||
* | ||
* The thread will terminate after the supplied user function, or if there are any exceptions. | ||
* | ||
* If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]] | ||
* | ||
* @return The port number of a local socket and the secret for authentication. | ||
* @param threadName Name for the background serving thread. | ||
* @param authHelper SocketAuthHelper for authentication | ||
* @param writeFunc User function to write to a given OutputStream | ||
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the | ||
* data collected from this job, the secret for authentication, and a socket auth | ||
* server object that can be used to join the JVM serving thread in Python. | ||
*/ | ||
def setupOneConnectionServer( | ||
authHelper: SocketAuthHelper, | ||
threadName: String) | ||
(func: Socket => Unit): (Int, String) = { | ||
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) | ||
// Close the socket if no connection in 15 seconds | ||
serverSocket.setSoTimeout(15000) | ||
|
||
new Thread(threadName) { | ||
setDaemon(true) | ||
override def run(): Unit = { | ||
var sock: Socket = null | ||
try { | ||
sock = serverSocket.accept() | ||
authHelper.authClient(sock) | ||
func(sock) | ||
} finally { | ||
JavaUtils.closeQuietly(serverSocket) | ||
JavaUtils.closeQuietly(sock) | ||
} | ||
def serveToStream( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Moved this from |
||
threadName: String, | ||
authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = { | ||
val handleFunc = (sock: Socket) => { | ||
val out = new BufferedOutputStream(sock.getOutputStream()) | ||
Utils.tryWithSafeFinally { | ||
writeFunc(out) | ||
} { | ||
out.close() | ||
} | ||
}.start() | ||
(serverSocket.getLocalPort, authHelper.secret) | ||
} | ||
|
||
val server = new SocketFuncServer(authHelper, threadName, handleFunc) | ||
Array(server.port, server.secret, server) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1389,7 +1389,9 @@ private[spark] object Utils extends Logging { | |
originalThrowable = cause | ||
try { | ||
logError("Aborting task", originalThrowable) | ||
TaskContext.get().markTaskFailed(originalThrowable) | ||
if (TaskContext.get() != null) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using this utility here https://github.com/apache/spark/pull/24834/files#diff-0a67bc4d171abe4df8eb305b0f4123a2R184, where the task fails and completes before hitting the |
||
TaskContext.get().markTaskFailed(originalThrowable) | ||
} | ||
catchBlock | ||
} catch { | ||
case t: Throwable => | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need
SocketAuthServer.setupOneConnectionServer
if we have this also, so it could be cleaned up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed
SocketAuthServer.setupOneConnectionServer
and replaced usage with SocketFuncServer.