Skip to content

Fix issue #428: ClientSession.initialize gets stuck if the MCP server process exits #434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/mcp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .client.session import ClientSession
from .client.stdio import StdioServerParameters, stdio_client
from .client.stdio import (
ProcessTerminatedEarlyError,
StdioServerParameters,
stdio_client,
)
from .server.session import ServerSession
from .server.stdio import stdio_server
from .shared.exceptions import McpError
Expand Down Expand Up @@ -101,6 +105,7 @@
"ServerResult",
"ServerSession",
"SetLevelRequest",
"ProcessTerminatedEarlyError",
"StdioServerParameters",
"StopReason",
"SubscribeRequest",
Expand Down
62 changes: 58 additions & 4 deletions src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
terminate_windows_process,
)

__all__ = [
"ProcessTerminatedEarlyError",
"StdioServerParameters",
"stdio_client",
"get_default_environment",
]

# Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = (
[
Expand All @@ -38,6 +45,13 @@
)


class ProcessTerminatedEarlyError(Exception):
"""Raised when a process terminates unexpectedly."""

def __init__(self, message: str):
super().__init__(message)


def get_default_environment() -> dict[str, str]:
"""
Returns a default environment object including only environment variables deemed
Expand Down Expand Up @@ -163,20 +177,60 @@ async def stdin_writer():
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

process_error: str | None = None

async with (
anyio.create_task_group() as tg,
process,
):
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)

# Add a task to monitor the process and detect early termination
async def monitor_process():
nonlocal process_error
try:
await process.wait()
# Only consider it an error if the process exits with a non-zero code
# during normal operation (not when we explicitly terminate it)
if process.returncode != 0 and not tg.cancel_scope.cancel_called:
process_error = f"Process exited with code {process.returncode}."
# Cancel the task group to stop other tasks
tg.cancel_scope.cancel()
except anyio.get_cancelled_exc_class():
# Task was cancelled, which is expected when we're done
pass

tg.start_soon(monitor_process)

try:
yield read_stream, write_stream
finally:
# Set a flag to indicate we're explicitly terminating the process
# This prevents the monitor_process from treating our termination
# as an error when we explicitly terminate it
tg.cancel_scope.cancel()

# Close all streams to prevent resource leaks
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()

# Clean up process to prevent any dangling orphaned processes
if sys.platform == "win32":
await terminate_windows_process(process)
else:
process.terminate()
try:
if sys.platform == "win32":
await terminate_windows_process(process)
else:
process.terminate()
except ProcessLookupError:
# Process has already exited, which is fine
pass

if process_error:
# Raise outside the task group so that the error is not wrapped in an
# ExceptionGroup
raise ProcessTerminatedEarlyError(process_error)


def _get_executable_command(command: str) -> str:
Expand Down
38 changes: 37 additions & 1 deletion tests/client/test_stdio.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import shutil

import pytest
from anyio import fail_after

from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.session import ClientSession
from mcp.client.stdio import (
ProcessTerminatedEarlyError,
StdioServerParameters,
stdio_client,
)
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse

tee: str = shutil.which("tee") # type: ignore
python: str = shutil.which("python") # type: ignore


@pytest.mark.anyio
Expand Down Expand Up @@ -41,3 +48,32 @@ async def test_stdio_client():
assert read_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)


@pytest.mark.anyio
@pytest.mark.skipif(python is None, reason="could not find python command")
async def test_initialize_with_exiting_server():
"""
Test that ClientSession.initialize raises an error if the server process exits.
"""
# Create a server that will exit during initialization
server_params = StdioServerParameters(
command="python",
args=[
"-c",
"import sys; print('Error: Missing API key', file=sys.stderr); sys.exit(1)",
],
)

with pytest.raises(ProcessTerminatedEarlyError):
try:
# Set a timeout to avoid hanging indefinitely if the test fails
with fail_after(5):
async with stdio_client(server_params) as (read_stream, write_stream):
# Create a client session
session = ClientSession(read_stream, write_stream)

# This should fail because the server process has exited
await session.initialize()
except TimeoutError:
pytest.fail("The connection hung and timed out.")
Loading