Fix TTS streaming for VoIP (#104620)
* Use wav instead of raw tts audio in voip

* More tests

* Use mock TTS dir
synesthesiam authored Nov 29, 2023
1 parent 47426a3 commit a894146
Showing 2 changed files with 239 additions and 8 deletions.
29 changes: 26 additions & 3 deletions homeassistant/components/voip/voip.py
@@ -5,10 +5,12 @@
from collections import deque
from collections.abc import AsyncIterable, MutableSequence, Sequence
from functools import partial
import io
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING
import wave

from voip_utils import (
CallInfo,
@@ -285,7 +287,7 @@ async def stt_stream():
),
conversation_id=self._conversation_id,
device_id=self.voip_device.device_id,
-tts_audio_output="raw",
+tts_audio_output="wav",
)

if self._pipeline_error:
@@ -402,19 +404,40 @@ async def _send_tts(self, media_id: str) -> None:
if self.transport is None:
return

-_extension, audio_bytes = await tts.async_get_media_source_audio(
+extension, data = await tts.async_get_media_source_audio(
self.hass,
media_id,
)

if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")

with io.BytesIO(data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()

if (
(sample_rate != 16000)
or (sample_width != 2)
or (sample_channels != 1)
):
raise ValueError(
    "Expected rate/width/channels as 16000/2/1,"
    f" got {sample_rate}/{sample_width}/{sample_channels}"
)

audio_bytes = wav_file.readframes(wav_file.getnframes())

_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))

# Time out 1 second after TTS audio should be finished
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
tts_seconds = tts_samples / RATE

async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
-# Assume TTS audio is 16Khz 16-bit mono
+# TTS audio is 16Khz 16-bit mono
await self._async_send_audio(audio_bytes)
except asyncio.TimeoutError as err:
_LOGGER.warning("TTS timeout")
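The substance of the change is in `_send_tts`: the pipeline is now asked for `wav` output, the WAV header is inspected with the standard-library `wave` module, and only the raw PCM frames are handed to `_async_send_audio`. Below is a minimal standalone sketch of the same checks, assuming the 16 kHz / 16-bit / mono stream format the component expects; the function names and module-level constants are illustrative and not part of the component.

# Illustrative sketch (not from the commit): validate a WAV payload and compute
# the playback duration the same way _send_tts does.
import io
import wave

RATE = 16000      # samples per second expected by the VoIP stream
WIDTH = 2         # bytes per sample (16-bit)
CHANNELS = 1      # mono


def extract_pcm(extension: str, data: bytes) -> bytes:
    """Return raw PCM frames after checking container type and audio format."""
    if extension != "wav":
        raise ValueError(f"Only WAV audio can be streamed, got {extension}")

    with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
        rate = wav_file.getframerate()
        width = wav_file.getsampwidth()
        channels = wav_file.getnchannels()
        if (rate, width, channels) != (RATE, WIDTH, CHANNELS):
            raise ValueError(
                f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
                f" got {rate}/{width}/{channels}"
            )
        return wav_file.readframes(wav_file.getnframes())


def playback_seconds(audio_bytes: bytes) -> float:
    """Convert bytes -> samples -> seconds; this is the base of the send timeout."""
    return len(audio_bytes) / (WIDTH * CHANNELS) / RATE

With a valid payload, `playback_seconds` plus `tts_extra_timeout` bounds the `asyncio.timeout` window used while the audio is streamed.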
218 changes: 213 additions & 5 deletions tests/components/voip/test_voip.py
@@ -1,7 +1,9 @@
"""Test VoIP protocol."""
import asyncio
import io
import time
from unittest.mock import AsyncMock, Mock, patch
import wave

import pytest

@@ -14,6 +16,24 @@
_MEDIA_ID = "12345"


@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
"""Mock the TTS cache dir with empty dir."""
return mock_tts_cache_dir


def _empty_wav() -> bytes:
"""Return bytes of an empty WAV file."""
with io.BytesIO() as wav_io:
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
with wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)

return wav_io.getvalue()


async def test_pipeline(
hass: HomeAssistant,
voip_device: VoIPDevice,
@@ -72,8 +92,7 @@ async def async_get_media_source_audio(
media_source_id: str,
) -> tuple[str, bytes]:
assert media_source_id == _MEDIA_ID

-return ("mp3", b"")
+return ("wav", _empty_wav())

with patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
@@ -266,7 +285,7 @@ async def async_get_media_source_audio(
media_source_id: str,
) -> tuple[str, bytes]:
# Should time out immediately
-return ("raw", bytes(0))
+return ("wav", _empty_wav())

with patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
@@ -305,8 +324,197 @@ async def send_tts(*args, **kwargs):

done.set()

-rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio)
-rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)
+rtp_protocol._async_send_audio = AsyncMock(side_effect=async_send_audio)  # type: ignore[method-assign]
+rtp_protocol._send_tts = AsyncMock(side_effect=send_tts)  # type: ignore[method-assign]

# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))

# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))

# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))

# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
await done.wait()


async def test_tts_wrong_extension(
hass: HomeAssistant,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio."""
assert await async_setup_component(hass, "voip", {})

def is_speech(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0

done = asyncio.Event()

async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass

# Fake intent result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.INTENT_END,
data={
"intent_output": {
"conversation_id": "fake-conversation",
}
},
)
)

# Proceed with media output
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}},
)
)

async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
# Should fail because it's not "wav"
return ("mp3", b"")

with patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
), patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
), patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()

original_send_tts = rtp_protocol._send_tts

async def send_tts(*args, **kwargs):
# Call original then end test successfully
with pytest.raises(ValueError):
await original_send_tts(*args, **kwargs)

done.set()

rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]

# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))

# "speech"
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))

# silence (assumes relaxed VAD sensitivity)
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))

# Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1):
await done.wait()


async def test_tts_wrong_wav_format(
hass: HomeAssistant,
voip_device: VoIPDevice,
) -> None:
"""Test that TTS will only stream WAV audio with a specific format."""
assert await async_setup_component(hass, "voip", {})

def is_speech(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0

done = asyncio.Event()

async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass

# Fake intent result
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.INTENT_END,
data={
"intent_output": {
"conversation_id": "fake-conversation",
}
},
)
)

# Proceed with media output
event_callback(
assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}},
)
)

async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
# Should fail because it's not 16Khz, 16-bit mono
with io.BytesIO() as wav_io:
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
with wav_file:
wav_file.setframerate(22050)
wav_file.setsampwidth(2)
wav_file.setnchannels(2)

return ("wav", wav_io.getvalue())

with patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
), patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
), patch(
"homeassistant.components.voip.voip.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
):
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
hass,
hass.config.language,
voip_device,
Context(),
opus_payload_type=123,
)
rtp_protocol.transport = Mock()

original_send_tts = rtp_protocol._send_tts

async def send_tts(*args, **kwargs):
# Call original then end test successfully
with pytest.raises(ValueError):
await original_send_tts(*args, **kwargs)

done.set()

rtp_protocol._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]

# silence
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
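For reference, here is a quick standalone check (not part of the commit) of what the tests' `_empty_wav()` helper yields: a header-only file that passes the 16000/2/1 format check but contains no frames, so the computed playback time is zero and only `tts_extra_timeout` remains before the timeout path in `_send_tts` is exercised.

# Illustrative check of the header-only WAV used by the timeout test.
import io
import wave


def empty_wav() -> bytes:
    """Mirror of the test helper: write a 16 kHz, 16-bit, mono header and no frames."""
    with io.BytesIO() as wav_io:
        with wave.open(wav_io, "wb") as wav_file:
            wav_file.setframerate(16000)
            wav_file.setsampwidth(2)
            wav_file.setnchannels(1)
        return wav_io.getvalue()


with io.BytesIO(empty_wav()) as wav_io, wave.open(wav_io, "rb") as wav_file:
    # Format matches what _send_tts requires...
    assert (wav_file.getframerate(), wav_file.getsampwidth(), wav_file.getnchannels()) == (16000, 2, 1)
    # ...but there is no audio to stream, so readframes() returns b"".
    assert wav_file.getnframes() == 0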
