Skip to content

Add MCP Client Support #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ jobs:
strategy:
matrix:
config:
- { python-version: "3.9", test_google: false, test_azure: false }
- { python-version: "3.10", test_google: false, test_azure: false }
- { python-version: "3.11", test_google: false, test_azure: false }
- { python-version: "3.12", test_google: true, test_azure: true }
- { python-version: "3.13", test_google: false, test_azure: false }
- { python-version: "3.9", test_google: false, test_azure: false, test_mcp: false }
- { python-version: "3.10", test_google: false, test_azure: false, test_mcp: true }
- { python-version: "3.11", test_google: false, test_azure: false, test_mcp: true }
- { python-version: "3.12", test_google: true, test_azure: true, test_mcp: true }
- { python-version: "3.13", test_google: false, test_azure: false, test_mcp: true }

env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
Expand All @@ -33,6 +33,8 @@ jobs:
TEST_GOOGLE: ${{ matrix.config.test_google }}
# Free tier of Azure is rate limited, so we only test on 3.12
TEST_AZURE: ${{ matrix.config.test_azure }}
# MCP is only supported on 3.10 and above
TEST_MCP: ${{ matrix.config.test_mcp }}

steps:
- uses: actions/checkout@v4
Expand All @@ -44,7 +46,13 @@ jobs:
run: uv python install ${{matrix.config.python-version }}

- name: 📦 Install the project
run: uv sync --python ${{ matrix.config.python-version }} --all-extras
run: |
if [[ "${{ matrix.config.python-version }}" == "3.9" ]]; then
# MCP is only supported on 3.10 and above
uv sync --python ${{ matrix.config.python-version }} -e test -e dev -e docs
else
uv sync --python ${{ matrix.config.python-version }} --all-extras
fi

