Skip to content

Commit

Permalink
Add prompts to MCP server (home-assistant#134619)
Browse files Browse the repository at this point in the history
* Add prompts to MCP server

* Improve test coverage for get prompt error cases
  • Loading branch information
allenporter authored Jan 4, 2025
1 parent c9a607a commit bb97a16
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
31 changes: 31 additions & 0 deletions homeassistant/components/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,37 @@ async def create_server(

server = Server("home-assistant")

@server.list_prompts() # type: ignore[no-untyped-call, misc]
async def handle_list_prompts() -> list[types.Prompt]:
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
return [
types.Prompt(
name=llm_api.api.name,
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
)
]

@server.get_prompt() # type: ignore[no-untyped-call, misc]
async def handle_get_prompt(
name: str, arguments: dict[str, str] | None
) -> types.GetPromptResult:
llm_api = await llm.async_get_api(hass, llm_api_id, llm_context)
if name != llm_api.api.name:
raise ValueError(f"Unknown prompt: {name}")

return types.GetPromptResult(
description=f"Default prompt for the Home Assistant LLM API {llm_api.api.name}",
messages=[
types.PromptMessage(
role="assistant",
content=types.TextContent(
type="text",
text=llm_api.api_prompt,
),
)
],
)

@server.list_tools() # type: ignore[no-untyped-call, misc]
async def list_tools() -> list[types.Tool]:
"""List available time tools."""
Expand Down
49 changes: 49 additions & 0 deletions tests/components/mcp_server/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import mcp
import mcp.client.session
import mcp.client.sse
from mcp.shared.exceptions import McpError
import pytest

from homeassistant.components.conversation import DOMAIN as CONVERSATION_DOMAIN
Expand Down Expand Up @@ -354,3 +355,51 @@ async def test_mcp_tool_call_failed(
assert len(result.content) == 1
assert result.content[0].type == "text"
assert "Error calling tool" in result.content[0].text


async def test_prompt_list(
hass: HomeAssistant,
setup_integration: None,
mcp_sse_url: str,
hass_supervisor_access_token: str,
) -> None:
"""Test the list prompt endpoint."""

async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
result = await session.list_prompts()

assert len(result.prompts) == 1
prompt = result.prompts[0]
assert prompt.name == "Assist"
assert prompt.description == "Default prompt for the Home Assistant LLM API Assist"


async def test_prompt_get(
hass: HomeAssistant,
setup_integration: None,
mcp_sse_url: str,
hass_supervisor_access_token: str,
) -> None:
"""Test the get prompt endpoint."""

async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
result = await session.get_prompt(name="Assist")

assert result.description == "Default prompt for the Home Assistant LLM API Assist"
assert len(result.messages) == 1
assert result.messages[0].role == "assistant"
assert result.messages[0].content.type == "text"
assert "When controlling Home Assistant" in result.messages[0].content.text


async def test_get_unknwon_prompt(
hass: HomeAssistant,
setup_integration: None,
mcp_sse_url: str,
hass_supervisor_access_token: str,
) -> None:
"""Test the get prompt endpoint."""

async with mcp_session(mcp_sse_url, hass_supervisor_access_token) as session:
with pytest.raises(McpError):
await session.get_prompt(name="Unknown")

0 comments on commit bb97a16

Please sign in to comment.