Enhanced audio pre-processing by implementing voice activity detection for selecting speech-containing audio chunks. Also addressed end-of-stream behavior for sequential mode and cleanup behavior for disconnected clients. Tidied up some logging messages for better clarity.
ethanzrd committed Aug 15, 2023
1 parent cf496f6 commit 33534eb
Showing 6 changed files with 55 additions and 28 deletions.
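
For orientation, here is a minimal standalone sketch of the chunk-gating idea this commit introduces, condensed from backend/silero_vad.py and Client.handle_chunk in the diff below. The decode step and queue are simplified stand-ins rather than the project's exact code; the threshold and sample rate mirror backend/config.py.

import numpy as np
import torch
from queue import Queue

SPEECH_CONFIDENCE_THRESHOLD = 0.3  # mirrors backend/config.py
SAMPLE_RATE = 16000                # mirrors backend/config.py

# Same Silero VAD model that the new backend/silero_vad.py loads via torch.hub
model, _utils = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")

audio_chunks = Queue()

def handle_chunk(audio: np.ndarray) -> None:
    # 'audio' is a float32 waveform at 16 kHz. The model returns its confidence
    # that the window contains speech; only chunks that clear the threshold are
    # queued for transcription, as in Client.handle_chunk below.
    confidence = model(torch.from_numpy(audio), SAMPLE_RATE).item()
    if confidence >= SPEECH_CONFIDENCE_THRESHOLD:
        audio_chunks.put(audio)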
26 changes: 12 additions & 14 deletions backend/client_manager.py
@@ -17,7 +17,7 @@ async def create_new_client(self, sid, sio, config):
self.clients[sid] = new_client
await new_client.start_transcribing()
else:
logging.warning("Client removed before transcription could start")
logging.warning("Client removed before transcription could start, new client may connect")

async def start_stream(self, sid, sio, config):
if not self.clients:
@@ -48,21 +48,19 @@ async def end_stream(self, sid):
def disconnect_from_stream(self, sid):
if sid in self.clients.keys():
client = self.clients[sid]
if not client.disconnected:
cleanup_needed = False
if client != "initializing":
client.handle_disconnection()
cleanup_needed = client.cleanup_needed
if client != "initializing":
client.handle_disconnection()
# No error if client is still not an object, it won't get to that point
if client == "initializing" or not client.is_ending_stream():
try:
self.clients.pop(sid)
except KeyError:
logging.warning("disconnected_from_stream attempted to remove an already removed client")
logging.info("Disconnected client removed")
if cleanup_needed:
cleanup()
else:
logging.warning("Disconnecting client attempting to disconnect multiple times")
if client == "initializing" or not client.is_ending_stream():
try:
self.clients.pop(sid)
except KeyError:
logging.warning("disconnected_from_stream attempted to remove an already removed client")
logging.info("Disconnected client removed")
if cleanup_needed:
cleanup()
else:
logging.warning("A non-existent client tried to disconnect from the stream.")

8 changes: 5 additions & 3 deletions backend/clients/Client.py
@@ -1,7 +1,7 @@
import logging
from queue import Queue


from silero_vad import silero_vad
from diart.utils import decode_audio
class Client:

def __init__(self, sid, socket, transcriber, transcription_timeout):
@@ -42,7 +42,9 @@ async def send_transcription(self, transcription):
logging.info("Transcription not sent, client disconnected")

def handle_chunk(self, chunk):
self.audio_chunks.put(chunk)
speech_present, speech_confidence = silero_vad(decode_audio(chunk))
if speech_present:
self.audio_chunks.put(chunk)
logging.debug("Chunk added")

def is_ending_stream(self):
23 changes: 14 additions & 9 deletions backend/clients/SequentialClient.py
@@ -22,17 +22,20 @@ async def start_transcribing(self):
await self.socket.emit("whisperingStarted")
logging.info("Stream start signaled to client")

def get_diarization(self, waveform):
audio = waveform.astype("float32").reshape(-1)
def get_diarization(self, audio):
save_batch_to_wav(audio, TEMP_FILE_PATH)
diarization = self.diarization_pipeline(TEMP_FILE_PATH)
return diarization

def transcribe_buffer(self, buffer):
diarization = self.get_diarization(buffer)
result = self.transcriber.sequential_transcription(buffer, diarization)
def transcribe_buffer(self, audio):
diarization = self.get_diarization(audio)
result = self.transcriber.sequential_transcription(audio, diarization)
asyncio.run(self.send_transcription(result))

@staticmethod
def convert_buffer_to_audio(buffer):
return buffer.astype("float32").reshape(-1)

@staticmethod
def modify_buffer(chunk, buffer):
decoded_chunk = decode_audio(chunk)
@@ -52,19 +55,21 @@ def stream_sequential_transcription(self):
break
if not self.ending_stream:
if chunk_counter >= batch_size:
self.transcribe_buffer(buffer)
buffer_audio = self.convert_buffer_to_audio(buffer)
self.transcribe_buffer(buffer_audio)
chunk_counter = 0

if not self.audio_chunks.empty():
current_chunk = self.audio_chunks.get()
buffer = self.modify_buffer(current_chunk, buffer)
chunk_counter += 1
else:
logging.info("Client is ending stream, preparing for a final transcription...")
chunk_counter = 0
while not self.audio_chunks.empty():
current_chunk = self.audio_chunks.get()
buffer = self.modify_buffer(current_chunk, buffer)
chunk_counter += 1
if chunk_counter > 0:
self.transcribe_buffer(buffer)
break
buffer_audio = self.convert_buffer_to_audio(buffer)
self.transcribe_buffer(buffer_audio)
break
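
As a reading aid, here is a compact, self-contained sketch of the end-of-stream path added above: drain whatever chunks are still queued, flatten the buffer to float32 (convert_buffer_to_audio), and run one final transcription. np.append stands in for the parts of modify_buffer this diff truncates, and transcribe_buffer is passed in as a callable.

import numpy as np
from queue import Queue

def drain_and_finalize(audio_chunks: Queue, buffer: np.ndarray, transcribe_buffer) -> None:
    # Once the client signals it is ending the stream, pull every chunk still queued...
    while not audio_chunks.empty():
        decoded_chunk = audio_chunks.get()           # already-decoded audio in this sketch
        buffer = np.append(buffer, decoded_chunk)    # stand-in for modify_buffer
    # ...then flatten to a float32 waveform and transcribe the whole buffer once.
    buffer_audio = buffer.astype("float32").reshape(-1)
    transcribe_buffer(buffer_audio)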
2 changes: 2 additions & 0 deletions backend/config.py
@@ -31,6 +31,8 @@
-1: "unknown"
}

SPEECH_CONFIDENCE_THRESHOLD = 0.3 # The minimal amount of confidence to determine speech presence in batch (e.g. 0.5 means 50% chance at minimum)

SAMPLE_RATE = 16000
NON_ENGLISH_SPECIFIC_MODELS = ["large", "large-v1", "large-v2"] # Models that don't have an English-only version

21 changes: 21 additions & 0 deletions backend/silero_vad.py
@@ -0,0 +1,21 @@
import torch
from config import SPEECH_CONFIDENCE_THRESHOLD


class SileroVAD:

def __init__(self):
self.model, self.utils = torch.hub.load(
repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=True
)
(self.get_speech_timestamps, self.save_audio, self.read_audio, self.VADIterator,
self.collect_chunks) = self.utils

def __call__(self, audio):
confidence = self.model(torch.from_numpy(audio), 16000).item()
return confidence >= SPEECH_CONFIDENCE_THRESHOLD, confidence


silero_vad = SileroVAD()
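
A quick, hypothetical usage check for the new module (in the backend the input comes from diart's decode_audio, as in Client.handle_chunk above; here a silent float32 buffer stands in):

import numpy as np
from silero_vad import silero_vad

audio = np.zeros(16000, dtype=np.float32)   # one second of silence at 16 kHz
speech_present, confidence = silero_vad(audio)
print(speech_present, confidence)           # silence should score below SPEECH_CONFIDENCE_THRESHOLD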
3 changes: 1 addition & 2 deletions backend/transcription/whisper_transcriber.py
@@ -89,9 +89,8 @@ def real_time_transcription(self, diarization, waveform):
speaker_transcriptions = identify_speakers(transcription, diarization, time_shift)
return speaker_transcriptions

def sequential_transcription(self, buffer, diarization):
def sequential_transcription(self, audio, diarization):
# Step 1: Transcribe
audio = buffer.astype("float32").reshape(-1)
transcription = self.transcribe(audio)
# Step 2: Assign speakers
diarizated_transcription = assign_speakers(transcription, diarization)
