Skip to content

Commit

Permalink
TTS Cleanup and expose get audio (home-assistant#79065)
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob authored Sep 26, 2022
1 parent 39ddc37 commit 697e7b3
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 84 deletions.
3 changes: 2 additions & 1 deletion homeassistant/components/media_source/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
URI_SCHEME_REGEX,
)
from .error import MediaSourceError, Unresolvable
from .models import BrowseMediaSource, MediaSourceItem, PlayMedia
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia

__all__ = [
"DOMAIN",
Expand All @@ -46,6 +46,7 @@
"PlayMedia",
"MediaSourceItem",
"Unresolvable",
"MediaSource",
"MediaSourceError",
"MEDIA_CLASS_MAP",
"MEDIA_MIME_TYPES",
Expand Down
168 changes: 110 additions & 58 deletions homeassistant/components/tts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import os
from pathlib import Path
import re
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast

from aiohttp import web
import mutagen
Expand All @@ -28,7 +28,6 @@
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.components.media_source import generate_media_source_id
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_DESCRIPTION,
Expand All @@ -48,6 +47,7 @@
from homeassistant.util.yaml import load_yaml

from .const import DOMAIN
from .media_source import generate_media_source_id, media_source_id_to_kwargs

_LOGGER = logging.getLogger(__name__)

Expand All @@ -74,9 +74,6 @@
DEFAULT_CACHE_DIR = "tts"
DEFAULT_TIME_MEMORY = 300

MEM_CACHE_FILENAME = "filename"
MEM_CACHE_VOICE = "voice"

SERVICE_CLEAR_CACHE = "clear_cache"
SERVICE_SAY = "say"

Expand Down Expand Up @@ -131,6 +128,24 @@ def valid_base_url(value: str) -> str:
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})


class TTSCache(TypedDict):
"""Cached TTS file."""

filename: str
voice: bytes


async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
"""Get TTS audio as extension, data."""
manager: SpeechManager = hass.data[DOMAIN]
return await manager.async_get_tts_audio(
**media_source_id_to_kwargs(media_source_id),
)


async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up TTS."""
tts = SpeechManager(hass)
Expand Down Expand Up @@ -197,30 +212,19 @@ async def async_setup_platform(
async def async_say_handle(service: ServiceCall) -> None:
"""Service handle for say."""
entity_ids = service.data[ATTR_ENTITY_ID]
message = service.data[ATTR_MESSAGE]
cache = service.data.get(ATTR_CACHE)
language = service.data.get(ATTR_LANGUAGE)
options = service.data.get(ATTR_OPTIONS)

tts.process_options(p_type, language, options)
params = {
"message": message,
}
if cache is not None:
params["cache"] = "true" if cache else "false"
if language is not None:
params["language"] = language
if options is not None:
params.update(options)

await hass.services.async_call(
DOMAIN_MP,
SERVICE_PLAY_MEDIA,
{
ATTR_ENTITY_ID: entity_ids,
ATTR_MEDIA_CONTENT_ID: generate_media_source_id(
DOMAIN,
str(yarl.URL.build(path=p_type, query=params)),
hass,
engine=p_type,
message=service.data[ATTR_MESSAGE],
language=service.data.get(ATTR_LANGUAGE),
options=service.data.get(ATTR_OPTIONS),
cache=service.data.get(ATTR_CACHE),
),
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
ATTR_MEDIA_ANNOUNCE: True,
Expand Down Expand Up @@ -296,7 +300,7 @@ def __init__(self, hass: HomeAssistant) -> None:
self.time_memory = DEFAULT_TIME_MEMORY
self.base_url: str | None = None
self.file_cache: dict[str, str] = {}
self.mem_cache: dict[str, dict[str, str | bytes]] = {}
self.mem_cache: dict[str, TTSCache] = {}

async def async_init_cache(
self, use_cache: bool, cache_dir: str, time_memory: int, base_url: str | None
Expand Down Expand Up @@ -380,10 +384,11 @@ def process_options(
options = options or provider.default_options

if options is not None:
supported_options = provider.supported_options or []
invalid_opts = [
opt_name
for opt_name in options.keys()
if opt_name not in (provider.supported_options or [])
if opt_name not in supported_options
]
if invalid_opts:
raise HomeAssistantError(f"Invalid options found: {invalid_opts}")
Expand All @@ -403,39 +408,80 @@ async def async_get_url_path(
This method is a coroutine.
"""
language, options = self.process_options(engine, language, options)
options_key = _hash_options(options) if options else "-"
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
cache_key = self._generate_cache_key(message, language, options, engine)
use_cache = cache if cache is not None else self.use_cache

key = KEY_PATTERN.format(
msg_hash, language.replace("_", "-"), options_key, engine
).lower()

