Skip to content

Commit d0cf667

Browse files
Expose initialization requests to middleware, and provide session storage
1 parent f950b8c commit d0cf667

File tree

3 files changed

+123
-1
lines changed

3 files changed

+123
-1
lines changed

src/fastmcp/server/low_level.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
from typing import Any
1+
from contextlib import AsyncExitStack
2+
from typing import TYPE_CHECKING, Any
23

4+
import anyio
5+
import mcp.types as types
36
from mcp.server.lowlevel.server import (
47
LifespanResultT,
58
NotificationOptions,
@@ -9,6 +12,66 @@
912
Server as _Server,
1013
)
1114
from mcp.server.models import InitializationOptions
15+
from mcp.server.session import ServerSession
16+
from mcp.shared.session import RequestResponder
17+
18+
if TYPE_CHECKING:
19+
from fastmcp.server.middleware.middleware import MiddlewareContext
20+
from fastmcp.server.server import FastMCP
21+
22+
23+
class MiddlewareExposedServerSession(ServerSession):
24+
"""ServerSession that routes initialization requests through FastMCP middleware."""
25+
26+
def __init__(self, fastmcp_server: "FastMCP", *args, **kwargs):
27+
super().__init__(*args, **kwargs)
28+
self.fastmcp_server = fastmcp_server
29+
30+
async def _received_request(
31+
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
32+
):
33+
# Check if this is an initialization request and if middleware should handle it
34+
if (
35+
isinstance(responder.request.root, types.InitializeRequest)
36+
and self.fastmcp_server
37+
and hasattr(self.fastmcp_server, "_apply_middleware")
38+
):
39+
# Import here to avoid circular imports
40+
from fastmcp.server.middleware.middleware import MiddlewareContext
41+
42+
# HACK: Pass session object directly to middleware context for proof-of-concept
43+
context = MiddlewareContext(
44+
message=responder.request.root.params,
45+
method="initialize",
46+
type="request",
47+
source="client",
48+
session=self, # Pass session so middleware can store data on it
49+
)
50+
51+
# Create a continuation that calls the original initialization handler
52+
async def call_original_handler(
53+
ctx: "MiddlewareContext",
54+
) -> types.InitializeResult:
55+
# Call the original handler by continuing to the parent implementation
56+
await super(MiddlewareExposedServerSession, self)._received_request(
57+
responder
58+
)
59+
# The response will be handled by the parent, we just need to extract the result
60+
# This is a bit tricky since the parent handles the response internally
61+
# For now, we'll call the parent and assume it handles the response correctly
62+
return None # Parent handles the actual response
63+
64+
# Apply middleware chain, but still let parent handle the actual response
65+
try:
66+
await self.fastmcp_server._apply_middleware(
67+
context, call_original_handler
68+
)
69+
except Exception:
70+
# If middleware fails, fall back to original handling
71+
await super()._received_request(responder)
72+
else:
73+
# For non-initialization requests or when no middleware, use original handling
74+
await super()._received_request(responder)
1275

1376

1477
class LowLevelServer(_Server[LifespanResultT, RequestT]):
@@ -20,6 +83,8 @@ def __init__(self, *args, **kwargs):
2083
resources_changed=True,
2184
tools_changed=True,
2285
)
86+
# Reference to FastMCP server for middleware integration
87+
self.fastmcp_server: FastMCP | None = None
2388

2489
def create_initialization_options(
2590
self,
@@ -35,3 +100,46 @@ def create_initialization_options(
35100
experimental_capabilities=experimental_capabilities,
36101
**kwargs,
37102
)
103+
104+
async def run(
105+
self,
106+
read_stream,
107+
write_stream,
108+
initialization_options,
109+
raise_exceptions=False,
110+
stateless=False,
111+
):
112+
"""Override run to use MiddlewareExposedServerSession when fastmcp_server is available."""
113+
async with AsyncExitStack() as stack:
114+
lifespan_context = await stack.enter_async_context(self.lifespan(self))
115+
116+
# Use MiddlewareExposedServerSession if we have a FastMCP server, otherwise use default
117+
if self.fastmcp_server:
118+
session = await stack.enter_async_context(
119+
MiddlewareExposedServerSession(
120+
self.fastmcp_server,
121+
read_stream,
122+
write_stream,
123+
initialization_options,
124+
stateless=stateless,
125+
)
126+
)
127+
else:
128+
session = await stack.enter_async_context(
129+
ServerSession(
130+
read_stream,
131+
write_stream,
132+
initialization_options,
133+
stateless=stateless,
134+
)
135+
)
136+
137+
async with anyio.create_task_group() as tg:
138+
async for message in session.incoming_messages:
139+
tg.start_soon(
140+
self._handle_message,
141+
message,
142+
session,
143+
lifespan_context,
144+
raise_exceptions,
145+
)

src/fastmcp/server/middleware/middleware.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ class MiddlewareContext(Generic[T]):
7373

7474
fastmcp_context: Context | None = None
7575

76+
# HACK: Session object for initialization middleware access
77+
session: Any | None = None
78+
7679
# Common metadata
7780
source: Literal["client", "server"] = "client"
7881
type: Literal["request", "notification"] = "request"
@@ -118,6 +121,8 @@ async def _dispatch_handler(
118121
handler = call_next
119122

120123
match context.method:
124+
case "initialize":
125+
handler = partial(self.on_initialize, call_next=handler)
121126
case "tools/call":
122127
handler = partial(self.on_call_tool, call_next=handler)
123128
case "resources/read":
@@ -164,6 +169,13 @@ async def on_notification(
164169
) -> Any:
165170
return await call_next(context)
166171

172+
async def on_initialize(
173+
self,
174+
context: MiddlewareContext[mt.InitializeRequestParams],
175+
call_next: CallNext[mt.InitializeRequestParams, mt.InitializeResult],
176+
) -> mt.InitializeResult:
177+
return await call_next(context)
178+
167179
async def on_call_tool(
168180
self,
169181
context: MiddlewareContext[mt.CallToolRequestParams],

src/fastmcp/server/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def __init__(
193193
instructions=instructions,
194194
lifespan=_lifespan_wrapper(self, lifespan),
195195
)
196+
# Set reference for middleware integration with initialization
197+
self._mcp_server.fastmcp_server = self
196198

197199
if auth is None and fastmcp.settings.default_auth_provider == "bearer_env":
198200
auth = EnvBearerAuthProvider()

0 commit comments

Comments
 (0)