Skip to content

Commit

Permalink
feat/multi_lang_limited_voc (#3)
Browse files Browse the repository at this point in the history
* feat/multi_lang_limited_voc

- add support for loading multiple languages, allows for lang to be set by request
- add support for OpenVoiceOS/ovos-core#78

authored-by: jarbasai <jarbasai@mailfence.com>
  • Loading branch information
NeonJarbas authored Mar 1, 2022
1 parent b6fc7b1 commit 07f01ae
Showing 1 changed file with 85 additions and 11 deletions.
96 changes: 85 additions & 11 deletions ovos_stt_plugin_vosk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from os.path import isdir
import json
from vosk import Model as KaldiModel, KaldiRecognizer
from queue import Queue
Expand All @@ -7,25 +8,82 @@
from ovos_skill_installer import download_extract_zip, download_extract_tar
from os.path import join, exists, isdir
from ovos_utils.xdg_utils import xdg_data_home
from ovos_utils.file_utils import read_vocab_file, resolve_resource_file, resolve_ovos_resource_file


class VoskKaldiSTT(STT):
def __init__(self, *args, **kwargs):
    """Resolve the default model and create the recognizer for ``self.lang``.

    The model location is read from the ``model_folder`` config key
    (kept for backwards compatibility) or from ``model``. When neither is
    set, the default model for ``self.lang`` is fetched through
    ``download_language``. A configured value that is an http(s) url is
    downloaded first, so that the "path or url" promise of the error
    message below actually holds.

    Raises:
        FileNotFoundError: if no model directory could be resolved.
    """
    super().__init__(*args, **kwargs)
    # model_folder for backwards compat
    self.model_path = self.config.get("model_folder") or self.config.get("model")
    if not self.model_path and self.lang:
        self.model_path = self.download_language(self.lang)
    elif self.model_path and self.model_path.startswith("http"):
        # a url can not be loaded directly, fetch it and keep the local path
        self.model_path = self.download_model(self.model_path)
    if not self.model_path or not isdir(self.model_path):
        LOG.error("You need to provide a valid model path or url")
        LOG.info(
            "download a model from https://alphacephei.com/vosk/models")
        raise FileNotFoundError

    # full vocabulary recognizers, keyed by language
    self.engines = {
        self.lang: KaldiRecognizer(KaldiModel(self.model_path), 16000)
    }
    # recognizers restricted to a word list, keyed by language
    self.limited_voc_engines = {}
    # True while limited vocabulary mode is active
    self.limited = False

def download_language(self, lang=None):
lang = lang or self.lang
lang = lang.split("-")[0].lower()
model_path = self.lang2modelurl(lang)
if model_path and model_path.startswith("http"):
model_path = self.download_model(model_path)
return model_path

def load_language(self, lang=None):
lang = lang or self.lang
lang = lang.split("-")[0].lower()
if lang in self.engines or lang in self.limited_voc_engines:
return
model_path = self.download_language(lang)
if model_path:
self.engines[lang] = KaldiRecognizer(KaldiModel(model_path), 16000)
else:
LOG.error(f"No default model available for {lang}")
raise FileNotFoundError

def unload_language(self, lang=None):
lang = lang or self.lang
if lang in self.engines:
del self.engines[lang]
self.engines.pop(lang)
if lang in self.limited_voc_engines:
del self.limited_voc_engines[lang]
self.limited_voc_engines.pop(lang)

def enable_full_vocabulary(self, lang=None):
""" enable default transcription mode """
lang = lang or self.lang
self.limited = False
if lang in self.limited_voc_engines:
self.limited_voc_engines.pop(lang)
self.engines[lang] = KaldiRecognizer(KaldiModel(model_path), 16000)

def enable_limited_vocabulary(self, words, lang=None, permanent=True):
    """
    enable limited vocabulary mode
    will only consider pre defined .voc files

    Args:
        words: the allowed vocabulary; serialized with ``json.dumps`` and
            handed to ``KaldiRecognizer`` as its third argument.
        lang: language to restrict, defaults to ``self.lang``.
        permanent: when true the limited recognizer also replaces the
            entry for ``lang`` in ``self.engines``.

    Nothing is enabled when no model location can be resolved for ``lang``.
    """
    lang = lang or self.lang
    if lang == self.lang:
        model_path = self.model_path
    else:
        # lang2modelurl only yields a url; download_language turns it
        # into a local path that KaldiModel can open
        model_path = self.download_language(lang)
    if not model_path:
        LOG.error(f"No default model available for {lang}")
        return
    limited_engine = KaldiRecognizer(KaldiModel(model_path),
                                     16000, json.dumps(words))
    self.limited_voc_engines[lang] = limited_engine
    if permanent:
        # plain assignment: the language may not be loaded yet, so
        # deleting the old entry first could raise KeyError
        self.engines[lang] = limited_engine
    self.limited = True

@staticmethod
def download_model(url):
Expand Down Expand Up @@ -84,11 +142,27 @@ def lang2modelurl(lang, small=True):
return lang2url.get(lang)

def execute(self, audio, language=None):
self.kaldi.AcceptWaveform(audio.get_wav_data())
res = self.kaldi.FinalResult()
# load a new model on the fly if needed
lang = language or self.lang
self.load_language(lang)

# if limited vocabulary mode is enabled use that model instead
if self.limited:
engine = self.limited_voc_engines.get(lang) or self.engines[lang]
else:
engine = self.engines[lang]

# transcribe
engine.AcceptWaveform(audio.get_wav_data())
res = engine.FinalResult()
res = json.loads(res)
return res["text"]

def shutdown(self):
for lang in set(self.engines.keys()) + \
set(self.limited_voc_engines.keys()):
self.unload_language(lang)


class VoskKaldiStreamThread(StreamThread):
def __init__(self, queue, lang, kaldi, verbose=True):
Expand Down

0 comments on commit 07f01ae

Please sign in to comment.