Skip to content

Commit 2d40cba

Browse files
committed
mock the insecure gateway, clean up everything related to creating insecure gateways
1 parent ee6343c commit 2d40cba

File tree

3 files changed

+20
-38
lines changed

3 files changed

+20
-38
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,12 @@ private[spark] object PythonGatewayServer extends Logging {
4343
// with the same secret, in case the app needs callbacks from the JVM to the underlying
4444
// python processes.
4545
val localhost = InetAddress.getLoopbackAddress()
46-
val builder = new GatewayServer.GatewayServerBuilder()
46+
val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
47+
.authToken(secret)
4748
.javaPort(0)
4849
.javaAddress(localhost)
4950
.callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
50-
if (sys.env.getOrElse("_PYSPARK_CREATE_INSECURE_GATEWAY", "0") != "1") {
51-
builder.authToken(secret)
52-
} else {
53-
assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1",
54-
"Creating insecure Java gateways only allowed for testing")
55-
}
56-
val gatewayServer: GatewayServer = builder.build()
51+
.build()
5752

5853
gatewayServer.start()
5954
val boundPort: Int = gatewayServer.getListeningPort

python/pyspark/java_gateway.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,8 @@ def launch_gateway(conf=None):
4141
"""
4242
launch jvm gateway
4343
:param conf: spark configuration passed to spark-submit
44-
:return: a JVM gateway
44+
:return:
4545
"""
46-
return _launch_gateway(conf)
47-
48-
49-
def _launch_gateway(conf=None, insecure=False):
50-
"""
51-
launch jvm gateway
52-
:param conf: spark configuration passed to spark-submit
53-
:param insecure: True to create an insecure gateway; only for testing
54-
:return: a JVM gateway
55-
"""
56-
if insecure and os.environ.get("SPARK_TESTING", "0") != "1":
57-
raise ValueError("creating insecure gateways is only for testing")
5846
if "PYSPARK_GATEWAY_PORT" in os.environ:
5947
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
6048
gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
@@ -86,8 +74,6 @@ def _launch_gateway(conf=None, insecure=False):
8674

8775
env = dict(os.environ)
8876
env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
89-
if insecure:
90-
env["_PYSPARK_CREATE_INSECURE_GATEWAY"] = "1"
9177

9278
# Launch the Java gateway.
9379
# We open a pipe to stdin so that the Java gateway can die when the pipe is broken
@@ -130,10 +116,9 @@ def killChild():
130116
atexit.register(killChild)
131117

132118
# Connect to the gateway
133-
gateway_params = GatewayParameters(port=gateway_port, auto_convert=True)
134-
if not insecure:
135-
gateway_params.auth_token = gateway_secret
136-
gateway = JavaGateway(gateway_parameters=gateway_params)
119+
gateway = JavaGateway(
120+
gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
121+
auto_convert=True))
137122

138123
# Import the classes used by PySpark
139124
java_import(gateway.jvm, "org.apache.spark.SparkConf")

python/pyspark/tests/test_context.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,15 @@
2020
import threading
2121
import time
2222
import unittest
23+
import sys
24+
if sys.version >= '3':
25+
from unittest.mock import MagicMock
26+
else:
27+
from mock import MagicMock
28+
29+
2330

2431
from pyspark import SparkFiles, SparkContext
25-
from pyspark.java_gateway import _launch_gateway
2632
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME
2733

2834

@@ -250,16 +256,12 @@ def test_startTime(self):
250256
def test_forbid_insecure_gateway(self):
251257
# Fail immediately if you try to create a SparkContext
252258
# with an insecure gateway
253-
gateway = _launch_gateway(insecure=True)
254-
log4j = gateway.jvm.org.apache.log4j
255-
old_level = log4j.LogManager.getRootLogger().getLevel()
256-
try:
257-
log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)
258-
with self.assertRaises(Exception) as context:
259-
SparkContext(gateway=gateway)
260-
self.assertIn("insecure Py4j gateway", str(context.exception))
261-
finally:
262-
log4j.LogManager.getRootLogger().setLevel(old_level)
259+
mock_insecure_gateway = MagicMock()
260+
mock_insecure_gateway.gateway_parameters = MagicMock()
261+
mock_insecure_gateway.gateway_parameters.auth_token = None
262+
with self.assertRaises(Exception) as context:
263+
SparkContext(gateway=mock_insecure_gateway)
264+
self.assertIn("insecure Py4j gateway", str(context.exception))
263265

264266

265267
if __name__ == "__main__":

0 commit comments

Comments
 (0)