Skip to content

Commit 30a811b

Browse files
squito authored and HyukjinKwon committed
[SPARK-26019][PYSPARK] Allow insecure py4j gateways
Spark always creates secure py4j connections between Java and Python, but it also allows users to pass in their own connection. This change restores the ability for users to pass in an _insecure_ connection, though it forces them to set the environment variable 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', and it still issues a warning.

Added test cases verifying the failure without the extra configuration, and verifying that things still work with an insecure configuration (in particular accumulators, which were previously broken with an insecure py4j gateway).

For the tests, I added ways to create insecure gateways, but I tried to put in protections to make sure they would not be used incorrectly.

Closes #23337 from squito/SPARK-26019.

Authored-by: Imran Rashid <irashid@cloudera.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 1e99f4e)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 70a99ba commit 30a811b

File tree

6 files changed

+81
-12
lines changed

6 files changed

+81
-12
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,12 +43,17 @@ private[spark] object PythonGatewayServer extends Logging {
4343
// with the same secret, in case the app needs callbacks from the JVM to the underlying
4444
// python processes.
4545
val localhost = InetAddress.getLoopbackAddress()
46-
val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
47-
.authToken(secret)
46+
val builder = new GatewayServer.GatewayServerBuilder()
4847
.javaPort(0)
4948
.javaAddress(localhost)
5049
.callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
51-
.build()
50+
if (sys.env.getOrElse("_PYSPARK_CREATE_INSECURE_GATEWAY", "0") != "1") {
51+
builder.authToken(secret)
52+
} else {
53+
assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1",
54+
"Creating insecure Java gateways only allowed for testing")
55+
}
56+
val gatewayServer: GatewayServer = builder.build()
5257

5358
gatewayServer.start()
5459
val boundPort: Int = gatewayServer.getListeningPort

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -595,8 +595,10 @@ private[spark] class PythonAccumulatorV2(
595595
if (socket == null || socket.isClosed) {
596596
socket = new Socket(serverHost, serverPort)
597597
logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
598-
// send the secret just for the initial authentication when opening a new connection
599-
socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
598+
if (secretToken != null) {
599+
// send the secret just for the initial authentication when opening a new connection
600+
socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
601+
}
600602
}
601603
socket
602604
}

python/pyspark/accumulators.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -263,9 +263,10 @@ def authenticate_and_accum_updates():
263263
raise Exception(
264264
"The value of the provided token to the AccumulatorServer is not correct.")
265265

266-
# first we keep polling till we've received the authentication token
267-
poll(authenticate_and_accum_updates)
268-
# now we've authenticated, don't need to check for the token anymore
266+
if auth_token is not None:
267+
# first we keep polling till we've received the authentication token
268+
poll(authenticate_and_accum_updates)
269+
# now we've authenticated if needed, don't need to check for the token anymore
269270
poll(accum_updates)
270271

271272

