Skip to content

Commit ee8106b

Browse files
wukathcopybara-github
authored andcommitted
fix: Fix error handling when MCP server is unreachable
Currently ADK web hangs and logs "AGSI callable returned without completing response" when the server is unreachable. To fix, set timeouts for connecting to server. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 828648928
1 parent 99ca6aa commit ee8106b

File tree

2 files changed

+65
-10
lines changed

2 files changed

+65
-10
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -339,38 +339,50 @@ async def create_session(
339339

340340
# Create a new session (either first time or replacing disconnected one)
341341
exit_stack = AsyncExitStack()
342+
timeout_in_seconds = (
343+
self._connection_params.timeout
344+
if hasattr(self._connection_params, 'timeout')
345+
else None
346+
)
342347

343348
try:
344349
client = self._create_client(merged_headers)
345350

346-
transports = await exit_stack.enter_async_context(client)
347-
# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
348-
# needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
351+
transports = await asyncio.wait_for(
352+
exit_stack.enter_async_context(client),
353+
timeout=timeout_in_seconds,
354+
)
355+
# The streamable http client returns a GetSessionCallback in addition to the
356+
# read/write MemoryObjectStreams needed to build the ClientSession, we limit
357+
# then to the two first values to be compatible with all clients.
349358
if isinstance(self._connection_params, StdioConnectionParams):
350359
session = await exit_stack.enter_async_context(
351360
ClientSession(
352361
*transports[:2],
353-
read_timeout_seconds=timedelta(
354-
seconds=self._connection_params.timeout
355-
),
362+
read_timeout_seconds=timedelta(seconds=timeout_in_seconds),
356363
)
357364
)
358365
else:
359366
session = await exit_stack.enter_async_context(
360367
ClientSession(*transports[:2])
361368
)
362-
await session.initialize()
369+
await asyncio.wait_for(session.initialize(), timeout=timeout_in_seconds)
363370

364371
# Store session and exit stack in the pool
365372
self._sessions[session_key] = (session, exit_stack)
366373
logger.debug('Created new session: %s', session_key)
367374
return session
368375

369-
except Exception:
376+
except Exception as e:
370377
# If session creation fails, clean up the exit stack
371378
if exit_stack:
372-
await exit_stack.aclose()
373-
raise
379+
try:
380+
await exit_stack.aclose()
381+
except Exception as exit_stack_error:
382+
logger.warning(
383+
'Error during session creation cleanup: %s', exit_stack_error
384+
)
385+
raise ConnectionError(f'Failed to create MCP session: {e}') from e
374386

375387
async def close(self):
376388
"""Closes all sessions and cleans up resources."""

tests/unittests/tools/mcp_tool/test_mcp_session_manager.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
16+
from datetime import timedelta
1517
import hashlib
1618
from io import StringIO
1719
import json
@@ -279,6 +281,47 @@ async def test_create_session_reuse_existing(self):
279281
# Should not create new session
280282
existing_session.initialize.assert_not_called()
281283

284+
@pytest.mark.asyncio
285+
@patch("google.adk.tools.mcp_tool.mcp_session_manager.stdio_client")
286+
@patch("google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack")
287+
@patch("google.adk.tools.mcp_tool.mcp_session_manager.ClientSession")
288+
async def test_create_session_timeout(
289+
self, mock_session_class, mock_exit_stack_class, mock_stdio
290+
):
291+
"""Test session creation timeout."""
292+
manager = MCPSessionManager(self.mock_stdio_connection_params)
293+
294+
mock_session = MockClientSession()
295+
mock_exit_stack = MockAsyncExitStack()
296+
297+
mock_exit_stack_class.return_value = mock_exit_stack
298+
mock_stdio.return_value = AsyncMock()
299+
mock_exit_stack.enter_async_context.side_effect = [
300+
("read", "write"), # First call returns transports
301+
mock_session, # Second call returns session
302+
]
303+
mock_session_class.return_value = mock_session
304+
305+
# Simulate timeout during session initialization
306+
mock_session.initialize.side_effect = asyncio.TimeoutError("Test timeout")
307+
308+
# Expect ConnectionError due to timeout
309+
with pytest.raises(ConnectionError, match="Failed to create MCP session"):
310+
await manager.create_session()
311+
312+
# Verify ClientSession called with timeout
313+
mock_session_class.assert_called_with(
314+
"read",
315+
"write",
316+
read_timeout_seconds=timedelta(
317+
seconds=manager._connection_params.timeout
318+
),
319+
)
320+
# Verify session was not added to pool
321+
assert not manager._sessions
322+
# Verify cleanup was called
323+
mock_exit_stack.aclose.assert_called_once()
324+
282325
@pytest.mark.asyncio
283326
async def test_close_success(self):
284327
"""Test successful cleanup of all sessions."""

0 commit comments

Comments
 (0)