-
-
Notifications
You must be signed in to change notification settings - Fork 31.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Voice assistant integration with pipelines (#89822)
* Initial commit
* Add websocket test tool
* Small tweak
* Tiny cleanup
* Make pipeline work with frontend branch
* Add some more info to start event
* Fixes
* First voice assistant tests
* Remove run_task
* Clean up for PR
* Add config_flow.py
* Remove CLI tool
* Simplify by removing stt/tts for now
* Clean up and fix tests
* More clean up and API changes
* Add quality_scale
* Remove data from run-finish
* Use StrEnum backport

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
- Loading branch information
1 parent
81c0382
commit e16f17f
Showing
10 changed files
with
392 additions
and
0 deletions.
There are no files selected for viewing
Validating CODEOWNERS rules …
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""The Voice Assistant integration.""" | ||
from __future__ import annotations | ||
|
||
from homeassistant.core import HomeAssistant | ||
from homeassistant.helpers.typing import ConfigType | ||
|
||
from .const import DEFAULT_PIPELINE, DOMAIN | ||
from .pipeline import Pipeline | ||
from .websocket_api import async_register_websocket_api | ||
|
||
|
||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Create the default pipeline and register the websocket API.

    The pipeline registry stored under ``hass.data[DOMAIN]`` maps pipeline
    names to ``Pipeline`` objects and initially contains only the default
    pipeline. Always returns ``True``.
    """
    # The default pipeline pins neither a language nor a conversation engine.
    default_pipeline = Pipeline(
        name=DEFAULT_PIPELINE,
        language=None,
        conversation_engine=None,
    )
    hass.data[DOMAIN] = {DEFAULT_PIPELINE: default_pipeline}

    async_register_websocket_api(hass)
    return True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
"""Constants for the Voice Assistant integration.""" | ||
# Integration domain; also the key under which pipelines are stored in hass.data.
DOMAIN = "voice_assistant"
# Name (and registry key) of the pipeline created during integration setup.
DEFAULT_PIPELINE = "default"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
{ | ||
"domain": "voice_assistant", | ||
"name": "Voice Assistant", | ||
"codeowners": ["@balloob", "@synesthesiam"], | ||
"dependencies": ["conversation"], | ||
"documentation": "https://www.home-assistant.io/integrations/voice_assistant", | ||
"iot_class": "local_push", | ||
"quality_scale": "internal" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
"""Classes for voice assistant pipelines.""" | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
from collections.abc import Callable | ||
from dataclasses import dataclass, field | ||
from typing import Any | ||
|
||
from homeassistant.backports.enum import StrEnum | ||
from homeassistant.components import conversation | ||
from homeassistant.core import Context, HomeAssistant | ||
from homeassistant.util.dt import utcnow | ||
|
||
# Time limit applied to a pipeline run when the caller does not supply one.
DEFAULT_TIMEOUT = 30  # seconds
|
||
|
||
@dataclass
class PipelineRequest:
    """Request to start a pipeline run.

    Holds the caller-supplied input that ``Pipeline`` forwards to the
    conversation component.
    """

    # Text to process; passed as the ``text`` argument of async_converse.
    intent_input: str
    # Passed through to async_converse unchanged; None when not supplied.
    conversation_id: str | None = None
|
||
|
||
class PipelineEventType(StrEnum):
    """Event types emitted during a pipeline run.

    A member is stored under the "type" key of the dict produced by
    ``PipelineEvent.as_dict``.
    """

    # Emitted first, carrying the pipeline name and the language in use.
    RUN_START = "run-start"
    # Emitted last, after the intent result was reported; carries no data.
    RUN_FINISH = "run-finish"
    # Emitted before the input text is handed to the conversation component.
    INTENT_START = "intent-start"
    # Emitted with the conversation result once async_converse has returned.
    INTENT_FINISH = "intent-finish"
    # NOTE(review): not emitted by any code visible in this file.
    ERROR = "error"
|
||
|
||
@dataclass
class PipelineEvent:
    """A single event produced while a pipeline is running.

    ``timestamp`` defaults to the ISO-formatted UTC time at which the event
    object is created.
    """

    type: PipelineEventType
    data: dict[str, Any] | None = None
    timestamp: str = field(default_factory=lambda: utcnow().isoformat())

    def as_dict(self) -> dict[str, Any]:
        """Serialize the event to a plain dict.

        Missing or empty ``data`` is reported as an empty dict.
        """
        payload = self.data if self.data else {}
        return {"type": self.type, "timestamp": self.timestamp, "data": payload}
|
||
|
||
@dataclass
class Pipeline:
    """Settings for one voice assistant pipeline plus the logic to run it.

    Attributes:
        name: Identifier reported in the run-start event.
        language: Language to use; ``hass.config.language`` is used when unset.
        conversation_engine: Agent id passed to the conversation component;
            reported as "default" in events when unset.
    """

    name: str
    language: str | None
    conversation_engine: str | None

    async def run(
        self,
        hass: HomeAssistant,
        context: Context,
        request: PipelineRequest,
        event_callback: Callable[[PipelineEvent], None],
        timeout: int | float | None = DEFAULT_TIMEOUT,
    ) -> None:
        """Execute the pipeline, bounded by ``timeout`` seconds.

        The run is wrapped in ``asyncio.wait_for``, so a ``timeout`` of
        ``None`` means no limit and exceeding the limit raises a timeout
        error out of this method.
        """
        pipeline_run = self._run(hass, context, request, event_callback)
        await asyncio.wait_for(pipeline_run, timeout=timeout)

    async def _run(
        self,
        hass: HomeAssistant,
        context: Context,
        request: PipelineRequest,
        event_callback: Callable[[PipelineEvent], None],
    ) -> None:
        """Process the request and report progress through ``event_callback``.

        Events are emitted in this order: run-start, intent-start,
        intent-finish, run-finish.
        """

        def emit(
            event_type: PipelineEventType, data: dict[str, Any] | None = None
        ) -> None:
            """Wrap the payload in a PipelineEvent and hand it to the callback."""
            event_callback(PipelineEvent(event_type, data))

        # A pipeline without its own language uses the instance-wide setting.
        language = self.language or hass.config.language
        emit(
            PipelineEventType.RUN_START,
            {"pipeline": self.name, "language": language},
        )

        text = request.intent_input
        engine = self.conversation_engine
        emit(
            PipelineEventType.INTENT_START,
            {"engine": engine or "default", "intent_input": text},
        )

        result = await conversation.async_converse(
            hass=hass,
            text=text,
            conversation_id=request.conversation_id,
            context=context,
            language=language,
            agent_id=engine,
        )
        emit(PipelineEventType.INTENT_FINISH, {"intent_output": result.as_dict()})

        emit(PipelineEventType.RUN_FINISH)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"""Voice Assistant Websocket API.""" | ||
from typing import Any | ||
|
||
import voluptuous as vol | ||
|
||
from homeassistant.components import websocket_api | ||
from homeassistant.core import HomeAssistant, callback | ||
|
||
from .const import DOMAIN | ||
from .pipeline import DEFAULT_TIMEOUT, PipelineRequest | ||
|
||
|
||
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
    """Register the websocket API.

    Adds the ``voice_assistant/run`` command handler (``websocket_run``) to
    the websocket API of the given Home Assistant instance.
    """
    websocket_api.async_register_command(hass, websocket_run)
|
||
|
||
@websocket_api.websocket_command(
    {
        vol.Required("type"): "voice_assistant/run",
        vol.Optional("pipeline", default="default"): str,
        vol.Required("intent_input"): str,
        vol.Optional("conversation_id"): vol.Any(str, None),
        vol.Optional("timeout"): vol.Any(float, int),
    }
)
@websocket_api.async_response
async def websocket_run(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Handle ``voice_assistant/run``: run a pipeline and stream its events.

    Replies with a ``pipeline_not_found`` error when the requested pipeline
    is not registered. Otherwise the run is started as a task, the command is
    acknowledged, and every pipeline event is forwarded to the client as a
    websocket event for this message id.
    """
    pipeline_id = msg["pipeline"]
    pipeline = hass.data[DOMAIN].get(pipeline_id)
    if pipeline is None:
        connection.send_error(
            msg["id"], "pipeline_not_found", f"Pipeline not found: {pipeline_id}"
        )
        return

    # The pipeline enforces the timeout itself; fall back to the default
    # when the client did not send one.
    timeout = msg.get("timeout", DEFAULT_TIMEOUT)
    run_context = connection.context(msg)
    pipeline_request = PipelineRequest(
        intent_input=msg["intent_input"],
        conversation_id=msg.get("conversation_id"),
    )
    run_task = hass.async_create_task(
        pipeline.run(
            hass,
            run_context,
            request=pipeline_request,
            # Each pipeline event is relayed to the client as it is produced.
            event_callback=lambda event: connection.send_event(
                msg["id"], event.as_dict()
            ),
            timeout=timeout,
        )
    )

    # Unsubscribing from this message id cancels the running pipeline.
    connection.subscriptions[msg["id"]] = run_task.cancel

    # Acknowledge the command before waiting for the run to complete.
    connection.send_result(msg["id"])

    # The task already carries its own timeout, so a plain await suffices.
    await run_task
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Tests for the Voice Assistant integration.""" |
Oops, something went wrong.