Implemented Moonshine and set it as the default speech to text service #463

Merged · 1 commit · Jan 13, 2025

1 change: 1 addition & 0 deletions requirements.txt
@@ -22,3 +22,4 @@ furo==2023.9.10
pywin32==306
mss==9.0.1
opencv-python==4.10.0.84
git+https://github.com/usefulsensors/moonshine.git@6a0846b556ca1c1f2f373996fe53b0b79e23cc17#subdirectory=moonshine-onnx
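
The pinned dependency above provides the moonshine_onnx package imported later in src/stt.py. As a quick smoke test of the install, a minimal sketch (the model name and API calls are taken from the stt.py diff below; the output for silence is an assumption and may vary):

import numpy as np
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer

# Load the smallest Moonshine model and its tokenizer
model = MoonshineOnnxModel(model_name="moonshine/tiny")
tokenizer = load_tokenizer()

# One second of silence at 16 kHz, shaped as a batch of one
silence = np.zeros((1, 16_000), dtype=np.float32)
tokens = model.generate(silence)
print(tokenizer.decode_batch(tokens)[0])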
4 changes: 3 additions & 1 deletion src/config/config_loader.py
@@ -191,7 +191,9 @@ def __update_config_values_from_current_state(self):
self.use_sr = self.__definitions.get_bool_value("use_sr")

#STT
self.whisper_model = self.__definitions.get_string_value("model_size")
self.stt_service = self.__definitions.get_string_value("stt_service").lower()
self.moonshine_model = self.__definitions.get_string_value("moonshine_model_size")
self.whisper_model = self.__definitions.get_string_value("whisper_model_size")
self.whisper_process_device = self.__definitions.get_string_value("process_device")
self.stt_language = self.__definitions.get_string_value("stt_language")
if (self.stt_language == 'default'):
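
The .lower() on stt_service is load-bearing: the UI stores the capitalized options defined in stt_definitions.py ("Moonshine"/"Whisper"), while src/stt.py branches on the lowercase string. A small illustration of the normalization (values taken from this PR):

stored_value = "Moonshine"          # as saved by the config UI
stt_service = stored_value.lower()  # -> "moonshine"; stt.py compares against 'whisper'
assert stt_service in ("moonshine", "whisper")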
55 changes: 35 additions & 20 deletions src/config/definitions/stt_definitions.py
@@ -30,10 +30,38 @@ def get_audio_threshold_folder_config_value() -> ConfigValue:
audio_threshold_description = """Controls how much background noise is filtered out.
If the mic is not picking up speech, try lowering this value.
If the mic is picking up too much background noise, try increasing this value."""
return ConfigValueInt("audio_threshold","Audio Threshold",audio_threshold_description, 175, 0, 999, tags=[ConfigValueTag.share_row])
return ConfigValueInt("audio_threshold","Audio Threshold",audio_threshold_description, 175, 0, 999)

@staticmethod
def get_model_size_config_value() -> ConfigValue:
def get_stt_service_config_value() -> ConfigValue:
description = """Choose between running Moonshine or Whisper as your speech to text service.
Moonshine runs faster than Whisper on a CPU, but only supports English.
Whisper can run on a CPU, GPU, or via an external service (see Advanced settings below)."""
options = ["Moonshine", "Whisper"]
return ConfigValueSelection("stt_service", "STT Service", description, "Moonshine", options, allows_free_edit=False)

@staticmethod
def get_pause_threshold_config_value() -> ConfigValue:
description = """How long to wait (in seconds) before converting mic input to text.
If you feel like you are being cut off before you finish your response, increase this value.
If you feel like there is too much of a delay between you finishing your response and the text conversion, decrease this value.
It is recommended to set this value to at least 0.1."""
return ConfigValueFloat("pause_threshold","Pause Threshold", description, 1.0, 0, 999, tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])

@staticmethod
def get_listen_timeout_config_value() -> ConfigValue:
description = """How long to wait (in seconds) for the player to speak before retrying.
This needs to be set to ensure that Mantella can periodically check if the conversation has ended."""
return ConfigValueInt("listen_timeout","Listen Timeout", description, 30, 0, 999, tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])

@staticmethod
def get_moonshine_model_size_config_value() -> ConfigValue:
description = """The size of the Moonshine model to use. The larger the model, the more accurate the transcription (at the cost of speed)."""
options = ["moonshine/tiny", "moonshine/base"]
return ConfigValueSelection("moonshine_model_size", "Moonshine Model", description, "moonshine/tiny", options, allows_free_edit=True, tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])

@staticmethod
def get_whisper_model_size_config_value() -> ConfigValue:
description = """The size of the Whisper model used. Some languages require larger models. The base.en model works well enough for English.
See here for a comparison of languages and their Whisper performance:
https://github.com/openai/whisper#available-models-and-languages"""
@@ -43,7 +71,7 @@ def get_model_size_config_value() -> ConfigValue:
"medium", "medium.en", "distil-medium.en",
"large-v1", "large-v2", "large-v3", "distil-large-v2", "distil-large-v3",
"whisper-1"]
return ConfigValueSelection("model_size", "Model Size", description, "base", options, allows_free_edit=True, tags=[ConfigValueTag.share_row])
return ConfigValueSelection("whisper_model_size", "Whisper Model", description, "base", options, allows_free_edit=True, tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])

@staticmethod
def get_external_whisper_service_config_value() -> ConfigValue:
@@ -59,33 +87,20 @@ def get_whisper_url_config_value() -> ConfigValue:
OpenAI: Ensure 'Speech-to-Text'->'Model Size' is set to `whisper-1`. Requires an OpenAI secret key.
Groq: Ensure 'Speech-to-Text'->'Model Size' is set to one of the following: https://console.groq.com/docs/speech-text#supported-models. Requires a Groq secret key.
whisper.cpp: whisper.cpp (https://github.com/ggerganov/whisper.cpp) can be connected to when it is run in server mode. No secret key is required. Ensure the server is running before starting Mantella. By default, selecting whisper.cpp will connect to the URL http://127.0.0.1:8080/inference, but you can also manually enter a URL in this field if you have selected a port other than 8080 or are running whisper.cpp on another machine."""
return ConfigValueSelection("whisper_url", "Whisper URL", description, "OpenAI", ["OpenAI", "Groq", "whisper.cpp"], allows_free_edit=True, tags=[ConfigValueTag.advanced])

@staticmethod
def get_pause_threshold_config_value() -> ConfigValue:
description = """How long to wait (in seconds) before converting mic input to text.
If you feel like you are being cut off before you finish your response, increase this value.
If you feel like there is too much of a delay between you finishing your response and the text conversion, decrease this value."""
return ConfigValueFloat("pause_threshold","Pause Threshold", description, 1.0, 0.1, 999, tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])

@staticmethod
def get_listen_timeout_config_value() -> ConfigValue:
description = """How long to wait (in seconds) for the player to speak before retrying.
This needs to be set to ensure that Mantella can periodically check if the conversation has ended."""
return ConfigValueInt("listen_timeout","Listen Timeout", description, 30, 0, 999, tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])
return ConfigValueSelection("whisper_url", "Whisper Service", description, "OpenAI", ["OpenAI", "Groq", "whisper.cpp"], allows_free_edit=True, tags=[ConfigValueTag.advanced])

@staticmethod
def get_stt_language_config_value() -> ConfigValue:
description = """The player's spoken language."""
return ConfigValueSelection("stt_language","STT Language",description,"default",["default","en", "ar", "cs", "da", "de", "el", "es", "fi", "fr", "hi", "hu", "it", "ja", "ko", "nl", "pl", "pt", "ro", "ru", "sv", "sw", "uk", "ha", "tr", "vi", "yo"], tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])
return ConfigValueSelection("stt_language","Whisper STT Language",description,"default",["default","en", "ar", "cs", "da", "de", "el", "es", "fi", "fr", "hi", "hu", "it", "ja", "ko", "nl", "pl", "pt", "ro", "ru", "sv", "sw", "uk", "ha", "tr", "vi", "yo"], tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])

@staticmethod
def get_stt_translate_config_value() -> ConfigValue:
description = """Translate the transcribed speech to English if supported by the Speech-To-Text engine (only impacts faster_whisper option, no impact on whispercpp, which is controlled by your server).
STTs that support this function: Whisper (faster_whisper)."""
return ConfigValueBool("stt_translate", "STT Translate",description, False, tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])
return ConfigValueBool("stt_translate", "Whisper STT Translate",description, False, tags=[ConfigValueTag.advanced,ConfigValueTag.share_row])

@staticmethod
def get_process_device_config_value() -> ConfigValue:
description = "Whether to run Whisper on your CPU or NVIDIA GPU (with CUDA installed) (only impacts faster_whisper option, no impact on whispercpp, which is controlled by your server)."
return ConfigValueSelection("process_device", "Process Device", description,"cpu",["cpu","cuda"], constraints=[STTDefinitions.WhisperProcessDeviceChecker()], tags=[ConfigValueTag.advanced])
return ConfigValueSelection("process_device", "Whisper Process Device", description,"cpu",["cpu","cuda"], constraints=[STTDefinitions.WhisperProcessDeviceChecker()], tags=[ConfigValueTag.advanced])
8 changes: 5 additions & 3 deletions src/config/mantella_config_value_definitions_new.py
@@ -72,11 +72,13 @@ def get_config_values(is_integrated: bool, actions: list[action], on_value_chang
stt_category = ConfigValueGroup("STT", "Speech-to-Text", "Settings for the STT methods Mantella supports.", on_value_change_callback)
stt_category.add_config_value(STTDefinitions.get_use_automatic_audio_threshold_folder_config_value())
stt_category.add_config_value(STTDefinitions.get_audio_threshold_folder_config_value())
stt_category.add_config_value(STTDefinitions.get_model_size_config_value())
stt_category.add_config_value(STTDefinitions.get_external_whisper_service_config_value())
stt_category.add_config_value(STTDefinitions.get_whisper_url_config_value())
stt_category.add_config_value(STTDefinitions.get_stt_service_config_value())
stt_category.add_config_value(STTDefinitions.get_pause_threshold_config_value())
stt_category.add_config_value(STTDefinitions.get_listen_timeout_config_value())
stt_category.add_config_value(STTDefinitions.get_moonshine_model_size_config_value())
stt_category.add_config_value(STTDefinitions.get_whisper_model_size_config_value())
stt_category.add_config_value(STTDefinitions.get_external_whisper_service_config_value())
stt_category.add_config_value(STTDefinitions.get_whisper_url_config_value())
stt_category.add_config_value(STTDefinitions.get_stt_language_config_value())
stt_category.add_config_value(STTDefinitions.get_stt_translate_config_value())
stt_category.add_config_value(STTDefinitions.get_process_device_config_value())
68 changes: 49 additions & 19 deletions src/stt.py
@@ -1,4 +1,5 @@
import sys
import numpy as np
from faster_whisper import WhisperModel
import speech_recognition as sr
import logging
@@ -17,6 +18,7 @@
import queue
import threading
import time
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer

@dataclass
class TranscriptionJob:
@@ -48,7 +50,9 @@ def __init__(self, config: ConfigLoader, stt_secret_key_file: str, secret_key_fi
# self.mic_enabled = config.mic_enabled
self.language = config.stt_language
self.task = "translate" if config.stt_translate == 1 else "transcribe"
self.model = config.whisper_model
self.stt_service = config.stt_service
self.moonshine_model = config.moonshine_model
self.whisper_model = config.whisper_model
self.process_device = config.whisper_process_device
self.audio_threshold = config.audio_threshold
self.listen_timeout = config.listen_timeout
@@ -57,7 +61,12 @@ def __init__(self, config: ConfigLoader, stt_secret_key_file: str, secret_key_fi
self.whisper_url = self.__get_endpoint(config.whisper_url)
self.pause_threshold = config.pause_threshold
# heavy-handed fix to non_speaking_duration as it is always required to be less than pause_threshold
self.non_speaking_duration = 0.5 if self.pause_threshold > 0.5 else self.pause_threshold - 0.01
if self.pause_threshold > 0.5:
self.non_speaking_duration = 0.5
elif self.pause_threshold == 0:
self.non_speaking_duration = 0
else:
self.non_speaking_duration = self.pause_threshold - 0.0001
self.show_mic_warning = True

self.end_conversation_keyword = config.end_conversation_keyword
@@ -69,7 +78,7 @@ def __init__(self, config: ConfigLoader, stt_secret_key_file: str, secret_key_fi
self.__secret_key_file = secret_key_file
self.__api_key: str | None = self.__get_api_key()
self.__initial_client: OpenAI | None = None
if (self.__api_key) and ('openai' in self.whisper_url):
if (self.__api_key) and ('openai' in self.whisper_url) and (self.external_whisper_service):
self.__initial_client = self.__generate_sync_client() # initialize first client in advance to save time

self.__ignore_list = ['', 'thank you', 'thank you for watching', 'thanks for watching', 'the transcript is from the', 'the', 'thank you very much', "thank you for watching and i'll see you in the next video", "we'll see you in the next video", 'see you next time']
@@ -89,13 +98,18 @@ def __init__(self, config: ConfigLoader, stt_secret_key_file: str, secret_key_fi
self.recognizer.energy_threshold = int(self.audio_threshold)
logging.log(self.loglevel, f"Audio threshold set to {self.audio_threshold}. If the mic is not picking up your voice, try lowering this `Speech-to-Text`->`Audio Threshold` value in the Mantella UI. If the mic is picking up too much background noise, try increasing this value.\n")

self.transcribe_model: WhisperModel | None = None
# if using faster_whisper, load model selected by player, otherwise skip this step
if not self.external_whisper_service:
if self.process_device == 'cuda':
self.transcribe_model = WhisperModel(self.model, device=self.process_device)
else:
self.transcribe_model = WhisperModel(self.model, device=self.process_device, compute_type="float32")
self.transcribe_model: WhisperModel | MoonshineOnnxModel | None = None

if self.stt_service == 'whisper':
# if using faster_whisper, load model selected by player, otherwise skip this step
if not self.external_whisper_service:
if self.process_device == 'cuda':
self.transcribe_model = WhisperModel(self.whisper_model, device=self.process_device)
else:
self.transcribe_model = WhisperModel(self.whisper_model, device=self.process_device, compute_type="float32")
else:
self.transcribe_model = MoonshineOnnxModel(model_name=self.moonshine_model)
self.tokenizer = load_tokenizer()

# Thread management
self.__listen_thread: Optional[threading.Thread] = None
@@ -173,29 +187,42 @@ def whisper_transcribe(self, audio, prompt: str):
elif 'openai' in self.whisper_url: # OpenAI compatible endpoint
client = self.__generate_sync_client()
try:
response_data = client.audio.transcriptions.create(model=self.model, language=self.language, file=audio, prompt=prompt)
response_data = client.audio.transcriptions.create(model=self.whisper_model, language=self.language, file=audio, prompt=prompt)
except Exception as e:
if e.code in [404, 'model_not_found']:
if self.whisper_service == 'OpenAI':
logging.error(f"Selected Whisper model '{self.model}' does not exist in the OpenAI service. Try changing 'Speech-to-Text'->'Model Size' to 'whisper-1' in the Mantella UI")
logging.error(f"Selected Whisper model '{self.whisper_model}' does not exist in the OpenAI service. Try changing 'Speech-to-Text'->'Model Size' to 'whisper-1' in the Mantella UI")
elif self.whisper_service == 'Groq':
logging.error(f"Selected Whisper model '{self.model}' does not exist in the Groq service. Try changing 'Speech-to-Text'->'Model Size' to one of the following models in the Mantella UI: https://console.groq.com/docs/speech-text#supported-models")
logging.error(f"Selected Whisper model '{self.whisper_model}' does not exist in the Groq service. Try changing 'Speech-to-Text'->'Model Size' to one of the following models in the Mantella UI: https://console.groq.com/docs/speech-text#supported-models")
else:
logging.error(f"Selected Whisper model '{self.model}' does not exist in the selected service {self.whisper_service}. Try changing 'Speech-to-Text'->'Model Size' to a compatible model in the Mantella UI")
logging.error(f"Selected Whisper model '{self.whisper_model}' does not exist in the selected service {self.whisper_service}. Try changing 'Speech-to-Text'->'Model Size' to a compatible model in the Mantella UI")
else:
logging.error(f'STT error: {e}')
input("Press Enter to exit.")
client.close()
return response_data.text.strip()
else: # custom server model
data = {'model': self.model, 'prompt': prompt}
data = {'model': self.whisper_model, 'prompt': prompt}
files = {'file': ('audio.wav', audio, 'audio/wav')}
response = requests.post(self.whisper_url, files=files, data=data)
if response.status_code != 200:
logging.error(f'STT Error: {response.content}')
response_data = json.loads(response.text)
if 'text' in response_data:
return response_data['text'].strip()


@utils.time_it
def moonshine_transcribe(self, audio_data: bytes) -> str:
"""Transcribe audio using Moonshine model"""
# Convert wav data to numpy array
audio_np = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0

# Generate transcription
tokens = self.transcribe_model.generate(audio_np[np.newaxis, :])
text = self.tokenizer.decode_batch(tokens)[0]

return text.strip()


@utils.time_it
@@ -240,10 +267,13 @@ def __process_transcriptions(self):
continue

audio_data = capture.audio_data.get_wav_data(convert_rate=16_000)
#transcript = base64.b64encode(audio_data).decode('utf-8')
audio_file = io.BytesIO(audio_data)
audio_file.name = 'out.wav'
transcript = self.whisper_transcribe(audio_file, capture.prompt)

if self.stt_service == 'whisper':
audio_file = io.BytesIO(audio_data)
audio_file.name = 'out.wav' # audio file needs a name or else Whisper gets angry
transcript = self.whisper_transcribe(audio_file, capture.prompt)
else:
transcript = self.moonshine_transcribe(audio_data)

transcript_cleaned = utils.clean_text(transcript)

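
Taken together, the new transcription path dispatches on stt_service: Whisper keeps the file-based flow (the in-memory WAV needs a .name for the API), while Moonshine decodes the raw 16-bit PCM directly. A condensed, paraphrased sketch of the combined flow (not the literal merged code; the self attributes are those set up in __init__ above):

import io
import numpy as np

def transcribe(self, audio_data: bytes, prompt: str) -> str:
    if self.stt_service == 'whisper':
        audio_file = io.BytesIO(audio_data)
        audio_file.name = 'out.wav'  # Whisper endpoints expect a named file
        return self.whisper_transcribe(audio_file, prompt)
    # Moonshine: scale 16-bit PCM samples into [-1.0, 1.0) float32 before inference
    audio_np = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
    tokens = self.transcribe_model.generate(audio_np[np.newaxis, :])
    return self.tokenizer.decode_batch(tokens)[0].strip()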