# [SPARK-25095][PYSPARK] Python support for BarrierTaskContext
## What changes were proposed in this pull request?

Add method `barrier()` and `getTaskInfos()` in python TaskContext, these two methods are only allowed for barrier tasks.
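
As an illustrative sketch (not part of the patch; the app name and cluster setup are assumed), the new API is expected to be used like this, mirroring the tests added below:

```python
from pyspark import SparkContext
from pyspark.taskcontext import BarrierTaskContext

sc = SparkContext(appName="barrier-demo")  # hypothetical setup
rdd = sc.parallelize(range(10), 4)

def f(iterator):
    yield sum(iterator)

def stage_info(x):
    # Only valid inside a barrier stage.
    tc = BarrierTaskContext.get()
    tc.barrier()  # global sync across all 4 tasks in the stage
    return [info.address for info in tc.getTaskInfos()]

# rdd.barrier() marks the stage as a barrier stage, so all tasks launch together.
print(rdd.barrier().mapPartitions(f).map(stage_info).collect())
```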

## How was this patch tested?

Add new tests in `tests.py`

Closes apache#22085 from jiangxb1987/python.barrier.

Authored-by: Xingbo Jiang <xingbo.jiang@databricks.com>
Signed-off-by: Xiangrui Meng <meng@databricks.com>
jiangxb1987 authored and mengxr committed Aug 21, 2018
1 parent 42035a4 commit ad45299
Showing 5 changed files with 305 additions and 4 deletions.
106 changes: 106 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -20,12 +20,14 @@ package org.apache.spark.api.python
import java.io._
import java.net._
import java.nio.charset.StandardCharsets
import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.JavaConverters._

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._


@@ -76,6 +78,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
  // TODO: support accumulator in multiple UDF
  protected val accumulator = funcs.head.funcs.head.accumulator

  // Expose a ServerSocket to support method calls via socket from Python side.
  private[spark] var serverSocket: Option[ServerSocket] = None

  // Authentication helper used when serving method calls via socket from Python side.
  private lazy val authHelper = new SocketAuthHelper(SparkEnv.get.conf)

  def compute(
      inputIterator: Iterator[IN],
      partitionIndex: Int,
@@ -180,7 +188,73 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
        dataOut.writeInt(partitionIndex)
        // Python version of driver
        PythonRDD.writeUTF(pythonVer, dataOut)
        // Init a ServerSocket to accept method calls from the Python side.
        val isBarrier = context.isInstanceOf[BarrierTaskContext]
        if (isBarrier) {
          serverSocket = Some(new ServerSocket(/* port */ 0,
            /* backlog */ 1,
            InetAddress.getByName("localhost")))
          // A timeout of 0 makes a call to accept() on the ServerSocket block indefinitely.
          serverSocket.foreach(_.setSoTimeout(0))
          new Thread("accept-connections") {
            setDaemon(true)

            override def run(): Unit = {
              while (!serverSocket.get.isClosed()) {
                var sock: Socket = null
                try {
                  sock = serverSocket.get.accept()
                  // Wait for a function call from the Python side.
                  sock.setSoTimeout(10000)
                  val input = new DataInputStream(sock.getInputStream())
                  input.readInt() match {
                    case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
                      // The barrier() call may block indefinitely, so the socket must not time
                      // out before the function finishes.
                      sock.setSoTimeout(0)
                      barrierAndServe(sock)

                    case _ =>
                      val out = new DataOutputStream(new BufferedOutputStream(
                        sock.getOutputStream))
                      writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
                  }
                } catch {
                  case e: SocketException if e.getMessage.contains("Socket closed") =>
                    // The native socket may already be closed even though the ServerSocket is
                    // not; catch and silently ignore this case.
                } finally {
                  if (sock != null) {
                    sock.close()
                  }
                }
              }
            }
          }.start()
        }
        val secret = if (isBarrier) {
          authHelper.secret
        } else {
          ""
        }
        // Close the ServerSocket on task completion.
        serverSocket.foreach { server =>
          context.addTaskCompletionListener(_ => server.close())
        }
        val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
        if (boundPort == -1) {
          val message = "ServerSocket failed to bind to Java side."
          logError(message)
          throw new SparkException(message)
        } else if (isBarrier) {
          logDebug(s"Started ServerSocket on port $boundPort.")
        }
        // Write out the TaskContextInfo
        dataOut.writeBoolean(isBarrier)
        dataOut.writeInt(boundPort)
        val secretBytes = secret.getBytes(UTF_8)
        dataOut.writeInt(secretBytes.length)
        dataOut.write(secretBytes, 0, secretBytes.length)
        dataOut.writeInt(context.stageId())
        dataOut.writeInt(context.partitionId())
        dataOut.writeInt(context.attemptNumber())
@@ -243,6 +317,32 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
          }
      }
    }

    /**
     * Gateway to call BarrierTaskContext.barrier().
     */
    def barrierAndServe(sock: Socket): Unit = {
      require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")

      authHelper.authClient(sock)

      val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
      try {
        context.asInstanceOf[BarrierTaskContext].barrier()
        writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
      } catch {
        case e: SparkException =>
          writeUTF(e.getMessage, out)
      } finally {
        out.close()
      }
    }

    def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
      val bytes = str.getBytes(UTF_8)
      dataOut.writeInt(bytes.length)
      dataOut.write(bytes)
    }
  }

  abstract class ReaderIterator(
@@ -465,3 +565,9 @@ private[spark] object SpecialLengths {
  val NULL = -5
  val START_ARROW_STREAM = -6
}

private[spark] object BarrierTaskContextMessageProtocol {
  val BARRIER_FUNCTION = 1
  val BARRIER_RESULT_SUCCESS = "success"
  val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side."
}
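
For reference — a sketch, not part of the patch — the `writeUTF` framing above (a 4-byte big-endian length followed by UTF-8 bytes, matching what `UTF8Deserializer` expects) can be decoded on the Python side like so:

```python
import struct

def read_utf(stream):
    # Read the 4-byte big-endian length written by writeUTF on the JVM side,
    # then decode that many UTF-8 bytes.
    length = struct.unpack("!i", stream.read(4))[0]
    return stream.read(length).decode("utf-8")
```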
7 changes: 7 additions & 0 deletions python/pyspark/serializers.py
@@ -715,6 +715,13 @@ def write_int(value, stream):
    stream.write(struct.pack("!i", value))


def read_bool(stream):
    length = stream.read(1)
    if not length:
        raise EOFError
    return struct.unpack("!?", length)[0]


def write_with_length(obj, stream):
    write_int(len(obj), stream)
    stream.write(obj)
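
A quick round-trip check of `read_bool` against the JVM's `writeBoolean` encoding (a single byte, struct format `"!?"`) — an illustrative sketch only:

```python
import io
import struct

from pyspark.serializers import read_bool  # available after this patch

buf = io.BytesIO(struct.pack("!?", True))  # what dataOut.writeBoolean(true) produces
assert read_bool(buf) is True
```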
144 changes: 144 additions & 0 deletions python/pyspark/taskcontext.py
@@ -16,6 +16,10 @@
#

from __future__ import print_function
import socket

from pyspark.java_gateway import do_server_auth
from pyspark.serializers import write_int, UTF8Deserializer


class TaskContext(object):
@@ -95,3 +99,143 @@ def getLocalProperty(self, key):
        Get a local property set upstream in the driver, or None if it is missing.
        """
        return self._localProperties.get(key, None)


BARRIER_FUNCTION = 1


def _load_from_socket(port, auth_secret):
    """
    Load data from a given socket. This is a blocking method, so it only returns once the
    socket connection has been closed.
    Adapted from context.py, with a modified message protocol.
    """
    sock = None
    # Support both IPv4 and IPv6.
    # On most IPv6-ready systems, IPv6 will take precedence.
    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
        af, socktype, proto, canonname, sa = res
        sock = socket.socket(af, socktype, proto)
        try:
            # Do not allow timeouts on socket read operations.
            sock.settimeout(None)
            sock.connect(sa)
        except socket.error:
            sock.close()
            sock = None
            continue
        break
    if not sock:
        raise Exception("could not open socket")

    # We don't strictly need a socket file here; it is a convenience that lets us reuse the
    # do_server_auth() function and the data serialization methods.
    sockfile = sock.makefile("rwb", 65536)

    # Make a barrier() function call.
    write_int(BARRIER_FUNCTION, sockfile)
    sockfile.flush()

    # Do server auth.
    do_server_auth(sockfile, auth_secret)

    # Collect the result.
    res = UTF8Deserializer().loads(sockfile)

    # Release resources.
    sockfile.close()
    sock.close()

    return res


class BarrierTaskContext(TaskContext):

    """
    .. note:: Experimental

    A TaskContext with extra info and tooling for a barrier stage. To access the
    BarrierTaskContext for a running task, use L{BarrierTaskContext.get()}.

    .. versionadded:: 2.4.0
    """

    _port = None
    _secret = None

    def __init__(self):
        """Construct a BarrierTaskContext; use get() instead."""
        pass

    @classmethod
    def _getOrCreate(cls):
        """Internal function to get or create the global BarrierTaskContext."""
        if cls._taskContext is None:
            cls._taskContext = BarrierTaskContext()
        return cls._taskContext

    @classmethod
    def get(cls):
        """
        Return the currently active BarrierTaskContext. This can be called inside user
        functions to access contextual information about running tasks.

        .. note:: Must be called on the worker, not the driver. Returns None if not initialized.
        """
        return cls._taskContext

    @classmethod
    def _initialize(cls, port, secret):
        """
        Initialize BarrierTaskContext; other methods within BarrierTaskContext can only be
        called after it has been initialized.
        """
        cls._port = port
        cls._secret = secret

    def barrier(self):
        """
        .. note:: Experimental

        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Note this method is only allowed for a BarrierTaskContext.

        .. versionadded:: 2.4.0
        """
        if self._port is None or self._secret is None:
            raise Exception("Not supported to call barrier() before initializing " +
                            "BarrierTaskContext.")
        else:
            _load_from_socket(self._port, self._secret)

    def getTaskInfos(self):
        """
        .. note:: Experimental

        Returns all task infos in this barrier stage, ordered by partition ID.
        Note this method is only allowed for a BarrierTaskContext.

        .. versionadded:: 2.4.0
        """
        if self._port is None or self._secret is None:
            raise Exception("Not supported to call getTaskInfos() before initializing " +
                            "BarrierTaskContext.")
        else:
            addresses = self._localProperties.get("addresses", "")
            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]


class BarrierTaskInfo(object):
    """
    .. note:: Experimental

    Carries all task infos of a barrier task.

    .. versionadded:: 2.4.0
    """

    def __init__(self, address):
        self.address = address
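
To make the `getTaskInfos()` plumbing concrete — an illustrative sketch with a made-up address list — the scheduler publishes a comma-separated "addresses" local property, which is split into one `BarrierTaskInfo` per task, ordered by partition ID:

```python
from pyspark.taskcontext import BarrierTaskInfo

addresses = "10.0.0.1:40000, 10.0.0.2:40000"  # hypothetical "addresses" property value
infos = [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
assert [i.address for i in infos] == ["10.0.0.1:40000", "10.0.0.2:40000"]
```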
36 changes: 35 additions & 1 deletion python/pyspark/tests.py
@@ -70,7 +70,7 @@
from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
from pyspark import shuffle
from pyspark.profiler import BasicProfiler
from pyspark.taskcontext import BarrierTaskContext, TaskContext

_have_scipy = False
_have_numpy = False
@@ -588,6 +588,40 @@ def test_get_local_property(self):
        finally:
            self.sc.setLocalProperty(key, None)

    def test_barrier(self):
        """
        Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
        within a stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 10))
            tc.barrier()
            return time.time()

        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        self.assertTrue(max(times) - min(times) < 1)

    def test_barrier_infos(self):
        """
        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
        barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
                                                       .getTaskInfos()).collect()
        self.assertTrue(len(taskInfos) == 4)
        self.assertTrue(len(taskInfos[0]) == 4)


class RDDTests(ReusedPySparkTestCase):

16 changes: 13 additions & 3 deletions python/pyspark/worker.py
@@ -28,10 +28,10 @@
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.java_gateway import do_server_auth
from pyspark.taskcontext import BarrierTaskContext, TaskContext
from pyspark.files import SparkFiles
from pyspark.rdd import PythonEvalType
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
    BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
@@ -259,8 +259,18 @@ def main(infile, outfile):
"PYSPARK_DRIVER_PYTHON are correctly set.") %
("%d.%d" % sys.version_info[:2], version))

        # read the barrier-protocol info; these fields are always sent, but boundPort and
        # secret are only meaningful for barrier tasks
        isBarrier = read_bool(infile)
        boundPort = read_int(infile)
        secret = UTF8Deserializer().loads(infile)
        # initialize global state
        taskContext = None
        if isBarrier:
            taskContext = BarrierTaskContext._getOrCreate()
            BarrierTaskContext._initialize(boundPort, secret)
        else:
            taskContext = TaskContext._getOrCreate()
        # read inputs for TaskContext info
        taskContext._stageId = read_int(infile)
        taskContext._partitionId = read_int(infile)
        taskContext._attemptNumber = read_int(infile)
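The reads above mirror the writes in `PythonRunner.compute`. As a standalone sketch (assuming big-endian values, as produced by `java.io.DataOutputStream`), the header the JVM sends before the TaskContext info could be parsed like this:

```python
import struct

def read_barrier_header(infile):
    # Mirrors dataOut.writeBoolean(isBarrier), writeInt(boundPort),
    # writeInt(secretBytes.length) and write(secretBytes) on the JVM side.
    is_barrier = struct.unpack("!?", infile.read(1))[0]
    bound_port = struct.unpack("!i", infile.read(4))[0]
    secret_len = struct.unpack("!i", infile.read(4))[0]
    secret = infile.read(secret_len).decode("utf-8")
    return is_barrier, bound_port, secret
```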
