Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ classifiers = [
]
dependencies = [
"anyio>=4.5",
"exceptiongroup>=1.2.0; python_version < '3.11'",
"httpx>=0.27.1",
"httpx-sse>=0.4",
"pydantic>=2.12.0",
Expand Down
41 changes: 28 additions & 13 deletions src/mcp/client/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
from collections.abc import AsyncIterator
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from types import TracebackType
from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from builtins import BaseExceptionGroup
else:
try:
from builtins import BaseExceptionGroup
except ImportError:
from exceptiongroup import BaseExceptionGroup

import anyio

Expand Down Expand Up @@ -49,20 +57,27 @@ async def _connect(self) -> AsyncIterator[TransportStreams]:
server_read, server_write = server_streams

async with anyio.create_task_group() as tg:
# Start server in background
tg.start_soon(
lambda: actual_server.run(
server_read,
server_write,
actual_server.create_initialization_options(),
raise_exceptions=self._raise_exceptions,
try:
# Start server in background
tg.start_soon(
lambda: actual_server.run(
server_read,
server_write,
actual_server.create_initialization_options(),
raise_exceptions=self._raise_exceptions,
)
)
)

try:
yield client_read, client_write
finally:
tg.cancel_scope.cancel()
try:
yield client_read, client_write
finally:
tg.cancel_scope.cancel()
except BaseExceptionGroup as e:
from mcp.shared.exceptions import unwrap_task_group_exception

real_exc = unwrap_task_group_exception(e)
if real_exc is not e:
raise real_exc

async def __aenter__(self) -> TransportStreams:
"""Connect to the server and return streams for communication."""
Expand Down
21 changes: 18 additions & 3 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
from collections.abc import Callable
from dataclasses import dataclass
from types import TracebackType
from typing import Any, TypeAlias
from typing import TYPE_CHECKING, Any, TypeAlias

if TYPE_CHECKING:
from builtins import BaseExceptionGroup
else:
try:
from builtins import BaseExceptionGroup
except ImportError:
from exceptiongroup import BaseExceptionGroup

import anyio
import httpx
Expand Down Expand Up @@ -167,8 +175,15 @@ async def __aexit__(

# Concurrently close session stacks.
async with anyio.create_task_group() as tg:
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)
try:
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)
except BaseExceptionGroup as e:
from mcp.shared.exceptions import unwrap_task_group_exception

real_exc = unwrap_task_group_exception(e)
if real_exc is not e:
raise real_exc

@property
def sessions(self) -> list[mcp.ClientSession]:
Expand Down
16 changes: 15 additions & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import logging
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urljoin, urlparse

if TYPE_CHECKING:
from builtins import BaseExceptionGroup
else:
try:
from builtins import BaseExceptionGroup
except ImportError:
from exceptiongroup import BaseExceptionGroup

import anyio
import httpx
from anyio.abc import TaskStatus
Expand Down Expand Up @@ -157,6 +165,12 @@ async def post_writer(endpoint_url: str):
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
except BaseExceptionGroup as e:
from mcp.shared.exceptions import unwrap_task_group_exception

real_exc = unwrap_task_group_exception(e)
if real_exc is not e:
raise real_exc
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
73 changes: 44 additions & 29 deletions src/mcp/client/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal, TextIO
from typing import TYPE_CHECKING, Literal, TextIO

if TYPE_CHECKING:
from builtins import BaseExceptionGroup
else:
try:
from builtins import BaseExceptionGroup
except ImportError:
from exceptiongroup import BaseExceptionGroup

import anyio
import anyio.lowlevel
Expand Down Expand Up @@ -178,37 +186,44 @@ async def stdin_writer():
await anyio.lowlevel.checkpoint()

async with anyio.create_task_group() as tg, process:
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)
try:
yield read_stream, write_stream
finally:
# MCP spec: stdio shutdown sequence
# 1. Close input stream to server
# 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time
# 3. Send SIGKILL if still not exited
if process.stdin: # pragma: no branch
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)
try:
yield read_stream, write_stream
finally:
# MCP spec: stdio shutdown sequence
# 1. Close input stream to server
# 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time
# 3. Send SIGKILL if still not exited
if process.stdin: # pragma: no branch
try:
await process.stdin.aclose()
except Exception: # pragma: no cover
# stdin might already be closed, which is fine
pass

