Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion src/fastmcp/server/low_level.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any

import anyio
import mcp.types as types
from mcp.server.lowlevel.server import (
LifespanResultT,
NotificationOptions,
Expand All @@ -9,6 +12,66 @@
Server as _Server,
)
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.session import RequestResponder

if TYPE_CHECKING:
from fastmcp.server.middleware.middleware import MiddlewareContext
from fastmcp.server.server import FastMCP


class MiddlewareExposedServerSession(ServerSession):
"""ServerSession that routes initialization requests through FastMCP middleware."""

def __init__(self, fastmcp_server: "FastMCP", *args, **kwargs):
super().__init__(*args, **kwargs)
self.fastmcp_server = fastmcp_server

async def _received_request(
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
):
# Check if this is an initialization request and if middleware should handle it
if (
isinstance(responder.request.root, types.InitializeRequest)
and self.fastmcp_server
and hasattr(self.fastmcp_server, "_apply_middleware")
):
# Import here to avoid circular imports
from fastmcp.server.middleware.middleware import MiddlewareContext

# HACK: Pass session object directly to middleware context for proof-of-concept
context = MiddlewareContext(
message=responder.request.root.params,
method="initialize",
type="request",
source="client",
session=self, # Pass session so middleware can store data on it
)

# Create a continuation that calls the original initialization handler
async def call_original_handler(
ctx: "MiddlewareContext",
) -> types.InitializeResult:
# Call the original handler by continuing to the parent implementation
await super(MiddlewareExposedServerSession, self)._received_request(
responder
)
# The response will be handled by the parent, we just need to extract the result
# This is a bit tricky since the parent handles the response internally
# For now, we'll call the parent and assume it handles the response correctly
return None # Parent handles the actual response

# Apply middleware chain, but still let parent handle the actual response
try:
await self.fastmcp_server._apply_middleware(
context, call_original_handler
)
except Exception:
# If middleware fails, fall back to original handling
await super()._received_request(responder)
else:
# For non-initialization requests or when no middleware, use original handling
await super()._received_request(responder)


class LowLevelServer(_Server[LifespanResultT, RequestT]):
Expand All @@ -20,6 +83,8 @@ def __init__(self, *args, **kwargs):
resources_changed=True,
tools_changed=True,
)
# Reference to FastMCP server for middleware integration
self.fastmcp_server: FastMCP | None = None

def create_initialization_options(
self,
Expand All @@ -35,3 +100,46 @@ def create_initialization_options(
experimental_capabilities=experimental_capabilities,
**kwargs,
)

async def run(
self,
read_stream,
write_stream,
initialization_options,
raise_exceptions=False,
stateless=False,
):
"""Override run to use MiddlewareExposedServerSession when fastmcp_server is available."""
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))

# Use MiddlewareExposedServerSession if we have a FastMCP server, otherwise use default
if self.fastmcp_server:
session = await stack.enter_async_context(
MiddlewareExposedServerSession(
self.fastmcp_server,
read_stream,
write_stream,
initialization_options,
stateless=stateless,
)
)
else:
session = await stack.enter_async_context(
ServerSession(
read_stream,
write_stream,
initialization_options,
stateless=stateless,
)
)

async with anyio.create_task_group() as tg:
async for message in session.incoming_messages:
tg.start_soon(
self._handle_message,
message,
session,
lifespan_context,
raise_exceptions,
)
12 changes: 12 additions & 0 deletions src/fastmcp/server/middleware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class MiddlewareContext(Generic[T]):

fastmcp_context: Context | None = None

# HACK: Session object for initialization middleware access
session: Any | None = None

# Common metadata
source: Literal["client", "server"] = "client"
type: Literal["request", "notification"] = "request"
Expand Down Expand Up @@ -98,6 +101,8 @@ async def _dispatch_handler(
handler = call_next

match context.method:
case "initialize":
handler = partial(self.on_initialize, call_next=handler)
case "tools/call":
handler = partial(self.on_call_tool, call_next=handler)
case "resources/read":
Expand Down Expand Up @@ -144,6 +149,13 @@ async def on_notification(
) -> Any:
return await call_next(context)

async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequestParams],
call_next: CallNext[mt.InitializeRequestParams, mt.InitializeResult],
) -> mt.InitializeResult:
return await call_next(context)

async def on_call_tool(
self,
context: MiddlewareContext[mt.CallToolRequestParams],
Expand Down
2 changes: 2 additions & 0 deletions src/fastmcp/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def __init__(
instructions=instructions,
lifespan=_lifespan_wrapper(self, lifespan),
)
# Set reference for middleware integration with initialization
self._mcp_server.fastmcp_server = self

# if auth is `NotSet`, try to create a provider from the environment
if auth is NotSet:
Expand Down
Loading