# Is speech already in memory
if key in self.mem_cache:
filename = cast(str, self.mem_cache[key][MEM_CACHE_FILENAME])
if cache_key in self.mem_cache:
filename = self.mem_cache[cache_key]["filename"]
# Is file store in file cache
elif use_cache and key in self.file_cache:
filename = self.file_cache[key]
self.hass.async_create_task(self.async_file_to_mem(key))
elif use_cache and cache_key in self.file_cache:
filename = self.file_cache[cache_key]
self.hass.async_create_task(self._async_file_to_mem(cache_key))
# Load speech from provider into memory
else:
filename = await self.async_get_tts_audio(
engine, key, message, use_cache, language, options
filename = await self._async_get_tts_audio(
engine,
cache_key,
message,
use_cache,
language,
options,
)

return f"/api/tts_proxy/{filename}"

async def async_get_tts_audio(
self,
engine: str,
key: str,
message: str,
cache: bool | None = None,
language: str | None = None,
options: dict | None = None,
) -> tuple[str, bytes]:
"""Fetch TTS audio."""
language, options = self.process_options(engine, language, options)
cache_key = self._generate_cache_key(message, language, options, engine)
use_cache = cache if cache is not None else self.use_cache

# If we have the file, load it into memory if necessary
if cache_key not in self.mem_cache:
if use_cache and cache_key in self.file_cache:
await self._async_file_to_mem(cache_key)
else:
await self._async_get_tts_audio(
engine, cache_key, message, use_cache, language, options
)

extension = os.path.splitext(self.mem_cache[cache_key]["filename"])[1][1:]
data = self.mem_cache[cache_key]["voice"]
return extension, data

@callback
def _generate_cache_key(
self,
message: str,
language: str,
options: dict | None,
engine: str,
) -> str:
"""Generate a cache key for a message."""
options_key = _hash_options(options) if options else "-"
msg_hash = hashlib.sha1(bytes(message, "utf-8")).hexdigest()
return KEY_PATTERN.format(
msg_hash, language.replace("_", "-"), options_key, engine
).lower()

async def _async_get_tts_audio(
self,
engine: str,
cache_key: str,
message: str,
cache: bool,
language: str,
options: dict | None,
) -> str:
"""Receive TTS and store for view in cache.
"""Receive TTS, store for view in cache and return filename.
This method is a coroutine.
"""
Expand All @@ -446,7 +492,7 @@ async def async_get_tts_audio(
raise HomeAssistantError(f"No TTS from {engine} for '{message}'")

# Create file infos
filename = f"{key}.{extension}".lower()
filename = f"{cache_key}.{extension}".lower()

# Validate filename
if not _RE_VOICE_FILE.match(filename):
Expand All @@ -456,14 +502,18 @@ async def async_get_tts_audio(

# Save to memory
data = self.write_tags(filename, data, provider, message, language, options)
self._async_store_to_memcache(key, filename, data)
self._async_store_to_memcache(cache_key, filename, data)

if cache:
self.hass.async_create_task(self.async_save_tts_audio(key, filename, data))
self.hass.async_create_task(
self._async_save_tts_audio(cache_key, filename, data)
)

return filename

async def async_save_tts_audio(self, key: str, filename: str, data: bytes) -> None:
async def _async_save_tts_audio(
self, cache_key: str, filename: str, data: bytes
) -> None:
"""Store voice data to file and file_cache.
This method is a coroutine.
Expand All @@ -477,17 +527,17 @@ def save_speech() -> None:

try:
await self.hass.async_add_executor_job(save_speech)
self.file_cache[key] = filename
self.file_cache[cache_key] = filename
except OSError as err:
_LOGGER.error("Can't write %s: %s", filename, err)

async def async_file_to_mem(self, key: str) -> None:
async def _async_file_to_mem(self, cache_key: str) -> None:
"""Load voice from file cache into memory.
This method is a coroutine.
"""
if not (filename := self.file_cache.get(key)):
raise HomeAssistantError(f"Key {key} not in file cache!")
if not (filename := self.file_cache.get(cache_key)):
raise HomeAssistantError(f"Key {cache_key} not in file cache!")

voice_file = os.path.join(self.cache_dir, filename)

Expand All @@ -499,20 +549,22 @@ def load_speech() -> bytes:
try:
data = await self.hass.async_add_executor_job(load_speech)
except OSError as err:
del self.file_cache[key]
del self.file_cache[cache_key]
raise HomeAssistantError(f"Can't read {voice_file}") from err

self._async_store_to_memcache(key, filename, data)
self._async_store_to_memcache(cache_key, filename, data)

@callback
def _async_store_to_memcache(self, key: str, filename: str, data: bytes) -> None:
def _async_store_to_memcache(
self, cache_key: str, filename: str, data: bytes
) -> None:
"""Store data to memcache and set timer to remove it."""
self.mem_cache[key] = {MEM_CACHE_FILENAME: filename, MEM_CACHE_VOICE: data}
self.mem_cache[cache_key] = {"filename": filename, "voice": data}

@callback
def async_remove_from_mem() -> None:
"""Cleanup memcache."""
self.mem_cache.pop(key, None)
self.mem_cache.pop(cache_key, None)

self.hass.loop.call_later(self.time_memory, async_remove_from_mem)

Expand All @@ -524,17 +576,17 @@ async def async_read_tts(self, filename: str) -> tuple[str | None, bytes]:
if not (record := _RE_VOICE_FILE.match(filename.lower())):
raise HomeAssistantError("Wrong tts file format!")

key = KEY_PATTERN.format(
cache_key = KEY_PATTERN.format(
record.group(1), record.group(2), record.group(3), record.group(4)
)

if key not in self.mem_cache:
if key not in self.file_cache:
raise HomeAssistantError(f"{key} not in cache!")
await self.async_file_to_mem(key)
if cache_key not in self.mem_cache:
if cache_key not in self.file_cache:
raise HomeAssistantError(f"{cache_key} not in cache!")
await self._async_file_to_mem(cache_key)

content, _ = mimetypes.guess_type(filename)
return content, cast(bytes, self.mem_cache[key][MEM_CACHE_VOICE])
return content, self.mem_cache[cache_key]["voice"]

@staticmethod
def write_tags(
Expand Down
Loading

0 comments on commit 697e7b3

Please sign in to comment.