Skip to content

Commit fbb8ea3

Browse files
squito authored and kai-chi committed
[SPARK-26019][PYSPARK] Allow insecure py4j gateways
Spark always creates secure py4j connections between java and python, but it also allows users to pass in their own connection. This restores the ability for users to pass in an _insecure_ connection, though it forces them to set the env variable 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', and still issues a warning. Added test cases verifying the failure without the extra configuration, and verifying things still work with an insecure configuration (in particular, accumulators, as those were broken with an insecure py4j gateway before). For the tests, I added ways to create insecure gateways, but I tried to put in protections to make sure that wouldn't get used incorrectly. Closes apache#23337 from squito/SPARK-26019. Authored-by: Imran Rashid <irashid@cloudera.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 6612588 commit fbb8ea3

File tree

6 files changed

+81
-12
lines changed

6 files changed

+81
-12
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,17 @@ private[spark] object PythonGatewayServer extends Logging {
4343
// with the same secret, in case the app needs callbacks from the JVM to the underlying
4444
// python processes.
4545
val localhost = InetAddress.getLoopbackAddress()
46-
val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
47-
.authToken(secret)
46+
val builder = new GatewayServer.GatewayServerBuilder()
4847
.javaPort(0)
4948
.javaAddress(localhost)
5049
.callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
51-
.build()
50+
if (sys.env.getOrElse("_PYSPARK_CREATE_INSECURE_GATEWAY", "0") != "1") {
51+
builder.authToken(secret)
52+
} else {
53+
assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1",
54+
"Creating insecure Java gateways only allowed for testing")
55+
}
56+
val gatewayServer: GatewayServer = builder.build()
5257

5358
gatewayServer.start()
5459
val boundPort: Int = gatewayServer.getListeningPort

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,10 @@ private[spark] class PythonAccumulatorV2(
616616
if (socket == null || socket.isClosed) {
617617
socket = new Socket(serverHost, serverPort)
618618
logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
619-
// send the secret just for the initial authentication when opening a new connection
620-
socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
619+
if (secretToken != null) {
620+
// send the secret just for the initial authentication when opening a new connection
621+
socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
622+
}
621623
}
622624
socket
623625
}

python/pyspark/accumulators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,10 @@ def authenticate_and_accum_updates():
262262
raise Exception(
263263
"The value of the provided token to the AccumulatorServer is not correct.")
264264

265-
# first we keep polling till we've received the authentication token
266-
poll(authenticate_and_accum_updates)
267-
# now we've authenticated, don't need to check for the token anymore
265+
if auth_token is not None:
266+
# first we keep polling till we've received the authentication token
267+
poll(authenticate_and_accum_updates)
268+
# now we've authenticated if needed, don't need to check for the token anymore
268269
poll(accum_updates)
269270

270271

python/pyspark/context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,20 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
112112
ValueError:...
113113
"""
114114
self._callsite = first_spark_call() or CallSite(None, None, None)
115+
if gateway is not None and gateway.gateway_parameters.auth_token is None:
116+
allow_insecure_env = os.environ.get("PYSPARK_ALLOW_INSECURE_GATEWAY", "0")
117+
if allow_insecure_env == "1" or allow_insecure_env.lower() == "true":
118+
warnings.warn(
119+
"You are passing in an insecure Py4j gateway. This "
120+
"presents a security risk, and will be completely forbidden in Spark 3.0")
121+
else:
122+
raise ValueError(
123+
"You are trying to pass an insecure Py4j gateway to Spark. This"
124+
" presents a security risk. If you are sure you understand and accept this"
125+
" risk, you can set the environment variable"
126+
" 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', but"
127+
" note this option will be removed in Spark 3.0")
128+
115129
SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
116130
try:
117131
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,

python/pyspark/java_gateway.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,20 @@ def launch_gateway(conf=None):
4141
"""
4242
launch jvm gateway
4343
:param conf: spark configuration passed to spark-submit
44-
:return:
44+
:return: a JVM gateway
4545
"""
46+
return _launch_gateway(conf)
47+
48+
49+
def _launch_gateway(conf=None, insecure=False):
50+
"""
51+
launch jvm gateway
52+
:param conf: spark configuration passed to spark-submit
53+
:param insecure: True to create an insecure gateway; only for testing
54+
:return: a JVM gateway
55+
"""
56+
if insecure and os.environ.get("SPARK_TESTING", "0") != "1":
57+
raise ValueError("creating insecure gateways is only for testing")
4658
if "PYSPARK_GATEWAY_PORT" in os.environ:
4759
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
4860
gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
@@ -74,6 +86,8 @@ def launch_gateway(conf=None):
7486

7587
env = dict(os.environ)
7688
env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
89+
if insecure:
90+
env["_PYSPARK_CREATE_INSECURE_GATEWAY"] = "1"
7791

7892
# Launch the Java gateway.
7993
# We open a pipe to stdin so that the Java gateway can die when the pipe is broken
@@ -116,9 +130,10 @@ def killChild():
116130
atexit.register(killChild)
117131

118132
# Connect to the gateway
119-
gateway = JavaGateway(
120-
gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
121-
auto_convert=True))
133+
gateway_params = GatewayParameters(port=gateway_port, auto_convert=True)
134+
if not insecure:
135+
gateway_params.auth_token = gateway_secret
136+
gateway = JavaGateway(gateway_parameters=gateway_params)
122137

123138
# Import the classes used by PySpark
124139
java_import(gateway.jvm, "org.apache.spark.SparkConf")

python/pyspark/tests.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from pyspark import keyword_only
6262
from pyspark.conf import SparkConf
6363
from pyspark.context import SparkContext
64+
from pyspark.java_gateway import _launch_gateway
6465
from pyspark.rdd import RDD
6566
from pyspark.files import SparkFiles
6667
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
@@ -2381,6 +2382,37 @@ def test_startTime(self):
23812382
with SparkContext() as sc:
23822383
self.assertGreater(sc.startTime, 0)
23832384

2385+
def test_forbid_insecure_gateway(self):
2386+
# By default, we fail immediately if you try to create a SparkContext
2387+
# with an insecure gateway
2388+
gateway = _launch_gateway(insecure=True)
2389+
log4j = gateway.jvm.org.apache.log4j
2390+
old_level = log4j.LogManager.getRootLogger().getLevel()
2391+
try:
2392+
log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)
2393+
with self.assertRaises(Exception) as context:
2394+
SparkContext(gateway=gateway)
2395+
self.assertIn("insecure Py4j gateway", str(context.exception))
2396+
self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", str(context.exception))
2397+
self.assertIn("removed in Spark 3.0", str(context.exception))
2398+
finally:
2399+
log4j.LogManager.getRootLogger().setLevel(old_level)
2400+
2401+
def test_allow_insecure_gateway_with_conf(self):
2402+
with SparkContext._lock:
2403+
SparkContext._gateway = None
2404+
SparkContext._jvm = None
2405+
gateway = _launch_gateway(insecure=True)
2406+
try:
2407+
os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"
2408+
with SparkContext(gateway=gateway) as sc:
2409+
a = sc.accumulator(1)
2410+
rdd = sc.parallelize([1, 2, 3])
2411+
rdd.foreach(lambda x: a.add(x))
2412+
self.assertEqual(7, a.value)
2413+
finally:
2414+
os.environ.pop("PYSPARK_ALLOW_INSECURE_GATEWAY", None)
2415+
23842416

23852417
class ConfTests(unittest.TestCase):
23862418
def test_memory_conf(self):

0 commit comments

Comments (0)