Skip to content

Commit

Permalink
Keep future
Browse files Browse the repository at this point in the history
  • Loading branch information
synesthesiam committed Sep 10, 2024
1 parent 02a4492 commit e6b71e0
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 54 deletions.
59 changes: 29 additions & 30 deletions homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Assist satellite entity."""

from abc import abstractmethod
from collections.abc import AsyncIterable, Callable
import asyncio
from collections.abc import AsyncIterable
from enum import StrEnum
import logging
import time
Expand All @@ -23,13 +24,13 @@
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import CALLBACK_TYPE, Context, callback
from homeassistant.core import Context, callback
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid

from .const import AssistSatelliteEntityFeature
from .errors import SatelliteBusyError
from .errors import AssistSatelliteError, SatelliteBusyError

_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes

Expand Down Expand Up @@ -70,7 +71,7 @@ class AssistSatelliteEntity(entity.Entity):

_run_has_tts: bool = False
_is_announcing = False
_wake_word_listener: Callable[[str | None, str], None] | None = None
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
_attr_tts_options: dict[str, Any] | None = None

__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
Expand All @@ -96,25 +97,24 @@ def tts_options(self) -> dict[str, Any] | None:
"""Options passed for text-to-speech."""
return self._attr_tts_options

def async_intercept_wake_word(
self, wake_word_listener: Callable[[str | None, str], None]
) -> CALLBACK_TYPE:
"""Set a listener to intercept the next wake word from the satellite.
async def async_intercept_wake_word(self) -> str | None:
"""Intercept the next wake word from the satellite.
Returns a callback to remove the listener.
Returns the detected wake word phrase or None.
"""
if self._wake_word_listener is not None:
# Cancel existing interception
self._wake_word_listener(None, "Only one interception is allowed")
if self._wake_word_intercept_future is not None:
raise SatelliteBusyError("Wake word interception already in progress")

self._wake_word_listener = wake_word_listener
# Will cause next wake word to be intercepted in
# async_accept_pipeline_from_satellite
self._wake_word_intercept_future = asyncio.Future()

@callback
def unsubscribe() -> None:
if self._wake_word_listener == wake_word_listener:
self._wake_word_listener = None
_LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)

return unsubscribe
try:
return await self._wake_word_intercept_future
finally:
self._wake_word_intercept_future = None

async def async_internal_announce(
self,
Expand Down Expand Up @@ -193,15 +193,16 @@ async def async_accept_pipeline_from_satellite(
wake_word_phrase: str | None = None,
) -> None:
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
if (self._wake_word_listener is not None) and start_stage in (
if self._wake_word_intercept_future and start_stage in (
PipelineStage.WAKE_WORD,
PipelineStage.STT,
):
if start_stage == PipelineStage.WAKE_WORD:
self._wake_word_listener(
None, "Only on-device wake words currently supported"
self._wake_word_intercept_future.set_exception(
AssistSatelliteError(
"Only on-device wake words currently supported"
)
)
self._wake_word_listener = None
return

# Intercepting wake word and immediately end pipeline
Expand All @@ -211,14 +212,12 @@ async def async_accept_pipeline_from_satellite(
self.entity_id,
)

try:
if wake_word_phrase is None:
self._wake_word_listener(None, "No wake word phrase provided")
else:
self._wake_word_listener(wake_word_phrase, "")
finally:
self._wake_word_listener = None

if wake_word_phrase is None:
self._wake_word_intercept_future.set_exception(
AssistSatelliteError("No wake word phrase provided")
)
else:
self._wake_word_intercept_future.set_result(wake_word_phrase)
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
return

Expand Down
16 changes: 9 additions & 7 deletions homeassistant/components/assist_satellite/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent

Expand Down Expand Up @@ -42,20 +43,21 @@ async def websocket_intercept_wake_word(
)
return

@callback
def wake_word_listener(wake_word_phrase: str | None, message: str) -> None:
async def intercept_wake_word() -> None:
"""Push an intercepted wake word to websocket."""
if wake_word_phrase is not None:
try:
wake_word_phrase = await satellite.async_intercept_wake_word()
connection.send_message(
websocket_api.event_message(
msg["id"],
{"wake_word_phrase": wake_word_phrase},
)
)
else:
connection.send_error(msg["id"], "home_assistant_error", message)
except HomeAssistantError as err:
connection.send_error(msg["id"], "home_assistant_error", str(err))

connection.subscriptions[msg["id"]] = satellite.async_intercept_wake_word(
wake_word_listener
task = hass.async_create_background_task(
intercept_wake_word(), "intercept_wake_word"
)
connection.subscriptions[msg["id"]] = task.cancel
connection.send_message(websocket_api.result_message(msg["id"]))
53 changes: 36 additions & 17 deletions tests/components/assist_satellite/test_websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ async def test_intercept_wake_word(
wake_word_phrase="ok, nabu",
)

msg = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["id"] == subscription_id
assert msg["type"] == "event"
assert msg["event"] == {"wake_word_phrase": "ok, nabu"}
Expand All @@ -61,7 +63,9 @@ async def test_intercept_wake_word_requires_on_device_wake_word(
}
)

msg = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

Expand All @@ -71,9 +75,11 @@ async def test_intercept_wake_word_requires_on_device_wake_word(
start_stage=PipelineStage.WAKE_WORD,
)

response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Only on-device wake words currently supported",
}
Expand All @@ -95,7 +101,9 @@ async def test_intercept_wake_word_requires_wake_word_phrase(
}
)

msg = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

Expand All @@ -105,9 +113,11 @@ async def test_intercept_wake_word_requires_wake_word_phrase(
# We are not passing wake word phrase
)

response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "No wake word phrase provided",
}
Expand All @@ -131,10 +141,12 @@ async def test_intercept_wake_word_require_admin(
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
"code": "unauthorized",
"message": "Unauthorized",
}
Expand All @@ -155,7 +167,8 @@ async def test_intercept_wake_word_invalid_satellite(
"entity_id": "assist_satellite.invalid",
}
)
msg = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
Expand All @@ -180,7 +193,9 @@ async def test_intercept_wake_word_twice(
}
)

msg = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

Expand All @@ -194,14 +209,18 @@ async def test_intercept_wake_word_twice(
)

# Should get an error from previous subscription
msg = await task
async with asyncio.timeout(1):
msg = await task

assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Only one interception is allowed",
"message": "Wake word interception already in progress",
}

# Response to second subscription
msg = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

0 comments on commit e6b71e0

Please sign in to comment.