From 3f3a3bcc8ac7eec2e5e9eba9981c74db3842f22d Mon Sep 17 00:00:00 2001 From: Pascal Vizeli Date: Wed, 11 Jan 2017 16:31:16 +0100 Subject: [PATCH] Cleanup language support on TTS (#5255) * Cleanup language support on TTS * change to default_language & address comments * Cleanup not needed code / comment from paulus --- homeassistant/components/tts/__init__.py | 32 +++++++++---- homeassistant/components/tts/demo.py | 36 ++++++++++++--- homeassistant/components/tts/google.py | 22 +++++---- homeassistant/components/tts/picotts.py | 22 +++++++-- homeassistant/components/tts/voicerss.py | 37 +++++++++------ tests/components/tts/test_init.py | 58 +++++++++++++++++++----- 6 files changed, 152 insertions(+), 55 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 01d0a6a15e331d..0f731a51485059 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -5,7 +5,6 @@ https://home-assistant.io/components/tts/ """ import asyncio -import functools import hashlib import logging import mimetypes @@ -247,8 +246,6 @@ def remove_files(): def async_register_engine(self, engine, provider, config): """Register a TTS provider.""" provider.hass = self.hass - if CONF_LANG in config: - provider.language = config.get(CONF_LANG) self.providers[engine] = provider @asyncio.coroutine @@ -257,9 +254,16 @@ def async_get_url(self, engine, message, cache=None, language=None): This method is a coroutine. """ + provider = self.providers[engine] + + language = language or provider.default_language + if language is None or \ + language not in provider.supported_languages: + raise HomeAssistantError("Not supported language {0}".format( + language)) + msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest() - language_key = language or self.providers[engine].language - key = KEY_PATTERN.format(msg_hash, language_key, engine).lower() + key = KEY_PATTERN.format(msg_hash, language, engine).lower() use_cache = cache if cache is not None else self.use_cache # is speech allready in memory @@ -387,13 +391,22 @@ class Provider(object): """Represent a single provider.""" hass = None - language = None - def get_tts_audio(self, message, language=None): + @property + def default_language(self): + """Default language.""" + return None + + @property + def supported_languages(self): + """List of supported languages.""" + return None + + def get_tts_audio(self, message, language): """Load tts audio file from provider.""" raise NotImplementedError() - def async_get_tts_audio(self, message, language=None): + def async_get_tts_audio(self, message, language): """Load tts audio file from provider. Return a tuple of file extension and data as bytes. @@ -401,8 +414,7 @@ def async_get_tts_audio(self, message, language=None): This method must be run in the event loop and returns a coroutine. """ return self.hass.loop.run_in_executor( - None, - functools.partial(self.get_tts_audio, message, language=language)) + None, self.get_tts_audio, message, language) class TextToSpeechView(HomeAssistantView): diff --git a/homeassistant/components/tts/demo.py b/homeassistant/components/tts/demo.py index 68d49d58f78a14..88afa0643f2b71 100644 --- a/homeassistant/components/tts/demo.py +++ b/homeassistant/components/tts/demo.py @@ -6,28 +6,50 @@ """ import os -from homeassistant.components.tts import Provider +import voluptuous as vol + +from homeassistant.components.tts import Provider, PLATFORM_SCHEMA, CONF_LANG + +SUPPORT_LANGUAGES = [ + 'en', 'de' +] + +DEFAULT_LANG = 'en' + +PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ + vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES), +}) def get_engine(hass, config): """Setup Demo speech component.""" - return DemoProvider() + return DemoProvider(config[CONF_LANG]) class DemoProvider(Provider): """Demo speech api provider.""" - def __init__(self): - """Initialize demo provider for TTS.""" - self.language = 'en' + def __init__(self, lang): + """Initialize demo provider.""" + self._lang = lang + + @property + def default_language(self): + """Default language.""" + return self._lang + + @property + def supported_languages(self): + """List of supported languages.""" + return SUPPORT_LANGUAGES - def get_tts_audio(self, message, language=None): + def get_tts_audio(self, message, language): """Load TTS from demo.""" filename = os.path.join(os.path.dirname(__file__), "demo.mp3") try: with open(filename, 'rb') as voice: data = voice.read() except OSError: - return + return (None, None) return ("mp3", data) diff --git a/homeassistant/components/tts/google.py b/homeassistant/components/tts/google.py index e1bb4e5e4e5b4d..dc03013d4f1a4f 100644 --- a/homeassistant/components/tts/google.py +++ b/homeassistant/components/tts/google.py @@ -42,15 +42,16 @@ @asyncio.coroutine def async_get_engine(hass, config): """Setup Google speech component.""" - return GoogleProvider(hass) + return GoogleProvider(hass, config[CONF_LANG]) class GoogleProvider(Provider): """Google speech api provider.""" - def __init__(self, hass): + def __init__(self, hass, lang): """Init Google TTS service.""" self.hass = hass + self._lang = lang self.headers = { 'Referer': "http://translate.google.com/", 'User-Agent': ("Mozilla/5.0 (Windows NT 10.0; WOW64) " @@ -58,8 +59,18 @@ def __init__(self, hass): "Chrome/47.0.2526.106 Safari/537.36") } + @property + def default_language(self): + """Default language.""" + return self._lang + + @property + def supported_languages(self): + """List of supported languages.""" + return SUPPORT_LANGUAGES + @asyncio.coroutine - def async_get_tts_audio(self, message, language=None): + def async_get_tts_audio(self, message, language): """Load TTS from google.""" from gtts_token import gtts_token @@ -67,11 +78,6 @@ def async_get_tts_audio(self, message, language=None): websession = async_get_clientsession(self.hass) message_parts = self._split_message_to_parts(message) - # If language is not specified or is not supported - use the language - # from the config. - if language not in SUPPORT_LANGUAGES: - language = self.language - data = b'' for idx, part in enumerate(message_parts): part_token = yield from self.hass.loop.run_in_executor( diff --git a/homeassistant/components/tts/picotts.py b/homeassistant/components/tts/picotts.py index 366973813a288d..28db88c03b04d8 100644 --- a/homeassistant/components/tts/picotts.py +++ b/homeassistant/components/tts/picotts.py @@ -29,18 +29,31 @@ def get_engine(hass, config): if shutil.which("pico2wave") is None: _LOGGER.error("'pico2wave' was not found") return False - return PicoProvider() + return PicoProvider(config[CONF_LANG]) class PicoProvider(Provider): """pico speech api provider.""" - def get_tts_audio(self, message, language=None): + def __init__(self, lang): + """Initialize pico provider.""" + self._lang = lang + + @property + def default_language(self): + """Default language.""" + return self._lang + + @property + def supported_languages(self): + """List of supported languages.""" + return SUPPORT_LANGUAGES + + def get_tts_audio(self, message, language): """Load TTS using pico2wave.""" - if language not in SUPPORT_LANGUAGES: - language = self.language with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpf: fname = tmpf.name + cmd = ['pico2wave', '--wave', fname, '-l', language, message] subprocess.call(cmd) data = None @@ -52,6 +65,7 @@ def get_tts_audio(self, message, language=None): return (None, None) finally: os.remove(fname) + if data: return ("wav", data) return (None, None) diff --git a/homeassistant/components/tts/voicerss.py b/homeassistant/components/tts/voicerss.py index 688ae7f6e25221..2dda27b0c06728 100644 --- a/homeassistant/components/tts/voicerss.py +++ b/homeassistant/components/tts/voicerss.py @@ -93,27 +93,34 @@ class VoiceRSSProvider(Provider): def __init__(self, hass, conf): """Init VoiceRSS TTS service.""" self.hass = hass - self.extension = conf.get(CONF_CODEC) - - self.form_data = { - 'key': conf.get(CONF_API_KEY), - 'hl': conf.get(CONF_LANG), - 'c': (conf.get(CONF_CODEC)).upper(), - 'f': conf.get(CONF_FORMAT), + self._extension = conf[CONF_CODEC] + self._lang = conf[CONF_LANG] + + self._form_data = { + 'key': conf[CONF_API_KEY], + 'hl': conf[CONF_LANG], + 'c': (conf[CONF_CODEC]).upper(), + 'f': conf[CONF_FORMAT], } + @property + def default_language(self): + """Default language.""" + return self._lang + + @property + def supported_languages(self): + """List of supported languages.""" + return SUPPORT_LANGUAGES + @asyncio.coroutine - def async_get_tts_audio(self, message, language=None): + def async_get_tts_audio(self, message, language): """Load TTS from voicerss.""" websession = async_get_clientsession(self.hass) - form_data = self.form_data.copy() + form_data = self._form_data.copy() form_data['src'] = message - - # If language is specified and supported - use it instead of the - # language in the config. - if language in SUPPORT_LANGUAGES: - form_data['hl'] = language + form_data['hl'] = language request = None try: @@ -141,4 +148,4 @@ def async_get_tts_audio(self, message, language=None): if request is not None: yield from request.release() - return (self.extension, data) + return (self._extension, data) diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 553813953132f5..715b98c4740a82 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -22,7 +22,7 @@ class TestTTS(object): def setup_method(self): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() - self.demo_provider = DemoProvider() + self.demo_provider = DemoProvider('en') self.default_tts_cache = self.hass.config.path(tts.DEFAULT_CACHE_DIR) def teardown_method(self): @@ -95,7 +95,7 @@ def test_setup_component_and_test_service_with_config_language(self): config = { tts.DOMAIN: { 'platform': 'demo', - 'language': 'lang' + 'language': 'de' } } @@ -111,11 +111,23 @@ def test_setup_component_and_test_service_with_config_language(self): assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_lang_demo.mp3") \ + "_de_demo.mp3") \ != -1 assert os.path.isfile(os.path.join( self.default_tts_cache, - "265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3")) + "265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3")) + + def test_setup_component_and_test_service_with_wrong_conf_language(self): + """Setup the demo platform and call service with wrong config.""" + config = { + tts.DOMAIN: { + 'platform': 'demo', + 'language': 'ru' + } + } + + with assert_setup_component(0, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) def test_setup_component_and_test_service_with_service_language(self): """Setup the demo platform and call service.""" @@ -132,7 +144,7 @@ def test_setup_component_and_test_service_with_service_language(self): self.hass.services.call(tts.DOMAIN, 'demo_say', { tts.ATTR_MESSAGE: "I person is on front of your door.", - tts.ATTR_LANGUAGE: "lang", + tts.ATTR_LANGUAGE: "de", }) self.hass.block_till_done() @@ -140,9 +152,33 @@ def test_setup_component_and_test_service_with_service_language(self): assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find( "/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd" - "_lang_demo.mp3") \ + "_de_demo.mp3") \ != -1 assert os.path.isfile(os.path.join( + self.default_tts_cache, + "265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3")) + + def test_setup_component_test_service_with_wrong_service_language(self): + """Setup the demo platform and call service.""" + calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + + config = { + tts.DOMAIN: { + 'platform': 'demo', + } + } + + with assert_setup_component(1, tts.DOMAIN): + setup_component(self.hass, tts.DOMAIN, config) + + self.hass.services.call(tts.DOMAIN, 'demo_say', { + tts.ATTR_MESSAGE: "I person is on front of your door.", + tts.ATTR_LANGUAGE: "lang", + }) + self.hass.block_till_done() + + assert len(calls) == 0 + assert not os.path.isfile(os.path.join( self.default_tts_cache, "265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3")) @@ -198,7 +234,7 @@ def test_setup_component_and_test_service_with_receive_voice(self): assert len(calls) == 1 req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID]) - _, demo_data = self.demo_provider.get_tts_audio("bla") + _, demo_data = self.demo_provider.get_tts_audio("bla", 'en') assert req.status_code == 200 assert req.content == demo_data @@ -319,7 +355,7 @@ def test_setup_component_test_with_cache_dir(self): """Setup demo platform with cache and call service without cache.""" calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) - _, demo_data = self.demo_provider.get_tts_audio("bla") + _, demo_data = self.demo_provider.get_tts_audio("bla", 'en') cache_file = os.path.join( self.default_tts_cache, "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3") @@ -339,7 +375,7 @@ def test_setup_component_test_with_cache_dir(self): setup_component(self.hass, tts.DOMAIN, config) with patch('homeassistant.components.tts.demo.DemoProvider.' - 'get_tts_audio', return_value=None): + 'get_tts_audio', return_value=(None, None)): self.hass.services.call(tts.DOMAIN, 'demo_say', { tts.ATTR_MESSAGE: "I person is on front of your door.", }) @@ -352,7 +388,7 @@ def test_setup_component_test_with_cache_dir(self): != -1 @patch('homeassistant.components.tts.demo.DemoProvider.get_tts_audio', - return_value=None) + return_value=(None, None)) def test_setup_component_test_with_error_on_get_tts(self, tts_mock): """Setup demo platform with wrong get_tts_audio.""" calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) @@ -375,7 +411,7 @@ def test_setup_component_test_with_error_on_get_tts(self, tts_mock): def test_setup_component_load_cache_retrieve_without_mem_cache(self): """Setup component and load cache and get without mem cache.""" - _, demo_data = self.demo_provider.get_tts_audio("bla") + _, demo_data = self.demo_provider.get_tts_audio("bla", 'en') cache_file = os.path.join( self.default_tts_cache, "265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")