From 4ecc6555bf5aada9f4923ae5d4d0884edffa70f2 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Wed, 4 Sep 2024 12:42:41 -0500 Subject: [PATCH] Add support for sample bytes in preferred TTS format (#125235) --- .../components/assist_pipeline/__init__.py | 3 +- .../components/assist_pipeline/pipeline.py | 7 +- homeassistant/components/tts/__init__.py | 26 ++++++ tests/components/assist_pipeline/test_init.py | 84 +++++++++++++++++-- 4 files changed, 112 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 8ee053162b066..0a03402105abd 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import AsyncIterable +from typing import Any import voluptuous as vol @@ -99,7 +100,7 @@ async def async_pipeline_from_audio_stream( wake_word_phrase: str | None = None, pipeline_id: str | None = None, conversation_id: str | None = None, - tts_audio_output: str | None = None, + tts_audio_output: str | dict[str, Any] | None = None, wake_word_settings: WakeWordSettings | None = None, audio_settings: AudioSettings | None = None, device_id: str | None = None, diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 342f811c99bf7..f6a6bc45b57eb 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -538,7 +538,7 @@ class PipelineRun: language: str = None # type: ignore[assignment] runner_data: Any | None = None intent_agent: str | None = None - tts_audio_output: str | None = None + tts_audio_output: str | dict[str, Any] | None = None wake_word_settings: WakeWordSettings | None = None audio_settings: AudioSettings = field(default_factory=AudioSettings) @@ -1052,12 +1052,15 @@ async def prepare_text_to_speech(self) -> None: if self.pipeline.tts_voice is not None: tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice - if self.tts_audio_output is not None: + if isinstance(self.tts_audio_output, dict): + tts_options.update(self.tts_audio_output) + elif isinstance(self.tts_audio_output, str): tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output if self.tts_audio_output == "wav": # 16 Khz, 16-bit mono tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = SAMPLE_RATE tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = SAMPLE_CHANNELS + tts_options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = SAMPLE_WIDTH try: options_supported = await tts.async_support_options( diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index 70bb2b4c713f8..9e3d9f65a768b 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -77,6 +77,7 @@ "ATTR_PREFERRED_FORMAT", "ATTR_PREFERRED_SAMPLE_RATE", "ATTR_PREFERRED_SAMPLE_CHANNELS", + "ATTR_PREFERRED_SAMPLE_BYTES", "CONF_LANG", "DEFAULT_CACHE_DIR", "generate_media_source_id", @@ -95,6 +96,7 @@ ATTR_PREFERRED_FORMAT = "preferred_format" ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate" ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels" +ATTR_PREFERRED_SAMPLE_BYTES = "preferred_sample_bytes" ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id" ATTR_VOICE = "voice" @@ -103,6 +105,7 @@ ATTR_PREFERRED_FORMAT, ATTR_PREFERRED_SAMPLE_RATE, ATTR_PREFERRED_SAMPLE_CHANNELS, + ATTR_PREFERRED_SAMPLE_BYTES, } CONF_LANG = "language" @@ -223,6 +226,7 @@ async def async_convert_audio( to_extension: str, to_sample_rate: int | None = None, to_sample_channels: int | None = None, + to_sample_bytes: int | None = None, ) -> bytes: """Convert audio to a preferred format using ffmpeg.""" ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass) @@ -234,6 +238,7 @@ async def async_convert_audio( to_extension, to_sample_rate=to_sample_rate, to_sample_channels=to_sample_channels, + to_sample_bytes=to_sample_bytes, ) ) @@ -245,6 +250,7 @@ def _convert_audio( to_extension: str, to_sample_rate: int | None = None, to_sample_channels: int | None = None, + to_sample_bytes: int | None = None, ) -> bytes: """Convert audio to a preferred format using ffmpeg.""" @@ -277,6 +283,10 @@ def _convert_audio( # Max quality for MP3 command.extend(["-q:a", "0"]) + if to_sample_bytes == 2: + # 16-bit samples + command.extend(["-sample_fmt", "s16"]) + command.append(output_file.name) with subprocess.Popen( @@ -738,11 +748,25 @@ async def _async_get_tts_audio( else: sample_rate = options.pop(ATTR_PREFERRED_SAMPLE_RATE, None) + if sample_rate is not None: + sample_rate = int(sample_rate) + if ATTR_PREFERRED_SAMPLE_CHANNELS in supported_options: sample_channels = options.get(ATTR_PREFERRED_SAMPLE_CHANNELS) else: sample_channels = options.pop(ATTR_PREFERRED_SAMPLE_CHANNELS, None) + if sample_channels is not None: + sample_channels = int(sample_channels) + + if ATTR_PREFERRED_SAMPLE_BYTES in supported_options: + sample_bytes = options.get(ATTR_PREFERRED_SAMPLE_BYTES) + else: + sample_bytes = options.pop(ATTR_PREFERRED_SAMPLE_BYTES, None) + + if sample_bytes is not None: + sample_bytes = int(sample_bytes) + async def get_tts_data() -> str: """Handle data available.""" if engine_instance.name is None or engine_instance.name is UNDEFINED: @@ -769,6 +793,7 @@ async def get_tts_data() -> str: (final_extension != extension) or (sample_rate is not None) or (sample_channels is not None) + or (sample_bytes is not None) ) if needs_conversion: @@ -779,6 +804,7 @@ async def get_tts_data() -> str: to_extension=final_extension, to_sample_rate=sample_rate, to_sample_channels=sample_channels, + to_sample_bytes=sample_bytes, ) # Create file infos diff --git a/tests/components/assist_pipeline/test_init.py b/tests/components/assist_pipeline/test_init.py index 31cc1268098c2..c4696573bade3 100644 --- a/tests/components/assist_pipeline/test_init.py +++ b/tests/components/assist_pipeline/test_init.py @@ -788,13 +788,12 @@ async def test_tts_audio_output( assert len(extra_options) == 0, extra_options -async def test_tts_supports_preferred_format( +async def test_tts_wav_preferred_format( hass: HomeAssistant, hass_client: ClientSessionGenerator, mock_tts_provider: MockTTSProvider, init_components, pipeline_data: assist_pipeline.pipeline.PipelineData, - snapshot: SnapshotAssertion, ) -> None: """Test that preferred format options are given to the TTS system if supported.""" client = await hass_client() @@ -829,6 +828,80 @@ async def test_tts_supports_preferred_format( tts.ATTR_PREFERRED_FORMAT, tts.ATTR_PREFERRED_SAMPLE_RATE, tts.ATTR_PREFERRED_SAMPLE_CHANNELS, + tts.ATTR_PREFERRED_SAMPLE_BYTES, + ] + ) + + with ( + patch.object(mock_tts_provider, "_supported_options", supported_options), + patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio, + ): + await pipeline_input.execute() + + for event in events: + if event.type == assist_pipeline.PipelineEventType.TTS_END: + # We must fetch the media URL to trigger the TTS + assert event.data + media_id = event.data["tts_output"]["media_id"] + resolved = await media_source.async_resolve_media(hass, media_id, None) + await client.get(resolved.url) + + assert mock_get_tts_audio.called + options = mock_get_tts_audio.call_args_list[0].kwargs["options"] + + # We should have received preferred format options in get_tts_audio + assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav" + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000 + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1 + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2 + + +async def test_tts_dict_preferred_format( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + mock_tts_provider: MockTTSProvider, + init_components, + pipeline_data: assist_pipeline.pipeline.PipelineData, +) -> None: + """Test that preferred format options are given to the TTS system if supported.""" + client = await hass_client() + assert await async_setup_component(hass, media_source.DOMAIN, {}) + + events: list[assist_pipeline.PipelineEvent] = [] + + pipeline_store = pipeline_data.pipeline_store + pipeline_id = pipeline_store.async_get_preferred_item() + pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id) + + pipeline_input = assist_pipeline.pipeline.PipelineInput( + tts_input="This is a test.", + conversation_id=None, + device_id=None, + run=assist_pipeline.pipeline.PipelineRun( + hass, + context=Context(), + pipeline=pipeline, + start_stage=assist_pipeline.PipelineStage.TTS, + end_stage=assist_pipeline.PipelineStage.TTS, + event_callback=events.append, + tts_audio_output={ + tts.ATTR_PREFERRED_FORMAT: "flac", + tts.ATTR_PREFERRED_SAMPLE_RATE: 48000, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2, + tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, + }, + ), + ) + await pipeline_input.validate() + + # Make the TTS provider support preferred format options + supported_options = list(mock_tts_provider.supported_options or []) + supported_options.extend( + [ + tts.ATTR_PREFERRED_FORMAT, + tts.ATTR_PREFERRED_SAMPLE_RATE, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS, + tts.ATTR_PREFERRED_SAMPLE_BYTES, ] ) @@ -850,6 +923,7 @@ async def test_tts_supports_preferred_format( options = mock_get_tts_audio.call_args_list[0].kwargs["options"] # We should have received preferred format options in get_tts_audio - assert tts.ATTR_PREFERRED_FORMAT in options - assert tts.ATTR_PREFERRED_SAMPLE_RATE in options - assert tts.ATTR_PREFERRED_SAMPLE_CHANNELS in options + assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac" + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000 + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2 + assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2