Improved AudioEnded handling and other small optimizations
dumitrugutu committed Oct 16, 2024
1 parent 5a22092 commit dceab4c
Showing 6 changed files with 78 additions and 32 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,25 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [0.0.2] - 2024-10-16

### Added

- Improved handling of the AudioEnded message, which previously caused the client to abruptly close the connection.
  The client now waits up to 5 seconds for a ConversationEnded message from the server before closing the connection.
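
  A minimal sketch of that shutdown pattern, assuming an `asyncio.Event` that the message consumer sets when ConversationEnded arrives (the 5-second value mirrors the new `conversation_ended_wait_timeout` attribute in `client.py` below; the function and variable names here are illustrative only):

  ```python
  import asyncio

  CONVERSATION_ENDED_WAIT_TIMEOUT = 5  # seconds, matching the new client default


  async def wait_then_close(conversation_ended: asyncio.Event) -> None:
      """Wait briefly for the server to confirm the conversation has ended."""
      try:
          await asyncio.wait_for(
              conversation_ended.wait(), CONVERSATION_ENDED_WAIT_TIMEOUT
          )
          print("ConversationEnded received; closing the connection cleanly.")
      except asyncio.TimeoutError:
          print("No ConversationEnded within 5 seconds; closing anyway.")


  async def main() -> None:
      conversation_ended = asyncio.Event()
      # In the real client, the message consumer sets this event when the
      # server sends a ConversationEnded message.
      asyncio.get_running_loop().call_later(1, conversation_ended.set)
      await wait_then_close(conversation_ended)


  asyncio.run(main())
  ```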

### Changed

- Do not generate a JWT when connecting to a local Flow server.
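
  Roughly, the check works as sketched below (the host list matches the one added in `client.py` further down; the helper name and URLs are illustrative only):

  ```python
  # Hosts treated as local Flow servers, as in the client.py change below.
  LOCAL_SERVERS = ("localhost", "127.0.0.1", "0.0.0.0")


  def should_generate_temp_token(url: str, generate_temp_token: bool) -> bool:
      """Generate a temporary JWT only for remote URLs when enabled."""
      is_local = any(host in url for host in LOCAL_SERVERS)
      return generate_temp_token and not is_local


  print(should_generate_temp_token("ws://localhost:8080/v1/flow", True))  # False
  print(should_generate_temp_token("wss://example.com/v1/flow", True))    # True
  ```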

### Fixed

- Removed the trailing `-` from the CLI usage example in the README, which caused an `unrecognized arguments` error.

### Removed

- `--generate-temp-token` CLI option, as it was always set to `True` with no way to set it to `False`.
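
  With the flag gone, connection setup inside the CLI boils down to something like the sketch below (based on the `cli.py` diff further down; the URL and token values are placeholders):

  ```python
  from speechmatics_flow.models import ConnectionSettings

  settings = ConnectionSettings(
      url="wss://example.com/v1/flow",  # placeholder Flow endpoint
      auth_token="YOUR_API_KEY",        # placeholder token
      generate_temp_token=True,         # always True now; the CLI flag was removed
  )
  ```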

## [0.0.1] - 2024-10-14

### Added
2 changes: 1 addition & 1 deletion README.md
@@ -30,7 +30,7 @@ python setup.py install --user
*Note: Requires access to microphone

```bash
speechmatics-flow --url $URL --auth-token $TOKEN -
speechmatics-flow --url $URL --auth-token $TOKEN
```

## Support
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.0.1
0.0.2
9 changes: 4 additions & 5 deletions speechmatics_flow/cli.py
@@ -67,15 +67,14 @@ def get_connection_settings(args):
:type args: dict
:return: Settings for the WebSocket connection.
:rtype: speechmatics_flow.models.ConnectionSettings
:rtype: models.ConnectionSettings
"""
auth_token = args.get("auth_token")
url = args.get("url")
generate_temp_token = args.get("generate_temp_token")
settings = ConnectionSettings(
url=url,
auth_token=auth_token,
generate_temp_token=generate_temp_token,
generate_temp_token=True,
)

if args.get("buffer_size") is not None:
@@ -101,7 +100,7 @@ def get_conversation_config(
:type args: Dict
:return: Settings for the ASR engine.
:rtype: flow.models.ConversationConfig
:rtype: models.ConversationConfig
"""

config: Dict[str, Any] = {}
@@ -124,7 +123,7 @@ def get_audio_settings(args):
args (dict): Keyword arguments, typically from the command line.
Returns:
flow.models.AudioSettings: Settings for the audio stream
models.AudioSettings: Settings for the audio stream
in the connection.
"""
settings = AudioSettings(
6 changes: 0 additions & 6 deletions speechmatics_flow/cli_parser.py
@@ -49,12 +49,6 @@ def get_arg_parser():
type=str,
help="Authentication token to authorize the client.",
)
parser.add_argument(
"--generate-temp-token",
default=True,
action="store_true",
help="Automatically generate a temporary token for authentication.",
)
parser.add_argument(
"--ssl-mode",
default="regular",
72 changes: 53 additions & 19 deletions speechmatics_flow/client.py
@@ -67,14 +67,15 @@ def __init__(

self.seq_no = 0
self.session_running = False
self._language_pack_info = None
self._transcription_config_needs_update = False
self.conversation_ended_wait_timeout = 5
self._session_needs_closing = False
self._audio_buffer = None
self._pyaudio = pyaudio.PyAudio

# The following asyncio fields are fully instantiated in
# _init_synchronization_primitives
self._conversation_started = asyncio.Event
self._conversation_ended = asyncio.Event
# Semaphore used to ensure that we don't send too much audio data to
# the server too quickly and burst any buffers downstream.
self._buffer_semaphore = asyncio.BoundedSemaphore
@@ -85,6 +86,8 @@ async def _init_synchronization_primitives(self):
an event loop
"""
self._conversation_started = asyncio.Event()
self._conversation_ended = asyncio.Event()
self._pyaudio = pyaudio.PyAudio()
self._buffer_semaphore = asyncio.BoundedSemaphore(
self.connection_settings.message_buffer_size
)
@@ -99,6 +102,16 @@ def _flag_conversation_started(self):
"""
self._conversation_started.set()

def _flag_conversation_ended(self):
"""
Handles a
:py:attr:`models.ServerMessageType.ConversationEnded`
message from the server.
This sets an internal flag to mark that the conversation has ended
and that the server connection can be closed.
"""
self._conversation_ended.set()

@json_utf8
def _start_conversation(self):
"""
Expand Down Expand Up @@ -130,6 +143,15 @@ def _end_of_audio(self):
LOGGER.debug(msg)
return msg

async def _wait_for_conversation_ended(self):
"""
Waits for a :py:attr:`models.ServerMessageType.ConversationEnded`
message from the server.
"""
await asyncio.wait_for(
self._conversation_ended.wait(), self.conversation_ended_wait_timeout
)

async def _consumer(self, message, from_cli: False):
"""
Consumes messages and acts on them.
@@ -170,6 +192,7 @@ async def _consumer(self, message, from_cli: False):
elif message_type == ServerMessageType.AudioAdded:
self._buffer_semaphore.release()
elif message_type == ServerMessageType.ConversationEnded:
self._flag_conversation_ended()
raise ConversationEndedException()
elif message_type == ServerMessageType.EndOfTranscript:
raise EndOfTranscriptException()
@@ -179,19 +202,22 @@
raise TranscriptionError(message["reason"])

async def _read_from_microphone(self):
p = pyaudio.PyAudio()
print(f"Default input device: {p.get_default_input_device_info()['name']}")
print(f"Default output device: {p.get_default_output_device_info()['name']}")
print(
f"Default input device: {self._pyaudio.get_default_input_device_info()['name']}"
)
print(
f"Default output device: {self._pyaudio.get_default_output_device_info()['name']}"
)
print("Start speaking...")
stream = p.open(
stream = self._pyaudio.open(
format=pyaudio.paInt16,
channels=1,
rate=self.audio_settings.sample_rate,
input=True,
)
try:
while True:
if self._session_needs_closing:
if self._session_needs_closing or self._conversation_ended.is_set():
break

await asyncio.wait_for(
@@ -205,11 +231,13 @@ async def _read_from_microphone(self):
self.seq_no += 1
self._call_middleware(ClientMessageType.AddAudio, audio_chunk, True)
await self.websocket.send(audio_chunk)
finally:
except KeyboardInterrupt:
await self.websocket.send(self._end_of_audio())
finally:
await self._wait_for_conversation_ended()
stream.stop_stream()
stream.close()
p.terminate()
self._pyaudio.terminate()

async def _consumer_handler(self, from_cli: False):
"""
@@ -231,9 +259,8 @@

async def _stream_producer(self, stream, audio_chunk_size):
async for audio_chunk in read_in_chunks(stream, audio_chunk_size):
if self._session_needs_closing:
if self._session_needs_closing or self._conversation_ended.is_set():
break

await asyncio.wait_for(
self._buffer_semaphore.acquire(),
timeout=self.connection_settings.semaphore_timeout_seconds,
@@ -248,7 +275,6 @@ async def _producer_handler(self, interactions: List[Interaction]):
Controls the producer loop for sending messages to the server.
"""
await self._conversation_started.wait()

if interactions[0].stream.name == "<stdin>":
return await self._read_from_microphone()

@@ -265,21 +291,21 @@ async def _producer_handler(self, interactions: List[Interaction]):
interaction.callback(self)

await self.websocket.send(self._end_of_audio())
await self._wait_for_conversation_ended()

async def _playback_handler(self):
"""
Reads audio binary messages from the playback buffer and plays them to the user.
"""
p = pyaudio.PyAudio()
stream = p.open(
stream = self._pyaudio.open(
format=pyaudio.paInt16,
channels=1,
rate=self.audio_settings.sample_rate,
output=True,
)
try:
while True:
if self._session_needs_closing:
if self._session_needs_closing or self._conversation_ended.is_set():
break
try:
audio_message = await self._audio_buffer.get()
@@ -291,8 +317,7 @@ def _playback_handler(self):
finally:
stream.close()
stream.stop_stream()
p.terminate()
LOGGER.debug("Exiting playback handler")
self._pyaudio.terminate()

def _call_middleware(self, event_name, *args):
"""
@@ -450,14 +475,23 @@ async def run(
consumer/producer tasks.
"""
self.seq_no = 0
self._language_pack_info = None
self.conversation_config = conversation_config
self.audio_settings = audio_settings

await self._init_synchronization_primitives()

extra_headers = {}
auth_token = await get_temp_token(self.connection_settings.auth_token)
auth_token = self.connection_settings.auth_token
# Do not generate a JWT when connecting to a local server
local_servers = ("localhost", "127.0.0.1", "0.0.0.0")
if (
not any(
local_server in self.connection_settings.url
for local_server in local_servers
)
and self.connection_settings.generate_temp_token
):
auth_token = await get_temp_token(self.connection_settings.auth_token)
extra_headers["Authorization"] = f"Bearer {auth_token}"
try:
async with websockets.connect( # pylint: disable=no-member
