Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,58 @@ def release_session(self) -> None:
except Exception as error:
self._handle_error(error)

def _get_operation_statuses(
    self,
    operation_ids: Optional[List[str]] = None,
    operation_extensions: Optional[List[any_pb2.Any]] = None,
    request_extensions: Optional[List[any_pb2.Any]] = None,
) -> "pb2.GetStatusResponse":
    """
    Query the server for the status of operations in this session.

    Parameters
    ----------
    operation_ids : list of str, optional
        Operation IDs to look up. When None or empty, the server reports the
        status of every operation in the session.
    operation_extensions : list of google.protobuf.any_pb2.Any, optional
        Extension messages placed on the OperationStatusRequest, requesting
        additional per-operation information.
    request_extensions : list of google.protobuf.any_pb2.Any, optional
        Extension messages placed on the top-level GetStatusRequest.

    Returns
    -------
    pb2.GetStatusResponse
        The full GetStatusResponse, including operation_statuses and any extensions.
    """
    request = pb2.GetStatusRequest()
    request.session_id = self._session_id
    request.client_type = self._builder.userAgent
    if self._user_id:
        request.user_context.user_id = self._user_id
    if self._server_session_id:
        request.client_observed_server_side_session_id = self._server_session_id

    # Force the operation_status sub-message to be serialized even when no
    # fields below are set, so the server sees an operation-status query.
    request.operation_status.SetInParent()

    if request_extensions:
        request.extensions.extend(request_extensions)
    if operation_ids:
        request.operation_status.operation_ids.extend(operation_ids)
    if operation_extensions:
        request.operation_status.extensions.extend(operation_extensions)

    try:
        for attempt in self._retrying():
            with attempt:
                response = self._stub.GetStatus(request, metadata=self._builder.metadata())
                self._verify_response_integrity(response)
                return response
        raise SparkConnectException("Invalid state during retry exception handling.")
    except Exception as error:
        self._handle_error(error)

