Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change wake word interception to a subscription #125629

Merged
merged 4 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions homeassistant/components/assist_satellite/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent

Expand Down Expand Up @@ -42,5 +43,19 @@ async def websocket_intercept_wake_word(
)
return

wake_word_phrase = await satellite.async_intercept_wake_word()
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
async def intercept_wake_word() -> None:
"""Push an intercepted wake word to websocket."""
try:
wake_word_phrase = await satellite.async_intercept_wake_word()
connection.send_message(
websocket_api.event_message(
msg["id"],
{"wake_word_phrase": wake_word_phrase},
)
)
except HomeAssistantError as err:
connection.send_error(msg["id"], "home_assistant_error", str(err))

task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word")
connection.subscriptions[msg["id"]] = task.cancel
synesthesiam marked this conversation as resolved.
Show resolved Hide resolved
connection.send_message(websocket_api.result_message(msg["id"]))
141 changes: 112 additions & 29 deletions tests/components/assist_satellite/test_websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Test WebSocket API."""

import asyncio
from unittest.mock import patch

import pytest

from homeassistant.components.assist_pipeline import PipelineStage
from homeassistant.config_entries import ConfigEntry
Expand Down Expand Up @@ -28,20 +31,23 @@ async def test_intercept_wake_word(
"entity_id": ENTITY_ID,
}
)

for _ in range(3):
await asyncio.sleep(0)
msg = await ws_client.receive_json()
assert msg["success"]
assert msg["result"] is None
subscription_id = msg["id"]

await entity.async_accept_pipeline_from_satellite(
object(),
object(), # type: ignore[arg-type]
start_stage=PipelineStage.STT,
wake_word_phrase="ok, nabu",
)

response = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert response["success"]
assert response["result"] == {"wake_word_phrase": "ok, nabu"}
assert msg["id"] == subscription_id
assert msg["type"] == "event"
assert msg["event"] == {"wake_word_phrase": "ok, nabu"}


async def test_intercept_wake_word_requires_on_device_wake_word(
Expand All @@ -60,18 +66,23 @@ async def test_intercept_wake_word_requires_on_device_wake_word(
}
)

for _ in range(3):
await asyncio.sleep(0)
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

await entity.async_accept_pipeline_from_satellite(
object(),
object(), # type: ignore[arg-type]
# Emulate wake word processing in Home Assistant
start_stage=PipelineStage.WAKE_WORD,
)

response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Only on-device wake words currently supported",
}
Expand All @@ -93,18 +104,23 @@ async def test_intercept_wake_word_requires_wake_word_phrase(
}
)

for _ in range(3):
await asyncio.sleep(0)
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

await entity.async_accept_pipeline_from_satellite(
object(),
object(), # type: ignore[arg-type]
start_stage=PipelineStage.STT,
# We are not passing wake word phrase
)

response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "No wake word phrase provided",
}
Expand All @@ -128,10 +144,12 @@ async def test_intercept_wake_word_require_admin(
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not msg["success"]
assert msg["error"] == {
"code": "unauthorized",
"message": "Unauthorized",
}
Expand All @@ -152,10 +170,11 @@ async def test_intercept_wake_word_invalid_satellite(
"entity_id": "assist_satellite.invalid",
}
)
response = await ws_client.receive_json()
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "Entity not found",
}
Expand All @@ -167,7 +186,7 @@ async def test_intercept_wake_word_twice(
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test intercepting a wake word requires admin access."""
"""Test intercepting a wake word twice cancels the previous request."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
Expand All @@ -177,16 +196,80 @@ async def test_intercept_wake_word_twice(
}
)

async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None

task = hass.async_create_task(ws_client.receive_json())

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()

assert not response["success"]
assert response["error"] == {
# Should get an error from previous subscription
async with asyncio.timeout(1):
msg = await task

assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Wake word interception already in progress",
}

# Response to second subscription
async with asyncio.timeout(1):
msg = await ws_client.receive_json()

assert msg["success"]
assert msg["result"] is None


async def test_intercept_wake_word_unsubscribe(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that closing the websocket connection stops interception."""
ws_client = await hass_ws_client(hass)

await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)

# Wait for interception to start
for _ in range(3):
await asyncio.sleep(0)

async def receive_json():
with pytest.raises(TypeError):
# Raises TypeError when connection is closed
await ws_client.receive_json()

task = hass.async_create_task(receive_json())

# Close connection
await ws_client.close()
await task

with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_pipeline_from_audio_stream,
):
# Start a pipeline with a wake word
await entity.async_accept_pipeline_from_satellite(
object(),
wake_word_phrase="ok, nabu", # type: ignore[arg-type]
)

# Wake word should not be intercepted
mock_pipeline_from_audio_stream.assert_called_once()
Loading