try:
await process.stdin.aclose()
except Exception: # pragma: no cover
# stdin might already be closed, which is fine
# Give the process time to exit gracefully after stdin closes
with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
await process.wait()
except TimeoutError:
# Process didn't exit from stdin closure, use platform-specific termination
# which handles SIGTERM -> SIGKILL escalation
await _terminate_process_tree(process)
except ProcessLookupError: # pragma: no cover
# Process already exited, which is fine
pass

try:
# Give the process time to exit gracefully after stdin closes
with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
await process.wait()
except TimeoutError:
# Process didn't exit from stdin closure, use platform-specific termination
# which handles SIGTERM -> SIGKILL escalation
await _terminate_process_tree(process)
except ProcessLookupError: # pragma: no cover
# Process already exited, which is fine
pass
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
except BaseExceptionGroup as e:
from mcp.shared.exceptions import unwrap_task_group_exception

real_exc = unwrap_task_group_exception(e)
if real_exc is not e:
raise real_exc


def _get_executable_command(command: str) -> str:
Expand Down
16 changes: 16 additions & 0 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from builtins import BaseExceptionGroup
else:
try:
from builtins import BaseExceptionGroup
except ImportError:
from exceptiongroup import BaseExceptionGroup

import anyio
import httpx
Expand Down Expand Up @@ -574,6 +583,13 @@ def start_get_stream() -> None:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
except BaseExceptionGroup as e:
# Unwrap ExceptionGroup to get only the real error
from mcp.shared.exceptions import unwrap_task_group_exception

real_exc = unwrap_task_group_exception(e)
if real_exc is not e:
raise real_exc
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
30 changes: 23 additions & 7 deletions src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import json
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from builtins import BaseExceptionGroup
else:
try:
from builtins import BaseExceptionGroup
except ImportError:
from exceptiongroup import BaseExceptionGroup

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand Down Expand Up @@ -69,12 +78,19 @@ async def ws_writer():
await ws.send(json.dumps(msg_dict))

async with anyio.create_task_group() as tg:
# Start reader and writer tasks
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
try:
# Start reader and writer tasks
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)

# Yield the receive/send streams
yield (read_stream, write_stream)

# Yield the receive/send streams
yield (read_stream, write_stream)
# Once the caller's 'async with' block exits, we shut down
tg.cancel_scope.cancel()
except BaseExceptionGroup as e:
from mcp.shared.exceptions import unwrap_task_group_exception

# Once the caller's 'async with' block exits, we shut down
tg.cancel_scope.cancel()
real_exc = unwrap_task_group_exception(e)
if real_exc is not e:
raise real_exc
55 changes: 35 additions & 20 deletions src/mcp/server/experimental/task_result_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
"""

import logging
from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from builtins import BaseExceptionGroup
else:
try:
from builtins import BaseExceptionGroup
except ImportError:
from exceptiongroup import BaseExceptionGroup

import anyio

Expand Down Expand Up @@ -163,25 +171,32 @@ async def _wait_for_task_update(self, task_id: str) -> None:
Races between store update and queue message - first one wins.
"""
async with anyio.create_task_group() as tg:

async def wait_for_store() -> None:
try:
await self._store.wait_for_update(task_id)
except Exception:
pass
finally:
tg.cancel_scope.cancel()

async def wait_for_queue() -> None:
try:
await self._queue.wait_for_message(task_id)
except Exception:
pass
finally:
tg.cancel_scope.cancel()

tg.start_soon(wait_for_store)
tg.start_soon(wait_for_queue)
try:

async def wait_for_store() -> None:
try:
await self._store.wait_for_update(task_id)
except Exception:
pass
finally:
tg.cancel_scope.cancel()

async def wait_for_queue() -> None:
try:
await self._queue.wait_for_message(task_id)
except Exception:
pass
finally:
tg.cancel_scope.cancel()

tg.start_soon(wait_for_store)
tg.start_soon(wait_for_queue)
except BaseExceptionGroup as e:
from mcp.shared.exceptions import unwrap_task_group_exception

real_exc = unwrap_task_group_exception(e)
if real_exc is not e:
raise real_exc

def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool:
"""Route a response back to the waiting resolver.
Expand Down
Loading
Loading