Smoother audio playback in poor networking conditions
dumitrugutu committed Dec 16, 2024
1 parent 33fac22 commit c73cf3e
Showing 7 changed files with 189 additions and 45 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [0.0.8] - 2024-11-29

### Added

- Introduced a new `PlaybackSettings` class to enable custom configuration of audio playback (buffering, sample rate and chunk size); see the usage sketch below.
- The client now buffers audio received from the server to ensure smoother playback, especially in challenging network conditions.
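
A minimal usage sketch of the new settings from library code, mirroring the keyword arguments the CLI passes below; the endpoint URL, API key placeholder, input file, and the default `AudioSettings`/`ConversationConfig` values are illustrative assumptions rather than part of this release:

```python
from speechmatics_flow.client import WebsocketClient
from speechmatics_flow.models import (
    AudioSettings,
    ConnectionSettings,
    ConversationConfig,
    Interaction,
    PlaybackSettings,
)

client = WebsocketClient(
    ConnectionSettings(
        url="wss://flow.api.speechmatics.com/v1/flow",  # assumed endpoint; use the URL for your account
        auth_token="YOUR_API_KEY",
    )
)

# Trade a little latency for resilience: buffer 50 ms of server audio before
# playback starts and read it back from the buffer in 512-byte chunks.
playback_settings = PlaybackSettings(buffering=50, sample_rate=16_000, chunk_size=512)

with open("conversation_input.wav", "rb") as audio:  # illustrative input file
    client.run_synchronously(
        interactions=[Interaction(audio)],
        audio_settings=AudioSettings(),
        conversation_config=ConversationConfig(),
        playback_settings=playback_settings,
        from_cli=True,  # enables local playback of the server's audio, as the CLI does
    )
```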

### Fixed

- Resolved an issue with reading piped audio from stdin.

## [0.0.7] - 2024-11-25

### Added
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.0.7
0.0.8
28 changes: 24 additions & 4 deletions speechmatics_flow/cli.py
@@ -23,6 +23,7 @@
ServerMessageType,
Interaction,
ConnectionSettings,
PlaybackSettings,
)
from speechmatics_flow.templates import TemplateOptions

@@ -139,6 +140,25 @@ def get_audio_settings(args):
return settings


def get_playback_settings(args):
"""
Helper function which returns a PlaybackSettings object based on the command
line options given to the program.
Args:
args (dict): Keyword arguments, typically from the command line.
Returns:
models.PlaybackSettings: Settings for the audio playback stream
in the connection.
"""
return PlaybackSettings(
buffering=args.get("playback_buffering"),
sample_rate=args.get("playback_sample_rate"),
chunk_size=args.get("playback_chunk_size"),
)


# pylint: disable=too-many-arguments,too-many-statements
def add_printing_handlers(
api,
@@ -248,7 +268,6 @@ def flow_main(args):
:param args: arguments from parse_args()
:type args: argparse.Namespace
"""
conversation_config = get_conversation_config(args)
settings = get_connection_settings(args)
api = WebsocketClient(settings)
transcripts = Transcripts()
@@ -261,9 +280,10 @@ def run(stream):
def run(stream):
try:
api.run_synchronously(
[Interaction(stream)],
get_audio_settings(args),
conversation_config,
interactions=[Interaction(stream)],
audio_settings=get_audio_settings(args),
conversation_config=get_conversation_config(args),
playback_settings=get_playback_settings(args),
from_cli=True,
)
except KeyboardInterrupt:
25 changes: 25 additions & 0 deletions speechmatics_flow/cli_parser.py
@@ -104,6 +104,31 @@ def get_arg_parser():
"acknowledgements from the server."
),
)
parser.add_argument(
"--playback-buffering",
type=int,
default=10,
help=(
"Buffer (in milliseconds) for audio received from the server before playback. "
"Increasing the buffer size can improve resilience to poor network conditions, "
"at the cost of increased latency."
),
)
parser.add_argument(
"--playback-sample-rate",
type=int,
default=16_000,
help="The sample rate in Hz of the output audio.",
)
parser.add_argument(
"--playback-chunk-size",
type=int,
default=256,
help=(
"The size of each audio chunk, in bytes, to read from the audio buffer. "
"Increasing the chunk size may improve playback smoothness."
),
)
parser.add_argument(
"--print-json",
default=False,
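
The three `--playback-*` flags above map one-to-one onto the new `PlaybackSettings` fields via `get_playback_settings` in `cli.py`. A rough sketch of what the defaults mean, assuming the 16-bit mono output stream the client opens in `_playback_handler` (shown in `client.py` below):

```python
from speechmatics_flow.models import PlaybackSettings

settings = PlaybackSettings(buffering=10, sample_rate=16_000, chunk_size=256)

BYTES_PER_SAMPLE = 2  # paInt16, single channel, as opened by the playback handler
chunk_ms = settings.chunk_size / BYTES_PER_SAMPLE / settings.sample_rate * 1000

# With the defaults each 256-byte chunk carries 8 ms of audio, and playback
# starts only after roughly 10 ms of server audio has accumulated in the buffer.
print(f"{chunk_ms:.0f} ms of audio per chunk, {settings.buffering} ms of pre-buffering")
```
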
124 changes: 89 additions & 35 deletions speechmatics_flow/client.py
@@ -9,6 +9,7 @@
import json
import logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

@@ -22,12 +23,13 @@
ConversationError,
)
from speechmatics_flow.models import (
ClientMessageType,
ServerMessageType,
AudioSettings,
ClientMessageType,
ConnectionSettings,
ConversationConfig,
Interaction,
ConnectionSettings,
PlaybackSettings,
ServerMessageType,
)
from speechmatics_flow.tool_function_param import ToolFunctionParam
from speechmatics_flow.utils import read_in_chunks, json_utf8
@@ -63,6 +65,7 @@ def __init__(
self.websocket = None
self.conversation_config = None
self.audio_settings = None
self.playback_settings = None
self.tools = None

self.event_handlers = {x: [] for x in ServerMessageType}
@@ -73,13 +76,15 @@ def __init__(
self.session_running = False
self.conversation_ended_wait_timeout = 5
self._session_needs_closing = False
self._audio_buffer = None
self._audio_buffer = bytearray()
self._audio_buffer_lock = asyncio.Lock()
self._executor = ThreadPoolExecutor()

# The following asyncio fields are fully instantiated in
# _init_synchronization_primitives
self._conversation_started = asyncio.Event
self._conversation_ended = asyncio.Event
self._response_started = asyncio.Event
# Semaphore used to ensure that we don't send too much audio data to
# the server too quickly and burst any buffers downstream.
self._buffer_semaphore = asyncio.BoundedSemaphore
@@ -91,24 +96,34 @@ async def _init_synchronization_primitives(self):
"""
self._conversation_started = asyncio.Event()
self._conversation_ended = asyncio.Event()
self._response_started = asyncio.Event()
self._buffer_semaphore = asyncio.BoundedSemaphore(
self.connection_settings.message_buffer_size
)

def _flag_conversation_started(self):
"""
Handle a
:py:attr:`models.ClientMessageType.ConversationStarted`
:py:attr:`models.ServerMessageType.ConversationStarted`
message from the server.
This updates an internal flag to mark the session as started,
meaning AddAudio is now allowed.
"""
self._conversation_started.set()

def _flag_response_started(self):
"""
Handle a
:py:attr:`models.ServerMessageType.ResponseStarted`
message from the server.
This updates an internal flag to mark that the server started sending audio.
"""
self._response_started.set()

def _flag_conversation_ended(self):
"""
Handle a
:py:attr:`models.ClientMessageType.ConversationEnded`
:py:attr:`models.ServerMessageType.ConversationEnded`
message from the server.
This updates an internal flag to mark that the session has ended
and the server connection is closed
@@ -158,7 +173,7 @@ def _audio_received(self):
msg = {
"message": ClientMessageType.AudioReceived,
"seq_no": self.server_seq_no,
"buffering": 0.01, # 10ms
"buffering": self.playback_settings.buffering / 1000,
}
self._call_middleware(ClientMessageType.AudioReceived, msg, False)
LOGGER.debug(msg)
@@ -169,9 +184,12 @@ async def _wait_for_conversation_ended(self):
Waits for a :py:attr:`models.ServerMessageType.ConversationEnded`
message from the server.
"""
await asyncio.wait_for(
self._conversation_ended.wait(), self.conversation_ended_wait_timeout
)
try:
await asyncio.wait_for(
self._conversation_ended.wait(), self.conversation_ended_wait_timeout
)
except asyncio.TimeoutError:
LOGGER.warning("Timeout waiting for ConversationEnded message.")

async def _consumer(self, message, from_cli=False):
"""
@@ -192,7 +210,8 @@ async def _consumer(self, message, from_cli=False):
await self.websocket.send(self._audio_received())
# add an audio message to local buffer only when running from cli
if from_cli:
await self._audio_buffer.put(message)
async with self._audio_buffer_lock:
self._audio_buffer.extend(message)
# Implicit name for all inbound binary messages.
# We must manually set it for event handler subscribed
# to `ServerMessageType.AddAudio` messages to work.
@@ -226,6 +245,13 @@

if message_type == ServerMessageType.ConversationStarted:
self._flag_conversation_started()
if message_type == ServerMessageType.ResponseStarted:
self._flag_response_started()
if message_type in [
ServerMessageType.ResponseCompleted,
ServerMessageType.ResponseInterrupted,
]:
self._response_started.clear()
elif message_type == ServerMessageType.AudioAdded:
self._buffer_semaphore.release()
elif message_type == ServerMessageType.ConversationEnded:
@@ -313,20 +339,31 @@ async def _producer_handler(self, interactions: List[Interaction]):
Controls the producer loop for sending messages to the server.
"""
await self._conversation_started.wait()
if interactions[0].stream.name == "<stdin>":
# Stream audio from microphone when running from the terminal and input is not piped
if (
sys.stdin.isatty()
and hasattr(interactions[0].stream, "name")
and interactions[0].stream.name == "<stdin>"
):
return await self._read_from_microphone()

for interaction in interactions:
async for message in self._stream_producer(
interaction.stream, self.audio_settings.chunk_size
):
try:
await self.websocket.send(message)
except Exception as e:
LOGGER.error(f"error sending message: {e}")
return
if interaction.callback:
interaction.callback(self)
try:
async for message in self._stream_producer(
interaction.stream, self.audio_settings.chunk_size
):
try:
await self.websocket.send(message)
except Exception as e:
LOGGER.error(f"Error sending message: {e}")
return

if interaction.callback:
LOGGER.debug("Executing callback for interaction.")
interaction.callback(self)

except Exception as e:
LOGGER.error(f"Error processing interaction: {e}")

await self.websocket.send(self._end_of_audio())
await self._wait_for_conversation_ended()
@@ -339,26 +376,38 @@ async def _playback_handler(self):
stream = _pyaudio.open(
format=pyaudio.paInt16,
channels=1,
rate=self.audio_settings.sample_rate,
frames_per_buffer=128,
rate=self.playback_settings.sample_rate,
frames_per_buffer=self.playback_settings.chunk_size,
output=True,
)
chunk_size = self.playback_settings.chunk_size

try:
while True:
if self._session_needs_closing or self._conversation_ended.is_set():
break
while not (self._session_needs_closing or self._conversation_ended.is_set()):
# Wait for the server to start sending audio
await self._response_started.wait()

# Ensure enough data is added to the buffer before starting playback
await asyncio.sleep(self.playback_settings.buffering / 1000)

# Start playback
try:
audio_message = await self._audio_buffer.get()
stream.write(audio_message)
self._audio_buffer.task_done()
# read from buffer at a constant rate
await asyncio.sleep(0.005)
while self._audio_buffer:
if len(self._audio_buffer) >= chunk_size:
async with self._audio_buffer_lock:
audio_chunk = bytes(self._audio_buffer[:chunk_size])
self._audio_buffer = self._audio_buffer[chunk_size:]
stream.write(audio_chunk)
await asyncio.sleep(0.005)
except Exception as e:
LOGGER.error(f"Error during audio playback: {e}")
LOGGER.error(f"Error during audio playback: {e}", exc_info=True)
raise e

except asyncio.CancelledError:
LOGGER.info("Playback handler cancelled.")
finally:
stream.close()
stream.stop_stream()
stream.close()
_pyaudio.terminate()

def _call_middleware(self, event_name, *args):
@@ -482,7 +531,6 @@ async def _communicate(self, interactions: List[Interaction], from_cli=False):

# Run the playback task that plays audio messages to the user when started from cli
if from_cli:
self._audio_buffer = asyncio.Queue()
tasks.append(asyncio.create_task(self._playback_handler()))

(done, pending) = await asyncio.wait(
@@ -509,6 +557,7 @@ async def run(
conversation_config: ConversationConfig = None,
from_cli: bool = False,
tools: Optional[List[ToolFunctionParam]] = None,
playback_settings: PlaybackSettings = PlaybackSettings(),
):
"""
Begin a new recognition session.
@@ -528,13 +577,18 @@
:param tools: Optional list of tool functions.
:type tools: List[ToolFunctionParam]
:param playback_settings: Configuration for the playback stream.
:type playback_settings: models.PlaybackSettings
:raises Exception: Can raise any exception returned by the
consumer/producer tasks.
"""
self.client_seq_no = 0
self.server_seq_no = 0
self.conversation_config = conversation_config
self.audio_settings = audio_settings
self.playback_settings = playback_settings
self.tools = tools

await self._init_synchronization_primitives()