Skip to content

Commit

Permalink
Add start_conversation service to Assist Satellite
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob committed Jan 11, 2025
1 parent ab8af03 commit 89b5235
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 15 deletions.
15 changes: 15 additions & 0 deletions homeassistant/components/assist_satellite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"async_internal_announce",
[AssistSatelliteEntityFeature.ANNOUNCE],
)
component.async_register_entity_service(
"start_conversation",
vol.All(
cv.make_entity_service_schema(
{
vol.Optional("start_message"): str,
vol.Optional("start_media_id"): str,
vol.Optional("extra_system_prompt"): str,
}
),
cv.has_at_least_one_key("start_message", "start_media_id"),
),
"async_internal_start_conversation",
[AssistSatelliteEntityFeature.START_CONVERSATION],
)
hass.data[CONNECTION_TEST_DATA] = {}
async_register_websocket_api(hass)
hass.http.register_view(ConnectionTestView())
Expand Down
3 changes: 3 additions & 0 deletions homeassistant/components/assist_satellite/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ class AssistSatelliteEntityFeature(IntFlag):

ANNOUNCE = 1
"""Device supports remotely triggered announcements."""

START_CONVERSATION = 2
"""Device supports starting conversations."""
55 changes: 54 additions & 1 deletion homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from typing import Any, Final, Literal, final

from homeassistant.components import media_source, stt, tts
from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
AudioSettings,
Expand All @@ -27,6 +27,7 @@
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription

Expand Down Expand Up @@ -113,6 +114,7 @@ class AssistSatelliteEntity(entity.Entity):