def add_tag(self, tag: str) -> None:
self._throw_if_invalid_tag(tag)
if not hasattr(self.thread_local, "tags"):
Expand Down Expand Up @@ -2169,6 +2221,7 @@ def _verify_response_integrity(
pb2.AnalyzePlanResponse,
pb2.FetchErrorDetailsResponse,
pb2.ReleaseSessionResponse,
pb2.GetStatusResponse,
],
) -> None:
"""
Expand Down
174 changes: 173 additions & 1 deletion python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,25 @@ class MockService:

req: Optional[proto.ExecutePlanRequest]

def __init__(self, session_id: str):
# Shorthand alias for the nested protobuf message type used below.
OperationStatus = proto.GetStatusResponse.OperationStatus
# Canned statuses served by GetStatus when the test does not pass its own
# `operation_statuses` to __init__: one succeeded and one running operation.
DEFAULT_OPERATION_STATUSES = [
OperationStatus(
operation_id="default-op-1",
state=OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
),
OperationStatus(
operation_id="default-op-2",
state=OperationStatus.OperationState.OPERATION_STATE_RUNNING,
),
]

def __init__(self, session_id: str, operation_statuses=None):
    """Create a mock service for *session_id*.

    When *operation_statuses* is None, the class-level defaults are served;
    otherwise the given statuses are indexed by their operation_id.
    """
    self._session_id = session_id
    self.req = None
    self.client_user_context_extensions = []
    statuses = self.DEFAULT_OPERATION_STATUSES if operation_statuses is None else operation_statuses
    self._operation_statuses = {status.operation_id: status for status in statuses}

def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
self.req = req
Expand Down Expand Up @@ -191,6 +206,45 @@ def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata):
resp.semantic_hash.result = 12345
return resp

def GetStatus(self, req: proto.GetStatusRequest, metadata):
    """Mock GetStatus: serve canned statuses and echo request extensions back."""
    self.req = req
    self.client_user_context_extensions = list(req.user_context.extensions)
    self.received_custom_server_session_id = req.client_observed_server_side_session_id

    resp = proto.GetStatusResponse(session_id=self._session_id)
    # Top-level request extensions are echoed back at the response level.
    if req.extensions:
        resp.extensions.extend(req.extensions)

    # Without an operation_status sub-message there is nothing more to report.
    if not req.HasField("operation_status"):
        return resp

    # Operation-status-level extensions get echoed back on every status entry.
    echoed_extensions = list(req.operation_status.extensions)

    wanted_ids = list(req.operation_status.operation_ids)
    if not wanted_ids:
        # No explicit IDs: report every operation this mock knows about.
        resp.operation_statuses.extend(self._operation_statuses.values())
        return resp

    Status = proto.GetStatusResponse.OperationStatus
    for op_id in wanted_ids:
        known = self._operation_statuses.get(op_id)
        if known is None:
            # Unknown operations still get an entry, marked UNKNOWN.
            entry = Status(
                operation_id=op_id,
                state=Status.OperationState.OPERATION_STATE_UNKNOWN,
            )
        else:
            entry = Status(operation_id=known.operation_id, state=known.state)
        entry.extensions.extend(echoed_extensions)
        resp.operation_statuses.append(entry)
    return resp

# The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster)
# and it blocks the test process exiting because it is registered as the atexit handler
# in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test.
Expand Down Expand Up @@ -396,6 +450,124 @@ def test_custom_operation_id(self):
for resp in client._stub.ExecutePlan(req, metadata=None):
assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"

def test_get_operations_statuses_all(self):
    """_get_operation_statuses with no IDs returns the status of every operation."""
    OperationStatus = proto.GetStatusResponse.OperationStatus
    client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
    client._stub = MockService(client._session_id)

    statuses = list(client._get_operation_statuses().operation_statuses)
    self.assertEqual(len(statuses), 2)
    states_by_id = {status.operation_id: status.state for status in statuses}
    self.assertEqual(
        states_by_id["default-op-1"],
        OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
    )
    self.assertEqual(
        states_by_id["default-op-2"],
        OperationStatus.OperationState.OPERATION_STATE_RUNNING,
    )

def test_get_operations_statuses_specific_ids(self):
    """_get_operation_statuses filters by explicit IDs; unknown IDs come back UNKNOWN."""
    OperationStatus = proto.GetStatusResponse.OperationStatus
    client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
    mock = MockService(client._session_id)
    client._stub = mock

    resp = client._get_operation_statuses(operation_ids=["default-op-1", "unknown-op"])
    statuses = list(resp.operation_statuses)
    self.assertEqual(len(statuses), 2)
    states_by_id = {status.operation_id: status.state for status in statuses}
    self.assertEqual(
        states_by_id["default-op-1"],
        OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
    )
    self.assertEqual(
        states_by_id["unknown-op"],
        OperationStatus.OperationState.OPERATION_STATE_UNKNOWN,
    )
    # The outgoing request must carry exactly the requested operation IDs.
    self.assertEqual(
        set(mock.req.operation_status.operation_ids), {"default-op-1", "unknown-op"}
    )

def test_get_operations_statuses_empty(self):
    """_get_operation_statuses yields no statuses when the session has no operations."""
    client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
    client._stub = MockService(client._session_id, operation_statuses=[])

    resp = client._get_operation_statuses()
    self.assertEqual(list(resp.operation_statuses), [])

def test_get_operations_statuses_with_operation_extensions(self):
    """_get_operation_statuses forwards operation-level extensions.

    The mock echoes them back on each returned operation status.
    """
    from google.protobuf import any_pb2, wrappers_pb2

    client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
    mock = MockService(client._session_id)
    client._stub = mock

    op_ext = any_pb2.Any()
    op_ext.Pack(wrappers_pb2.StringValue(value="op_extension"))

    resp = client._get_operation_statuses(
        operation_ids=["default-op-1", "default-op-2"],
        operation_extensions=[op_ext],
    )
    statuses = list(resp.operation_statuses)
    self.assertEqual({status.operation_id for status in statuses}, {"default-op-1", "default-op-2"})
    self.assertEqual(len(statuses), 2)

    # The outgoing request carried exactly one operation-level extension.
    self.assertEqual(len(mock.req.operation_status.extensions), 1)
    sent = wrappers_pb2.StringValue()
    mock.req.operation_status.extensions[0].Unpack(sent)
    self.assertEqual(sent.value, "op_extension")

    # Each returned status carries the echoed extension.
    for status in statuses:
        self.assertEqual(len(status.extensions), 1)
        echoed = wrappers_pb2.StringValue()
        status.extensions[0].Unpack(echoed)
        self.assertEqual(echoed.value, "op_extension")

def test_get_operations_statuses_with_request_extensions(self):
    """_get_operation_statuses forwards request-level extensions.

    The mock echoes them back at the top level of the response.
    """
    from google.protobuf import any_pb2, wrappers_pb2

    client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
    mock = MockService(client._session_id)
    client._stub = mock

    req_ext = any_pb2.Any()
    req_ext.Pack(wrappers_pb2.StringValue(value="request_extension"))

    resp = client._get_operation_statuses(
        operation_ids=["default-op-1"],
        request_extensions=[req_ext],
    )

    # Exactly the requested operation status comes back.
    statuses = list(resp.operation_statuses)
    self.assertEqual(len(statuses), 1)
    self.assertEqual(statuses[0].operation_id, "default-op-1")

    # The outgoing request carried exactly one request-level extension.
    self.assertEqual(len(mock.req.extensions), 1)
    sent = wrappers_pb2.StringValue()
    mock.req.extensions[0].Unpack(sent)
    self.assertEqual(sent.value, "request_extension")

    # The response echoes the request-level extension back.
    self.assertEqual(len(resp.extensions), 1)
    echoed = wrappers_pb2.StringValue()
    resp.extensions[0].Unpack(echoed)
    self.assertEqual(echoed.value, "request_extension")


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectClientReattachTestCase(unittest.TestCase):
Expand Down
Loading