Skip to content

Commit

Permalink
Add API to fetch Assist devices (home-assistant#107333)
Browse files Browse the repository at this point in the history
* Add API to fetch Assist devices

* Revert some changes to fixture, make a single fixture for an Assist device
  • Loading branch information
balloob authored Jan 6, 2024
1 parent 6201e81 commit f1d2868
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 23 deletions.
14 changes: 11 additions & 3 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,7 +1703,7 @@ async def _change_listener(
pipeline_run.abort_wake_word_detection = True


@dataclass
@dataclass(slots=True)
class DeviceAudioQueue:
"""Audio capture queue for a satellite device."""

Expand All @@ -1717,19 +1717,27 @@ class DeviceAudioQueue:
"""Flag to be set if audio samples were dropped because the queue was full."""


@dataclass(slots=True)
class AssistDevice:
"""Assist device."""

domain: str
unique_id_prefix: str


class PipelineData:
"""Store and debug data stored in hass.data."""

def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
"""Initialize."""
self.pipeline_store = pipeline_store
self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
self.pipeline_devices: set[str] = set()
self.pipeline_devices: dict[str, AssistDevice] = {}
self.pipeline_runs = PipelineRuns(pipeline_store)
self.device_audio_queues: dict[str, DeviceAudioQueue] = {}


@dataclass
@dataclass(slots=True)
class PipelineRunDebug:
"""Debug data for a pipelinerun."""

Expand Down
17 changes: 12 additions & 5 deletions homeassistant/components/assist_pipeline/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from homeassistant.helpers import collection, entity_registry as er, restore_state

from .const import DOMAIN
from .pipeline import PipelineData, PipelineStorageCollection
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
from .vad import VadSensitivity

OPTION_PREFERRED = "preferred"
Expand Down Expand Up @@ -70,8 +70,10 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
_attr_current_option = OPTION_PREFERRED
_attr_options = [OPTION_PREFERRED]

def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
def __init__(self, hass: HomeAssistant, domain: str, unique_id_prefix: str) -> None:
"""Initialize a pipeline selector."""
self._domain = domain
self._unique_id_prefix = unique_id_prefix
self._attr_unique_id = f"{unique_id_prefix}-pipeline"
self.hass = hass
self._update_options()
Expand All @@ -91,11 +93,16 @@ async def async_added_to_hass(self) -> None:
self._attr_current_option = state.state

if self.registry_entry and (device_id := self.registry_entry.device_id):
pipeline_data.pipeline_devices.add(device_id)
self.async_on_remove(
lambda: pipeline_data.pipeline_devices.discard(device_id)
pipeline_data.pipeline_devices[device_id] = AssistDevice(
self._domain, self._unique_id_prefix
)

def cleanup() -> None:
"""Clean up registered device."""
pipeline_data.pipeline_devices.pop(device_id)

self.async_on_remove(cleanup)

async def async_select_option(self, option: str) -> None:
"""Select an option."""
self._attr_current_option = option
Expand Down
32 changes: 31 additions & 1 deletion homeassistant/components/assist_pipeline/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from homeassistant.components import conversation, stt, tts, websocket_api
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.util import language as language_util

from .const import (
Expand Down Expand Up @@ -53,6 +53,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_run)
websocket_api.async_register_command(hass, websocket_list_languages)
websocket_api.async_register_command(hass, websocket_list_runs)
websocket_api.async_register_command(hass, websocket_list_devices)
websocket_api.async_register_command(hass, websocket_get_run)
websocket_api.async_register_command(hass, websocket_device_capture)

Expand Down Expand Up @@ -287,6 +288,35 @@ def websocket_list_runs(
)


@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/device/list",
}
)
def websocket_list_devices(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""List assist devices."""
pipeline_data: PipelineData = hass.data[DOMAIN]
ent_reg = er.async_get(hass)
connection.send_result(
msg["id"],
[
{
"device_id": device_id,
"pipeline_entity": ent_reg.async_get_entity_id(
"select", info.domain, f"{info.unique_id_prefix}-pipeline"
),
}
for device_id, info in pipeline_data.pipeline_devices.items()
],
)


@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/components/esphome/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from .const import DOMAIN
from .domain_data import DomainData
from .entity import (
EsphomeAssistEntity,
Expand Down Expand Up @@ -75,7 +76,7 @@ class EsphomeAssistPipelineSelect(EsphomeAssistEntity, AssistPipelineSelect):
def __init__(self, hass: HomeAssistant, entry_data: RuntimeEntryData) -> None:
"""Initialize a pipeline selector."""
EsphomeAssistEntity.__init__(self, entry_data)
AssistPipelineSelect.__init__(self, hass, self._device_info.mac_address)
AssistPipelineSelect.__init__(self, hass, DOMAIN, self._device_info.mac_address)


class EsphomeVadSensitivitySelect(EsphomeAssistEntity, VadSensitivitySelect):
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/voip/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
"""Initialize a pipeline selector."""
VoIPEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, device.voip_id)
AssistPipelineSelect.__init__(self, hass, DOMAIN, device.voip_id)


class VoipVadSensitivitySelect(VoIPEntity, VadSensitivitySelect):
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/wyoming/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None:
self.device = device

WyomingSatelliteEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, device.satellite_id)
AssistPipelineSelect.__init__(self, hass, DOMAIN, device.satellite_id)

async def async_select_option(self, option: str) -> None:
"""Select an option."""
Expand Down
81 changes: 78 additions & 3 deletions tests/components/assist_pipeline/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import pytest

from homeassistant.components import stt, tts, wake_word
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select
from homeassistant.components.assist_pipeline.pipeline import (
PipelineData,
PipelineStorageCollection,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_setup_component

Expand Down Expand Up @@ -288,7 +290,7 @@ async def async_setup_entry_init(
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setups(
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
config_entry, [Platform.STT, Platform.WAKE_WORD]
)
return True

Expand All @@ -297,7 +299,7 @@ async def async_unload_entry_init(
) -> bool:
"""Unload up test config entry."""
await hass.config_entries.async_unload_platforms(
config_entry, [stt.DOMAIN, wake_word.DOMAIN]
config_entry, [Platform.STT, Platform.WAKE_WORD]
)
return True

Expand Down Expand Up @@ -369,6 +371,79 @@ async def init_components(hass: HomeAssistant, init_supporting_components):
assert await async_setup_component(hass, "assist_pipeline", {})


@pytest.fixture
async def assist_device(hass: HomeAssistant, init_components) -> dr.DeviceEntry:
"""Create an assist device."""
config_entry = MockConfigEntry(domain="test_assist_device")
config_entry.add_to_hass(hass)

dev_reg = dr.async_get(hass)
device = dev_reg.async_get_or_create(
name="Test Device",
config_entry_id=config_entry.entry_id,
identifiers={("test_assist_device", "test")},
)

async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setups(
config_entry, [Platform.SELECT]
)
return True

async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload up test config entry."""
await hass.config_entries.async_unload_platforms(
config_entry, [Platform.SELECT]
)
return True

async def async_setup_entry_select_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test select platform via config entry."""
entities = [
assist_select.AssistPipelineSelect(
hass, "test_assist_device", "test-prefix"
),
assist_select.VadSensitivitySelect(hass, "test-prefix"),
]
for ent in entities:
ent._attr_device_info = dr.DeviceInfo(
identifiers={("test_assist_device", "test")},
)
async_add_entities(entities)

mock_integration(
hass,
MockModule(
"test_assist_device",
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
mock_platform(
hass,
"test_assist_device.select",
MockPlatform(
async_setup_entry=async_setup_entry_select_platform,
),
)
mock_platform(hass, "test_assist_device.config_flow")

with mock_config_flow("test_assist_device", ConfigFlow):
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()

return device


@pytest.fixture
def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData:
"""Return pipeline data."""
Expand Down
19 changes: 11 additions & 8 deletions tests/components/assist_pipeline/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from homeassistant.components.assist_pipeline import Pipeline
from homeassistant.components.assist_pipeline.pipeline import (
AssistDevice,
PipelineData,
PipelineStorageCollection,
)
Expand Down Expand Up @@ -33,7 +34,7 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up fake select platform."""
pipeline_entity = AssistPipelineSelect(hass, "test")
pipeline_entity = AssistPipelineSelect(hass, "test-domain", "test-prefix")
pipeline_entity._attr_device_info = DeviceInfo(
identifiers={("test", "test")},
)
Expand Down Expand Up @@ -109,13 +110,15 @@ async def test_select_entity_registering_device(
assert device is not None

# Test device is registered
assert pipeline_data.pipeline_devices == {device.id}
assert pipeline_data.pipeline_devices == {
device.id: AssistDevice("test-domain", "test-prefix")
}

await hass.config_entries.async_remove(init_select.entry_id)
await hass.async_block_till_done()

# Test device is removed
assert pipeline_data.pipeline_devices == set()
assert pipeline_data.pipeline_devices == {}


async def test_select_entity_changing_pipelines(
Expand All @@ -128,7 +131,7 @@ async def test_select_entity_changing_pipelines(
"""Test entity tracking pipeline changes."""
config_entry = init_select # nicer naming

state = hass.states.get("select.assist_pipeline_test_pipeline")
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
assert state is not None
assert state.state == "preferred"
assert state.attributes["options"] == [
Expand All @@ -143,28 +146,28 @@ async def test_select_entity_changing_pipelines(
"select",
"select_option",
{
"entity_id": "select.assist_pipeline_test_pipeline",
"entity_id": "select.assist_pipeline_test_prefix_pipeline",
"option": pipeline_2.name,
},
blocking=True,
)

state = hass.states.get("select.assist_pipeline_test_pipeline")
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
assert state is not None
assert state.state == pipeline_2.name

# Reload config entry to test selected option persists
assert await hass.config_entries.async_forward_entry_unload(config_entry, "select")
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")

state = hass.states.get("select.assist_pipeline_test_pipeline")
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
assert state is not None
assert state.state == pipeline_2.name

# Remove selected pipeline
await pipeline_storage.async_delete_item(pipeline_2.id)

state = hass.states.get("select.assist_pipeline_test_pipeline")
state = hass.states.get("select.assist_pipeline_test_prefix_pipeline")
assert state is not None
assert state.state == "preferred"
assert state.attributes["options"] == [
Expand Down
19 changes: 19 additions & 0 deletions tests/components/assist_pipeline/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2502,3 +2502,22 @@ async def test_pipeline_empty_tts_output(
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])


async def test_pipeline_list_devices(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
assist_device,
) -> None:
"""Test list devices."""
client = await hass_ws_client(hass)

await client.send_json_auto_id({"type": "assist_pipeline/device/list"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == [
{
"device_id": assist_device.id,
"pipeline_entity": "select.test_assist_device_test_prefix_pipeline",
}
]

0 comments on commit f1d2868

Please sign in to comment.