Skip to content

Commit 309439f

Browse files
committed
Support different transports in Client
1 parent b38716e commit 309439f

File tree

19 files changed

+290
-378
lines changed

19 files changed

+290
-378
lines changed

README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2213,11 +2213,7 @@ from mcp.client.streamable_http import streamable_http_client
22132213

22142214
async def main():
22152215
# Connect to a streamable HTTP server
2216-
async with streamable_http_client("http://localhost:8000/mcp") as (
2217-
read_stream,
2218-
write_stream,
2219-
_,
2220-
):
2216+
async with streamable_http_client("http://localhost:8000/mcp") as (read_stream, write_stream):
22212217
# Create a session using the client streams
22222218
async with ClientSession(read_stream, write_stream) as session:
22232219
# Initialize the connection
@@ -2395,7 +2391,7 @@ async def main():
23952391
)
23962392

23972393
async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client:
2398-
async with streamable_http_client("http://localhost:8001/mcp", http_client=custom_client) as (read, write, _):
2394+
async with streamable_http_client("http://localhost:8001/mcp", http_client=custom_client) as (read, write):
23992395
async with ClientSession(read, write) as session:
24002396
await session.initialize()
24012397

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import time
1515
import webbrowser
1616
from http.server import BaseHTTPRequestHandler, HTTPServer
17-
from typing import Any, Callable
17+
from typing import Any
1818
from urllib.parse import parse_qs, urlparse
1919

2020
import httpx
@@ -223,15 +223,15 @@ async def _default_redirect_handler(authorization_url: str) -> None:
223223
auth=oauth_auth,
224224
timeout=60.0,
225225
) as (read_stream, write_stream):
226-
await self._run_session(read_stream, write_stream, None)
226+
await self._run_session(read_stream, write_stream)
227227
else:
228228
print("📡 Opening StreamableHTTP transport connection with auth...")
229229
async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client:
230-
async with streamable_http_client(
231-
url=self.server_url,
232-
http_client=custom_client,
233-
) as (read_stream, write_stream, get_session_id):
234-
await self._run_session(read_stream, write_stream, get_session_id)
230+
async with streamable_http_client(url=self.server_url, http_client=custom_client) as (
231+
read_stream,
232+
write_stream,
233+
):
234+
await self._run_session(read_stream, write_stream)
235235

