[SPARK-47986][CONNECT][PYTHON] Unable to create a new session when the default session is closed by the server

### What changes were proposed in this pull request?

When the server closes a session, usually after a cluster restart,
the client is unaware of this until it receives an error.

At this point, the client is unable to create a new session to the
same Connect endpoint, since the stale session is still recorded
as the active and default session.

With this change, when the server communicates via a gRPC error
that the session has changed, the session and the respective client
are marked as stale. A new default session can then be created
via the session builder.
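The builder logic described above can be sketched in plain Python. The `Session` and `Builder` classes below are hypothetical stand-ins for illustration, not the actual pyspark classes:

```python
import uuid


class Session:
    """Minimal stand-in for a Connect session."""

    def __init__(self) -> None:
        self.session_id = str(uuid.uuid4())
        self._closed = False  # flipped when the server reports SESSION_CHANGED

    @property
    def is_stopped(self) -> bool:
        return self._closed


class Builder:
    """Minimal stand-in for the session builder."""

    _default_session = None

    def getOrCreate(self) -> Session:
        session = Builder._default_session
        # A stale (server-closed) session no longer shadows a new one.
        if session is None or session.is_stopped:
            session = Session()
            Builder._default_session = session
        return session


old = Builder().getOrCreate()
old._closed = True             # simulate the server invalidating the session
new = Builder().getOrCreate()  # a fresh session is created, not the stale one
assert new is not old and not new.is_stopped
```

Before this change, the `is_stopped` checks were absent, so the stale default session was returned forever.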

### Why are the changes needed?

See section above.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Attached unit tests

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46221 from nija-at/session-expires.

Authored-by: Niranjan Jayakar <nija@databricks.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
nija-at authored and zhengruifeng committed Apr 26, 2024
1 parent b0e03a1 commit 7d04d0f
Showing 3 changed files with 19 additions and 2 deletions.
3 changes: 3 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -1763,6 +1763,9 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
                     info = error_details_pb2.ErrorInfo()
                     d.Unpack(info)

+                    if info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED":
+                        self._closed = True
+
                     raise convert_exception(
                         info,
                         status.message,
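The check in `_handle_rpc_error` can be mirrored with a small stand-in, where a plain dict replaces the protobuf `ErrorInfo.metadata` and `FakeClient` is a hypothetical class for illustration:

```python
class FakeClient:
    """Hypothetical stand-in for the Connect client's error handling."""

    def __init__(self) -> None:
        self._closed = False

    def handle_error_metadata(self, metadata: dict) -> None:
        # The server reports a replaced session with this error class;
        # mark the client closed so the builder can create a fresh session.
        if metadata.get("errorClass") == "INVALID_HANDLE.SESSION_CHANGED":
            self._closed = True


client = FakeClient()
client.handle_error_metadata({"errorClass": "INVALID_HANDLE.SESSION_CHANGED"})
assert client._closed
```

Any other error class leaves the client open, so unrelated failures do not invalidate the session.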
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/session.py
@@ -237,9 +237,9 @@ def create(self) -> "SparkSession":
     def getOrCreate(self) -> "SparkSession":
         with SparkSession._lock:
             session = SparkSession.getActiveSession()
-            if session is None:
+            if session is None or session.is_stopped:
                 session = SparkSession._default_session
-                if session is None:
+                if session is None or session.is_stopped:
                     session = self.create()
                 self._apply_options(session)
                 return session
14 changes: 14 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_session.py
@@ -242,6 +242,20 @@ def toChannel(self):
         session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
         session.sql("select 1 + 1")

+    def test_reset_when_server_session_changes(self):
+        session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
+        # run a simple query so the session id is synchronized.
+        session.range(3).collect()
+
+        # trigger a mismatch between client session id and server session id.
+        session._client._session_id = str(uuid.uuid4())
+        with self.assertRaises(SparkConnectException):
+            session.range(3).collect()
+
+        # assert that getOrCreate() generates a new session
+        session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
+        session.range(3).collect()
+

 class SparkConnectSessionWithOptionsTest(unittest.TestCase):
     def setUp(self) -> None:
