Skip to content

Commit fe1b2d5

Browse files
yhayarannvidiarmittal-github
authored and committed
Fix microphone case for realtime ASR client (#151)
* Realtime ASR microphone fix
* Set default server to localhost for realtime ASR client
* Refactor argument parsing in realtime ASR client to use mutually exclusive input options and update default server port to 9000
* Refactor logging in RealtimeClient to use debug level for detailed internal state and error messages
* Enhance RealtimeClient to support microphone input with PCM16 encoding and improve audio chunk handling with async iteration. Update logging for word information formatting and handle timeouts during audio processing.
1 parent 4a9bf9e commit fe1b2d5

File tree

2 files changed

+206
-75
lines changed

2 files changed

+206
-75
lines changed

riva/client/realtime.py

Lines changed: 99 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -38,10 +38,6 @@ def __init__(self, args: argparse.Namespace):
3838
self.input_playback_thread = None
3939
self.is_input_playing = False
4040
self.input_buffer_size = 1024 # Buffer size for input audio playback
41-
42-
# Transcription results
43-
self.delta_transcripts: List[str] = []
44-
self.interim_final_transcripts: List[str] = []
4541
self.final_transcript: str = ""
4642
self.is_config_updated = False
4743

@@ -58,13 +54,13 @@ async def connect(self):
5854
await self._initialize_session()
5955

6056
except requests.exceptions.RequestException as e:
61-
logger.error(f"HTTP request failed: {e}")
57+
logger.error("HTTP request failed: %s", e)
6258
raise
6359
except WebSocketException as e:
64-
logger.error(f"WebSocket connection failed: {e}")
60+
logger.error("WebSocket connection failed: %s", e)
6561
raise
6662
except Exception as e:
67-
logger.error(f"Unexpected error during connection: {e}")
63+
logger.error("Unexpected error during connection: %s", e)
6864
raise
6965

7066
async def _initialize_http_session(self) -> Dict[str, Any]:
@@ -73,7 +69,7 @@ async def _initialize_http_session(self) -> Dict[str, Any]:
7369
uri = f"http://{self.args.server}/v1/realtime/transcription_sessions"
7470
if self.args.use_ssl:
7571
uri = f"https://{self.args.server}/v1/realtime/transcription_sessions"
76-
logger.info(f"Initializing session via HTTP POST request to: {uri}")
72+
logger.debug("Initializing session via HTTP POST request to: %s", uri)
7773
response = requests.post(
7874
uri,
7975
headers=headers,
@@ -89,7 +85,7 @@ async def _initialize_http_session(self) -> Dict[str, Any]:
8985
)
9086

9187
session_data = response.json()
92-
logger.info(f"Session initialized: {session_data}")
88+
logger.debug("Session initialized: %s", session_data)
9389
return session_data
9490

9591
async def _connect_websocket(self):
@@ -110,7 +106,7 @@ async def _connect_websocket(self):
110106
ssl_context.check_hostname = False
111107
# ssl_context.verify_mode = ssl.CERT_REQUIRED
112108

113-
logger.info(f"Connecting to WebSocket: {ws_url}")
109+
logger.debug("Connecting to WebSocket: %s", ws_url)
114110
self.websocket = await websockets.connect(ws_url, ssl=ssl_context)
115111

116112
async def _initialize_session(self):
@@ -119,14 +115,14 @@ async def _initialize_session(self):
119115
# Handle first response: "conversation.created"
120116
response = await self.websocket.recv()
121117
response_data = json.loads(response)
122-
logger.info("Session created: %s", response_data)
118+
logger.debug("Session created: %s", response_data)
123119

124120
event_type = response_data.get("type", "")
125121
if event_type == "conversation.created":
126-
logger.info("Conversation created successfully")
122+
logger.debug("Conversation created successfully")
127123
logger.debug("Response structure: %s", list(response_data.keys()))
128124
else:
129-
logger.warning(f"Unexpected first response type: {event_type}")
125+
logger.warning("Unexpected first response type: %s", event_type)
130126
logger.debug("Full response: %s", response_data)
131127

132128
# Update session configuration
@@ -135,16 +131,16 @@ async def _initialize_session(self):
135131
logger.error("Failed to update session")
136132
raise Exception("Failed to update session")
137133

138-
logger.info("Session initialization complete")
134+
logger.debug("Session initialization complete")
139135

140136
except json.JSONDecodeError as e:
141-
logger.error(f"Failed to parse JSON response: {e}")
137+
logger.error("Failed to parse JSON response: %s", e)
142138
raise
143139
except KeyError as e:
144-
logger.error(f"Missing expected key in response: {e}")
140+
logger.error("Missing expected key in response: %s", e)
145141
raise
146142
except Exception as e:
147-
logger.error(f"Unexpected error during session initialization: {e}")
143+
logger.error("Unexpected error during session initialization: %s", e)
148144
raise
149145

150146
def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, section: str = None):
@@ -160,25 +156,33 @@ def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, sect
160156
if section not in config:
161157
config[section] = {}
162158
config[section][key] = value
163-
logger.debug(f"Updated {section}.{key} = {value}")
159+
logger.debug("Updated %s.%s = %s", section, key, value)
164160
else:
165161
config[key] = value
166-
logger.debug(f"Updated {key} = {value}")
162+
logger.debug("Updated %s = %s", key, value)
167163

168164
async def _update_session(self) -> bool:
169165
"""Update session configuration by selectively overriding server defaults.
170166
171167
Returns:
172168
True if session was updated successfully, False otherwise
173169
"""
174-
logger.info("Updating session configuration...")
175-
logger.info(f"Server default config: {self.session_config}")
170+
logger.debug("Updating session configuration...")
171+
logger.debug("Server default config: %s", self.session_config)
176172

177173
# Create a copy of the session config from server defaults
178174
session_config = self.session_config.copy()
179175

180176
# Track what we're overriding
181177
overrides = []
178+
179+
# Check if the input is microphone, then set the encoding to pcm16
180+
if hasattr(self.args, 'mic') and self.args.mic:
181+
self._safe_update_config(session_config, "input_audio_format", "pcm16")
182+
overrides.append("input_audio_format")
183+
else:
184+
self._safe_update_config(session_config, "input_audio_format", "none")
185+
overrides.append("input_audio_format")
182186

183187
# Update input audio transcription - only override if args are provided
184188
if hasattr(self.args, 'language_code') and self.args.language_code:
@@ -260,11 +264,11 @@ async def _update_session(self) -> bool:
260264
overrides.append("custom_configuration")
261265

262266
if overrides:
263-
logger.info(f"Overriding server defaults for: {', '.join(overrides)}")
267+
logger.debug("Overriding server defaults for: %s", ', '.join(overrides))
264268
else:
265-
logger.info("Using server default configuration (no overrides)")
269+
logger.debug("Using server default configuration (no overrides)")
266270

267-
logger.info(f"Final session config: {session_config}")
271+
logger.debug("Final session config: %s", session_config)
268272

269273
# Send update request
270274
update_session_request = {
@@ -333,16 +337,16 @@ async def _handle_session_update_response(self) -> bool:
333337
"""
334338
response = await self.websocket.recv()
335339
response_data = json.loads(response)
336-
logger.info("Session updated: %s", response_data)
340+
logger.info("Current Session Config: %s", response_data)
337341

338342
event_type = response_data.get("type", "")
339343
if event_type == "transcription_session.updated":
340-
logger.info("Transcription session updated successfully")
344+
logger.debug("Transcription session updated successfully")
341345
logger.debug("Response structure: %s", list(response_data.keys()))
342346
self.session_config = response_data["session"]
343347
return True
344348
else:
345-
logger.warning(f"Unexpected response type: {event_type}")
349+
logger.warning("Unexpected response type: %s", event_type)
346350
logger.debug("Full response: %s", response_data)
347351
return False
348352

@@ -352,23 +356,49 @@ async def _send_message(self, message: Dict[str, Any]):
352356

353357
async def send_audio_chunks(self, audio_chunks):
354358
"""Send audio chunks to the server for transcription."""
355-
logger.info("Sending audio chunks...")
356-
357-
for chunk in audio_chunks:
358-
chunk_base64 = base64.b64encode(chunk).decode("utf-8")
359-
360-
# Send chunk to the server
361-
await self._send_message({
362-
"type": "input_audio_buffer.append",
363-
"audio": chunk_base64,
364-
})
365-
366-
# Commit the chunk
367-
await self._send_message({
368-
"type": "input_audio_buffer.commit",
369-
})
370-
371-
logger.info("All chunks sent")
359+
logger.debug("Sending audio chunks...")
360+
361+
# Check if the audio_chunks supports async iteration
362+
if hasattr(audio_chunks, '__aiter__'):
363+
# Use async for for async iterators - this allows proper task switching
364+
async for chunk in audio_chunks:
365+
try:
366+
chunk_base64 = base64.b64encode(chunk).decode("utf-8")
367+
368+
# Send chunk to the server
369+
await self._send_message({
370+
"type": "input_audio_buffer.append",
371+
"audio": chunk_base64,
372+
})
373+
374+
# Commit the chunk
375+
await self._send_message({
376+
"type": "input_audio_buffer.commit",
377+
})
378+
except TimeoutError:
379+
# Handle timeout from AsyncAudioIterator - no audio available, continue
380+
logger.debug("No audio chunk available within timeout, continuing...")
381+
continue
382+
except Exception as e:
383+
logger.error(f"Error processing audio chunk: {e}")
384+
continue
385+
else:
386+
# Fallback for regular iterators
387+
for chunk in audio_chunks:
388+
chunk_base64 = base64.b64encode(chunk).decode("utf-8")
389+
390+
# Send chunk to the server
391+
await self._send_message({
392+
"type": "input_audio_buffer.append",
393+
"audio": chunk_base64,
394+
})
395+
396+
# Commit the chunk
397+
await self._send_message({
398+
"type": "input_audio_buffer.commit",
399+
})
400+
401+
logger.debug("All chunks sent")
372402

373403
# Tell the server that we are done sending chunks
374404
await self._send_message({
@@ -377,7 +407,7 @@ async def send_audio_chunks(self, audio_chunks):
377407

378408
async def receive_responses(self):
379409
"""Receive and process transcription responses from the server."""
380-
logger.info("Listening for responses...")
410+
logger.debug("Listening for responses...")
381411
received_final_response = False
382412

383413
while not received_final_response:
@@ -389,12 +419,10 @@ async def receive_responses(self):
389419
if event_type == "conversation.item.input_audio_transcription.delta":
390420
delta = event.get("delta", "")
391421
logger.info("Transcript: %s", delta)
392-
self.delta_transcripts.append(delta)
393422

394423
elif event_type == "conversation.item.input_audio_transcription.completed":
395424
is_last_result = event.get("is_last_result", False)
396425
interim_final_transcript = event.get("transcript", "")
397-
self.interim_final_transcripts.append(interim_final_transcript)
398426
self.final_transcript = interim_final_transcript
399427

400428
if is_last_result:
@@ -405,7 +433,28 @@ async def receive_responses(self):
405433
else:
406434
logger.info("Interim Transcript: %s", interim_final_transcript)
407435

408-
logger.info("Words Info: %s", event.get("words_info", ""))
436+
# Format Words Info similar to print_streaming function
437+
words_info = event.get("words_info", {})
438+
if words_info and "words" in words_info:
439+
print("Words Info:")
440+
441+
# Create header format similar to print_streaming
442+
header_format = '{: <40s}{: <16s}{: <16s}{: <16s}{: <16s}'
443+
header_values = ['Word', 'Start (ms)', 'End (ms)', 'Confidence', 'Speaker']
444+
print(header_format.format(*header_values))
445+
446+
# Print each word with formatted information
447+
for word_data in words_info["words"]:
448+
word = word_data.get("word", "")
449+
start_time = word_data.get("start_time", 0)
450+
end_time = word_data.get("end_time", 0)
451+
confidence = word_data.get("confidence", 0.0)
452+
speaker_tag = word_data.get("speaker_tag", 0)
453+
454+
# Format the word info line similar to print_streaming
455+
word_format = '{: <40s}{: <16.0f}{: <16.0f}{: <16.4f}{: <16d}'
456+
word_values = [word, start_time, end_time, confidence, speaker_tag]
457+
print(word_format.format(*word_values))
409458

410459
elif "error" in event_type.lower():
411460
logger.error(
@@ -417,7 +466,7 @@ async def receive_responses(self):
417466
except asyncio.TimeoutError:
418467
continue
419468
except Exception as e:
420-
logger.error(f"Error: {e}")
469+
logger.error("Error: %s", e)
421470
break
422471

423472
def save_responses(self, output_text_file: str):
@@ -431,7 +480,7 @@ def save_responses(self, output_text_file: str):
431480
with open(output_text_file, "w") as f:
432481
f.write(self.final_transcript)
433482
except Exception as e:
434-
logger.error(f"Error saving text: {e}")
483+
logger.error("Error saving text: %s", e)
435484

436485
async def disconnect(self):
437486
"""Close the WebSocket connection."""

0 commit comments

Comments
 (0)