Skip to content

Commit 9eae96a

Browse files
Add get_server_capabilities() to ClientSession (#1588)
1 parent b7b0f8e commit 9eae96a

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

src/mcp/client/session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
134134
self._logging_callback = logging_callback or _default_logging_callback
135135
self._message_handler = message_handler or _default_message_handler
136136
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
137+
self._server_capabilities: types.ServerCapabilities | None = None
137138

138139
async def initialize(self) -> types.InitializeResult:
139140
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -170,10 +171,19 @@ async def initialize(self) -> types.InitializeResult:
170171
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
171172
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
172173

174+
self._server_capabilities = result.capabilities
175+
173176
await self.send_notification(types.ClientNotification(types.InitializedNotification()))
174177

175178
return result
176179

180+
def get_server_capabilities(self) -> types.ServerCapabilities | None:
181+
"""Return the server capabilities received during initialization.
182+
183+
Returns None if the session has not been initialized yet.
184+
"""
185+
return self._server_capabilities
186+
177187
async def send_ping(self) -> types.EmptyResult:
178188
"""Send a ping request."""
179189
return await self.send_request(

tests/client/test_session.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,78 @@ async def mock_server():
504504
assert received_capabilities.roots.listChanged is True
505505

506506

507+
@pytest.mark.anyio
508+
async def test_get_server_capabilities():
509+
"""Test that get_server_capabilities returns None before init and capabilities after"""
510+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
511+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
512+
513+
expected_capabilities = ServerCapabilities(
514+
logging=types.LoggingCapability(),
515+
prompts=types.PromptsCapability(listChanged=True),
516+
resources=types.ResourcesCapability(subscribe=True, listChanged=True),
517+
tools=types.ToolsCapability(listChanged=False),
518+
)
519+
520+
async def mock_server():
521+
session_message = await client_to_server_receive.receive()
522+
jsonrpc_request = session_message.message
523+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
524+
request = ClientRequest.model_validate(
525+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
526+
)
527+
assert isinstance(request.root, InitializeRequest)
528+
529+
result = ServerResult(
530+
InitializeResult(
531+
protocolVersion=LATEST_PROTOCOL_VERSION,
532+
capabilities=expected_capabilities,
533+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
534+
)
535+
)
536+
537+
async with server_to_client_send:
538+
await server_to_client_send.send(
539+
SessionMessage(
540+
JSONRPCMessage(
541+
JSONRPCResponse(
542+
jsonrpc="2.0",
543+
id=jsonrpc_request.root.id,
544+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
545+
)
546+
)
547+
)
548+
)
549+
await client_to_server_receive.receive()
550+
551+
async with (
552+
ClientSession(
553+
server_to_client_receive,
554+
client_to_server_send,
555+
) as session,
556+
anyio.create_task_group() as tg,
557+
client_to_server_send,
558+
client_to_server_receive,
559+
server_to_client_send,
560+
server_to_client_receive,
561+
):
562+
assert session.get_server_capabilities() is None
563+
564+
tg.start_soon(mock_server)
565+
await session.initialize()
566+
567+
capabilities = session.get_server_capabilities()
568+
assert capabilities is not None
569+
assert capabilities == expected_capabilities
570+
assert capabilities.logging is not None
571+
assert capabilities.prompts is not None
572+
assert capabilities.prompts.listChanged is True
573+
assert capabilities.resources is not None
574+
assert capabilities.resources.subscribe is True
575+
assert capabilities.tools is not None
576+
assert capabilities.tools.listChanged is False
577+
578+
507579
@pytest.mark.anyio
508580
@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}])
509581
async def test_client_tool_call_with_meta(meta: dict[str, Any] | None):

0 commit comments

Comments
 (0)