Skip to content

Commit

Permalink
Merge pull request saharmor#40 from saharmor/fix/hallucinations
Browse files Browse the repository at this point in the history
Fix hallucinations + bugfixes
  • Loading branch information
saharmor authored Aug 17, 2023
2 parents cf496f6 + b53face commit 2db92ab
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 137 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ Note: You need to have a Hugging Face account to use pyannote

1. [In the sequential mode, there may be uncontrolled speaker swapping.](https://github.com/saharmor/whisper-playground/issues/27)
2. [In real-time mode, audio data not meeting the transcription timeout won't be transcribed.](https://github.com/saharmor/whisper-playground/issues/28)
3. [Speechless batches may cause hallucinations.](https://github.com/saharmor/whisper-playground/issues/25)

This repository hasn't been tested for all languages; please create an issue if you encounter any problems.

Expand Down
39 changes: 17 additions & 22 deletions backend/client_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
import asyncio
import threading
from clients.utils import initialize_client
from clients.utils import get_client_class
from utils import cleanup
from config import ClientState


class ClientManager:
Expand All @@ -11,21 +12,19 @@ def __init__(self):
self.clients = {}

async def create_new_client(self, sid, sio, config):
self.clients[sid] = "initializing"
new_client = initialize_client(sid, sio, config)
new_client = get_client_class(config)(sid, sio, config)
self.clients[sid] = new_client
new_client.initialize_client()
if self.clients.get(sid):
self.clients[sid] = new_client
await new_client.start_transcribing()
else:
logging.warning("Client removed before transcription could start")
logging.warning("Client removed before transcription could start, new client may connect")

async def start_stream(self, sid, sio, config):
if not self.clients:
if sid not in self.clients.keys():
threading.Thread(target=asyncio.run, args=(self.create_new_client(sid, sio, config),)).start()
else:
logging.warning("A streaming client tried to initiate another stream")
await sio.emit("clientAlreadyStreaming")
else:
logging.warning("A new client tried to start streaming when there is already a client streaming")
await sio.emit("noMoreClientsAllowed")
Expand All @@ -48,21 +47,17 @@ async def end_stream(self, sid):
def disconnect_from_stream(self, sid):
if sid in self.clients.keys():
client = self.clients[sid]
if not client.disconnected:
cleanup_needed = client.cleanup_needed
if client != "initializing":
client.handle_disconnection()
# No error if client is still not an object, it won't get to that point
if client == "initializing" or not client.is_ending_stream():
try:
self.clients.pop(sid)
except KeyError:
logging.warning("disconnected_from_stream attempted to remove an already removed client")
logging.info("Disconnected client removed")
if cleanup_needed:
cleanup()
else:
logging.warning("Disconnecting client attempting to disconnect multiple times")
client.handle_disconnection()
cleanup_needed = client.cleanup_needed
client_state = client.get_state()
if client_state == ClientState.NOT_INITIALIZED or not client_state == ClientState.ENDING_STREAM:
try:
self.clients.pop(sid)
except KeyError:
logging.warning("disconnected_from_stream attempted to remove an already removed client")
logging.info("Disconnected client removed")
if cleanup_needed:
cleanup()
else:
logging.warning("A non-existent client tried to disconnect from the stream.")

Expand Down
48 changes: 35 additions & 13 deletions backend/clients/Client.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,71 @@
import logging
from queue import Queue
from silero_vad import silero_vad
from diart.utils import decode_audio
from utils import get_transcriber_information
from transcription.whisper_transcriber import WhisperTranscriber
from abc import abstractmethod
from config import ClientState


class Client:

def __init__(self, sid, socket, transcriber, transcription_timeout):
def __init__(self, sid, socket, config):
self.sid = sid
self.config = config
self.diarization_pipeline = None
self.transcriber = transcriber
self.transcription_timeout = transcription_timeout
self.transcriber = None
self.transcription_timeout = None
self.socket = socket
self.audio_chunks = Queue()
self.transcription_thread = None
self.disconnected = False
self.ending_stream = False
self.cleanup_needed = False
self.state = ClientState.NOT_INITIALIZED

def initialize_client(self):
    """Create the transcriber from ``self.config`` and mark the client as initialized.

    The model size and language code come from ``get_transcriber_information``.
    The beam size is read from ``config["beamSize"]`` and falls back to 1 when
    the value is missing or cannot be converted to an integer. The
    transcription timeout is read from ``config["transcribeTimeout"]``
    (default 5).

    Raises:
        TypeError, ValueError: if ``transcribeTimeout`` cannot be converted
            to an integer.
    """
    whisper_model_size, language_code = get_transcriber_information(self.config)
    try:
        beam_size = int(self.config.get("beamSize", 1))
    except (TypeError, ValueError):
        # int() raises TypeError for None or containers and ValueError for
        # non-numeric strings such as "abc"; both mean the beam size is invalid.
        logging.warning(f"Invalid beam size {self.config.get('beamSize')}, defaulting to 1")
        beam_size = 1
    self.transcriber = WhisperTranscriber(model_size=whisper_model_size, language_code=language_code,
                                          beam_size=beam_size)
    self.transcription_timeout = int(self.config.get("transcribeTimeout", 5))
    # Only flip the state once the transcriber and timeout are both in place.
    self.state = ClientState.INITIALIZED

@abstractmethod
async def start_transcribing(self):
pass
if self.transcriber is None:
raise ValueError("The transcriber must be defined before using this method")

async def stop_transcribing(self):
self.ending_stream = True
self.state = ClientState.ENDING_STREAM
self.transcription_thread.join()
logging.info("Transcription thread closed due to completion (stream ended)")
await self.socket.emit("whisperingStopped")
logging.info("Stream end signaled to client")

def handle_disconnection(self):
logging.info("Starting disconnection process, no longer sending transcriptions to client")
self.disconnected = True
if not self.ending_stream:
if self.state not in [ClientState.ENDING_STREAM, ClientState.NOT_INITIALIZED]:
self.state = ClientState.DISCONNECTED
self.transcription_thread.join()
logging.info("Transcription thread closed due to disconnection")

async def send_transcription(self, transcription):
logging.info(f"Transcription generated: {transcription}")
if not self.disconnected:
if self.state != ClientState.DISCONNECTED:
await self.socket.emit("transcriptionDataAvailable", transcription)
logging.info("Transcription sent")
else:
logging.info("Transcription not sent, client disconnected")

def handle_chunk(self, chunk):
self.audio_chunks.put(chunk)
speech_present, speech_confidence = silero_vad(decode_audio(chunk))
if speech_present:
self.audio_chunks.put(chunk)
logging.debug("Chunk added")

def is_ending_stream(self):
return self.ending_stream
def get_state(self):
return self.state
13 changes: 6 additions & 7 deletions backend/clients/RealTimeClient.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging

from diart import OnlineSpeakerDiarization
from config import DIARIZATION_PIPELINE_CONFIG
from config import DIARIZATION_PIPELINE_CONFIG, ClientState
import asyncio
import diart.operators as dops
import rx.operators as ops
Expand All @@ -13,8 +12,8 @@

class RealTimeClient(Client):

def __init__(self, sid, socket, transcriber, transcription_timeout):
super().__init__(sid, socket, transcriber, transcription_timeout)
def __init__(self, sid, socket, config):
super().__init__(sid, socket, config)
self.pipeline_config = DIARIZATION_PIPELINE_CONFIG
self.diarization_pipeline = OnlineSpeakerDiarization(self.pipeline_config)
self.chunk_receiving_thread = None
Expand All @@ -28,7 +27,7 @@ async def start_transcribing(self):
await self.socket.emit("whisperingStarted")
logging.info("Stream start signaled to client")

def receive_chunk(self, chunk):
def receive_chunk(self, chunk: str):
self.source.receive_chunk(chunk)

def complete_stream(self):
Expand All @@ -38,7 +37,7 @@ def complete_stream(self):
def receive_chunks(self):
logging.info("New chunks handler started")
while True:
if self.disconnected:
if self.state == ClientState.DISCONNECTED:
logging.info("Client disconnected, ending transcription...")
self.complete_stream()
return
Expand All @@ -47,7 +46,7 @@ def receive_chunks(self):
# Not a heavy operation, but it blocks while the pipeline executes; running it in a separate thread keeps it from blocking the main thread
self.source.receive_chunk(current_chunk)
else:
if self.ending_stream:
if self.state == ClientState.ENDING_STREAM:
logging.info("No more chunks, preparing for a final transcription...")
self.complete_stream()
return
Expand Down
35 changes: 21 additions & 14 deletions backend/clients/SequentialClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from diart.utils import decode_audio
from utils import save_batch_to_wav
import numpy as np
from config import STEP, TEMP_FILE_PATH
from config import STEP, TEMP_FILE_PATH, REQUIRED_AUDIO_TYPE, ClientState
from pyannote.audio import Pipeline
from clients.Client import Client


class SequentialClient(Client):

def __init__(self, sid, socket, transcriber, transcription_timeout):
super().__init__(sid, socket, transcriber, transcription_timeout)
def __init__(self, sid, socket, config):
super().__init__(sid, socket, config)
self.diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")
self.cleanup_needed = True

Expand All @@ -22,19 +22,24 @@ async def start_transcribing(self):
await self.socket.emit("whisperingStarted")
logging.info("Stream start signaled to client")

def get_diarization(self, waveform):
audio = waveform.astype("float32").reshape(-1)
save_batch_to_wav(audio, TEMP_FILE_PATH)
def get_diarization(self, buffer: np.ndarray):
assert buffer.dtype == REQUIRED_AUDIO_TYPE, f"audio array data type must be {REQUIRED_AUDIO_TYPE}"
save_batch_to_wav(buffer, TEMP_FILE_PATH)
diarization = self.diarization_pipeline(TEMP_FILE_PATH)
return diarization

def transcribe_buffer(self, buffer):
def transcribe_buffer(self, buffer: np.ndarray):
assert buffer.dtype == REQUIRED_AUDIO_TYPE, f"audio array data type must be {REQUIRED_AUDIO_TYPE}"
diarization = self.get_diarization(buffer)
result = self.transcriber.sequential_transcription(buffer, diarization)
asyncio.run(self.send_transcription(result))

@staticmethod
def modify_buffer(chunk, buffer):
def convert_buffer_to_float32(buffer: np.ndarray):
return buffer.astype("float32").reshape(-1)

@staticmethod
def modify_buffer(chunk: str, buffer: np.ndarray):
decoded_chunk = decode_audio(chunk)
buffer = decoded_chunk if buffer is None else np.concatenate([buffer, decoded_chunk], axis=1)
return buffer
Expand All @@ -47,24 +52,26 @@ def stream_sequential_transcription(self):
assert batch_size > 0, "batch size must be above 0"

while True:
if self.disconnected:
if self.state == ClientState.DISCONNECTED:
logging.info("Client disconnected, ending transcription...")
break
if not self.ending_stream:
if not self.state == ClientState.ENDING_STREAM:
if chunk_counter >= batch_size:
self.transcribe_buffer(buffer)
buffer_float32 = self.convert_buffer_to_float32(buffer)
self.transcribe_buffer(buffer_float32)
chunk_counter = 0

if not self.audio_chunks.empty():
current_chunk = self.audio_chunks.get()
buffer = self.modify_buffer(current_chunk, buffer)
chunk_counter += 1
else:
logging.info("Client is ending stream, preparing for a final transcription...")
chunk_counter = 0
while not self.audio_chunks.empty():
current_chunk = self.audio_chunks.get()
buffer = self.modify_buffer(current_chunk, buffer)
chunk_counter += 1
if chunk_counter > 0:
self.transcribe_buffer(buffer)
break
buffer_float32 = self.convert_buffer_to_float32(buffer)
self.transcribe_buffer(buffer_float32)
break
42 changes: 3 additions & 39 deletions backend/clients/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
from clients.SequentialClient import SequentialClient
from enum import Enum

from transcription.whisper_transcriber import WhisperTranscriber
from utils import format_whisper_model_name
from config import WhisperModelSize, LANGUAGE_MAPPING


class TranscriptionMethod(Enum):
REAL_TIME = RealTimeClient
Expand All @@ -20,40 +16,8 @@ def format_transcription_method_name(transcription_method):
def get_client_class(config):
transcription_method_name = format_transcription_method_name(config.get("transcriptionMethod"))
try:
client_class = getattr(TranscriptionMethod, transcription_method_name).value
client_class = getattr(TranscriptionMethod, transcription_method_name)
except AttributeError:
logging.warning(f"Invalid transcription method {transcription_method_name}, defaulting to sequential.")
client_class = TranscriptionMethod.SEQUENTIAL.value
return client_class


def get_whisper_model_name(config):
# Format the name received from the client to match the enum members
whisper_model_name = format_whisper_model_name(config.get("model", "small"))
try:
# Retrieve the corresponding enum member
whisper_model = getattr(WhisperModelSize, whisper_model_name)
except AttributeError:
logging.warning(f"Invalid model size {whisper_model_name}, defaulting to small")
whisper_model = WhisperModelSize.SMALL
language = config.get("language", "english")
try:
language_code = LANGUAGE_MAPPING[language.lower()]
except KeyError:
logging.warning(f"Language {language} not supported, defaulting to English")
language_code = "en"
return whisper_model, language_code


def initialize_client(sid, socket, config):
client_class = get_client_class(config)
whisper_model, language_code = get_whisper_model_name(config)
try:
beam_size = int(config.get("beamSize", 1))
except TypeError:
logging.warning(f"Invalid beam size {config.get('beamSize')}, defaulting to 1")
beam_size = 1
transcriber = WhisperTranscriber(model_name=whisper_model.value, language_code=language_code, beam_size=beam_size)
transcription_timeout = int(config.get("transcribeTimeout", 5))
new_client = client_class(sid, socket, transcriber, transcription_timeout)
return new_client
client_class = TranscriptionMethod.SEQUENTIAL
return client_class.value
19 changes: 18 additions & 1 deletion backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,37 @@
-1: "unknown"
}

SPEECH_CONFIDENCE_THRESHOLD = 0.3  # Minimum confidence required to treat a batch as containing speech (e.g. 0.5 means at least a 50% chance)

SAMPLE_RATE = 16000
NON_ENGLISH_SPECIFIC_MODELS = ["large", "large-v1", "large-v2"] # Models that don't have an English-only version

TEMP_FILE_PATH = "temp/batch.wav" # Path to the temporary file used for batch transcription in SequentialClient


class ClientState(Enum):
    """Lifecycle states of a transcription client."""

    # Set in the client constructor, before the transcriber is built.
    NOT_INITIALIZED = "not_initialized"
    # Set at the end of initialize_client().
    INITIALIZED = "initialized"
    # Set by stop_transcribing() before it joins the transcription thread.
    ENDING_STREAM = "ending_stream"
    # Set by handle_disconnection(); transcriptions are no longer sent in this state.
    DISCONNECTED = "disconnected"


class WhisperModelSize(Enum):
    """Whisper model size identifiers.

    Members ending in ``_ENGLISH`` carry the ``.en`` suffix; the large models
    have no English-only version.
    """

    TINY = 'tiny'
    TINY_ENGLISH = 'tiny.en'
    BASE = 'base'
    BASE_ENGLISH = 'base.en'
    SMALL = 'small'
    SMALL_ENGLISH = 'small.en'
    MEDIUM = 'medium'
    MEDIUM_ENGLISH = 'medium.en'
    LARGE_V1 = 'large-v1'
    LARGE_V2 = 'large-v2'


NON_ENGLISH_SPECIFIC_MODELS = [WhisperModelSize.LARGE_V1, WhisperModelSize.LARGE_V2] # Models that don't have an English-only version

REQUIRED_AUDIO_TYPE = "float32"

# Language code mapping
LANGUAGE_MAPPING = {
"english": "en",
Expand Down
Loading

0 comments on commit 2db92ab

Please sign in to comment.