From ad45299d047c10472fd3a86103930fe7c54a4cf1 Mon Sep 17 00:00:00 2001
From: Xingbo Jiang
Date: Tue, 21 Aug 2018 15:54:30 -0700
Subject: [PATCH] [SPARK-25095][PYSPARK] Python support for BarrierTaskContext

## What changes were proposed in this pull request?

Add the methods `barrier()` and `getTaskInfos()` to the Python `TaskContext`; these two
methods are only allowed for barrier tasks.

## How was this patch tested?

Add new tests in `tests.py`.
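As a usage illustration (not part of the patch; `sc` is assumed to be an existing
`SparkContext`, and a barrier stage needs all 4 tasks scheduled at the same time, so
4 free slots are required; the pattern mirrors the new tests below):

```python
from pyspark.taskcontext import BarrierTaskContext

def run_with_barrier(iterator):
    context = BarrierTaskContext.get()
    # Block until every task in this barrier stage has reached the barrier.
    context.barrier()
    # Task infos for all tasks in this stage, ordered by partitionId.
    yield [info.address for info in context.getTaskInfos()]

rdd = sc.parallelize(range(8), 4)
print(rdd.barrier().mapPartitions(run_with_barrier).collect())
```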
Closes #22085 from jiangxb1987/python.barrier.

Authored-by: Xingbo Jiang
Signed-off-by: Xiangrui Meng
---
 .../spark/api/python/PythonRunner.scala       | 106 +++++++++++++
 python/pyspark/serializers.py                 |   7 +
 python/pyspark/taskcontext.py                 | 144 ++++++++++++++++++
 python/pyspark/tests.py                       |  36 ++++-
 python/pyspark/worker.py                      |  16 +-
 5 files changed, 305 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 7b31857588252..f8241915e4849 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -20,12 +20,14 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
 import java.nio.charset.StandardCharsets
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.JavaConverters._
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util._
 
@@ -76,6 +78,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   // TODO: support accumulator in multiple UDF
   protected val accumulator = funcs.head.funcs.head.accumulator
 
+  // Expose a ServerSocket to support method calls via socket from the Python side.
+  private[spark] var serverSocket: Option[ServerSocket] = None
+
+  // Authentication helper used when serving method calls via socket from the Python side.
+  private lazy val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+
   def compute(
       inputIterator: Iterator[IN],
       partitionIndex: Int,
@@ -180,7 +188,73 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(partitionIndex)
         // Python version of driver
         PythonRDD.writeUTF(pythonVer, dataOut)
+        // Init a ServerSocket to accept method calls from the Python side.
+        val isBarrier = context.isInstanceOf[BarrierTaskContext]
+        if (isBarrier) {
+          serverSocket = Some(new ServerSocket(/* port */ 0,
+            /* backlog */ 1,
+            InetAddress.getByName("localhost")))
+          // A call to accept() on the ServerSocket shall block indefinitely.
+          serverSocket.map(_.setSoTimeout(0))
+          new Thread("accept-connections") {
+            setDaemon(true)
+
+            override def run(): Unit = {
+              while (!serverSocket.get.isClosed()) {
+                var sock: Socket = null
+                try {
+                  sock = serverSocket.get.accept()
+                  // Wait for a function call from the Python side.
+                  sock.setSoTimeout(10000)
+                  val input = new DataInputStream(sock.getInputStream())
+                  input.readInt() match {
+                    case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
+                      // The barrier() call may block indefinitely, so the socket shall not
+                      // time out before the call finishes.
+                      sock.setSoTimeout(0)
+                      barrierAndServe(sock)
+
+                    case _ =>
+                      val out = new DataOutputStream(new BufferedOutputStream(
+                        sock.getOutputStream))
+                      writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
+                  }
+                } catch {
+                  case e: SocketException if e.getMessage.contains("Socket closed") =>
+                    // It is possible that the ServerSocket is not closed, but the native socket
+                    // has already been closed; we shall catch and silently ignore this case.
+                } finally {
+                  if (sock != null) {
+                    sock.close()
+                  }
+                }
+              }
+            }
+          }.start()
+        }
+        val secret = if (isBarrier) {
+          authHelper.secret
+        } else {
+          ""
+        }
+        // Close the ServerSocket on task completion.
+        serverSocket.foreach { server =>
+          context.addTaskCompletionListener(_ => server.close())
+        }
+        val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
+        if (boundPort == -1) {
+          val message = "ServerSocket failed to bind on the Java side."
+          logError(message)
+          throw new SparkException(message)
+        } else if (isBarrier) {
+          logDebug(s"Started ServerSocket on port $boundPort.")
+        }
         // Write out the TaskContextInfo
+        dataOut.writeBoolean(isBarrier)
+        dataOut.writeInt(boundPort)
+        val secretBytes = secret.getBytes(UTF_8)
+        dataOut.writeInt(secretBytes.length)
+        dataOut.write(secretBytes, 0, secretBytes.length)
         dataOut.writeInt(context.stageId())
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
@@ -243,6 +317,32 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         }
       }
     }
+
+    /**
+     * Gateway to call BarrierTaskContext.barrier().
+     */
+    def barrierAndServe(sock: Socket): Unit = {
+      require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
+
+      authHelper.authClient(sock)
+
+      val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+      try {
+        context.asInstanceOf[BarrierTaskContext].barrier()
+        writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
+      } catch {
+        case e: SparkException =>
+          writeUTF(e.getMessage, out)
+      } finally {
+        out.close()
+      }
+    }
+
+    def writeUTF(str: String, dataOut: DataOutputStream) {
+      val bytes = str.getBytes(UTF_8)
+      dataOut.writeInt(bytes.length)
+      dataOut.write(bytes)
+    }
   }
 
   abstract class ReaderIterator(
@@ -465,3 +565,9 @@ private[spark] object SpecialLengths {
   val NULL = -5
   val START_ARROW_STREAM = -6
 }
+
+private[spark] object BarrierTaskContextMessageProtocol {
+  val BARRIER_FUNCTION = 1
+  val BARRIER_RESULT_SUCCESS = "success"
+  val ERROR_UNRECOGNIZED_FUNCTION = "Unrecognized function call from the Python side."
+}
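The reply framing used by `writeUTF` above is a 4-byte big-endian length followed by the
UTF-8 payload; a minimal client-side decoder (illustrative only, assuming a binary stream
object) looks like:

```python
import struct

def read_utf(stream):
    # writeUTF frames each message as a 4-byte big-endian length followed by
    # the UTF-8 payload; UTF8Deserializer.loads() consumes the same framing.
    (length,) = struct.unpack("!i", stream.read(4))
    return stream.read(length).decode("utf-8")
```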
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 47c4c3e663b97..10385589c4d3b 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -715,6 +715,13 @@ def write_int(value, stream):
     stream.write(struct.pack("!i", value))
 
 
+def read_bool(stream):
+    data = stream.read(1)
+    if not data:
+        raise EOFError
+    return struct.unpack("!?", data)[0]
+
+
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
     stream.write(obj)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 63ae1f30e17ca..c0312e5265c6e 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -16,6 +16,10 @@
 #
 from __future__ import print_function
+import socket
+
+from pyspark.java_gateway import do_server_auth
+from pyspark.serializers import write_int, UTF8Deserializer
 
 
 class TaskContext(object):
 
@@ -95,3 +99,143 @@ def getLocalProperty(self, key):
         Get a local property set upstream in the driver, or None if it is missing.
         """
         return self._localProperties.get(key, None)
+
+
+BARRIER_FUNCTION = 1
+
+
+def _load_from_socket(port, auth_secret):
+    """
+    Load data from a given socket. This is a blocking method, so it only returns
+    when the socket connection has been closed.
+
+    This is copied from context.py, but with a modified message protocol.
+    """
+    sock = None
+    # Support for both IPv4 and IPv6.
+    # On most IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, canonname, sa = res
+        sock = socket.socket(af, socktype, proto)
+        try:
+            # Do not allow a timeout for socket read operations.
+            sock.settimeout(None)
+            sock.connect(sa)
+        except socket.error:
+            sock.close()
+            sock = None
+            continue
+        break
+    if not sock:
+        raise Exception("could not open socket")
+
+    # We don't really need a socket file here; it's just for convenience, so that we
+    # can reuse the do_server_auth() function and data serialization methods.
+    sockfile = sock.makefile("rwb", 65536)
+
+    # Make a barrier() function call.
+    write_int(BARRIER_FUNCTION, sockfile)
+    sockfile.flush()
+
+    # Do server auth.
+    do_server_auth(sockfile, auth_secret)
+
+    # Collect the result.
+    res = UTF8Deserializer().loads(sockfile)
+
+    # Release resources.
+    sockfile.close()
+    sock.close()
+
+    return res
+
+
+class BarrierTaskContext(TaskContext):
+
+    """
+    .. note:: Experimental
+
+    A TaskContext with extra info and tooling for a barrier stage. To access the
+    BarrierTaskContext for a running task, use L{BarrierTaskContext.get()}.
+
+    .. versionadded:: 2.4.0
+    """
+
+    _port = None
+    _secret = None
+
+    def __init__(self):
+        """Construct a BarrierTaskContext; use get() instead."""
+        pass
+
+    @classmethod
+    def _getOrCreate(cls):
+        """Internal function to get or create the global BarrierTaskContext."""
+        if cls._taskContext is None:
+            cls._taskContext = BarrierTaskContext()
+        return cls._taskContext
+
+    @classmethod
+    def get(cls):
+        """
+        Return the currently active BarrierTaskContext. This can be called inside user
+        functions to access contextual information about running tasks.
+
+        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
+        """
+        return cls._taskContext
+
+    @classmethod
+    def _initialize(cls, port, secret):
+        """
+        Initialize BarrierTaskContext; other methods within BarrierTaskContext can only be
+        called after it is initialized.
+        """
+        cls._port = port
+        cls._secret = secret
+
+    def barrier(self):
+        """
+        .. note:: Experimental
+
+        Sets a global barrier and waits until all tasks in this stage hit this barrier.
+        Note that this method is only allowed for a BarrierTaskContext.
+
+        .. versionadded:: 2.4.0
+        """
+        if self._port is None or self._secret is None:
+            raise Exception("barrier() cannot be called before BarrierTaskContext " +
+                            "is initialized.")
+        else:
+            _load_from_socket(self._port, self._secret)
+
+    def getTaskInfos(self):
+        """
+        .. note:: Experimental
+
+        Returns all task infos in this barrier stage, ordered by partitionId.
+        Note that this method is only allowed for a BarrierTaskContext.
+
+        .. versionadded:: 2.4.0
+        """
+        if self._port is None or self._secret is None:
+            raise Exception("getTaskInfos() cannot be called before BarrierTaskContext " +
+                            "is initialized.")
+        else:
+            addresses = self._localProperties.get("addresses", "")
+            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
+
+
+class BarrierTaskInfo(object):
+    """
+    .. note:: Experimental
+
+    Carries the task info of a single barrier task.
+
+    .. versionadded:: 2.4.0
+    """
+
+    def __init__(self, address):
+        self.address = address
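To make the `addresses` parsing in `getTaskInfos()` concrete (the executor addresses
below are hypothetical), the local property is a comma-separated list that becomes one
`BarrierTaskInfo` per task:

```python
from pyspark.taskcontext import BarrierTaskInfo

# Hypothetical "addresses" local property for a two-task barrier stage:
addresses = "10.0.0.1:40000, 10.0.0.2:40000"
infos = [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
print([info.address for info in infos])  # ['10.0.0.1:40000', '10.0.0.2:40000']
```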
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index a4c5fb1db8b37..8ac1df52fc597 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -70,7 +70,7 @@
 from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
 from pyspark import shuffle
 from pyspark.profiler import BasicProfiler
-from pyspark.taskcontext import TaskContext
+from pyspark.taskcontext import BarrierTaskContext, TaskContext
 
 _have_scipy = False
 _have_numpy = False
@@ -588,6 +588,40 @@ def test_get_local_property(self):
         finally:
             self.sc.setLocalProperty(key, None)
 
+    def test_barrier(self):
+        """
+        Verify that BarrierTaskContext.barrier() performs a global sync among all barrier
+        tasks within a stage.
+        """
+        rdd = self.sc.parallelize(range(10), 4)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        def context_barrier(x):
+            tc = BarrierTaskContext.get()
+            time.sleep(random.randint(1, 10))
+            tc.barrier()
+            return time.time()
+
+        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
+        self.assertTrue(max(times) - min(times) < 1)
+
+    def test_barrier_infos(self):
+        """
+        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in
+        the barrier stage.
+        """
+        rdd = self.sc.parallelize(range(10), 4)
+
+        def f(iterator):
+            yield sum(iterator)
+
+        taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
+                                                       .getTaskInfos()).collect()
+        self.assertTrue(len(taskInfos) == 4)
+        self.assertTrue(len(taskInfos[0]) == 4)
+
 
 class RDDTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index eaaae2b14e107..d54a5b8e396ea 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -28,10 +28,10 @@
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.java_gateway import do_server_auth
-from pyspark.taskcontext import TaskContext
+from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
-from pyspark.serializers import write_with_length, write_int, read_long, \
+from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
@@ -259,8 +259,18 @@ def main(infile, outfile):
                             "PYSPARK_DRIVER_PYTHON are correctly set.") %
                         ("%d.%d" % sys.version_info[:2], version))
 
+        # read inputs used only for a barrier task
+        isBarrier = read_bool(infile)
+        boundPort = read_int(infile)
+        secret = UTF8Deserializer().loads(infile)
         # initialize global state
-        taskContext = TaskContext._getOrCreate()
+        taskContext = None
+        if isBarrier:
+            taskContext = BarrierTaskContext._getOrCreate()
+            BarrierTaskContext._initialize(boundPort, secret)
+        else:
+            taskContext = TaskContext._getOrCreate()
+        # read inputs for TaskContext info
         taskContext._stageId = read_int(infile)
         taskContext._partitionId = read_int(infile)
         taskContext._attemptNumber = read_int(infile)
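The extended handshake the worker now performs can be summarized by this illustrative
decoder (a sketch only; field names are descriptive, and the remaining TaskContext
fields follow unchanged, as read above):

```python
import struct

def read_barrier_header(infile):
    # 1-byte boolean: whether this task belongs to a barrier stage.
    (is_barrier,) = struct.unpack("!?", infile.read(1))
    # 4-byte big-endian int: port of the JVM-side ServerSocket (0 for non-barrier tasks).
    (bound_port,) = struct.unpack("!i", infile.read(4))
    # Length-prefixed UTF-8 string: auth secret for barrier() calls ("" for non-barrier tasks).
    (secret_len,) = struct.unpack("!i", infile.read(4))
    secret = infile.read(secret_len).decode("utf-8")
    return is_barrier, bound_port, secret
```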