python/pyspark/context.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,20 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
112112
ValueError:...
113113
"""
114114
self._callsite = first_spark_call() or CallSite(None, None, None)
115+
if gateway is not None and gateway.gateway_parameters.auth_token is None:
116+
allow_insecure_env = os.environ.get("PYSPARK_ALLOW_INSECURE_GATEWAY", "0")
117+
if allow_insecure_env == "1" or allow_insecure_env.lower() == "true":
118+
warnings.warn(
119+
"You are passing in an insecure Py4j gateway. This "
120+
"presents a security risk, and will be completely forbidden in Spark 3.0")
121+
else:
122+
raise ValueError(
123+
"You are trying to pass an insecure Py4j gateway to Spark. This"
124+
" presents a security risk. If you are sure you understand and accept this"
125+
" risk, you can set the environment variable"
126+
" 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', but"
127+
" note this option will be removed in Spark 3.0")
128+
115129
SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
116130
try:
117131
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,

python/pyspark/java_gateway.py

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,8 +40,20 @@ def launch_gateway(conf=None):
4040
"""
4141
launch jvm gateway
4242
:param conf: spark configuration passed to spark-submit
43-
:return:
43+
:return: a JVM gateway
4444
"""
45+
return _launch_gateway(conf)
46+
47+
48+
def _launch_gateway(conf=None, insecure=False):
49+
"""
50+
launch jvm gateway
51+
:param conf: spark configuration passed to spark-submit
52+
:param insecure: True to create an insecure gateway; only for testing
53+
:return: a JVM gateway
54+
"""
55+
if insecure and os.environ.get("SPARK_TESTING", "0") != "1":
56+
raise ValueError("creating insecure gateways is only for testing")
4557
if "PYSPARK_GATEWAY_PORT" in os.environ:
4658
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
4759
gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
@@ -73,6 +85,8 @@ def launch_gateway(conf=None):
7385

7486
env = dict(os.environ)
7587
env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
88+
if insecure:
89+
env["_PYSPARK_CREATE_INSECURE_GATEWAY"] = "1"
7690

7791
# Launch the Java gateway.
7892
# We open a pipe to stdin so that the Java gateway can die when the pipe is broken
@@ -115,9 +129,10 @@ def killChild():
115129
atexit.register(killChild)
116130

117131
# Connect to the gateway
118-
gateway = JavaGateway(
119-
gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
120-
auto_convert=True))
132+
gateway_params = GatewayParameters(port=gateway_port, auto_convert=True)
133+
if not insecure:
134+
gateway_params.auth_token = gateway_secret
135+
gateway = JavaGateway(gateway_parameters=gateway_params)
121136

122137
# Import the classes used by PySpark
123138
java_import(gateway.jvm, "org.apache.spark.SparkConf")

python/pyspark/tests.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,7 @@
6161
from pyspark import keyword_only
6262
from pyspark.conf import SparkConf
6363
from pyspark.context import SparkContext
64+
from pyspark.java_gateway import _launch_gateway
6465
from pyspark.rdd import RDD
6566
from pyspark.files import SparkFiles
6667
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
@@ -2295,6 +2296,37 @@ def test_startTime(self):
22952296
with SparkContext() as sc:
22962297
self.assertGreater(sc.startTime, 0)
22972298

2299+
def test_forbid_insecure_gateway(self):
2300+
# By default, we fail immediately if you try to create a SparkContext
2301+
# with an insecure gateway
2302+
gateway = _launch_gateway(insecure=True)
2303+
log4j = gateway.jvm.org.apache.log4j
2304+
old_level = log4j.LogManager.getRootLogger().getLevel()
2305+
try:
2306+
log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)
2307+
with self.assertRaises(Exception) as context:
2308+
SparkContext(gateway=gateway)
2309+
self.assertIn("insecure Py4j gateway", str(context.exception))
2310+
self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", str(context.exception))
2311+
self.assertIn("removed in Spark 3.0", str(context.exception))
2312+
finally:
2313+
log4j.LogManager.getRootLogger().setLevel(old_level)
2314+
2315+
def test_allow_insecure_gateway_with_conf(self):
2316+
with SparkContext._lock:
2317+
SparkContext._gateway = None
2318+
SparkContext._jvm = None
2319+
gateway = _launch_gateway(insecure=True)
2320+
try:
2321+
os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"
2322+
with SparkContext(gateway=gateway) as sc:
2323+
a = sc.accumulator(1)
2324+
rdd = sc.parallelize([1, 2, 3])
2325+
rdd.foreach(lambda x: a.add(x))
2326+
self.assertEqual(7, a.value)
2327+
finally:
2328+
os.environ.pop("PYSPARK_ALLOW_INSECURE_GATEWAY", None)
2329+
22982330

22992331
class ConfTests(unittest.TestCase):
23002332
def test_memory_conf(self):

0 commit comments

Comments
 (0)