Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions src/google/adk/tools/mcp_tool/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import TextIO
from typing import Union
import warnings
import asyncio

from pydantic import model_validator
from typing_extensions import override
Expand Down Expand Up @@ -104,6 +105,9 @@ def __init__(
errlog: TextIO = sys.stderr,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
session_close_timeout: float = 5.0,
session_create_timeout: float = 15.0,
list_tools_timeout: float = 30.0,
):
"""Initializes the MCPToolset.

Expand Down Expand Up @@ -140,6 +144,9 @@ def __init__(
)
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
self._session_close_timeout = session_close_timeout
self._session_create_timeout = session_create_timeout
self._list_tools_timeout = list_tools_timeout

@retry_on_closed_resource
async def get_tools(
Expand All @@ -155,11 +162,43 @@ async def get_tools(
Returns:
List[BaseTool]: A list of tools available under the specified context.
"""
# Get session from session manager
session = await self._mcp_session_manager.create_session()

# Close stale session manager and create fresh one
try:
await asyncio.wait_for(
self._mcp_session_manager.close(),
timeout=self._session_close_timeout
)
except (asyncio.TimeoutError, Exception) as e:
logger.warning('Ignoring error while closing stale MCP session manager: %s', e)

# Recreate session manager with fresh connections
self._mcp_session_manager = MCPSessionManager(
connection_params=self._connection_params,
errlog=self._errlog,
)

# Get session from session manager with timeout
try:
session = await asyncio.wait_for(
self._mcp_session_manager.create_session(),
timeout=self._session_create_timeout
)
except asyncio.TimeoutError:
raise RuntimeError(
f"Failed to create MCP session: timeout after {self._session_create_timeout}s"
)

# Fetch available tools from the MCP server
tools_response: ListToolsResult = await session.list_tools()
# Fetch available tools from the MCP server with timeout
try:
tools_response: ListToolsResult = await asyncio.wait_for(
session.list_tools(),
timeout=self._list_tools_timeout
)
except asyncio.TimeoutError:
raise RuntimeError(
f"Failed to list MCP tools: timeout after {self._list_tools_timeout}s"
)

# Apply filtering based on context and tool_filter
tools = []
Expand Down