From bb97a16756f59fb0f7b0211bb9d8ce2f5926607e Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 4 Jan 2025 09:35:05 -0800 Subject: [PATCH] Add prompts to MCP server (#134619) * Add prompts to MCP server * Improve test coverage for get prompt error cases --- homeassistant/components/mcp_server/server.py | 31 ++++++++++++ tests/components/mcp_server/test_http.py | 49 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/homeassistant/components/mcp_server/server.py b/homeassistant/components/mcp_server/server.py index a52a0f92c0befc..ba21abd722cac9 100644 --- a/homeassistant/components/mcp_server/server.py +++ b/homeassistant/components/mcp_server/server.py @@ -50,6 +50,37 @@ async def create_server( server = Server("home-assistant") + @server.list_prompts() # type: ignore[no-untyped-call, misc] + async def handle_list_prompts() -> list[types.Prompt]: + llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) + return [ + types.Prompt( + name=llm_api.api.name, + description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}", + ) + ] + + @server.get_prompt() # type: ignore[no-untyped-call, misc] + async def handle_get_prompt( + name: str, arguments: dict[str, str] | None + ) -> types.GetPromptResult: + llm_api = await llm.async_get_api(hass, llm_api_id, llm_context) + if name != llm_api.api.name: + raise ValueError(f"Unknown prompt: {name}") + + return types.GetPromptResult( + description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}", + messages=[ + types.PromptMessage( + role="assistant", + content=types.TextContent( + type="text", + text=llm_api.api_prompt, + ), + ) + ], + ) + @server.list_tools() # type: ignore[no-untyped-call, misc] async def list_tools() -> list[types.Tool]: """List available time tools.""" diff --git a/tests/components/mcp_server/test_http.py b/tests/components/mcp_server/test_http.py index 78f1364502dd36..a71bf42acc89db 100644 --- a/tests/components/mcp_server/test_http.py +++ b/tests/components/mcp_server/test_http.py @@ -10,6 +10,7 @@ import mcp import mcp.client.session import mcp.client.sse +from mcp.shared.exceptions import McpError import pytest from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN @@ -354,3 +355,51 @@ async def test_mcp_tool_call_failed( assert len(result.content) == 1 assert result.content[0].type == "text" assert "Error calling tool" in result.content[0].text + + +async def test_prompt_list( + hass: HomeAssistant, + setup_integration: None, + mcp_sse_url: str, + hass_supervisor_access_token: str, +) -> None: + """Test the list prompt endpoint.""" + + async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: + result = await session.list_prompts() + + assert len(result.prompts) == 1 + prompt = result.prompts[0] + assert prompt.name == "Assist" + assert prompt.description == "Default prompt for the Home Assistant LLM API Assist" + + +async def test_prompt_get( + hass: HomeAssistant, + setup_integration: None, + mcp_sse_url: str, + hass_supervisor_access_token: str, +) -> None: + """Test the get prompt endpoint.""" + + async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: + result = await session.get_prompt(name="Assist") + + assert result.description == "Default prompt for the Home Assistant LLM API Assist" + assert len(result.messages) == 1 + assert result.messages[0].role == "assistant" + assert result.messages[0].content.type == "text" + assert "When controlling Home Assistant" in result.messages[0].content.text + + +async def test_get_unknwon_prompt( + hass: HomeAssistant, + setup_integration: None, + mcp_sse_url: str, + hass_supervisor_access_token: str, +) -> None: + """Test the get prompt endpoint.""" + + async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session: + with pytest.raises(McpError): + await session.get_prompt(name="Unknown")