Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47986][CONNECT][PYTHON] Unable to create a new session when the default session is closed by the server #46435

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1829,6 +1829,7 @@ def _verify_response_integrity(
response.server_side_session_id
and response.server_side_session_id != self._server_session_id
):
self._closed = True
raise PySparkAssertionError(
"Received incorrect server side session identifier for request. "
"Please create a new Spark Session to reconnect. ("
Expand Down
20 changes: 15 additions & 5 deletions python/pyspark/sql/connect/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,9 @@ def create(self) -> "SparkSession":
def getOrCreate(self) -> "SparkSession":
    """Return an existing usable session or create a new one.

    Resolution order: the thread-local active session first, then the
    process-wide default session, and finally a freshly created session.
    Both lookups already filter out sessions whose client has been
    stopped (e.g. closed by the server), so a dead session is never
    handed back to the caller.

    The builder's configured options are applied to whichever session is
    returned.
    """
    with SparkSession._lock:
        session = SparkSession.getActiveSession()
        if session is None:
            session = SparkSession._get_default_session()
            if session is None:
                session = self.create()
        self._apply_options(session)
        return session
Expand Down Expand Up @@ -285,9 +285,19 @@ def _set_default_and_active_session(cls, session: "SparkSession") -> None:
if getattr(cls._active_session, "session", None) is None:
cls._active_session.session = session

@classmethod
def _get_default_session(cls) -> Optional["SparkSession"]:
    """Return the process-wide default session, or ``None`` when it is
    unset or its underlying client has already been stopped."""
    session = cls._default_session
    if session is None or session.is_stopped:
        return None
    return session

@classmethod
def getActiveSession(cls) -> Optional["SparkSession"]:
    """Return the thread-local active session, or ``None`` when no
    session has been set for this thread or the recorded session has
    already been stopped.

    Checking ``is_stopped`` here means callers such as ``getOrCreate``
    never receive a session whose client was closed by the server.
    """
    session = getattr(cls._active_session, "session", None)
    if session is not None and not session.is_stopped:
        return session
    return None

@classmethod
def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession":
Expand Down Expand Up @@ -315,7 +325,7 @@ def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession":
def active(cls) -> "SparkSession":
session = cls.getActiveSession()
if session is None:
session = cls._default_session
session = cls._get_default_session()
if session is None:
raise PySparkRuntimeError(
error_class="NO_ACTIVE_OR_DEFAULT_SESSION",
Expand Down
16 changes: 15 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def toChannel(self):
session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
session.sql("select 1 + 1")

def test_reset_when_server_session_changes(self):
def test_reset_when_server_and_client_sessionids_mismatch(self):
session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
# run a simple query so the session id is synchronized.
session.range(3).collect()
Expand All @@ -256,6 +256,20 @@ def test_reset_when_server_session_changes(self):
session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
session.range(3).collect()

def test_reset_when_server_session_id_mismatch(self):
    """A server session id mismatch must fail the running session and
    make getOrCreate() hand out a fresh, working one."""
    spark = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
    # Execute one query so the client caches the server-side session id.
    spark.range(3).collect()

    # Corrupt the cached id to force a mismatch on the next request.
    spark._client._server_session_id = str(uuid.uuid4())
    with self.assertRaises(SparkConnectException):
        spark.range(3).collect()

    # The builder must no longer reuse the invalidated session.
    spark = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
    spark.range(3).collect()


class SparkConnectSessionWithOptionsTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests/connect/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,34 @@ def test_session_create_sets_active_session(self):
self.assertIs(session, session2)
session.stop()

def test_active_session_expires_when_client_closes(self):
    """A closed client must invalidate the cached active session."""
    first = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
    self.assertIs(first, RemoteSparkSession.getActiveSession())

    # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator
    first._client._closed = True

    # The stopped session must disappear from the active-session lookup
    # and getOrCreate() must build a replacement instead of reusing it.
    self.assertIsNone(RemoteSparkSession.getActiveSession())
    replacement = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
    self.assertIsNot(first, replacement)

def test_default_session_expires_when_client_closes(self):
    """A closed client must invalidate the cached default session."""
    first = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
    self.assertIs(first, RemoteSparkSession.getDefaultSession())

    # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator
    first._client._closed = True

    # The stopped session must disappear from the default-session lookup
    # and getOrCreate() must build a replacement instead of reusing it.
    self.assertIsNone(RemoteSparkSession.getDefaultSession())
    replacement = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
    self.assertIsNot(first, replacement)


class JobCancellationTests(ReusedConnectTestCase):
def test_tags(self):
Expand Down