Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 99 additions & 50 deletions riva/client/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ def __init__(self, args: argparse.Namespace):
self.input_playback_thread = None
self.is_input_playing = False
self.input_buffer_size = 1024 # Buffer size for input audio playback

# Transcription results
self.delta_transcripts: List[str] = []
self.interim_final_transcripts: List[str] = []
self.final_transcript: str = ""
self.is_config_updated = False

Expand All @@ -58,13 +54,13 @@ async def connect(self):
await self._initialize_session()

except requests.exceptions.RequestException as e:
logger.error(f"HTTP request failed: {e}")
logger.error("HTTP request failed: %s", e)
raise
except WebSocketException as e:
logger.error(f"WebSocket connection failed: {e}")
logger.error("WebSocket connection failed: %s", e)
raise
except Exception as e:
logger.error(f"Unexpected error during connection: {e}")
logger.error("Unexpected error during connection: %s", e)
raise

async def _initialize_http_session(self) -> Dict[str, Any]:
Expand All @@ -73,7 +69,7 @@ async def _initialize_http_session(self) -> Dict[str, Any]:
uri = f"http://{self.args.server}/v1/realtime/transcription_sessions"
if self.args.use_ssl:
uri = f"https://{self.args.server}/v1/realtime/transcription_sessions"
logger.info(f"Initializing session via HTTP POST request to: {uri}")
logger.debug("Initializing session via HTTP POST request to: %s", uri)
response = requests.post(
uri,
headers=headers,
Expand All @@ -89,7 +85,7 @@ async def _initialize_http_session(self) -> Dict[str, Any]:
)

session_data = response.json()
logger.info(f"Session initialized: {session_data}")
logger.debug("Session initialized: %s", session_data)
return session_data

async def _connect_websocket(self):
Expand All @@ -110,7 +106,7 @@ async def _connect_websocket(self):
ssl_context.check_hostname = False
# ssl_context.verify_mode = ssl.CERT_REQUIRED

logger.info(f"Connecting to WebSocket: {ws_url}")
logger.debug("Connecting to WebSocket: %s", ws_url)
self.websocket = await websockets.connect(ws_url, ssl=ssl_context)

async def _initialize_session(self):
Expand All @@ -119,14 +115,14 @@ async def _initialize_session(self):
# Handle first response: "conversation.created"
response = await self.websocket.recv()
response_data = json.loads(response)
logger.info("Session created: %s", response_data)
logger.debug("Session created: %s", response_data)

event_type = response_data.get("type", "")
if event_type == "conversation.created":
logger.info("Conversation created successfully")
logger.debug("Conversation created successfully")
logger.debug("Response structure: %s", list(response_data.keys()))
else:
logger.warning(f"Unexpected first response type: {event_type}")
logger.warning("Unexpected first response type: %s", event_type)
logger.debug("Full response: %s", response_data)

# Update session configuration
Expand All @@ -135,16 +131,16 @@ async def _initialize_session(self):
logger.error("Failed to update session")
raise Exception("Failed to update session")

logger.info("Session initialization complete")
logger.debug("Session initialization complete")

except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON response: {e}")
logger.error("Failed to parse JSON response: %s", e)
raise
except KeyError as e:
logger.error(f"Missing expected key in response: {e}")
logger.error("Missing expected key in response: %s", e)
raise
except Exception as e:
logger.error(f"Unexpected error during session initialization: {e}")
logger.error("Unexpected error during session initialization: %s", e)
raise

def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, section: str = None):
Expand All @@ -160,25 +156,33 @@ def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, sect
if section not in config:
config[section] = {}
config[section][key] = value
logger.debug(f"Updated {section}.{key} = {value}")
logger.debug("Updated %s.%s = %s", section, key, value)
else:
config[key] = value
logger.debug(f"Updated {key} = {value}")
logger.debug("Updated %s = %s", key, value)

async def _update_session(self) -> bool:
"""Update session configuration by selectively overriding server defaults.

Returns:
True if session was updated successfully, False otherwise
"""
logger.info("Updating session configuration...")
logger.info(f"Server default config: {self.session_config}")
logger.debug("Updating session configuration...")
logger.debug("Server default config: %s", self.session_config)

# Create a copy of the session config from server defaults
session_config = self.session_config.copy()

# Track what we're overriding
overrides = []

# Check if the input is microphone, then set the encoding to pcm16
if hasattr(self.args, 'mic') and self.args.mic:
self._safe_update_config(session_config, "input_audio_format", "pcm16")
overrides.append("input_audio_format")
else:
self._safe_update_config(session_config, "input_audio_format", "none")
overrides.append("input_audio_format")

# Update input audio transcription - only override if args are provided
if hasattr(self.args, 'language_code') and self.args.language_code:
Expand Down Expand Up @@ -260,11 +264,11 @@ async def _update_session(self) -> bool:
overrides.append("custom_configuration")

if overrides:
logger.info(f"Overriding server defaults for: {', '.join(overrides)}")
logger.debug("Overriding server defaults for: %s", ', '.join(overrides))
else:
logger.info("Using server default configuration (no overrides)")
logger.debug("Using server default configuration (no overrides)")

logger.info(f"Final session config: {session_config}")
logger.debug("Final session config: %s", session_config)

# Send update request
update_session_request = {
Expand Down Expand Up @@ -333,16 +337,16 @@ async def _handle_session_update_response(self) -> bool:
"""
response = await self.websocket.recv()
response_data = json.loads(response)
logger.info("Session updated: %s", response_data)
logger.info("Current Session Config: %s", response_data)

event_type = response_data.get("type", "")
if event_type == "transcription_session.updated":
logger.info("Transcription session updated successfully")
logger.debug("Transcription session updated successfully")
logger.debug("Response structure: %s", list(response_data.keys()))
self.session_config = response_data["session"]
return True
else:
logger.warning(f"Unexpected response type: {event_type}")
logger.warning("Unexpected response type: %s", event_type)
logger.debug("Full response: %s", response_data)
return False

Expand All @@ -352,23 +356,49 @@ async def _send_message(self, message: Dict[str, Any]):

async def send_audio_chunks(self, audio_chunks):
"""Send audio chunks to the server for transcription."""
logger.info("Sending audio chunks...")

for chunk in audio_chunks:
chunk_base64 = base64.b64encode(chunk).decode("utf-8")

# Send chunk to the server
await self._send_message({
"type": "input_audio_buffer.append",
"audio": chunk_base64,
})

# Commit the chunk
await self._send_message({
"type": "input_audio_buffer.commit",
})

logger.info("All chunks sent")
logger.debug("Sending audio chunks...")

# Check if the audio_chunks supports async iteration
if hasattr(audio_chunks, '__aiter__'):
# Use async for for async iterators - this allows proper task switching
async for chunk in audio_chunks:
try:
chunk_base64 = base64.b64encode(chunk).decode("utf-8")

# Send chunk to the server
await self._send_message({
"type": "input_audio_buffer.append",
"audio": chunk_base64,
})

# Commit the chunk
await self._send_message({
"type": "input_audio_buffer.commit",
})
except TimeoutError:
# Handle timeout from AsyncAudioIterator - no audio available, continue
logger.debug("No audio chunk available within timeout, continuing...")
continue
except Exception as e:
logger.error(f"Error processing audio chunk: {e}")
continue
else:
# Fallback for regular iterators
for chunk in audio_chunks:
chunk_base64 = base64.b64encode(chunk).decode("utf-8")

# Send chunk to the server
await self._send_message({
"type": "input_audio_buffer.append",
"audio": chunk_base64,
})

# Commit the chunk
await self._send_message({
"type": "input_audio_buffer.commit",
})

logger.debug("All chunks sent")

# Tell the server that we are done sending chunks
await self._send_message({
Expand All @@ -377,7 +407,7 @@ async def send_audio_chunks(self, audio_chunks):

async def receive_responses(self):
"""Receive and process transcription responses from the server."""
logger.info("Listening for responses...")
logger.debug("Listening for responses...")
received_final_response = False

while not received_final_response:
Expand All @@ -389,12 +419,10 @@ async def receive_responses(self):
if event_type == "conversation.item.input_audio_transcription.delta":
delta = event.get("delta", "")
logger.info("Transcript: %s", delta)
self.delta_transcripts.append(delta)

elif event_type == "conversation.item.input_audio_transcription.completed":
is_last_result = event.get("is_last_result", False)
interim_final_transcript = event.get("transcript", "")
self.interim_final_transcripts.append(interim_final_transcript)
self.final_transcript = interim_final_transcript

if is_last_result:
Expand All @@ -405,7 +433,28 @@ async def receive_responses(self):
else:
logger.info("Interim Transcript: %s", interim_final_transcript)

logger.info("Words Info: %s", event.get("words_info", ""))
# Format Words Info similar to print_streaming function
words_info = event.get("words_info", {})
if words_info and "words" in words_info:
print("Words Info:")

# Create header format similar to print_streaming
header_format = '{: <40s}{: <16s}{: <16s}{: <16s}{: <16s}'
header_values = ['Word', 'Start (ms)', 'End (ms)', 'Confidence', 'Speaker']
print(header_format.format(*header_values))

# Print each word with formatted information
for word_data in words_info["words"]:
word = word_data.get("word", "")
start_time = word_data.get("start_time", 0)
end_time = word_data.get("end_time", 0)
confidence = word_data.get("confidence", 0.0)
speaker_tag = word_data.get("speaker_tag", 0)

# Format the word info line similar to print_streaming
word_format = '{: <40s}{: <16.0f}{: <16.0f}{: <16.4f}{: <16d}'
word_values = [word, start_time, end_time, confidence, speaker_tag]
print(word_format.format(*word_values))

elif "error" in event_type.lower():
logger.error(
Expand All @@ -417,7 +466,7 @@ async def receive_responses(self):
except asyncio.TimeoutError:
continue
except Exception as e:
logger.error(f"Error: {e}")
logger.error("Error: %s", e)
break

def save_responses(self, output_text_file: str):
Expand All @@ -431,7 +480,7 @@ def save_responses(self, output_text_file: str):
with open(output_text_file, "w") as f:
f.write(self.final_transcript)
except Exception as e:
logger.error(f"Error saving text: {e}")
logger.error("Error saving text: %s", e)

async def disconnect(self):
"""Close the WebSocket connection."""
Expand Down
Loading