Skip to content

Add support for Elicitation #625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
- [Prompts](#prompts)
- [Images](#images)
- [Context](#context)
- [Elicitation](#elicitation)
- [Authentication](#authentication)
- [Running Your Server](#running-your-server)
- [Development Mode](#development-mode)
- [Claude Desktop Integration](#claude-desktop-integration)
Expand Down Expand Up @@ -310,6 +312,44 @@ async def long_task(files: list[str], ctx: Context) -> str:
return "Processing complete"
```

### Elicitation

Request additional information from users during tool execution:

```python
from mcp.server.fastmcp import FastMCP, Context
from pydantic import BaseModel, Field

mcp = FastMCP("Booking System")


@mcp.tool()
async def book_table(date: str, party_size: int, ctx: Context) -> str:
"""Book a table with confirmation"""

# Schema must only contain primitive types (str, int, float, bool)
class ConfirmBooking(BaseModel):
confirm: bool = Field(description="Confirm booking?")
notes: str = Field(default="", description="Special requests")

result = await ctx.elicit(
message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking
)
Comment on lines +335 to +337
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it gives us a better API and user experience if ctx.elicit(schema=SchemaT) always return an instance of SchemaT, or an exception.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My comment implies that an exception would be raised if user rejects.


if result.action == "accept" and result.data:
if result.data.confirm:
return f"Booked! Notes: {result.data.notes or 'None'}"
return "Booking cancelled"

# User declined or cancelled
return f"Booking {result.action}"
```

The `elicit()` method returns an `ElicitationResult` with:
- `action`: "accept", "decline", or "cancel"
- `data`: The validated response (only when accepted)
- `validation_error`: Any validation error message

### Authentication

Authentication can be used by servers that want to expose tools accessing protected resources.
Expand Down
30 changes: 30 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ async def __call__(
) -> types.CreateMessageResult | types.ErrorData: ...


class ElicitationFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData: ...


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
Expand Down Expand Up @@ -58,6 +66,16 @@ async def _default_sampling_callback(
)


async def _default_elicitation_callback(
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Elicitation not supported",
)


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
Expand Down Expand Up @@ -91,6 +109,7 @@ def __init__(
write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
Expand All @@ -105,12 +124,16 @@ def __init__(
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
elicitation = (
types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None
)
roots = (
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
Expand All @@ -128,6 +151,7 @@ async def initialize(self) -> types.InitializeResult:
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
experimental=None,
roots=roots,
),
Expand Down Expand Up @@ -356,6 +380,12 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.ElicitRequest(params=params):
with responder:
response = await self._elicitation_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
Expand Down
106 changes: 104 additions & 2 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@

import inspect
import re
import types
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
from itertools import chain
from typing import Any, Generic, Literal
from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin

import anyio
import pydantic_core
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ValidationError
from pydantic.fields import FieldInfo
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.applications import Starlette
Expand Down Expand Up @@ -65,6 +67,21 @@

logger = get_logger(__name__)

ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)


class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]):
"""Result of an elicitation request."""

action: Literal["accept", "decline", "cancel"]
"""The user's action in response to the elicitation."""

data: ElicitSchemaModelT | None = None
"""The validated data if action is 'accept', None otherwise."""

validation_error: str | None = None
"""Validation error message if data failed to validate."""


class Settings(BaseSettings, Generic[LifespanResultT]):
"""FastMCP server settings.
Expand Down Expand Up @@ -858,6 +875,43 @@ def _convert_to_content(
return [TextContent(type="text", text=result)]


# Primitive types allowed in elicitation schemas
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)


def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
"""Validate that a Pydantic model only contains primitive field types."""
for field_name, field_info in schema.model_fields.items():
if not _is_primitive_field(field_info):
raise TypeError(
f"Elicitation schema field '{field_name}' must be a primitive type "
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
f"Complex types like lists, dicts, or nested models are not allowed."
)


def _is_primitive_field(field_info: FieldInfo) -> bool:
"""Check if a field is a primitive type allowed in elicitation schemas."""
annotation = field_info.annotation

# Handle None type
if annotation is types.NoneType:
return True

# Handle basic primitive types
if annotation in _ELICITATION_PRIMITIVE_TYPES:
return True

# Handle Union types
origin = get_origin(annotation)
if origin is Union or origin is types.UnionType:
args = get_args(annotation)
# All args must be primitive types or None
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)

return False


class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
"""Context object providing access to MCP capabilities.

Expand Down Expand Up @@ -954,6 +1008,54 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
assert self._fastmcp is not None, "Context is not available outside of a request"
return await self._fastmcp.read_resource(uri)

async def elicit(
self,
message: str,
schema: type[ElicitSchemaModelT],
) -> ElicitationResult[ElicitSchemaModelT]:
"""Elicit information from the client/user.

This method can be used to interactively ask for additional information from the
client within a tool's execution. The client might display the message to the
user and collect a response according to the provided schema. Or in case a
client is an agent, it might decide how to handle the elicitation -- either by asking
the user or automatically generating a response.

Args:
schema: A Pydantic model class defining the expected response structure, according to the specification,
only primive types are allowed.
message: Optional message to present to the user. If not provided, will use
a default message based on the schema

Returns:
An ElicitationResult containing the action taken and the data if accepted

Note:
Check the result.action to determine if the user accepted, declined, or cancelled.
The result.data will only be populated if action is "accept" and validation succeeded.
"""

# Validate that schema only contains primitive types and fail loudly if not
_validate_elicitation_schema(schema)

json_schema = schema.model_json_schema()

result = await self.request_context.session.elicit(
message=message,
requestedSchema=json_schema,
related_request_id=self.request_id,
)

if result.action == "accept" and result.content:
# Validate and parse the content using the schema
try:
validated_data = schema.model_validate(result.content)
return ElicitationResult(action="accept", data=validated_data)
except ValidationError as e:
return ElicitationResult(action="accept", validation_error=str(e))
else:
return ElicitationResult(action=result.action)

async def log(
self,
level: Literal["debug", "info", "warning", "error"],
Expand Down
33 changes: 33 additions & 0 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
if client_caps.sampling is None:
return False

if capability.elicitation is not None:
if client_caps.elicitation is None:
return False

if capability.experimental is not None:
if client_caps.experimental is None:
return False
Expand Down Expand Up @@ -251,6 +255,35 @@ async def list_roots(self) -> types.ListRootsResult:
types.ListRootsResult,
)

async def elicit(
self,
message: str,
requestedSchema: types.ElicitRequestedSchema,
related_request_id: types.RequestId | None = None,
) -> types.ElicitResult:
"""Send an elicitation/create request.

Args:
message: The message to present to the user
requestedSchema: Schema defining the expected response structure

Returns:
The client's response
"""
return await self.send_request(
types.ServerRequest(
types.ElicitRequest(
method="elicitation/create",
params=types.ElicitRequestParams(
message=message,
requestedSchema=requestedSchema,
),
)
),
types.ElicitResult,
metadata=ServerMessageMetadata(related_request_id=related_request_id),
)

async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(
Expand Down
11 changes: 10 additions & 1 deletion src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

import mcp.types as types
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.client.session import (
ClientSession,
ElicitationFnT,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
SamplingFnT,
)
from mcp.server import Server
from mcp.shared.message import SessionMessage

Expand Down Expand Up @@ -53,6 +60,7 @@ async def create_connected_server_and_client_session(
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
raise_exceptions: bool = False,
elicitation_callback: ElicitationFnT | None = None,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
Expand Down Expand Up @@ -83,6 +91,7 @@ async def create_connected_server_and_client_session(
logging_callback=logging_callback,
message_handler=message_handler,
client_info=client_info,
elicitation_callback=elicitation_callback,
) as client_session:
await client_session.initialize()
yield client_session
Expand Down
Loading
Loading