Skip to content

Commit 2f70689

Browse files
committed
Use stdin PIPE to share fate with driver
1 parent 8bf956e commit 2f70689

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,12 @@ private[spark] object PythonGatewayServer {
4141
val callbackSocket = new Socket(callbackHost, callbackPort)
4242
val dos = new DataOutputStream(callbackSocket.getOutputStream)
4343
dos.writeInt(boundPort)
44-
dos.flush()
44+
dos.close()
45+
callbackSocket.close()
4546

46-
// Exit once the callback socket is closed to ensure that this process dies when the Python
47-
// driver dies:
48-
while (callbackSocket.getInputStream.read() != -1) {
47+
System.in.read()
48+
// Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
49+
while (System.in.read() != -1) {
4950
// Do nothing
5051
}
5152
System.exit(0)

python/pyspark/java_gateway.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@
2828

2929
from pyspark.serializers import read_int
3030

31-
_gateway_connection = None
3231

3332
def launch_gateway():
34-
global _gateway_connection
3533
SPARK_HOME = os.environ["SPARK_HOME"]
3634

3735
if "PYSPARK_GATEWAY_PORT" in os.environ:
@@ -55,6 +53,8 @@ def launch_gateway():
5553
env['PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
5654
env['PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)
5755

56+
# Launch the Java gateway.
57+
# We open a pipe to stdin so that the Java gateway can die when the pipe is broken
5858
if not on_windows:
5959
# Don't send ctrl-c / SIGINT to the Java gateway:
6060
def preexec_func():
@@ -65,9 +65,11 @@ def preexec_func():
6565
# preexec_fn not supported on Windows
6666
proc = Popen(command, stdout=PIPE, stdin=PIPE, env=env)
6767

68-
_gateway_connection = callback_socket.accept()[0]
68+
gateway_connection = callback_socket.accept()[0]
6969
# Determine which ephemeral port the server started on:
70-
gateway_port = read_int(_gateway_connection.makefile())
70+
gateway_port = read_int(gateway_connection.makefile())
71+
gateway_connection.close()
72+
callback_socket.close()
7173

7274
# In Windows, ensure the Java child processes do not linger after Python has exited.
7375
# In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when

0 commit comments

Comments
 (0)