More informative error
David Rubinstein committed Oct 1, 2023
1 parent e2dd22b commit b85ed1b
Showing 1 changed file with 38 additions and 6 deletions.
44 changes: 38 additions & 6 deletions basic_pitch/inference.py
@@ -18,6 +18,7 @@
import csv
import enum
import json
import logging
import os
import pathlib
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast
@@ -72,44 +73,75 @@ class MODEL_TYPES(enum.Enum):
ONNX = enum.auto()

    def __init__(self, model_path: Union[pathlib.Path, str]):
        present = []
        if TF_PRESENT:
            present.append("TensorFlow")
            try:
                self.model_type = Model.MODEL_TYPES.TENSORFLOW
                self.model = tf.saved_model.load(model_path)
                return
            except Exception as e:
                if os.path.isdir(model_path) and {"saved_model.pb", "variables"} & set(os.listdir(model_path)):
                    logging.warning(
                        "Could not load TensorFlow saved model %s even "
                        "though it looks like a saved model directory with error %s. "
                        "Are you sure it's a saved model?",
                        model_path,
                        e.__repr__(),
                    )

        if CT_PRESENT:
            present.append("CoreML")
            try:
                self.model_type = Model.MODEL_TYPES.COREML
                self.model = ct.models.MLModel(str(model_path))
                return
            except Exception as e:
                if str(model_path).endswith(".mlpackage"):
                    logging.warning(
                        "Could not load CoreML file %s even "
                        "though it looks like a CoreML file with error %s. "
                        "Are you sure it's a CoreML file?",
                        model_path,
                        e.__repr__(),
                    )

        if TFLITE_PRESENT:
            present.append("TensorFlowLite")
            try:
                self.model_type = Model.MODEL_TYPES.TFLITE
                self.interpreter = tflite.Interpreter(model_path=model_path)
                self.interpreter.allocate_tensors()
                self.model = self.interpreter.get_signature_runner()
                return
            except Exception as e:
                if str(model_path).endswith(".tflite"):
                    logging.warning(
                        "Could not load TensorFlowLite file %s even "
                        "though it looks like a TFLite file with error %s. "
                        "Are you sure it's a TFLite file?",
                        model_path,
                        e.__repr__(),
                    )

        if ONNX_PRESENT:
            present.append("ONNX")
            try:
                self.model_type = Model.MODEL_TYPES.ONNX
                self.model = ort.InferenceSession(model_path)
                return
            except Exception as e:
                if str(model_path).endswith(".onnx"):
                    logging.warning(
                        "Could not load ONNX file %s even "
                        "though it looks like an ONNX file with error %s. "
                        "Are you sure it's an ONNX file?",
                        model_path,
                        e.__repr__(),
                    )

        raise ValueError(
            f"File {model_path} cannot be loaded into any of "
            "TensorFlow, CoreML, TFLite or ONNX. "
            "Please check if it is a supported and valid serialized model "
            "and that at least one of these packages is installed. On this system, "
            f"the following are installed: {present}."
        )

def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float32]]:

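For reference, below is a minimal usage sketch of the loader this commit touches. The ICASSP_2022_MODEL_PATH constant and the broken.onnx path are assumptions used only for illustration and are not part of this diff.

import logging

from basic_pitch import ICASSP_2022_MODEL_PATH  # assumed: default model bundled with the package
from basic_pitch.inference import Model

logging.basicConfig(level=logging.WARNING)

# A valid model loads via the first backend that succeeds
# (TensorFlow, CoreML, TFLite, then ONNX) and the constructor returns quietly.
model = Model(ICASSP_2022_MODEL_PATH)

# With this change, a file that looks like a supported format but fails to load
# (hypothetical path below) first emits a logging.warning carrying the backend's
# exception repr, then the final ValueError reports which backends are installed.
try:
    Model("broken.onnx")
except ValueError as err:
    print(err)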