[SPARK-26019][PYSPARK] Allow insecure py4j gateways #23337

Status: Closed · wants to merge 5 commits

@@ -43,12 +43,17 @@ private[spark] object PythonGatewayServer extends Logging {
     // with the same secret, in case the app needs callbacks from the JVM to the underlying
     // python processes.
     val localhost = InetAddress.getLoopbackAddress()
-    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
-      .authToken(secret)
+    val builder = new GatewayServer.GatewayServerBuilder()
       .javaPort(0)
       .javaAddress(localhost)
       .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
-      .build()
+    if (sys.env.getOrElse("_PYSPARK_CREATE_INSECURE_GATEWAY", "0") != "1") {
+      builder.authToken(secret)
+    } else {
+      assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1",
+        "Creating insecure Java gateways only allowed for testing")
+    }
+    val gatewayServer: GatewayServer = builder.build()

     gatewayServer.start()
     val boundPort: Int = gatewayServer.getListeningPort
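
For context, the Python side has to mirror this switch when it connects. A minimal sketch, assuming py4j's standard client API (py4j 0.10.7+) and that a GatewayServer is already listening; the port and secret values are placeholders, not values from this PR:

```python
from py4j.java_gateway import JavaGateway, GatewayParameters

# Secure connection: the token must match the one the JVM passed to authToken().
secure = JavaGateway(gateway_parameters=GatewayParameters(
    port=25333, auth_token="placeholder-secret", auto_convert=True))

# Insecure connection: no token is sent; this only works when the JVM side
# skipped authToken(), as in the _PYSPARK_CREATE_INSECURE_GATEWAY branch above.
insecure = JavaGateway(gateway_parameters=GatewayParameters(
    port=25333, auto_convert=True))
```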
@@ -616,8 +616,10 @@ private[spark] class PythonAccumulatorV2(
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
       logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
-      // send the secret just for the initial authentication when opening a new connection
-      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      if (secretToken != null) {
+        // send the secret just for the initial authentication when opening a new connection
+        socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      }
     }
     socket
   }
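
The client side of that handshake is easy to picture. A hedged sketch of what the Scala code above does on connect (illustrative only; connect_to_accumulator_server is a hypothetical helper, not PySpark API):

```python
import socket

def connect_to_accumulator_server(host, port, secret_token=None):
    # Mirrors the Scala client above: open the socket once, and send the secret
    # only when one is configured; an insecure setup sends nothing at all.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((host, port))
    if secret_token is not None:
        sock.sendall(secret_token.encode("utf-8"))
    return sock
```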
7 changes: 4 additions & 3 deletions python/pyspark/accumulators.py
@@ -262,9 +262,10 @@ def authenticate_and_accum_updates():
                 raise Exception(
                     "The value of the provided token to the AccumulatorServer is not correct.")

-        # first we keep polling till we've received the authentication token
-        poll(authenticate_and_accum_updates)
-        # now we've authenticated, don't need to check for the token anymore
+        if auth_token is not None:
+            # first we keep polling till we've received the authentication token
+            poll(authenticate_and_accum_updates)
+        # now we've authenticated if needed, don't need to check for the token anymore
         poll(accum_updates)

14 changes: 14 additions & 0 deletions python/pyspark/context.py
@@ -112,6 +112,20 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         ValueError:...
         """
         self._callsite = first_spark_call() or CallSite(None, None, None)
+        if gateway is not None and gateway.gateway_parameters.auth_token is None:
+            allow_insecure_env = os.environ.get("PYSPARK_ALLOW_INSECURE_GATEWAY", "0")
+            if allow_insecure_env == "1" or allow_insecure_env.lower() == "true":
+                warnings.warn(
+                    "You are passing in an insecure Py4j gateway. This "
+                    "presents a security risk, and will be completely forbidden in Spark 3.0")
+            else:
+                raise ValueError(
+                    "You are trying to pass an insecure Py4j gateway to Spark. This"
+                    " presents a security risk. If you are sure you understand and accept this"
+                    " risk, you can set the environment variable"
+                    " 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', but"
+                    " note this option will be removed in Spark 3.0")
         SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,

Review comment (Member): +1. Honestly, I still think insecure mode is a misuse of Spark and it should be removed. I'm going to merge this as an effort to make upgrading Spark easier in other projects like Zeppelin.
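
To make the new contract concrete, a hedged usage sketch that mirrors the tests later in this PR (it reuses the PR's test-only _launch_gateway helper, so it assumes SPARK_TESTING=1 is set and a fresh Python session):

```python
import os
from pyspark import SparkContext
from pyspark.java_gateway import _launch_gateway

gateway = _launch_gateway(insecure=True)   # token-less gateway, testing only

try:
    SparkContext(gateway=gateway)          # default: rejected up front
except ValueError as e:
    print(e)                               # points at PYSPARK_ALLOW_INSECURE_GATEWAY

os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"
with SparkContext(gateway=gateway) as sc:  # opt-in: accepted, with a warning
    print(sc.parallelize([1, 2, 3]).sum())
```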
23 changes: 19 additions & 4 deletions python/pyspark/java_gateway.py
@@ -41,8 +41,20 @@ def launch_gateway(conf=None):
     """
     launch jvm gateway
     :param conf: spark configuration passed to spark-submit
-    :return:
+    :return: a JVM gateway
     """
+    return _launch_gateway(conf)
+
+
+def _launch_gateway(conf=None, insecure=False):
+    """
+    launch jvm gateway
+    :param conf: spark configuration passed to spark-submit
+    :param insecure: True to create an insecure gateway; only for testing
+    :return: a JVM gateway
+    """
+    if insecure and os.environ.get("SPARK_TESTING", "0") != "1":
+        raise ValueError("creating insecure gateways is only for testing")
     if "PYSPARK_GATEWAY_PORT" in os.environ:
         gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
         gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
@@ -74,6 +86,8 @@ def launch_gateway(conf=None):

     env = dict(os.environ)
     env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
+    if insecure:
+        env["_PYSPARK_CREATE_INSECURE_GATEWAY"] = "1"

     # Launch the Java gateway.
     # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
@@ -116,9 +130,10 @@ def killChild():
     atexit.register(killChild)

     # Connect to the gateway
-    gateway = JavaGateway(
-        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
-                                             auto_convert=True))
+    gateway_params = GatewayParameters(port=gateway_port, auto_convert=True)
+    if not insecure:
+        gateway_params.auth_token = gateway_secret
+    gateway = JavaGateway(gateway_parameters=gateway_params)

     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
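
Putting the java_gateway.py changes together, here is a short sketch of how a test obtains an insecure gateway through this path (the SPARK_TESTING guard above must be satisfied; the final print is just a liveness check, not part of the PR):

```python
import os

# _launch_gateway refuses insecure mode outside of testing (see the guard above).
os.environ["SPARK_TESTING"] = "1"

from pyspark.java_gateway import _launch_gateway

gateway = _launch_gateway(insecure=True)
# The child JVM was started with _PYSPARK_CREATE_INSECURE_GATEWAY=1, so its
# GatewayServer was built without an auth token (see the Scala change above).
print(gateway.jvm.java.lang.System.getProperty("java.version"))
```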
32 changes: 32 additions & 0 deletions python/pyspark/tests.py
@@ -61,6 +61,7 @@
 from pyspark import keyword_only
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
+from pyspark.java_gateway import _launch_gateway
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
@@ -2381,6 +2382,37 @@ def test_startTime(self):
         with SparkContext() as sc:
             self.assertGreater(sc.startTime, 0)

+    def test_forbid_insecure_gateway(self):
+        # By default, we fail immediately if you try to create a SparkContext
+        # with an insecure gateway
+        gateway = _launch_gateway(insecure=True)
+        log4j = gateway.jvm.org.apache.log4j
+        old_level = log4j.LogManager.getRootLogger().getLevel()
+        try:
+            log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)
+            with self.assertRaises(Exception) as context:
+                SparkContext(gateway=gateway)
+            self.assertIn("insecure Py4j gateway", str(context.exception))
+            self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", str(context.exception))
+            self.assertIn("removed in Spark 3.0", str(context.exception))
+        finally:
+            log4j.LogManager.getRootLogger().setLevel(old_level)
+
+    def test_allow_insecure_gateway_with_conf(self):
+        with SparkContext._lock:
+            SparkContext._gateway = None
+            SparkContext._jvm = None

Review comment (Contributor, PR author): This part of the test really bothers me, so I'd like to explain it to reviewers. Without this, the test passes -- but it passes even without the changes to the main code! Or rather, it only passes when it's run as part of the entire suite; it would fail when run individually.

What's happening is that SparkContext._gateway and SparkContext._jvm don't get reset by most tests (e.g., they are not reset in sc.stop()), so a test running before this one will set those variables, and then this test will end up holding on to a gateway which does have the auth_token set, and so the accumulator server would still work.

Now that in itself sounds crazy to me, and seems like a problem for things like Zeppelin. I tried just adding these two lines into sc.stop(), but then when I ran all the tests, I got a lot of java.io.IOException: error=23, Too many open files in system. So maybe something else is not getting cleaned up properly in the pyspark tests?

I was hoping somebody else might have some ideas about what is going on, or whether there is a better way to do this.

Review comment (Member): [comment text did not load in this extract]

Review comment (Contributor, PR author): I don't think that's really answering my question. I don't have a problem calling start & stop; I'm wondering why SparkContext._gateway and SparkContext._jvm don't get reset in sc.stop(). This means that if you have multiple spark contexts in one python session (as we do in our tests), they all reuse the same gateway:

    if not SparkContext._gateway:
        SparkContext._gateway = gateway or launch_gateway(conf)
        SparkContext._jvm = SparkContext._gateway.jvm

For normal use of spark, that's not a problem, but it seems like it would be a problem (a) in our tests and (b) for systems like Zeppelin that might have multiple spark contexts over the lifetime of the python session (I assume, anyway ...).

Review comment (Member): Ah, gotcha. In that case, we could consider simply moving the test into a separate class at the top, but yes, I suspect tests depending on their order isn't a great idea either. I'm okay as long as the tests pass. I can take a separate look at this later.

Review comment (Contributor): Yeah, this logic is sort of rough. In spark-testing-base, for example, in between tests where folks do not intend to reuse the same Spark context, we also clear some extra properties (although we do reuse the gateway). I think for environments where folks want multiple SparkContexts from Python on the same machine, they end up using multiple Python processes anyway.

Review comment (Contributor, PR author): OK, it's good to get some confirmation of this weird behavior... but I feel like I still don't understand why we don't reset SparkContext._gateway and SparkContext._jvm in sc.stop(), and why, when I tried to make that change, I hit all those errors. If nothing else, any chance this is related to general flakiness in the pyspark tests?
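
For readers following along, the workaround the thread converges on can be read as a small helper (hypothetical; not part of the PR, it just packages the three lines at the top of test_allow_insecure_gateway_with_conf in the diff below):

```python
from pyspark import SparkContext

def reset_cached_gateway():
    # Drop the process-wide cached gateway and JVM handle so the next
    # SparkContext(gateway=...) call actually uses the gateway it is given.
    with SparkContext._lock:
        SparkContext._gateway = None
        SparkContext._jvm = None
```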

+        gateway = _launch_gateway(insecure=True)
+        try:
+            os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"
+            with SparkContext(gateway=gateway) as sc:
+                a = sc.accumulator(1)
+                rdd = sc.parallelize([1, 2, 3])
+                rdd.foreach(lambda x: a.add(x))
+                self.assertEqual(7, a.value)
+        finally:
+            os.environ.pop("PYSPARK_ALLOW_INSECURE_GATEWAY", None)


 class ConfTests(unittest.TestCase):
     def test_memory_conf(self):