Skip to content

Commit 37fe06f

Browse files
committed
use random port for callback server
1 parent d05871e commit 37fe06f

File tree

2 files changed

+51
-15
lines changed

2 files changed

+51
-15
lines changed

python/pyspark/streaming/context.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import sys
1919

2020
from py4j.java_collections import ListConverter
21-
from py4j.java_gateway import java_import
21+
from py4j.java_gateway import java_import, JavaObject
2222

2323
from pyspark import RDD, SparkConf
2424
from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
@@ -38,6 +38,8 @@ def _daemonize_callback_server():
3838
from exiting if it's not shut down. The following code replaces `start()`
3939
of CallbackServer with a new version, which sets daemon=True for this
4040
thread.
41+
42+
Also, it will update the port number (0) with the real port
4143
"""
4244
# TODO: create a patch for Py4J
4345
import socket
@@ -54,8 +56,11 @@ def start(self):
5456
1)
5557
try:
5658
self.server_socket.bind((self.address, self.port))
57-
except Exception:
58-
msg = 'An error occurred while trying to start the callback server'
59+
if not self.port:
60+
# update port with real port
61+
self.port = self.server_socket.getsockname()[1]
62+
except Exception as e:
63+
msg = 'An error occurred while trying to start the callback server: %s' % e
5964
logger.exception(msg)
6065
raise Py4JNetworkError(msg)
6166

@@ -105,15 +110,24 @@ def _jduration(self, seconds):
105110
def _ensure_initialized(cls):
106111
SparkContext._ensure_initialized()
107112
gw = SparkContext._gateway
108-
# start callback server
109-
# getattr will fallback to JVM
110-
if "_callback_server" not in gw.__dict__:
111-
_daemonize_callback_server()
112-
gw._start_callback_server(gw._python_proxy_port)
113113

114114
java_import(gw.jvm, "org.apache.spark.streaming.*")
115115
java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
116116
java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
117+
118+
# start callback server
119+
# getattr will fall back to the JVM, so we cannot test with hasattr()
120+
if "_callback_server" not in gw.__dict__:
121+
_daemonize_callback_server()
122+
# use random port
123+
gw._start_callback_server(0)
124+
# gateway with real port
125+
gw._python_proxy_port = gw._callback_server.port
126+
# get the GatewayServer object in JVM by ID
127+
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
128+
# update the port of CallbackClient with real port
129+
gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
130+
117131
# register serializer for TransformFunction
118132
# it happens before creating SparkContext when loading from checkpointing
119133
cls._transformerSerializer = TransformFunctionSerializer(

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import scala.collection.JavaConversions._
2424
import scala.collection.JavaConverters._
2525
import scala.language.existentials
2626

27+
import py4j.GatewayServer
28+
2729
import org.apache.spark.api.java._
2830
import org.apache.spark.api.python._
2931
import org.apache.spark.rdd.RDD
@@ -88,10 +90,14 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun
8890
*/
8991
private[python] object PythonTransformFunctionSerializer {
9092

91-
// A serializer in Python, used to serialize PythonTransformFunction
93+
/**
94+
* A serializer in Python, used to serialize PythonTransformFunction
95+
*/
9296
private var serializer: PythonTransformFunctionSerializer = _
9397

94-
// Register a serializer from Python, should be called during initialization
98+
/**
99+
* Register a serializer from Python, should be called during initialization
100+
*/
95101
def register(ser: PythonTransformFunctionSerializer): Unit = {
96102
serializer = ser
97103
}
@@ -117,20 +123,36 @@ private[python] object PythonTransformFunctionSerializer {
117123
*/
118124
private[python] object PythonDStream {
119125

120-
// can not access PythonTransformFunctionSerializer.register() via Py4j
121-
// Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
126+
/**
127+
* Cannot access PythonTransformFunctionSerializer.register() via Py4j
128+
* Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
129+
*/
122130
def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = {
123131
PythonTransformFunctionSerializer.register(ser)
124132
}
125133

126-
// helper function for DStream.foreachRDD(),
127-
// cannot be `foreachRDD`, it will confusing py4j
134+
/**
135+
* Update the port of callback client to `port`
136+
*/
137+
def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = {
138+
val cl = gws.getCallbackClient
139+
val f = cl.getClass.getDeclaredField("port")
140+
f.setAccessible(true)
141+
f.setInt(cl, port)
142+
}
143+
144+
/**
145+
* helper function for DStream.foreachRDD(),
146+
* cannot be named `foreachRDD`, as that would confuse py4j
147+
*/
128148
def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
129149
val func = new TransformFunction((pfunc))
130150
jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
131151
}
132152

133-
// convert list of RDD into queue of RDDs, for ssc.queueStream()
153+
/**
154+
* convert a list of RDDs into a queue of RDDs, for ssc.queueStream()
155+
*/
134156
def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
135157
val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
136158
rdds.forall(queue.add(_))

0 commit comments

Comments
 (0)