Skip to content

Timeout for initializing MCP client #1833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 31 additions & 37 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import base64
import json
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Sequence
from contextlib import AsyncExitStack, asynccontextmanager
Expand All @@ -12,6 +11,7 @@
from types import TracebackType
from typing import Any

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.shared.message import SessionMessage
from mcp.types import (
Expand Down Expand Up @@ -77,6 +77,9 @@ def _get_log_level(self) -> LoggingLevel | None:
"""Get the log level for the MCP server."""
raise NotImplementedError('MCP Server subclasses must implement this method.')

def _get_client_initialize_timeout(self) -> float:
return 5 # pragma: no cover

def get_prefixed_tool_name(self, tool_name: str) -> str:
"""Get the tool name with prefix if `tool_prefix` is set."""
return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name
Expand Down Expand Up @@ -136,7 +139,9 @@ async def __aenter__(self) -> Self:
client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream)
self._client = await self._exit_stack.enter_async_context(client)

await self._client.initialize()
with anyio.fail_after(self._get_client_initialize_timeout()):
await self._client.initialize()

if log_level := self._get_log_level():
await self._client.set_logging_level(log_level)
self.is_running = True
Expand Down Expand Up @@ -251,6 +256,9 @@ async def main():
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
"""

timeout: float = 5
""" The timeout in seconds to wait for the client to initialize."""

@asynccontextmanager
async def client_streams(
self,
Expand All @@ -267,6 +275,9 @@ def _get_log_level(self) -> LoggingLevel | None:
def __repr__(self) -> str:
return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'

def _get_client_initialize_timeout(self) -> float:
return self.timeout


@dataclass
class MCPServerHTTP(MCPServer):
Expand Down Expand Up @@ -312,15 +323,15 @@ async def main():
Useful for authentication, custom headers, or other HTTP-specific configurations.
"""

timeout: timedelta | float = timedelta(seconds=5)
"""Initial connection timeout as a timedelta for establishing the connection.
timeout: float = 5
"""Initial connection timeout in seconds for establishing the connection.

This timeout applies to the initial connection setup and handshake.
If the connection cannot be established within this time, the operation will fail.
"""

sse_read_timeout: timedelta | float = timedelta(minutes=5)
"""Maximum time as a timedelta to wait for new SSE messages before timing out.
sse_read_timeout: float = 300
"""Maximum time as in seconds to wait for new SSE messages before timing out.

This timeout applies to the long-lived SSE connection after it's established.
If no new messages are received within this time, the connection will be considered stale
Expand All @@ -343,46 +354,26 @@ async def main():
"""

def __post_init__(self):
if not isinstance(self.timeout, timedelta):
warnings.warn(
'Passing timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.timeout = timedelta(seconds=self.timeout)
# streamablehttp_client expects timedeltas, so we accept them too to match,
# but primarily work with floats for a simpler user API.

if not isinstance(self.sse_read_timeout, timedelta):
warnings.warn(
'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout)
if isinstance(self.timeout, timedelta):
self.timeout = self.timeout.total_seconds()

if isinstance(self.sse_read_timeout, timedelta):
self.sse_read_timeout = self.sse_read_timeout.total_seconds()

@asynccontextmanager
async def client_streams(
self,
) -> AsyncIterator[
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
]: # pragma: no cover
if not isinstance(self.timeout, timedelta):
warnings.warn(
'Passing timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.timeout = timedelta(seconds=self.timeout)

if not isinstance(self.sse_read_timeout, timedelta):
warnings.warn(
'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.',
DeprecationWarning,
stacklevel=2,
)
self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout)

async with streamablehttp_client(
url=self.url, headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout
url=self.url,
headers=self.headers,
timeout=timedelta(seconds=self.timeout),
sse_read_timeout=timedelta(self.sse_read_timeout),
) as (read_stream, write_stream, _):
yield read_stream, write_stream

Expand All @@ -391,3 +382,6 @@ def _get_log_level(self) -> LoggingLevel | None:

def __repr__(self) -> str: # pragma: no cover
return f'MCPServerHTTP(url={self.url!r}, tool_prefix={self.tool_prefix!r})'

def _get_client_initialize_timeout(self) -> float: # pragma: no cover
return self.timeout
29 changes: 14 additions & 15 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,30 +81,29 @@ def test_http_server_with_header_and_timeout():
http_server = MCPServerHTTP(
url='http://localhost:8000/sse',
headers={'my-custom-header': 'my-header-value'},
timeout=timedelta(seconds=10),
sse_read_timeout=timedelta(seconds=100),
timeout=10,
sse_read_timeout=100,
log_level='info',
)
assert http_server.url == 'http://localhost:8000/sse'
assert http_server.headers is not None and http_server.headers['my-custom-header'] == 'my-header-value'
assert http_server.timeout == timedelta(seconds=10)
assert http_server.sse_read_timeout == timedelta(seconds=100)
assert http_server.timeout == 10
assert http_server.sse_read_timeout == 100
assert http_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage]


def test_http_server_with_deprecated_arguments():
with pytest.warns(DeprecationWarning):
http_server = MCPServerHTTP(
url='http://localhost:8000/sse',
headers={'my-custom-header': 'my-header-value'},
timeout=10,
sse_read_timeout=100,
log_level='info',
)
def test_http_server_with_timedelta_arguments():
http_server = MCPServerHTTP(
url='http://localhost:8000/sse',
headers={'my-custom-header': 'my-header-value'},
timeout=timedelta(seconds=10), # type: ignore[arg-type]
sse_read_timeout=timedelta(seconds=100), # type: ignore[arg-type]
log_level='info',
)
assert http_server.url == 'http://localhost:8000/sse'
assert http_server.headers is not None and http_server.headers['my-custom-header'] == 'my-header-value'
assert http_server.timeout == timedelta(seconds=10)
assert http_server.sse_read_timeout == timedelta(seconds=100)
assert http_server.timeout == 10
assert http_server.sse_read_timeout == 100
assert http_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage]


Expand Down