Skip to content

Commit

Permalink
Add support for sample bytes in preferred TTS format (home-assistant#…
Browse files Browse the repository at this point in the history
  • Loading branch information
synesthesiam authored Sep 4, 2024
1 parent 892c32c commit 4ecc655
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 8 deletions.
3 changes: 2 additions & 1 deletion homeassistant/components/assist_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections.abc import AsyncIterable
from typing import Any

import voluptuous as vol

Expand Down Expand Up @@ -99,7 +100,7 @@ async def async_pipeline_from_audio_stream(
wake_word_phrase: str | None = None,
pipeline_id: str | None = None,
conversation_id: str | None = None,
tts_audio_output: str | None = None,
tts_audio_output: str | dict[str, Any] | None = None,
wake_word_settings: WakeWordSettings | None = None,
audio_settings: AudioSettings | None = None,
device_id: str | None = None,
Expand Down
7 changes: 5 additions & 2 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ class PipelineRun:
language: str = None # type: ignore[assignment]
runner_data: Any | None = None
intent_agent: str | None = None
tts_audio_output: str | None = None
tts_audio_output: str | dict[str, Any] | None = None
wake_word_settings: WakeWordSettings | None = None
audio_settings: AudioSettings = field(default_factory=AudioSettings)

Expand Down Expand Up @@ -1052,12 +1052,15 @@ async def prepare_text_to_speech(self) -> None:
if self.pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice

if self.tts_audio_output is not None:
if isinstance(self.tts_audio_output, dict):
tts_options.update(self.tts_audio_output)
elif isinstance(self.tts_audio_output, str):
tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output
if self.tts_audio_output == "wav":
# 16 Khz, 16-bit mono
tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = SAMPLE_RATE
tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = SAMPLE_CHANNELS
tts_options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = SAMPLE_WIDTH

try:
options_supported = await tts.async_support_options(
Expand Down
26 changes: 26 additions & 0 deletions homeassistant/components/tts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"ATTR_PREFERRED_FORMAT",
"ATTR_PREFERRED_SAMPLE_RATE",
"ATTR_PREFERRED_SAMPLE_CHANNELS",
"ATTR_PREFERRED_SAMPLE_BYTES",
"CONF_LANG",
"DEFAULT_CACHE_DIR",
"generate_media_source_id",
Expand All @@ -95,6 +96,7 @@
ATTR_PREFERRED_FORMAT = "preferred_format"
ATTR_PREFERRED_SAMPLE_RATE = "preferred_sample_rate"
ATTR_PREFERRED_SAMPLE_CHANNELS = "preferred_sample_channels"
ATTR_PREFERRED_SAMPLE_BYTES = "preferred_sample_bytes"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
ATTR_VOICE = "voice"

Expand All @@ -103,6 +105,7 @@
ATTR_PREFERRED_FORMAT,
ATTR_PREFERRED_SAMPLE_RATE,
ATTR_PREFERRED_SAMPLE_CHANNELS,
ATTR_PREFERRED_SAMPLE_BYTES,
}

CONF_LANG = "language"
Expand Down Expand Up @@ -223,6 +226,7 @@ async def async_convert_audio(
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
to_sample_bytes: int | None = None,
) -> bytes:
"""Convert audio to a preferred format using ffmpeg."""
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
Expand All @@ -234,6 +238,7 @@ async def async_convert_audio(
to_extension,
to_sample_rate=to_sample_rate,
to_sample_channels=to_sample_channels,
to_sample_bytes=to_sample_bytes,
)
)

Expand All @@ -245,6 +250,7 @@ def _convert_audio(
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
to_sample_bytes: int | None = None,
) -> bytes:
"""Convert audio to a preferred format using ffmpeg."""

Expand Down Expand Up @@ -277,6 +283,10 @@ def _convert_audio(
# Max quality for MP3
command.extend(["-q:a", "0"])

if to_sample_bytes == 2:
# 16-bit samples
command.extend(["-sample_fmt", "s16"])

command.append(output_file.name)

with subprocess.Popen(
Expand Down Expand Up @@ -738,11 +748,25 @@ async def _async_get_tts_audio(
else:
sample_rate = options.pop(ATTR_PREFERRED_SAMPLE_RATE, None)

if sample_rate is not None:
sample_rate = int(sample_rate)

if ATTR_PREFERRED_SAMPLE_CHANNELS in supported_options:
sample_channels = options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
else:
sample_channels = options.pop(ATTR_PREFERRED_SAMPLE_CHANNELS, None)

if sample_channels is not None:
sample_channels = int(sample_channels)

if ATTR_PREFERRED_SAMPLE_BYTES in supported_options:
sample_bytes = options.get(ATTR_PREFERRED_SAMPLE_BYTES)
else:
sample_bytes = options.pop(ATTR_PREFERRED_SAMPLE_BYTES, None)

if sample_bytes is not None:
sample_bytes = int(sample_bytes)

async def get_tts_data() -> str:
"""Handle data available."""
if engine_instance.name is None or engine_instance.name is UNDEFINED:
Expand All @@ -769,6 +793,7 @@ async def get_tts_data() -> str:
(final_extension != extension)
or (sample_rate is not None)
or (sample_channels is not None)
or (sample_bytes is not None)
)

if needs_conversion:
Expand All @@ -779,6 +804,7 @@ async def get_tts_data() -> str:
to_extension=final_extension,
to_sample_rate=sample_rate,
to_sample_channels=sample_channels,
to_sample_bytes=sample_bytes,
)

# Create file infos
Expand Down
84 changes: 79 additions & 5 deletions tests/components/assist_pipeline/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,13 +788,12 @@ async def test_tts_audio_output(
assert len(extra_options) == 0, extra_options


async def test_tts_supports_preferred_format(
async def test_tts_wav_preferred_format(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
snapshot: SnapshotAssertion,
) -> None:
"""Test that preferred format options are given to the TTS system if supported."""
client = await hass_client()
Expand Down Expand Up @@ -829,6 +828,80 @@ async def test_tts_supports_preferred_format(
tts.ATTR_PREFERRED_FORMAT,
tts.ATTR_PREFERRED_SAMPLE_RATE,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
tts.ATTR_PREFERRED_SAMPLE_BYTES,
]
)

with (
patch.object(mock_tts_provider, "_supported_options", supported_options),
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
):
await pipeline_input.execute()

for event in events:
if event.type == assist_pipeline.PipelineEventType.TTS_END:
# We must fetch the media URL to trigger the TTS
assert event.data
media_id = event.data["tts_output"]["media_id"]
resolved = await media_source.async_resolve_media(hass, media_id, None)
await client.get(resolved.url)

assert mock_get_tts_audio.called
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

# We should have received preferred format options in get_tts_audio
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2


async def test_tts_dict_preferred_format(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
mock_tts_provider: MockTTSProvider,
init_components,
pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
"""Test that preferred format options are given to the TTS system if supported."""
client = await hass_client()
assert await async_setup_component(hass, media_source.DOMAIN, {})

events: list[assist_pipeline.PipelineEvent] = []

pipeline_store = pipeline_data.pipeline_store
pipeline_id = pipeline_store.async_get_preferred_item()
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

pipeline_input = assist_pipeline.pipeline.PipelineInput(
tts_input="This is a test.",
conversation_id=None,
device_id=None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
pipeline=pipeline,
start_stage=assist_pipeline.PipelineStage.TTS,
end_stage=assist_pipeline.PipelineStage.TTS,
event_callback=events.append,
tts_audio_output={
tts.ATTR_PREFERRED_FORMAT: "flac",
tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
},
),
)
await pipeline_input.validate()

# Make the TTS provider support preferred format options
supported_options = list(mock_tts_provider.supported_options or [])
supported_options.extend(
[
tts.ATTR_PREFERRED_FORMAT,
tts.ATTR_PREFERRED_SAMPLE_RATE,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
tts.ATTR_PREFERRED_SAMPLE_BYTES,
]
)

Expand All @@ -850,6 +923,7 @@ async def test_tts_supports_preferred_format(
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

# We should have received preferred format options in get_tts_audio
assert tts.ATTR_PREFERRED_FORMAT in options
assert tts.ATTR_PREFERRED_SAMPLE_RATE in options
assert tts.ATTR_PREFERRED_SAMPLE_CHANNELS in options
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2

0 comments on commit 4ecc655

Please sign in to comment.