From 96510721039c6ff846e5593ca000b6b0f0302d7f Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Wed, 11 Sep 2024 19:57:47 -0500 Subject: [PATCH] Fix audio format for VoIP (#125785) Fix audio format --- .../components/voip/assist_satellite.py | 12 +++++++++++- tests/components/voip/test_voip.py | 18 ++++++++++++++++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/voip/assist_satellite.py b/homeassistant/components/voip/assist_satellite.py index 9f117fc98782a0..f75f65a08ea3cc 100644 --- a/homeassistant/components/voip/assist_satellite.py +++ b/homeassistant/components/voip/assist_satellite.py @@ -8,7 +8,7 @@ import io import logging from pathlib import Path -from typing import TYPE_CHECKING, Final +from typing import TYPE_CHECKING, Any, Final import wave from voip_utils import RtpDatagramProtocol @@ -120,6 +120,16 @@ def vad_sensitivity_entity_id(self) -> str | None: """Return the entity ID of the VAD sensitivity to use for the next conversation.""" return self.voip_device.get_vad_sensitivity_entity_id(self.hass) + @property + def tts_options(self) -> dict[str, Any] | None: + """Options passed for text-to-speech.""" + return { + tts.ATTR_PREFERRED_FORMAT: "wav", + tts.ATTR_PREFERRED_SAMPLE_RATE: 16000, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1, + tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, + } + async def async_added_to_hass(self) -> None: """Run when entity about to be added to hass.""" await super().async_added_to_hass() diff --git a/tests/components/voip/test_voip.py b/tests/components/voip/test_voip.py index e6a635619a1e07..edd4d2972f44b6 100644 --- a/tests/components/voip/test_voip.py +++ b/tests/components/voip/test_voip.py @@ -3,6 +3,7 @@ import asyncio import io from pathlib import Path +from typing import Any from unittest.mock import AsyncMock, Mock, patch import wave @@ -10,7 +11,7 @@ from syrupy.assertion import SnapshotAssertion from voip_utils import CallInfo -from homeassistant.components import assist_pipeline, assist_satellite, voip +from homeassistant.components import assist_pipeline, assist_satellite, tts, voip from homeassistant.components.assist_satellite.entity import ( AssistSatelliteEntity, AssistSatelliteState, @@ -205,11 +206,24 @@ async def test_pipeline( bad_chunk = bytes([1, 2, 3, 4]) async def async_pipeline_from_audio_stream( - hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs + hass: HomeAssistant, + context: Context, + *args, + device_id: str | None, + tts_audio_output: str | dict[str, Any] | None, + **kwargs, ): assert context.user_id == voip_user_id assert device_id == voip_device.device_id + # voip can only stream WAV + assert tts_audio_output == { + tts.ATTR_PREFERRED_FORMAT: "wav", + tts.ATTR_PREFERRED_SAMPLE_RATE: 16000, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1, + tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, + } + stt_stream = kwargs["stt_stream"] event_callback = kwargs["event_callback"] in_command = False