Skip to content

Commit

Permalink
Fixed bug with custom verifier model loading/prediction and increment…
Browse files Browse the repository at this point in the history
…ed versioning accordingly
  • Loading branch information
dscripka committed Mar 6, 2023
1 parent a0311f2 commit 8322a96
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 8 deletions.
19 changes: 14 additions & 5 deletions openwakeword/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
class_mapping_dicts: List[dict] = [],
enable_speex_noise_suppression: bool = False,
vad_threshold: float = 0,
custom_verifier_models: Union[bool, dict] = False,
custom_verifier_models: dict = {},
custom_verifier_threshold: float = 0.1,
**kwargs
):
Expand Down Expand Up @@ -112,6 +112,14 @@ def __init__(
if custom_verifier_models.get(mdl_name, False):
self.custom_verifier_models[mdl_name] = pickle.load(open(custom_verifier_models[mdl_name], 'rb'))

if len(self.custom_verifier_models.keys()) < len(custom_verifier_models.keys()):
raise ValueError(
"Custom verifier models were provided, but some were not matched with a base model!"
" Make sure that the keys provided in the `custom_verifier_models` dictionary argument"
" exactly match that of the `.models` attribute of an instantiated openWakeWord Model object"
" that has the same base models but doesn't have custom verifier models."
)

# Create buffer to store frame predictions
self.prediction_buffer: DefaultDict[str, deque] = defaultdict(partial(deque, maxlen=30))

Expand Down Expand Up @@ -208,10 +216,11 @@ def predict(self, x: np.ndarray, patience: dict = {}, threshold: dict = {}, timi
for cls in predictions.keys():
if predictions[cls] >= self.custom_verifier_threshold:
parent_model = self.get_parent_model_from_label(cls)
verifier_prediction = self.custom_verifier_models[parent_model].predict_proba(
self.preprocessor.get_features(self.model_inputs[mdl])
)[0][-1]
predictions[cls] = verifier_prediction
if self.custom_verifier_models.get(parent_model, False):
verifier_prediction = self.custom_verifier_models[parent_model].predict_proba(
self.preprocessor.get_features(self.model_inputs[mdl])
)[0][-1]
predictions[cls] = verifier_prediction

# Update prediction buffer, and zero predictions for first 5 frames during model initialization
for cls in predictions.keys():
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ testpaths = [

[project]
name = "openwakeword"
version = "0.3.0"
version = "0.3.1"
authors = [
{ name="David Scripka", email="david.scripka@gmail.com" },
]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def build_additional_requires():

setuptools.setup(
name="openwakeword",
version="0.3.0",
version="0.3.1",
install_requires=['onnxruntime>=1.10.0,<2', 'tqdm>=4.0,<5.0', 'scipy>=1.3,<2', 'scikit-learn>=1,<2'],
extras_require={
'test': [
Expand Down
10 changes: 9 additions & 1 deletion tests/test_custom_verifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,15 @@ def test_train_verifier_model(self):
model_name=os.path.join("openwakeword", "resources", "models", "hey_mycroft_v0.1.onnx")
)

# Load model with verifier model
with pytest.raises(ValueError):
# Load model with verifier model incorrectly to catch ValueError
owwModel = openwakeword.Model(
wakeword_model_paths=[os.path.join("openwakeword", "resources", "models", "hey_mycroft_v0.1.onnx")],
custom_verifier_models={"bad_key": os.path.join(tmp_dir, "verifier_model.pkl")},
custom_verifier_threshold=0.3,
)

# Load model with verifier model correctly, using a key that matches the base model
owwModel = openwakeword.Model(
wakeword_model_paths=[os.path.join("openwakeword", "resources", "models", "hey_mycroft_v0.1.onnx")],
custom_verifier_models={"hey_mycroft_v0.1": os.path.join(tmp_dir, "verifier_model.pkl")},
Expand Down

0 comments on commit 8322a96

Please sign in to comment.