Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mcp_optimizer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def normalize_runtime_mode(cls, v) -> str:

# Timeout configuration
mcp_timeout: int = Field(
default=10, ge=1, le=300, description="MCP operation timeout in seconds (1-300)"
default=20, ge=1, le=300, description="MCP operation timeout in seconds (1-300)"
)

# Database configuration
Expand Down
90 changes: 88 additions & 2 deletions src/mcp_optimizer/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
from typing import Any, Awaitable, Callable

import httpx
import structlog
from mcp import ClientSession
from mcp.client.sse import sse_client
Expand All @@ -18,6 +19,79 @@
logger = structlog.get_logger(__name__)


class _TolerantStream(httpx.AsyncByteStream):
"""
Stream wrapper that tolerates incomplete response errors.

Some remote SSE servers (behind proxies/CDNs) close POST response connections
before sending the complete response body. This is not a problem for SSE
because the actual MCP response arrives via the SSE stream, not the POST response.
"""

def __init__(self, original_stream: httpx.AsyncByteStream):
self._original: httpx.AsyncByteStream = original_stream

async def __aiter__(self):
try:
async for chunk in self._original:
yield chunk
except httpx.RemoteProtocolError as e:
# Server closed connection before body was sent - this is OK
# for SSE since the actual response comes via the SSE stream
logger.debug(
"Ignoring RemoteProtocolError on POST response (expected for some SSE servers)",
error=str(e),
)

async def aclose(self):
await self._original.aclose()


class _TolerantTransport(httpx.AsyncHTTPTransport):
"""
Custom transport that tolerates servers closing POST response connections early.

This is needed for some remote SSE MCP servers where the proxy/CDN closes
the POST response connection before the body is fully sent. The actual MCP
response arrives via SSE, so the POST response body is not needed.
"""

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
response = await super().handle_async_request(request)

# For POST requests, wrap the stream to tolerate incomplete responses
if request.method == "POST":
original_stream = response.stream
# AsyncHTTPTransport always produces AsyncByteStream
if not isinstance(original_stream, httpx.AsyncByteStream):
raise TypeError(
"Expected response.stream to be an instance of httpx.AsyncByteStream"
)
response.stream = _TolerantStream(original_stream)

return response


def _create_tolerant_httpx_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
"""
Create an httpx client that tolerates incomplete POST responses.

This is needed for remote SSE MCP servers where the server/proxy closes
the POST response connection before the body is sent. The actual MCP
response arrives via SSE, so this is safe to ignore.
"""
return httpx.AsyncClient(
headers=headers,
timeout=timeout,
auth=auth,
transport=_TolerantTransport(),
)


class WorkloadConnectionError(Exception):
"""Custom exception for workload-related errors."""