236236
except Exception as e:
237237
print(f"❌ Failed to connect: {e}")
@@ -243,7 +243,6 @@ async def _run_session(
243243
self,
244244
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
245245
write_stream: MemoryObjectSendStream[SessionMessage],
246-
get_session_id: Callable[[], str | None] | None = None,
247246
):
248247
"""Run the MCP session with the given streams."""
249248
print("🤝 Initializing MCP session...")
@@ -254,10 +253,6 @@ async def _run_session(
254253
print("✨ Session initialization complete!")
255254

256255
print(f"\n✅ Connected to MCP server at {self.server_url}")
257-
if get_session_id:
258-
session_id = get_session_id()
259-
if session_id:
260-
print(f"Session ID: {session_id}")
261256

262257
# Run interactive loop
263258
await self.interactive_loop()

examples/clients/simple-task-client/mcp_simple_task_client/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
async def run(url: str) -> None:
12-
async with streamable_http_client(url) as (read, write, _):
12+
async with streamable_http_client(url) as (read, write):
1313
async with ClientSession(read, write) as session:
1414
await session.initialize()
1515

examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def get_text(result: CallToolResult) -> str:
7373

7474

7575
async def run(url: str) -> None:
76-
async with streamable_http_client(url) as (read, write, _):
76+
async with streamable_http_client(url) as (read, write):
7777
async with ClientSession(
7878
read,
7979
write,

examples/clients/sse-polling-client/mcp_sse_polling_client/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None:
3131
print(f"Processing {items} items with checkpoints every {checkpoint_every}")
3232
print(f"{'=' * 60}\n")
3333

34-
async with streamable_http_client(url) as (read_stream, write_stream, _):
34+
async with streamable_http_client(url) as (read_stream, write_stream):
3535
async with ClientSession(read_stream, write_stream) as session:
3636
# Initialize the connection
3737
print("Initializing connection...")

examples/snippets/clients/oauth_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ async def main():
6969
)
7070

7171
async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client:
72-
async with streamable_http_client("http://localhost:8001/mcp", http_client=custom_client) as (read, write, _):
72+
async with streamable_http_client("http://localhost:8001/mcp", http_client=custom_client) as (read, write):
7373
async with ClientSession(read, write) as session:
7474
await session.initialize()
7575

examples/snippets/clients/streamable_basic.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010

1111
async def main():
1212
# Connect to a streamable HTTP server
13-
async with streamable_http_client("http://localhost:8000/mcp") as (
14-
read_stream,
15-
write_stream,
16-
_,
17-
):
13+
async with streamable_http_client("http://localhost:8000/mcp") as (read_stream, write_stream):
1814
# Create a session using the client streams
1915
async with ClientSession(read_stream, write_stream) as session:
2016
# Initialize the connection

src/mcp/client/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""MCP Client module."""
22

3+
from mcp.client._transport import Transport
34
from mcp.client.client import Client
45
from mcp.client.session import ClientSession
56

6-
__all__ = ["Client", "ClientSession"]
7+
__all__ = ["Client", "ClientSession", "Transport"]

src/mcp/client/_memory.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import AsyncGenerator
6-
from contextlib import asynccontextmanager
5+
from collections.abc import AsyncIterator
6+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
7+
from types import TracebackType
78
from typing import Any
89

910
import anyio
10-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111

12+
from mcp.client._transport import TransportStreams
1213
from mcp.server import Server
1314
from mcp.server.mcpserver import MCPServer
1415
from mcp.shared.memory import create_client_server_memory_streams
15-
from mcp.shared.message import SessionMessage
1616

1717

1818
class InMemoryTransport:
@@ -23,17 +23,17 @@ class InMemoryTransport:
2323
stopped when the context manager exits.
2424
2525
Example:
26-
server = MCPServer("test")
27-
transport = InMemoryTransport(server)
26+
```python
27+
from mcp.client import Client, ClientSession
28+
from mcp.server.mcpserver import MCPServer
29+
from mcp.client.memory import InMemoryTransport
2830
29-
async with transport.connect() as (read_stream, write_stream):
31+
server = MCPServer("test")
32+
async with InMemoryTransport(server) as (read_stream, write_stream):
3033
async with ClientSession(read_stream, write_stream) as session:
3134
await session.initialize()
3235
# Use the session...
33-
34-
Or more commonly, use with Client:
35-
async with Client(server) as client:
36-
result = await client.call_tool("my_tool", {...})
36+
```
3737
"""
3838

3939
def __init__(self, server: Server[Any] | MCPServer, *, raise_exceptions: bool = False) -> None:
@@ -45,26 +45,15 @@ def __init__(self, server: Server[Any] | MCPServer, *, raise_exceptions: bool =
4545
"""
4646
self._server = server
4747
self._raise_exceptions = raise_exceptions
48+
self._cm: AbstractAsyncContextManager[TransportStreams] | None = None
4849

4950
@asynccontextmanager
50-
async def connect(
51-
self,
52-
) -> AsyncGenerator[
53-
tuple[
54-
MemoryObjectReceiveStream[SessionMessage | Exception],
55-
MemoryObjectSendStream[SessionMessage],
56-
],
57-
None,
58-
]:
59-
"""Connect to the server and return streams for communication.
60-
61-
Yields:
62-
A tuple of (read_stream, write_stream) for bidirectional communication
63-
"""
51+
async def _connect(self) -> AsyncIterator[TransportStreams]:
52+
"""Connect to the server and yield streams for communication."""
6453
# Unwrap MCPServer to get underlying Server
65-
actual_server: Server[Any]
6654
if isinstance(self._server, MCPServer):
67-
actual_server = self._server._lowlevel_server # type: ignore[reportPrivateUsage]
55+
# TODO(Marcelo): Make `lowlevel_server` public.
56+
actual_server: Server[Any] = self._server._lowlevel_server # type: ignore[reportPrivateUsage]
6857
else:
6958
actual_server = self._server
7059

@@ -87,3 +76,16 @@ async def connect(
8776
yield client_read, client_write
8877
finally:
8978
tg.cancel_scope.cancel()
79+
80+
async def __aenter__(self) -> TransportStreams:
81+
"""Connect to the server and return streams for communication."""
82+
self._cm = self._connect()
83+
return await self._cm.__aenter__()
84+
85+
async def __aexit__(
86+
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
87+
) -> None:
88+
"""Close the transport and stop the server."""
89+
if self._cm is not None:
90+
await self._cm.__aexit__(exc_type, exc_val, exc_tb)
91+
self._cm = None

src/mcp/client/_transport.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Transport protocol for MCP clients."""
2+
3+
from __future__ import annotations
4+
5+
from contextlib import AbstractAsyncContextManager
6+
from typing import Protocol
7+
8+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
9+
10+
from mcp.shared.message import SessionMessage
11+
12+
TransportStreams = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
13+
14+
15+
class Transport(AbstractAsyncContextManager[TransportStreams], Protocol):
16+
"""Protocol for MCP transports.
17+
18+
A transport is an async context manager that yields read and write streams
19+
for bidirectional communication with an MCP server.
20+
"""

0 commit comments

Comments
 (0)