Skip to content

Commit

Permalink
Add WS API to stt (home-assistant#91329)
Browse files Browse the repository at this point in the history
  • Loading branch information
emontnemery authored Apr 17, 2023
1 parent e3ff7d0 commit b5817e4
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 2 deletions.
45 changes: 44 additions & 1 deletion homeassistant/components/stt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
HTTPNotFound,
HTTPUnsupportedMediaType,
)
import voluptuous as vol

from homeassistant.components import websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import dt as dt_util
from homeassistant.util import dt as dt_util, language as language_util

from .const import (
DATA_PROVIDERS,
Expand Down Expand Up @@ -74,6 +76,8 @@ def async_get_speech_to_text_entity(

async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up STT."""
websocket_api.async_register_command(hass, websocket_list_engines)

component = hass.data[DOMAIN] = EntityComponent[SpeechToTextEntity](
_LOGGER, DOMAIN, hass
)
Expand Down Expand Up @@ -376,3 +380,42 @@ def _metadata_from_header(request: web.Request) -> SpeechMetadata:
)
except ValueError as err:
raise ValueError(f"Wrong format of X-Speech-Content: {err}") from err


@websocket_api.websocket_command(
{
"type": "stt/engine/list",
vol.Optional("language"): str,
}
)
@callback
def websocket_list_engines(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""List speech to text engines and, optionally, if they support a given language."""
component: EntityComponent[SpeechToTextEntity] = hass.data[DOMAIN]
legacy_providers: dict[str, Provider] = hass.data[DATA_PROVIDERS]

language = msg.get("language")
providers = []
provider_info: dict[str, Any]

for entity in component.entities:
provider_info = {"engine_id": entity.entity_id}
if language:
provider_info["language_supported"] = bool(
language_util.matches(language, entity.supported_languages)
)
providers.append(provider_info)

for engine_id, provider in legacy_providers.items():
provider_info = {"engine_id": engine_id}
if language:
provider_info["language_supported"] = bool(
language_util.matches(language, provider.supported_languages)
)
providers.append(provider_info)

connection.send_message(
websocket_api.result_message(msg["id"], {"providers": providers})
)
47 changes: 46 additions & 1 deletion tests/components/stt/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
mock_platform,
mock_restore_cache,
)
from tests.typing import ClientSessionGenerator
from tests.typing import ClientSessionGenerator, WebSocketGenerator

TEST_DOMAIN = "test"

Expand Down Expand Up @@ -375,3 +375,48 @@ async def test_restore_state(
state = hass.states.get(entity_id)
assert state
assert state.state == timestamp


@pytest.mark.parametrize(
("setup", "engine_id"),
[("mock_setup", "test"), ("mock_config_entry_setup", "stt.test")],
indirect=["setup"],
)
async def test_ws_list_engines(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
setup: str,
engine_id: str,
) -> None:
"""Test listing speech to text engines."""
client = await hass_ws_client()

await client.send_json_auto_id({"type": "stt/engine/list"})

msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"providers": [{"engine_id": engine_id}]}

await client.send_json_auto_id({"type": "stt/engine/list", "language": "smurfish"})

msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"providers": [{"engine_id": engine_id, "language_supported": False}]
}

await client.send_json_auto_id({"type": "stt/engine/list", "language": "en"})

msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"providers": [{"engine_id": engine_id, "language_supported": True}]
}

await client.send_json_auto_id({"type": "stt/engine/list", "language": "en-UK"})

msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"providers": [{"engine_id": engine_id, "language_supported": True}]
}

0 comments on commit b5817e4

Please sign in to comment.