Expand Down Expand Up @@ -243,13 +317,25 @@ async def _execute_streamable_session(
async def _execute_sse_session(
self, operation: Callable[[ClientSession], Awaitable], url: str
) -> Any:
"""Execute operation with SSE session."""
"""
Execute operation with SSE session.

Uses a tolerant client that ignores incomplete POST response body errors.
This is needed because some remote SSE servers (behind proxies/CDNs) close
the POST response connection before the body is fully sent. The actual MCP
response arrives via the SSE stream, so the POST response body is not needed.
"""
logger.debug(
f"Establishing SSE session for workload '{self.workload.name}'",
workload=self.workload.name,
url=url,
)
async with sse_client(url) as (read_stream, write_stream):

# Use tolerant client to handle servers that close POST response connections early
async with sse_client(url=url, httpx_client_factory=_create_tolerant_httpx_client) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
logger.info(
"Initializing SSE MCP session for workload",
Expand Down
159 changes: 158 additions & 1 deletion tests/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

from unittest.mock import AsyncMock, patch

import httpx
import pytest
from mcp.shared.exceptions import McpError
from mcp.types import ErrorData

from mcp_optimizer.mcp_client import (
MCPServerClient,
WorkloadConnectionError,
_create_tolerant_httpx_client,
_TolerantStream,
determine_transport_type,
)
from mcp_optimizer.toolhive.api_models.core import Workload
Expand Down Expand Up @@ -325,7 +328,13 @@ async def test_workload_url_unchanged_during_list_tools(
assert workload.url == url

# Verify the client was called with the original URL
mock_client.assert_called_once_with(url)
# SSE client uses keyword arguments including httpx_client_factory
if client_mock_name == "sse_client":
mock_client.assert_called_once()
assert mock_client.call_args.kwargs["url"] == url
assert "httpx_client_factory" in mock_client.call_args.kwargs
else:
mock_client.assert_called_once_with(url)


@pytest.mark.asyncio
Expand Down Expand Up @@ -633,3 +642,151 @@ def test_determine_transport_type_docker_fallback_when_no_proxy_mode():
result = determine_transport_type(workload, "docker")
# Docker mode should ignore transport_type and fallback to URL
assert result == ToolHiveTransportMode.STREAMABLE


# Unit tests for SSE tolerant client behavior


@pytest.mark.asyncio
async def test_sse_session_propagates_errors(mock_mcp_session):
"""Test that SSE session propagates errors as WorkloadConnectionError."""
workload = Workload(
name="test-server",
url="http://localhost:8080/sse/test-server",
status="running",
tool_type="mcp",
)

client = MCPServerClient(workload, timeout=10, runtime_mode="docker")

class FailingCM:
async def __aenter__(self):
raise ExceptionGroup("errors", [ValueError("some error")])

async def __aexit__(self, exc_type, exc_val, exc_tb):
return False

with (
patch("mcp_optimizer.mcp_client.sse_client") as mock_sse_client,
patch("mcp_optimizer.mcp_client.ClientSession", return_value=mock_mcp_session),
):
mock_sse_client.return_value = FailingCM()

# Should raise WorkloadConnectionError
with pytest.raises(WorkloadConnectionError):
await client.list_tools()

# Verify sse_client was called once
assert mock_sse_client.call_count == 1


@pytest.mark.asyncio
async def test_sse_session_uses_tolerant_client(mock_mcp_session):
"""Test that SSE session always uses the tolerant httpx client."""
workload = Workload(
name="test-server",
url="http://localhost:8080/sse/test-server",
status="running",
tool_type="mcp",
)

client = MCPServerClient(workload, timeout=10, runtime_mode="docker")

with (
patch("mcp_optimizer.mcp_client.sse_client") as mock_sse_client,
patch(
"mcp_optimizer.mcp_client.ClientSession", return_value=mock_mcp_session
) as mock_session_class,
):
# Mock successful connection
mock_sse_client.return_value.__aenter__.return_value = (AsyncMock(), AsyncMock())
mock_session_class.return_value.__aenter__.return_value = mock_mcp_session

# Call list_tools
await client.list_tools()

# Verify sse_client was called with httpx_client_factory
assert mock_sse_client.call_count == 1
call_kwargs = mock_sse_client.call_args.kwargs
assert "httpx_client_factory" in call_kwargs
assert call_kwargs["httpx_client_factory"] == _create_tolerant_httpx_client


# Unit tests for _TolerantStream class


class MockAsyncByteStream(httpx.AsyncByteStream):
"""Mock async byte stream for testing _TolerantStream."""

def __init__(self, chunks: list[bytes], exception: Exception | None = None):
self.chunks = chunks
self.exception = exception
self._closed = False

async def __aiter__(self):
for chunk in self.chunks:
yield chunk
if self.exception:
raise self.exception

async def aclose(self):
self._closed = True


@pytest.mark.asyncio
async def test_tolerant_stream_swallows_remote_protocol_error():
"""Test that _TolerantStream catches and ignores RemoteProtocolError."""
# Create a stream that raises RemoteProtocolError after yielding some data
error = httpx.RemoteProtocolError("Server disconnected")
mock_stream = MockAsyncByteStream(chunks=[b"chunk1", b"chunk2"], exception=error)

tolerant_stream = _TolerantStream(mock_stream)

# Should not raise, just stop iterating
chunks = []
async for chunk in tolerant_stream:
chunks.append(chunk)

# Should have received the chunks before the error
assert chunks == [b"chunk1", b"chunk2"]


@pytest.mark.asyncio
async def test_tolerant_stream_propagates_other_errors():
"""Test that _TolerantStream does not swallow non-RemoteProtocolError exceptions."""
# Create a stream that raises a different error
error = ValueError("Some other error")
mock_stream = MockAsyncByteStream(chunks=[b"chunk1"], exception=error)

tolerant_stream = _TolerantStream(mock_stream)

# Should raise ValueError, not swallow it
with pytest.raises(ValueError, match="Some other error"):
async for _ in tolerant_stream:
pass


@pytest.mark.asyncio
async def test_tolerant_stream_passes_through_chunks():
"""Test that _TolerantStream correctly passes through all chunks when no error."""
mock_stream = MockAsyncByteStream(chunks=[b"chunk1", b"chunk2", b"chunk3"])

tolerant_stream = _TolerantStream(mock_stream)

chunks = []
async for chunk in tolerant_stream:
chunks.append(chunk)

assert chunks == [b"chunk1", b"chunk2", b"chunk3"]


@pytest.mark.asyncio
async def test_tolerant_stream_aclose():
"""Test that _TolerantStream properly closes the underlying stream."""
mock_stream = MockAsyncByteStream(chunks=[])

tolerant_stream = _TolerantStream(mock_stream)

assert not mock_stream._closed
await tolerant_stream.aclose()
assert mock_stream._closed