22
33from __future__ import annotations
44
5- from collections .abc import AsyncGenerator
6- from contextlib import asynccontextmanager
5+ from collections .abc import AsyncIterator
6+ from contextlib import AbstractAsyncContextManager , asynccontextmanager
7+ from types import TracebackType
78from typing import Any
89
910import anyio
10- from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
1111
12+ from mcp .client ._transport import TransportStreams
1213from mcp .server import Server
1314from mcp .server .mcpserver import MCPServer
1415from mcp .shared .memory import create_client_server_memory_streams
15- from mcp .shared .message import SessionMessage
1616
1717
1818class InMemoryTransport :
@@ -23,17 +23,17 @@ class InMemoryTransport:
2323 stopped when the context manager exits.
2424
2525 Example:
26- server = MCPServer("test")
27- transport = InMemoryTransport(server)
26+ ```python
27+ from mcp.client import Client, ClientSession
28+ from mcp.server.mcpserver import MCPServer
29+ from mcp.client.memory import InMemoryTransport
2830
29- async with transport.connect() as (read_stream, write_stream):
31+ server = MCPServer("test")
32+ async with InMemoryTransport(server) as (read_stream, write_stream):
3033 async with ClientSession(read_stream, write_stream) as session:
3134 await session.initialize()
3235 # Use the session...
33-
34- Or more commonly, use with Client:
35- async with Client(server) as client:
36- result = await client.call_tool("my_tool", {...})
36+ ```
3737 """
3838
3939 def __init__ (self , server : Server [Any ] | MCPServer , * , raise_exceptions : bool = False ) -> None :
@@ -45,26 +45,15 @@ def __init__(self, server: Server[Any] | MCPServer, *, raise_exceptions: bool =
4545 """
4646 self ._server = server
4747 self ._raise_exceptions = raise_exceptions
48+ self ._cm : AbstractAsyncContextManager [TransportStreams ] | None = None
4849
4950 @asynccontextmanager
50- async def connect (
51- self ,
52- ) -> AsyncGenerator [
53- tuple [
54- MemoryObjectReceiveStream [SessionMessage | Exception ],
55- MemoryObjectSendStream [SessionMessage ],
56- ],
57- None ,
58- ]:
59- """Connect to the server and return streams for communication.
60-
61- Yields:
62- A tuple of (read_stream, write_stream) for bidirectional communication
63- """
51+ async def _connect (self ) -> AsyncIterator [TransportStreams ]:
52+ """Connect to the server and yield streams for communication."""
6453 # Unwrap MCPServer to get underlying Server
65- actual_server : Server [Any ]
6654 if isinstance (self ._server , MCPServer ):
67- actual_server = self ._server ._lowlevel_server # type: ignore[reportPrivateUsage]
55+ # TODO(Marcelo): Make `lowlevel_server` public.
56+ actual_server : Server [Any ] = self ._server ._lowlevel_server # type: ignore[reportPrivateUsage]
6857 else :
6958 actual_server = self ._server
7059
@@ -87,3 +76,16 @@ async def connect(
8776 yield client_read , client_write
8877 finally :
8978 tg .cancel_scope .cancel ()
79+
80+ async def __aenter__ (self ) -> TransportStreams :
81+ """Connect to the server and return streams for communication."""
82+ self ._cm = self ._connect ()
83+ return await self ._cm .__aenter__ ()
84+
85+ async def __aexit__ (
86+ self , exc_type : type [BaseException ] | None , exc_val : BaseException | None , exc_tb : TracebackType | None
87+ ) -> None :
88+ """Close the transport and stop the server."""
89+ if self ._cm is not None :
90+ await self ._cm .__aexit__ (exc_type , exc_val , exc_tb )
91+ self ._cm = None
0 commit comments