18
18
import sys
19
19
20
20
from py4j .java_collections import ListConverter
21
- from py4j .java_gateway import java_import
21
+ from py4j .java_gateway import java_import , JavaObject
22
22
23
23
from pyspark import RDD , SparkConf
24
24
from pyspark .serializers import UTF8Deserializer , CloudPickleSerializer
@@ -38,6 +38,8 @@ def _daemonize_callback_server():
38
38
from exiting if it's not shutdown. The following code replace `start()`
39
39
of CallbackServer with a new version, which set daemon=True for this
40
40
thread.
41
+
42
+ Also, it will update the port number (0) with real port
41
43
"""
42
44
# TODO: create a patch for Py4J
43
45
import socket
@@ -54,8 +56,11 @@ def start(self):
54
56
1 )
55
57
try :
56
58
self .server_socket .bind ((self .address , self .port ))
57
- except Exception :
58
- msg = 'An error occurred while trying to start the callback server'
59
+ if not self .port :
60
+ # update port with real port
61
+ self .port = self .server_socket .getsockname ()[1 ]
62
+ except Exception as e :
63
+ msg = 'An error occurred while trying to start the callback server: %s' % e
59
64
logger .exception (msg )
60
65
raise Py4JNetworkError (msg )
61
66
@@ -105,15 +110,24 @@ def _jduration(self, seconds):
105
110
def _ensure_initialized (cls ):
106
111
SparkContext ._ensure_initialized ()
107
112
gw = SparkContext ._gateway
108
- # start callback server
109
- # getattr will fallback to JVM
110
- if "_callback_server" not in gw .__dict__ :
111
- _daemonize_callback_server ()
112
- gw ._start_callback_server (gw ._python_proxy_port )
113
113
114
114
java_import (gw .jvm , "org.apache.spark.streaming.*" )
115
115
java_import (gw .jvm , "org.apache.spark.streaming.api.java.*" )
116
116
java_import (gw .jvm , "org.apache.spark.streaming.api.python.*" )
117
+
118
+ # start callback server
119
+ # getattr will fallback to JVM, so we cannot test by hasattr()
120
+ if "_callback_server" not in gw .__dict__ :
121
+ _daemonize_callback_server ()
122
+ # use random port
123
+ gw ._start_callback_server (0 )
124
+ # gateway with real port
125
+ gw ._python_proxy_port = gw ._callback_server .port
126
+ # get the GatewayServer object in JVM by ID
127
+ jgws = JavaObject ("GATEWAY_SERVER" , gw ._gateway_client )
128
+ # update the port of CallbackClient with real port
129
+ gw .jvm .PythonDStream .updatePythonGatewayPort (jgws , gw ._python_proxy_port )
130
+
117
131
# register serializer for TransformFunction
118
132
# it happens before creating SparkContext when loading from checkpointing
119
133
cls ._transformerSerializer = TransformFunctionSerializer (
0 commit comments