[SPARK-27992][PYTHON] Allow Python to join with connection thread to propagate errors #24834

90 changes: 49 additions & 41 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.BUFFER_SIZE
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
import org.apache.spark.util._


@@ -137,8 +137,9 @@ private[spark] object PythonRDD extends Logging {
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a socket auth
* server object that can be used to join the JVM serving thread in Python.
*/
def runJob(
sc: SparkContext,
@@ -156,8 +157,9 @@
/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a socket auth
* server object that can be used to join the JVM serving thread in Python.
*/
def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
@@ -168,58 +170,59 @@
* are collected as separate jobs, by order of index. Partition data is first requested by a
* non-zero integer to start a collection job. The response is prefaced by an integer with 1
* meaning partition data will be served, 0 meaning the local iterator has been consumed,
* and -1 meaining an error occurred during collection. This function is used by
* and -1 meaning an error occurred during collection. This function is used by
* pyspark.rdd._local_iterator_from_socket().
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from these jobs, and the secret for authentication.
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a socket auth
* server object that can be used to join the JVM serving thread in Python.
*/
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
val (port, secret) = SocketAuthServer.setupOneConnectionServer(
authHelper, "serve toLocalIterator") { s =>
val out = new DataOutputStream(s.getOutputStream)
val in = new DataInputStream(s.getInputStream)
Utils.tryWithSafeFinally {

val handleFunc = (sock: Socket) => {
val out = new DataOutputStream(sock.getOutputStream)
val in = new DataInputStream(sock.getInputStream)
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
}

// Read request for data and send next partition if nonzero
// Write data until iteration is complete, client stops iteration, or error occurs
var complete = false
while (!complete && in.readInt() != 0) {
if (collectPartitionIter.hasNext) {
try {
// Attempt to collect the next partition
val partitionArray = collectPartitionIter.next()

// Send response there is a partition to read
out.writeInt(1)

// Write the next object and signal end of data for this iteration
writeIteratorToStream(partitionArray.toIterator, out)
out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
out.flush()
} catch {
case e: SparkException =>
// Send response that an error occurred followed by error message
out.writeInt(-1)
writeUTF(e.getMessage, out)
complete = true
}
while (!complete) {

// Read request for data, value of zero will stop iteration or non-zero to continue
if (in.readInt() == 0) {
complete = true
} else if (collectPartitionIter.hasNext) {

// Client requested more data, attempt to collect the next partition
val partitionArray = collectPartitionIter.next()

// Send response there is a partition to read
out.writeInt(1)

// Write the next object and signal end of data for this iteration
writeIteratorToStream(partitionArray.toIterator, out)
out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
out.flush()
} else {
// Send response there are no more partitions to read and close
out.writeInt(0)
complete = true
}
}
} {
})(catchBlock = {
// Send response that an error occurred, original exception is re-thrown
out.writeInt(-1)
}, finallyBlock = {
out.close()
in.close()
}
})
}
Array(port, secret)

val server = new SocketFuncServer(authHelper, "serve toLocalIterator", handleFunc)
Array(server.port, server.secret, server)
}
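
To make the request/response protocol described in the comment above concrete, here is a minimal client-side sketch. It is illustrative only: the real consumer is pyspark.rdd._local_iterator_from_socket, and the authentication handshake and the framed partition payload written by writeIteratorToStream are elided.

```scala
// Hypothetical client-side sketch of the toLocalIterator protocol; not part of
// this PR. Assumes `sock` is already connected and authenticated.
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket

object LocalIteratorClientSketch {
  def drain(sock: Socket): Unit = {
    val in = new DataInputStream(sock.getInputStream)
    val out = new DataOutputStream(sock.getOutputStream)
    var done = false
    while (!done) {
      out.writeInt(1)          // any non-zero value requests the next partition
      out.flush()
      in.readInt() match {
        case 1 =>
          // Partition data follows, terminated by END_OF_DATA_SECTION; reading
          // the framed payload is elided in this sketch.
        case 0 =>
          done = true          // all partitions have been served
        case -1 =>
          done = true          // the JVM hit an error while collecting
        case _ =>
          done = true
      }
    }
  }
}
```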

def readRDDFromFile(
@@ -443,8 +446,9 @@ private[spark] object PythonRDD extends Logging {
*
* The thread will terminate after all the data are sent or any exceptions happen.
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a socket auth
* server object that can be used to join the JVM serving thread in Python.
*/
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
serveToStream(threadName) { out =>
@@ -464,10 +468,14 @@
*
* The thread will terminate after the block of code is executed or any
* exceptions happen.
*
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a socket auth
* server object that can be used to join the JVM serving thread in Python.
*/
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
SocketAuthServer.serveToStream(threadName, authHelper)(writeFunc)
}

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
import org.apache.spark.security.SocketAuthServer

private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
parent: RDD[T],
@@ -166,7 +166,7 @@ private[spark] object RRDD {

private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
SocketAuthHelper.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
SocketAuthServer.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
}
}

core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -17,7 +17,7 @@

package org.apache.spark.security

import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8

@@ -113,21 +113,4 @@ private[spark] class SocketAuthHelper(conf: SparkConf) {
dout.write(bytes, 0, bytes.length)
dout.flush()
}

}

private[spark] object SocketAuthHelper {
def serveToStream(
threadName: String,
authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { s =>
val out = new BufferedOutputStream(s.getOutputStream())
Utils.tryWithSafeFinally {
writeFunc(out)
} {
out.close()
}
}
Array(port, secret)
}
}
core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -17,6 +17,7 @@

package org.apache.spark.security

import java.io.{BufferedOutputStream, OutputStream}
import java.net.{InetAddress, ServerSocket, Socket}

import scala.concurrent.Promise
@@ -25,12 +26,15 @@ import scala.util.Try

import org.apache.spark.SparkEnv
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.{ThreadUtils, Utils}


/**
* Creates a server in the JVM to communicate with external processes (e.g., Python and R) for
* handling one batch of data, with authentication and error handling.
*
* The socket server can only accept one connection, or close if no connection
* in 15 seconds.
*/
private[spark] abstract class SocketAuthServer[T](
authHelper: SocketAuthHelper,
@@ -41,10 +45,30 @@

private val promise = Promise[T]()

val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock =>
promise.complete(Try(handleConnection(sock)))
private def startServer(): (Int, String) = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)

new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
var sock: Socket = null
try {
sock = serverSocket.accept()
authHelper.authClient(sock)
promise.complete(Try(handleConnection(sock)))
} finally {
JavaUtils.closeQuietly(serverSocket)
JavaUtils.closeQuietly(sock)
}
}
}.start()
(serverSocket.getLocalPort, authHelper.secret)
}

val (port, secret) = startServer()

/**
* Handle a connection which has already been authenticated. Any error from this function
* will clean up this connection and the entire server, and get propagated to [[getResult]].
@@ -66,42 +90,50 @@

}

/**
* Create a socket server class and run user function on the socket in a background thread
* that can read and write to the socket input/output streams. The function is passed in a
* socket that has been connected and authenticated.
*/
private[spark] class SocketFuncServer(
authHelper: SocketAuthHelper,
threadName: String,
func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, threadName) {

Member Author: I don't think we need SocketAuthServer.setupOneConnectionServer if we have this also, so it could be cleaned up.

Member Author: I removed SocketAuthServer.setupOneConnectionServer and replaced its usage with SocketFuncServer.

override def handleConnection(sock: Socket): Unit = {
func(sock)
}
}
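
As a hedged illustration of the pattern SocketFuncServer follows, any subclass of SocketAuthServer can return a typed result from handleConnection; the result, or any exception it throws, completes the internal promise and is handed back through getResult(), which is how the Python side can surface JVM errors. The class below is a hypothetical example, not part of the PR (it would need to live under an org.apache.spark package since these classes are private[spark]):

```scala
// Hypothetical subclass: reads one Int from the authenticated client and makes
// it the server's result. An exception thrown here fails the promise and is
// re-thrown to whoever calls getResult().
import java.io.DataInputStream
import java.net.Socket

private[spark] class ReadIntServer(
    authHelper: SocketAuthHelper,
    threadName: String)
  extends SocketAuthServer[Int](authHelper, threadName) {

  override def handleConnection(sock: Socket): Int = {
    val in = new DataInputStream(sock.getInputStream)
    in.readInt()
  }
}
```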

private[spark] object SocketAuthServer {

/**
* Create a socket server and run user function on the socket in a background thread.
* Convenience function to create a socket server and run a user function in a background
* thread to write to an output stream.
*
* The socket server can only accept one connection, or close if no connection
* in 15 seconds.
*
* The thread will terminate after the supplied user function, or if there are any exceptions.
*
* If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]]
*
* @return The port number of a local socket and the secret for authentication.
* @param threadName Name for the background serving thread.
* @param authHelper SocketAuthHelper for authentication
* @param writeFunc User function to write to a given OutputStream
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a socket auth
* server object that can be used to join the JVM serving thread in Python.
*/
def setupOneConnectionServer(
authHelper: SocketAuthHelper,
threadName: String)
(func: Socket => Unit): (Int, String) = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)

new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
var sock: Socket = null
try {
sock = serverSocket.accept()
authHelper.authClient(sock)
func(sock)
} finally {
JavaUtils.closeQuietly(serverSocket)
JavaUtils.closeQuietly(sock)
}
def serveToStream(

Member Author: Moved this from SocketAuthHelper because it seemed more fitting to be here.

threadName: String,
authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
val handleFunc = (sock: Socket) => {
val out = new BufferedOutputStream(sock.getOutputStream())
Utils.tryWithSafeFinally {
writeFunc(out)
} {
out.close()
}
}.start()
(serverSocket.getLocalPort, authHelper.secret)
}

val server = new SocketFuncServer(authHelper, threadName, handleFunc)
Array(server.port, server.secret, server)
}
}
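
For reference, a hedged usage sketch of the relocated serveToStream: it assumes a running SparkEnv (so a SocketAuthHelper can be built) and must sit under an org.apache.spark package, just like the production call sites in PythonRDD and RRDD.

```scala
// Hypothetical caller of SocketAuthServer.serveToStream; not part of this PR.
import java.nio.charset.StandardCharsets.UTF_8

import org.apache.spark.SparkEnv
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}

object ServeToStreamSketch {
  // Returns Array(port, secret, server): port and secret let the external
  // process connect and authenticate; server lets the caller join the serving
  // thread and re-raise any exception thrown by the write function.
  def serveHello(): Array[Any] = {
    val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
    SocketAuthServer.serveToStream("serve hello sketch", authHelper) { out =>
      out.write("hello from the JVM".getBytes(UTF_8))
    }
  }
}
```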
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1389,7 +1389,9 @@ private[spark] object Utils extends Logging {
originalThrowable = cause
try {
logError("Aborting task", originalThrowable)
TaskContext.get().markTaskFailed(originalThrowable)
if (TaskContext.get() != null) {

Member Author: Using this utility here https://github.com/apache/spark/pull/24834/files#diff-0a67bc4d171abe4df8eb305b0f4123a2R184, where the task fails and completes before hitting the catchBlock, so TaskContext.get() returns null.

TaskContext.get().markTaskFailed(originalThrowable)
}
catchBlock
} catch {
case t: Throwable =>
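
As the author's comment above explains, the failure callback path can now run on a plain serving thread where no task is active. A hedged sketch of that situation follows (hypothetical call, and it must live under an org.apache.spark package because Utils is private[spark]):

```scala
// Sketch of tryWithSafeFinallyAndFailureCallbacks invoked outside any Spark
// task, so TaskContext.get() is null when the failure path runs.
import org.apache.spark.util.Utils

object FailureCallbackSketch {
  def run(): Unit = {
    Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
      throw new RuntimeException("simulated failure while serving data")
    })(catchBlock = {
      // Before this change, Utils called TaskContext.get().markTaskFailed(...)
      // unconditionally and hit a NullPointerException before reaching this
      // callback whenever no task was active.
      println("failure callback ran with no active TaskContext")
    }, finallyBlock = {
      println("cleanup runs whether or not the block succeeded")
    })
  }
}
```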