_run_has_tts: bool = False
_is_announcing = False
_extra_system_prompt: str | None = None
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
_attr_tts_options: dict[str, Any] | None = None
_pipeline_task: asyncio.Task | None = None
Expand Down Expand Up @@ -212,6 +214,56 @@ async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> Non
"""
raise NotImplementedError

async def async_internal_start_conversation(
self,
start_message: str | None = None,
start_media_id: str | None = None,
extra_system_prompt: str | None = None,
) -> None:
"""Start a conversation from the satellite.
If start_media_id is not provided, message is synthesized to
audio with the selected pipeline.
If start_media_id is provided, it is played directly. It is possible
to omit the message and the satellite will not show any text.
Calls async_start_conversation.
"""
await self._cancel_running_pipeline()

# The Home Assistant built-in agent doesn't support conversations.
pipeline = async_get_pipeline(self.hass, self._resolve_pipeline())
if pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT:
raise HomeAssistantError(
"Built-in conversation agent does not support starting conversations"
)

if start_message is None:
start_message = ""

announcement = await self._resolve_announcement_media_id(
start_message, start_media_id
)

if self._is_announcing:
raise SatelliteBusyError

self._is_announcing = True
self._extra_system_prompt = extra_system_prompt

try:
await self.async_start_conversation(announcement)
finally:
self._is_announcing = False
self._extra_system_prompt = None

async def async_start_conversation(
self, start_announcement: AssistSatelliteAnnouncement
) -> None:
"""Start a conversation from the satellite."""
raise NotImplementedError

async def async_accept_pipeline_from_satellite(
self,
audio_stream: AsyncIterable[bytes],
Expand Down Expand Up @@ -298,6 +350,7 @@ async def async_accept_pipeline_from_satellite(
),
start_stage=start_stage,
end_stage=end_stage,
conversation_extra_system_prompt=self._extra_system_prompt,
),
f"{self.entity_id}_pipeline",
)
Expand Down
3 changes: 3 additions & 0 deletions homeassistant/components/assist_satellite/icons.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
"services": {
"announce": {
"service": "mdi:bullhorn"
},
"start_conversation": {
"service": "mdi:forum"
}
}
}
16 changes: 16 additions & 0 deletions homeassistant/components/assist_satellite/services.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,19 @@ announce:
required: false
selector:
text:
start_conversation:
target:
entity:
domain: assist_satellite
supported_features:
- assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
fields:
start_message:
required: false
example: "You left the lights on in the living room. Turn them off?"
selector:
text:
start_media_id:
required: false
selector:
text:
14 changes: 14 additions & 0 deletions homeassistant/components/assist_satellite/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
"description": "The media ID to announce instead of using text-to-speech."
}
}
},
"start_conversation": {
"name": "Start Conversation",
"description": "Start a conversation from a satellite.",
"fields": {
"start_message": {
"name": "Message",
"description": "The message to start with."
},
"start_media_id": {
"name": "Media ID",
"description": "The media ID to start with instead of using text-to-speech."
}
}
}
}
}
2 changes: 2 additions & 0 deletions homeassistant/helpers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def _base_components() -> dict[str, ModuleType]:
# pylint: disable-next=import-outside-toplevel
from homeassistant.components import (
alarm_control_panel,
assist_satellite,
calendar,
camera,
climate,
Expand All @@ -108,6 +109,7 @@ def _base_components() -> dict[str, ModuleType]:

return {
"alarm_control_panel": alarm_control_panel,
"assist_satellite": assist_satellite,
"calendar": calendar,
"camera": camera,
"climate": climate,
Expand Down
13 changes: 12 additions & 1 deletion tests/components/assist_satellite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ class MockAssistSatellite(AssistSatelliteEntity):
"""Mock Assist Satellite Entity."""

_attr_name = "Test Entity"
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
_attr_supported_features = (
AssistSatelliteEntityFeature.ANNOUNCE
| AssistSatelliteEntityFeature.START_CONVERSATION
)
_attr_tts_options = {"test-option": "test-value"}

def __init__(self) -> None:
"""Initialize the mock entity."""
Expand All @@ -59,6 +63,7 @@ def __init__(self) -> None:
active_wake_words=["1234"],
max_active_wake_words=1,
)
self.start_conversations = []

def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
Expand All @@ -79,6 +84,12 @@ async def async_set_configuration(
"""Set the current satellite configuration."""
self.config = config

async def async_start_conversation(
self, start_announcement: AssistSatelliteConfiguration
) -> None:
"""Start a conversation from the satellite."""
self.start_conversations.append((self._extra_system_prompt, start_announcement))


@pytest.fixture
def entity() -> MockAssistSatellite:
Expand Down
117 changes: 104 additions & 13 deletions tests/components/assist_satellite/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,24 @@
from homeassistant.components.media_source import PlayMedia
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError

from . import ENTITY_ID
from .conftest import MockAssistSatellite


@pytest.fixture(autouse=True)
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
"""Set up a pipeline with a TTS engine."""
await async_update_pipeline(
hass,
async_get_pipeline(hass),
tts_engine="tts.mock_entity",
tts_language="en",
tts_voice="test-voice",
)


async def test_entity_state(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
Expand Down Expand Up @@ -64,7 +77,7 @@ async def test_entity_state(
assert kwargs["stt_stream"] is audio_stream
assert kwargs["pipeline_id"] is None
assert kwargs["device_id"] is None
assert kwargs["tts_audio_output"] is None
assert kwargs["tts_audio_output"] == {"test-option": "test-value"}
assert kwargs["wake_word_phrase"] is None
assert kwargs["audio_settings"] == AudioSettings(
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
Expand Down Expand Up @@ -189,24 +202,12 @@ async def test_announce(
expected_params: tuple[str, str],
) -> None:
"""Test announcing on a device."""
await async_update_pipeline(
hass,
async_get_pipeline(hass),
tts_engine="tts.mock_entity",
tts_language="en",
tts_voice="test-voice",
)

entity._attr_tts_options = {"test-option": "test-value"}

original_announce = entity.async_announce
announce_started = asyncio.Event()

async def async_announce(announcement):
# Verify state change
assert entity.state == AssistSatelliteState.RESPONDING
await original_announce(announcement)
announce_started.set()

def tts_generate_media_source_id(
hass: HomeAssistant,
Expand Down Expand Up @@ -464,3 +465,93 @@ async def test_vad_sensitivity_entity_not_found(

with pytest.raises(RuntimeError):
await entity.async_accept_pipeline_from_satellite(audio_stream)


@pytest.mark.parametrize(
("service_data", "expected_params"),
[
(
{
"start_message": "Hello",
"extra_system_prompt": "Better system prompt",
},
(
"Better system prompt",
AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "tts"
),
),
),
(
{
"start_message": "Hello",
"start_media_id": "media-source://bla",
},
(
None,
AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "media_id"
),
),
),
(
{"start_media_id": "http://example.com/bla.mp3"},
(
None,
AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"),
),
),
],
)
async def test_start_conversation(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
service_data: dict,
expected_params: tuple[str, str],
) -> None:
"""Test starting a conversation on a device."""
await async_update_pipeline(
hass,
async_get_pipeline(hass),
conversation_engine="conversation.some_llm",
)

with (
patch(
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=PlayMedia(
url="https://www.home-assistant.io/resolved.mp3",
mime_type="audio/mp3",
),
),
):
await hass.services.async_call(
"assist_satellite",
"start_conversation",
service_data,
target={"entity_id": "assist_satellite.test_entity"},
blocking=True,
)

assert entity.start_conversations[0] == expected_params


async def test_start_conversation_reject_builtin_agent(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
) -> None:
"""Test starting a conversation on a device."""
with pytest.raises(HomeAssistantError):
await hass.services.async_call(
"assist_satellite",
"start_conversation",
{"start_message": "Hey!"},
target={"entity_id": "assist_satellite.test_entity"},
blocking=True,
)

0 comments on commit 89b5235

Please sign in to comment.