11"""Base classes for FastMCP prompts."""
22
3+ from __future__ import annotations
4+
35import inspect
46from collections .abc import Awaitable , Callable , Sequence
5- from typing import Any , Literal
7+ from typing import TYPE_CHECKING , Any , Literal
68
79import pydantic_core
810from pydantic import BaseModel , Field , TypeAdapter , validate_call
911
12+ from mcp .server .fastmcp .utilities .context_injection import find_context_parameter , inject_context
13+ from mcp .server .fastmcp .utilities .func_metadata import func_metadata
1014from mcp .types import ContentBlock , TextContent
1115
16+ if TYPE_CHECKING :
17+ from mcp .server .fastmcp .server import Context
18+ from mcp .server .session import ServerSessionT
19+ from mcp .shared .context import LifespanContextT , RequestT
20+
1221
1322class Message (BaseModel ):
1423 """Base class for all prompt messages."""
@@ -62,6 +71,7 @@ class Prompt(BaseModel):
6271 description : str | None = Field (None , description = "Description of what the prompt does" )
6372 arguments : list [PromptArgument ] | None = Field (None , description = "Arguments that can be passed to the prompt" )
6473 fn : Callable [..., PromptResult | Awaitable [PromptResult ]] = Field (exclude = True )
74+ context_kwarg : str | None = Field (None , description = "Name of the kwarg that should receive context" , exclude = True )
6575
6676 @classmethod
6777 def from_function (
@@ -70,7 +80,8 @@ def from_function(
7080 name : str | None = None ,
7181 title : str | None = None ,
7282 description : str | None = None ,
73- ) -> "Prompt" :
83+ context_kwarg : str | None = None ,
84+ ) -> Prompt :
7485 """Create a Prompt from a function.
7586
7687 The function can return:
@@ -84,8 +95,16 @@ def from_function(
8495 if func_name == "<lambda>" :
8596 raise ValueError ("You must provide a name for lambda functions" )
8697
87- # Get schema from TypeAdapter - will fail if function isn't properly typed
88- parameters = TypeAdapter (fn ).json_schema ()
98+ # Find context parameter if it exists
99+ if context_kwarg is None :
100+ context_kwarg = find_context_parameter (fn )
101+
102+ # Get schema from func_metadata, excluding context parameter
103+ func_arg_metadata = func_metadata (
104+ fn ,
105+ skip_names = [context_kwarg ] if context_kwarg is not None else [],
106+ )
107+ parameters = func_arg_metadata .arg_model .model_json_schema ()
89108
90109 # Convert parameters to PromptArguments
91110 arguments : list [PromptArgument ] = []
@@ -109,9 +128,14 @@ def from_function(
109128 description = description or fn .__doc__ or "" ,
110129 arguments = arguments ,
111130 fn = fn ,
131+ context_kwarg = context_kwarg ,
112132 )
113133
114- async def render (self , arguments : dict [str , Any ] | None = None ) -> list [Message ]:
134+ async def render (
135+ self ,
136+ arguments : dict [str , Any ] | None = None ,
137+ context : Context [ServerSessionT , LifespanContextT , RequestT ] | None = None ,
138+ ) -> list [Message ]:
115139 """Render the prompt with arguments."""
116140 # Validate required arguments
117141 if self .arguments :
@@ -122,8 +146,11 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]
122146 raise ValueError (f"Missing required arguments: { missing } " )
123147
124148 try :
149+ # Add context to arguments if needed
150+ call_args = inject_context (self .fn , arguments or {}, context , self .context_kwarg )
151+
125152 # Call function and check if result is a coroutine
126- result = self .fn (** ( arguments or {}) )
153+ result = self .fn (** call_args )
127154 if inspect .iscoroutine (result ):
128155 result = await result
129156
0 commit comments