Skip to content

Commit

Permalink
Change wake word interception to a subscription (#125629)
Browse files Browse the repository at this point in the history
* Allow stopping intercepting wake words

* Make wake word interception a subscription

* Keep future

* Add test for unsub
  • Loading branch information
synesthesiam authored Sep 16, 2024
1 parent 3ba39d5 commit c63cab3
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 31 deletions.
19 changes: 17 additions & 2 deletions homeassistant/components/assist_satellite/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent

Expand Down Expand Up @@ -42,5 +43,19 @@ async def websocket_intercept_wake_word(
)
return

wake_word_phrase = await satellite.async_intercept_wake_word()
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
async def intercept_wake_word() -> None:
"""Push an intercepted wake word to websocket."""
try:
wake_word_phrase = await satellite.async_intercept_wake_word()
connection.send_message(
websocket_api.event_message(
msg["id"],
{"wake_word_phrase": wake_word_phrase},
)
)
except HomeAssistantError as err:
connection.send_error(msg["id"], "home_assistant_error", str(err))

task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word")
connection.subscriptions[msg["id"]] = task.cancel
connection.send_message(websocket_api.result_message(msg["id"]))
141 changes: 112 additions & 29 deletions tests/components/assist_satellite/test_websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Test WebSocket API."""

import asyncio
from unittest.mock import patch

import pytest

from homeassistant.components.assist_pipeline import PipelineStage
from homeassistant.config_entries import ConfigEntry
Expand Down Expand Up @@ -28,20 +31,23 @@ async def test_intercept_wake_word(
"entity_id": ENTITY_ID,
}
)

for _ in range(3):
await asyncio.sleep(0)
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"] is None
subscription_id = msg["id"]

await entity.async_accept_pipeline_from_satellite(
object(),
object(), # type: ignore[arg-type]
start_stage=PipelineStage.STT,
wake_word_phrase="ok, nabu",
)

response = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert response["success"]
assert response["result"] == {"wake_word_phrase": "ok, nabu"}
assert msg["id"] == subscription_id
assert msg["type"] == "event"
assert msg["event"] == {"wake_word_phrase": "ok, nabu"}


async def test_intercept_wake_word_requires_on_device_wake_word(
Expand All @@ -60,18 +66,23 @@ async def test_intercept_wake_word_requires_on_device_wake_word(
}
)

for _ in range(3):
await asyncio.sleep(0)
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

await entity.async_accept_pipeline_from_satellite(
object(),
object(), # type: ignore[arg-type]
# Emulate wake word processing in Home Assistant
start_stage=PipelineStage.WAKE_WORD,
)

response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Only on-device wake words currently supported",
}
Expand All @@ -93,18 +104,23 @@ async def test_intercept_wake_word_requires_wake_word_phrase(
}
)

for _ in range(3):
await asyncio.sleep(0)
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

await entity.async_accept_pipeline_from_satellite(
object(),
object(), # type: ignore[arg-type]
start_stage=PipelineStage.STT,
# We are not passing wake word phrase
)

response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "No wake word phrase provided",
}
Expand All @@ -128,10 +144,12 @@ async def test_intercept_wake_word_require_admin(
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
"code": "unauthorized",
"message": "Unauthorized",
}
Expand All @@ -152,10 +170,11 @@ async def test_intercept_wake_word_invalid_satellite(
"entity_id": "assist_satellite.invalid",
}
)
response = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "Entity not found",
}
Expand All @@ -167,7 +186,7 @@ async def test_intercept_wake_word_twice(
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test intercepting a wake word requires admin access."""
"""Test intercepting a wake word twice cancels the previous request."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
Expand All @@ -177,16 +196,80 @@ async def test_intercept_wake_word_twice(
}
)

async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

task = hass.async_create_task(ws_client.receive_json())

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
# Should get an error from previous subscription
async with asyncio.timeout(1):
msg = await task

assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Wake word interception already in progress",
}

# Response to second subscription
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None


async def test_intercept_wake_word_unsubscribe(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that closing the websocket connection stops interception."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)

# Wait for interception to start
for _ in range(3):
await asyncio.sleep(0)

async def receive_json():
with pytest.raises(TypeError):
# Raises TypeError when connection is closed
await ws_client.receive_json()

task = hass.async_create_task(receive_json())

# Close connection
await ws_client.close()
await task

with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_pipeline_from_audio_stream,
):
# Start a pipeline with a wake word
await entity.async_accept_pipeline_from_satellite(
object(),
wake_word_phrase="ok, nabu", # type: ignore[arg-type]
)

# Wake word should not be intercepted
mock_pipeline_from_audio_stream.assert_called_once()

0 comments on commit c63cab3

Please sign in to comment.