- name: 🧪 Check tests
run: make check-tests
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `ContentToolRequest` and `ContentToolResponse` now have `.tagify()` methods, making it so they can render automatically in a Shiny chatbot. (#75)
* `ContentToolResult` instances can be returned from tools. This allows for custom rendering of the tool result. (#75)
* `Chat` gains a new `.current_display` property. When a `.chat()` or `.stream()` is currently active, this property returns an object with a `.echo()` method (to echo new content to the display). This is primarily useful for displaying custom content during a tool call. (#79)
* Added support for `Chat` providers to be [MCP clients](https://modelcontextprotocol.io/) and register with remote tools over SSE or stdio.

### Improvements

Expand Down
2 changes: 1 addition & 1 deletion chatlas/_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _structured_tool_call(**kwargs: Any):
"""Extract structured data"""
pass

data_model_tool = Tool(_structured_tool_call)
data_model_tool = Tool.from_func(_structured_tool_call)

data_model_tool.schema["function"]["parameters"] = {
"type": "object",
Expand Down
181 changes: 180 additions & 1 deletion chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import traceback
import warnings
from contextlib import AsyncExitStack
from pathlib import Path
from threading import Thread
from typing import (
Expand All @@ -23,6 +24,11 @@
overload,
)

from mcp import (
ClientSession as MCPClientSession,
)
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
Comment on lines +27 to +31
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these would have to be run time imports given mcp is an optional dependency

from pydantic import BaseModel

from ._content import (
Expand Down Expand Up @@ -102,6 +108,9 @@ def __init__(
"css_styles": {},
}

self._mcp_sessions: dict[str, MCPClientSession] = {}
self._mcp_exit_stack: AsyncExitStack = AsyncExitStack()

def get_turns(
self,
*,
Expand Down Expand Up @@ -904,6 +913,176 @@ async def extract_data_async(
json = res[0]
return json.value

async def _register_mcp_tools(
self,
session: MCPClientSession,
include_tools: Optional[list[str]] = None,
exclude_tools: Optional[list[str]] = None,
):
assert not (include_tools and exclude_tools), (
"Cannot specify both include_tools and exclude_tools."
)

response = await session.list_tools()
for tool in response.tools:
if include_tools:
if tool.name not in include_tools:
continue
if exclude_tools:
if tool.name in exclude_tools:
continue
self._tools[tool.name] = Tool.from_mcp(
session=session,
mcp_tool=tool,
)

async def register_mcp_sse_server_async(
self,
name: str,
url: str,
include_tools: Optional[list[str]] = None,
exclude_tools: Optional[list[str]] = None,
transport_kwargs: dict[str, Any] = {},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way we can support typing on these arguments?

Suggested change
transport_kwargs: dict[str, Any] = {},
kwargs: dict[str, Any] = {},

):
"""
Register a SSE-based MCP server session asynchronously.

This method establishes a new SSE (Server-Sent Events) connection to an MCP server and registers
the available tools. The server is identified by a unique name and URL.

Parameters
----------
name
Unique identifier for this MCP server session
url
URL endpoint of the MCP server
include_tools
List of tool names to include. If None, all available tools will be included. Defaults to None.
exclude_tools
List of tool names to exclude. This parameter and include_tools are mutually exclusive. Defaults to None.
transport_kwargs
Additional keyword arguments to pass to the SSE transport layer. Defaults to {}.

Raises
------
AssertionError
If a session with the given name already exists

Returns
-------
None

Examples
--------
```python
await chat.register_mcp_sse_server_async(
name="my_server",
url="http://localhost:8080/sse",
include_tools=["tool1", "tool2"],
transport_kwargs={"timeout": 30},
)
```
"""
assert name not in self._mcp_sessions, f"Session {name} already exists."

transport = await self._mcp_exit_stack.enter_async_context(
sse_client(url, **transport_kwargs)
)
self._mcp_sessions[name] = await self._mcp_exit_stack.enter_async_context(
MCPClientSession(*transport)
)
session = self._mcp_sessions[name]
await session.initialize()

await self._register_mcp_tools(
session,
include_tools=include_tools,
exclude_tools=exclude_tools,
)

async def register_mcp_stdio_server_async(
self,
name: str,
command: str,
args: list[str],
env: dict[str, str] | None = None,
include_tools: Optional[list[str]] = None,
exclude_tools: Optional[list[str]] = None,
transport_kwargs: dict[str, Any] = {},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way we can support typing on these arguments?

Suggested change
transport_kwargs: dict[str, Any] = {},
kwargs: dict[str, Any] = {},

):
"""
Register a stdio-based MCP server session asynchronously.

This method establishes a new stdio connection to an MCP server and registers
the available tools. The server is identified by a unique name and command.

Parameters
----------
name
Unique identifier for this MCP server session
command
Command to execute to start the MCP server
args
Arguments to pass to the command
env
Environment variables to set for the command. Defaults to None.
include_tools
List of tool names to include. If None, all available tools will be included. Defaults to None.
exclude_tools
List of tool names to exclude. This parameter and include_tools are mutually exclusive. Defaults to None.
transport_kwargs
Additional keyword arguments to pass to the stdio transport layer. Defaults to {}.

Raises
------
AssertionError
If a session with the given name already exists

Returns
-------
None

Examples
--------
```python
await chat.register_mcp_sse_server_async(
name="my_server",
command="python",
args=["-m", "my_mcp_server"],
env={"DEBUG": "1"},
include_tools=["tool1", "tool2"],
transport_kwargs={"timeout": 30},
)
```
"""
assert name not in self._mcp_sessions, f"Session {name} already exists."

server_params = StdioServerParameters(
command=command,
args=args,
env=env,
**transport_kwargs,
)

transport = await self._mcp_exit_stack.enter_async_context(
stdio_client(server_params)
)
self._mcp_sessions[name] = await self._mcp_exit_stack.enter_async_context(
MCPClientSession(*transport)
)
session = self._mcp_sessions[name]
await session.initialize()

await self._register_mcp_tools(
session,
include_tools=include_tools,
exclude_tools=exclude_tools,
)

async def close_mcp_sessions(self):
"""Clean up resources."""
await self._mcp_exit_stack.aclose()

def register_tool(
self,
func: Callable[..., Any] | Callable[..., Awaitable[Any]],
Expand Down Expand Up @@ -984,7 +1163,7 @@ def add(a: int, b: int) -> int:
Note that the name and docstring of the model takes precedence over the
name and docstring of the function.
"""
tool = Tool(func, model=model)
tool = Tool.from_func(func, model=model)
self._tools[tool.name] = tool

@property
Expand Down
Loading
Loading