Skip to content

Commit

Permalink
Allow picking a pipeline for voip devices (home-assistant#91524)
Browse files Browse the repository at this point in the history
* Allow picking a pipeline for voip device

* Add tests

* Fix test

* Adjust on new pipeline data
  • Loading branch information
balloob authored Apr 17, 2023
1 parent 9bd12f6 commit bd22e0b
Show file tree
Hide file tree
Showing 12 changed files with 323 additions and 8 deletions.
8 changes: 4 additions & 4 deletions homeassistant/components/assist_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,13 @@ async def async_pipeline_from_audio_stream(
tts_options: dict | None = None,
) -> None:
"""Create an audio pipeline from an audio stream."""
if language is None:
if language is None and pipeline_id is None:
language = hass.config.language

# Temporary workaround for language codes
if language == "en":
language = "en-US"

if stt_metadata.language == "":
stt_metadata.language = language

if context is None:
context = Context()

Expand All @@ -75,6 +72,9 @@ async def async_pipeline_from_audio_stream(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
)

if stt_metadata.language == "":
stt_metadata.language = pipeline.language

pipeline_input = PipelineInput(
conversation_id=conversation_id,
stt_metadata=stt_metadata,
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class Pipeline:
"""A voice assistant pipeline."""

conversation_engine: str | None
language: str | None
language: str
name: str
stt_engine: str | None
tts_engine: str | None
Expand Down
95 changes: 95 additions & 0 deletions homeassistant/components/assist_pipeline/select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Select entities for a pipeline."""

from __future__ import annotations

from collections.abc import Iterable

from homeassistant.components.select import SelectEntity, SelectEntityDescription
from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state

from .const import DOMAIN
from .pipeline import PipelineStorageCollection

OPTION_PREFERRED = "preferred"


@callback
def get_chosen_pipeline(
hass: HomeAssistant, domain: str, unique_id_prefix: str
) -> str | None:
"""Get the chosen pipeline for a domain."""
ent_reg = er.async_get(hass)
pipeline_entity_id = ent_reg.async_get_entity_id(
Platform.SELECT, domain, f"{unique_id_prefix}-pipeline"
)
if pipeline_entity_id is None:
return None

state = hass.states.get(pipeline_entity_id)
if state is None or state.state == OPTION_PREFERRED:
return None

pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
return next(
(item.id for item in pipeline_store.async_items() if item.name == state.state),
None,
)


class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
"""Entity to represent a pipeline selector."""

entity_description = SelectEntityDescription(
key="pipeline",
translation_key="pipeline",
entity_category=EntityCategory.CONFIG,
)
_attr_should_poll = False
_attr_current_option = OPTION_PREFERRED
_attr_options = [OPTION_PREFERRED]

def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
"""Initialize a pipeline selector."""
self._attr_unique_id = f"{unique_id_prefix}-pipeline"
self.hass = hass
self._update_options()

async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()

pipeline_store: PipelineStorageCollection = self.hass.data[
DOMAIN
].pipeline_store
pipeline_store.async_add_change_set_listener(self._pipelines_updated)

state = await self.async_get_last_state()
if state is not None and state.state in self.options:
self._attr_current_option = state.state

async def async_select_option(self, option: str) -> None:
"""Select an option."""
self._attr_current_option = option
self.async_write_ha_state()

async def _pipelines_updated(
self, change_sets: Iterable[collection.CollectionChangeSet]
) -> None:
"""Handle pipeline update."""
self._update_options()
self.async_write_ha_state()

@callback
def _update_options(self) -> None:
"""Handle pipeline update."""
pipeline_store: PipelineStorageCollection = self.hass.data[
DOMAIN
].pipeline_store
options = [OPTION_PREFERRED]
options.extend(sorted(item.name for item in pipeline_store.async_items()))
self._attr_options = options

if self._attr_current_option not in options:
self._attr_current_option = OPTION_PREFERRED
12 changes: 12 additions & 0 deletions homeassistant/components/assist_pipeline/strings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"entity": {
"select": {
"pipeline": {
"name": "Assist Pipeline",
"state": {
"preferred": "Preferred"
}
}
}
}
}
1 change: 1 addition & 0 deletions homeassistant/components/voip/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

PLATFORMS = (
Platform.BINARY_SENSOR,
Platform.SELECT,
Platform.SWITCH,
)
_LOGGER = logging.getLogger(__name__)
Expand Down
46 changes: 46 additions & 0 deletions homeassistant/components/voip/select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Select entities for VoIP integration."""

from __future__ import annotations

from typing import TYPE_CHECKING

from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from .const import DOMAIN
from .devices import VoIPDevice
from .entity import VoIPEntity

if TYPE_CHECKING:
from . import DomainData


async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up VoIP switch entities."""
domain_data: DomainData = hass.data[DOMAIN]

@callback
def async_add_device(device: VoIPDevice) -> None:
"""Add device."""
async_add_entities([VoipPipelineSelect(hass, device)])

domain_data.devices.async_add_new_device_listener(async_add_device)

async_add_entities(
[VoipPipelineSelect(hass, device) for device in domain_data.devices]
)


class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
"""Pipeline selector for VoIP devices."""

def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
"""Initialize a pipeline selector."""
VoIPEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, device.voip_id)
8 changes: 8 additions & 0 deletions homeassistant/components/voip/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
"allow_call": {
"name": "Allow Calls"
}
},
"select": {
"pipeline": {
"name": "[%key:component::assist_pipeline::entity::select::pipeline::name%]",
"state": {
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
}
}
}
}
}
7 changes: 6 additions & 1 deletion homeassistant/components/voip/voip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
PipelineEvent,
PipelineEventType,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
from homeassistant.const import __version__
from homeassistant.core import HomeAssistant

from .const import DOMAIN

if TYPE_CHECKING:
from .devices import VoIPDevice, VoIPDevices

Expand Down Expand Up @@ -151,7 +154,9 @@ async def stt_stream():
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream(),
language=self.language,
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.voip_device.voip_id
),
conversation_id=self._conversation_id,
tts_options={tts.ATTR_AUDIO_OUTPUT: "raw"},
)
Expand Down
8 changes: 6 additions & 2 deletions homeassistant/helpers/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,24 @@ def async_items(self) -> list[_ItemT]:
return list(self.data.values())

@callback
def async_add_listener(self, listener: ChangeListener) -> None:
def async_add_listener(self, listener: ChangeListener) -> Callable[[], None]:
"""Add a listener.
Will be called with (change_type, item_id, updated_config).
"""
self.listeners.append(listener)
return lambda: self.listeners.remove(listener)

@callback
def async_add_change_set_listener(self, listener: ChangeSetListener) -> None:
def async_add_change_set_listener(
self, listener: ChangeSetListener
) -> Callable[[], None]:
"""Add a listener for a full change set.
Will be called with [(change_type, item_id, updated_config), ...]
"""
self.change_set_listeners.append(listener)
return lambda: self.change_set_listeners.remove(listener)

async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
"""Notify listeners of a change."""
Expand Down
8 changes: 8 additions & 0 deletions tests/components/assist_pipeline/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pytest

from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
Expand Down Expand Up @@ -137,3 +139,9 @@ async def init_components(
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(hass, "assist_pipeline", {})


@pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
"""Return pipeline storage collection."""
return hass.data[DOMAIN].pipeline_store
Loading

0 comments on commit bd22e0b

Please sign in to comment.