Skip to content

Feat/prompts caching #1087

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ agent = Agent(

## Caching

Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change.
Every time an Agent runs, it calls `list_tools()` and `list_prompts()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the lists of tools and prompts, you can pass `cache_tools_list=True` and `cache_prompts_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool and prompt lists will not change.

If you want to invalidate the cache, you can call `invalidate_tools_cache()` on the servers.
If you want to invalidate the cache, you can call `invalidate_tools_cache()` and `invalidate_prompts_cache()` on the servers.

## End-to-end examples

Expand Down
13 changes: 13 additions & 0 deletions examples/mcp/caching/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Caching Example

This example shows how to use tool and prompt caching with a Streamable HTTP server in [server.py](server.py).

Run the example via:

```
uv run python examples/mcp/caching/main.py
```

## Details

The example uses the `MCPServerStreamableHttp` class from `agents.mcp`. The server runs in a sub-process at `http://localhost:8000/mcp`.
83 changes: 83 additions & 0 deletions examples/mcp/caching/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import asyncio
import os
import shutil
import subprocess
import sys
import time
from typing import Any

from agents import gen_trace_id, trace
from agents.mcp import MCPServerStreamableHttp


async def run(mcp_server: MCPServerStreamableHttp):
    """Demonstrate tool/prompt caching by inspecting the server's cache state.

    NOTE(review): this reads the private `_tools_list` / `_prompts_list`
    attributes to show the cache contents before and after the list calls;
    these are SDK internals and may change.

    Args:
        mcp_server: A connected Streamable HTTP MCP server with
            `cache_tools_list=True` and `cache_prompts_list=True`.
    """
    # Before the first `list_tools()` call, nothing has been cached yet.
    print("Cached tools before invoking list_tools")
    print(mcp_server._tools_list)

    # Calling `list_tools()` populates the cache (when caching is enabled).
    print("Cached tool names after invoking list_tools")
    await mcp_server.list_tools()
    cached_tools_list = mcp_server._tools_list
    if cached_tools_list:
        for tool in cached_tools_list:
            print(f"name: {tool.name}")
    else:
        # Was "Failed to cache list_prompts" — this branch is about the tools cache.
        print("Failed to cache list_tools")

    # Same demonstration for the prompts cache.
    print("Cached prompts before invoking list_prompts")
    print(mcp_server._prompts_list)

    print("Cached prompts after invoking list_prompts")
    await mcp_server.list_prompts()
    cached_prompts_list = mcp_server._prompts_list
    if cached_prompts_list:
        for prompt in cached_prompts_list.prompts:
            print(f"name: {prompt.name}")
    else:
        print("Failed to cache list_prompts")

async def main():
    """Connect to the local Streamable HTTP server with caching enabled and run the demo."""
    connection_params = {
        "url": "http://localhost:8000/mcp",
    }
    async with MCPServerStreamableHttp(
        name="Streamable HTTP Python Server",
        cache_tools_list=True,
        cache_prompts_list=True,
        params=connection_params,
    ) as server:
        trace_id = gen_trace_id()
        # Wrap the run in a trace so it can be inspected in the OpenAI platform UI.
        with trace(workflow_name="Caching Example", trace_id=trace_id):
            print(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n")
            await run(server)


if __name__ == "__main__":
    # The demo server is started via `uv run`, so fail fast if uv is missing.
    if not shutil.which("uv"):
        raise RuntimeError(
            "uv is not installed. Please install it: https://docs.astral.sh/uv/getting-started/installation/"
        )

    # We'll run the Streamable HTTP server in a subprocess. Usually this would be a remote
    # server, but for this demo, we'll run it locally at http://localhost:8000/mcp
    process: subprocess.Popen[Any] | None = None
    try:
        this_dir = os.path.dirname(os.path.abspath(__file__))
        server_file = os.path.join(this_dir, "server.py")

        print("Starting Streamable HTTP server at http://localhost:8000/mcp ...")

        # Run `uv run server.py` to start the Streamable HTTP server
        process = subprocess.Popen(["uv", "run", server_file])
        # Give it 3 seconds to start
        time.sleep(3)

        print("Streamable HTTP server started. Running example...\n\n")
    except Exception as e:
        print(f"Error starting Streamable HTTP server: {e}")
        # Use sys.exit() rather than the `exit` builtin: `exit` is injected by the
        # `site` module and is not guaranteed to exist in all interpreter setups.
        sys.exit(1)

    try:
        asyncio.run(main())
    finally:
        # Always tear down the demo server subprocess, even if the example fails.
        if process:
            process.terminate()
37 changes: 37 additions & 0 deletions examples/mcp/caching/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import random

import requests
from mcp.server.fastmcp import FastMCP

# Create server
mcp = FastMCP("Echo Server")


@mcp.tool()
def add(a: int, b: int) -> int:
    """Add two numbers"""
    # Log the invocation so the demo shows when the tool is called.
    print(f"[debug-server] add({a}, {b})")
    total = a + b
    return total


@mcp.tool()
def get_secret_word() -> str:
    # Log the invocation so the demo shows when the tool is called.
    print("[debug-server] get_secret_word()")
    words = ["apple", "banana", "cherry"]
    return random.choice(words)


@mcp.tool()
def get_current_weather(city: str) -> str:
    # Log the invocation so the demo shows when the tool is called.
    print(f"[debug-server] get_current_weather({city})")

    # wttr.in returns a plain-text weather report for the given city.
    endpoint = "https://wttr.in"
    # A timeout keeps the tool call from hanging indefinitely if wttr.in
    # is unreachable; without one, requests.get can block forever.
    response = requests.get(f"{endpoint}/{city}", timeout=10)
    return response.text

@mcp.prompt()
def system_prompt() -> str:
    # Static prompt registered on the server; used by the caching example to
    # exercise list_prompts(). Kept docstring-free on purpose — FastMCP may
    # surface a docstring as the prompt's description (TODO confirm).
    return "Use the tools to answer the questions."


if __name__ == "__main__":
    # Serve this MCP server over the Streamable HTTP transport; the caching
    # example (main.py) connects to it at http://localhost:8000/mcp.
    mcp.run(transport="streamable-http")
97 changes: 77 additions & 20 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,26 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
def __init__(
self,
cache_tools_list: bool,
cache_prompts_list: bool,
client_session_timeout_seconds: float | None,
tool_filter: ToolFilter = None,
):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be invalidated
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be invalidated
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
server will not change its tools list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).

cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list
will be cached and only fetched from the server once. If `False`, the prompts
list will be fetched from the server on each call to `list_prompts()`.
The cache can be invalidated by calling `invalidate_prompts_cache()`.
You should set this to `True` if you know the server will not change
its prompts list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
tool_filter: The tool filter to use for filtering tools.
Expand All @@ -103,13 +112,16 @@ def __init__(
self.exit_stack: AsyncExitStack = AsyncExitStack()
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.cache_tools_list = cache_tools_list
self.cache_prompts_list = cache_prompts_list
self.server_initialize_result: InitializeResult | None = None

self.client_session_timeout_seconds = client_session_timeout_seconds

# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
# The cache is always dirty at startup, so that we fetch tools and prompts at least once
self._cache_dirty_tools = True
self._tools_list: list[MCPTool] | None = None
self._cache_dirty_prompts = True
self._prompts_list: ListPromptsResult | None = None

self.tool_filter = tool_filter

Expand Down Expand Up @@ -213,7 +225,11 @@ async def __aexit__(self, exc_type, exc_value, traceback):

def invalidate_tools_cache(self):
"""Invalidate the tools cache."""
self._cache_dirty = True
self._cache_dirty_tools = True

def invalidate_prompts_cache(self):
"""Invalidate the prompts cache."""
self._cache_dirty_prompts = True

async def connect(self):
"""Connect to the server."""
Expand Down Expand Up @@ -251,11 +267,11 @@ async def list_tools(
raise UserError("Server not initialized. Make sure you call `connect()` first.")

# Return from cache if caching is enabled, we have tools, and the cache is not dirty
if self.cache_tools_list and not self._cache_dirty and self._tools_list:
if self.cache_tools_list and not self._cache_dirty_tools and self._tools_list:
tools = self._tools_list
else:
# Reset the cache dirty to False
self._cache_dirty = False
self._cache_dirty_tools = False
# Fetch the tools from the server
self._tools_list = (await self.session.list_tools()).tools
tools = self._tools_list
Expand All @@ -282,7 +298,16 @@ async def list_prompts(
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")

return await self.session.list_prompts()
if self.cache_prompts_list and not self._cache_dirty_prompts and self._prompts_list:
prompts = self._prompts_list
else:
# Reset the cache dirty to False
self._cache_dirty_prompts = False
# Fetch the prompts from the server
self._prompts_list = await self.session.list_prompts()
prompts = self._prompts_list

return prompts

async def get_prompt(
self, name: str, arguments: dict[str, Any] | None = None
Expand Down Expand Up @@ -343,6 +368,7 @@ def __init__(
self,
params: MCPServerStdioParams,
cache_tools_list: bool = False,
cache_prompts_list: bool = False,
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
tool_filter: ToolFilter = None,
Expand All @@ -354,21 +380,32 @@ def __init__(
start the server, the args to pass to the command, the environment variables to
set for the server, the working directory to use when spawning the process, and
the text encoding used when sending/receiving messages to the server.

cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).

cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list
will be cached and only fetched from the server once. If `False`, the prompts
list will be fetched from the server on each call to `list_prompts()`.
The cache can be invalidated by calling `invalidate_prompts_cache()`.
You should set this to `True` if you know the server will not change
its prompts list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).

name: A readable name for the server. If not provided, we'll create one from the
command.
client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
tool_filter: The tool filter to use for filtering tools.
"""
super().__init__(
cache_tools_list,
client_session_timeout_seconds,
tool_filter,
cache_tools_list=cache_tools_list,
cache_prompts_list=cache_prompts_list,
client_session_timeout_seconds=client_session_timeout_seconds,
tool_filter=tool_filter,
)

self.params = StdioServerParameters(
Expand Down Expand Up @@ -426,6 +463,7 @@ def __init__(
self,
params: MCPServerSseParams,
cache_tools_list: bool = False,
cache_prompts_list: bool = False,
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
tool_filter: ToolFilter = None,
Expand All @@ -444,16 +482,25 @@ def __init__(
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).

cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list
will be cached and only fetched from the server once. If `False`, the prompts
list will be fetched from the server on each call to `list_prompts()`.
The cache can be invalidated by calling `invalidate_prompts_cache()`.
You should set this to `True` if you know the server will not change
its prompts list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).

name: A readable name for the server. If not provided, we'll create one from the
URL.

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
tool_filter: The tool filter to use for filtering tools.
"""
super().__init__(
cache_tools_list,
client_session_timeout_seconds,
tool_filter,
cache_tools_list=cache_tools_list,
cache_prompts_list=cache_prompts_list,
client_session_timeout_seconds=client_session_timeout_seconds,
tool_filter=tool_filter,
)

self.params = params
Expand Down Expand Up @@ -511,6 +558,7 @@ def __init__(
self,
params: MCPServerStreamableHttpParams,
cache_tools_list: bool = False,
cache_prompts_list: bool = False,
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
tool_filter: ToolFilter = None,
Expand All @@ -530,16 +578,25 @@ def __init__(
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).

cache_prompts_list: Whether to cache the prompts list. If `True`, the prompts list
will be cached and only fetched from the server once. If `False`, the prompts
list will be fetched from the server on each call to `list_prompts()`.
The cache can be invalidated by calling `invalidate_prompts_cache()`.
You should set this to `True` if you know the server will not change
its prompts list, because it can drastically improve latency
(by avoiding a round-trip to the server every time).

name: A readable name for the server. If not provided, we'll create one from the
URL.

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
tool_filter: The tool filter to use for filtering tools.
"""
super().__init__(
cache_tools_list,
client_session_timeout_seconds,
tool_filter,
cache_tools_list=cache_tools_list,
cache_prompts_list=cache_prompts_list,
client_session_timeout_seconds=client_session_timeout_seconds,
tool_filter=tool_filter,
)

self.params = params
Expand Down
1 change: 1 addition & 0 deletions tests/mcp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, tool_filter: ToolFilter, server_name: str):
# Initialize parent class properly to avoid type errors
super().__init__(
cache_tools_list=False,
cache_prompts_list=False,
client_session_timeout_seconds=None,
tool_filter=tool_filter,
)
Expand Down
Loading