Skip to content

Commit

Permalink
TFLite should be enabled if TF is installed
Browse files Browse the repository at this point in the history
  • Loading branch information
David Rubinstein committed Oct 2, 2023
1 parent 82abe24 commit e946bb1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
24 changes: 13 additions & 11 deletions basic_pitch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,6 @@
import logging
import pathlib

try:
import tensorflow

TF_PRESENT = True
except ImportError:
TF_PRESENT = False
logging.warning(
"Tensorflow is not installed. "
"If you plan to use a TF Saved Model, "
"reinstall basic-pitch with `pip install 'basic-pitch[tf]'`"
)

try:
import coremltools
Expand Down Expand Up @@ -69,6 +58,19 @@
)


try:
import tensorflow

TF_PRESENT = True
except ImportError:
TF_PRESENT = False
logging.warning(
"Tensorflow is not installed. "
"If you plan to use a TF Saved Model, "
"reinstall basic-pitch with `pip install 'basic-pitch[tf]'`"
)


class FilenameSuffix(enum.Enum):
tf = "nmp"
coreml = "nmp.mlpackage"
Expand Down
27 changes: 20 additions & 7 deletions basic_pitch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast


from basic_pitch import CT_PRESENT, ONNX_PRESENT, TF_PRESENT, TFLITE_PRESENT

try:
import tensorflow as tf
except ImportError:
Expand All @@ -37,7 +39,8 @@
try:
import tflite_runtime.interpreter as tflite
except ImportError:
pass
if TF_PRESENT:
import tensorflow.lite as tflite

try:
import onnxruntime as ort
Expand All @@ -55,7 +58,6 @@
ANNOTATIONS_FPS,
FFT_HOP,
)
from basic_pitch import CT_PRESENT, ONNX_PRESENT, TF_PRESENT, TFLITE_PRESENT
from basic_pitch.commandline_printing import (
generating_file_message,
no_tf_warnings,
Expand Down Expand Up @@ -106,7 +108,7 @@ def __init__(self, model_path: Union[pathlib.Path, str]):
e.__repr__(),
)

if TFLITE_PRESENT:
if TFLITE_PRESENT or TF_PRESENT:
present.append("TensorFlowLite")
try:
self.model_type = Model.MODEL_TYPES.TFLITE
Expand Down Expand Up @@ -153,11 +155,22 @@ def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float3
elif self.model_type == Model.MODEL_TYPES.COREML:
return cast(ct.models.MLModel, self.model.predict({"input": x.tolist()})) # type: ignore
elif self.model_type == Model.MODEL_TYPES.TFLITE:
return cast(tflite.SignatureRunner, self.model)(x) # type: ignore
return self.model(input_2=x) # type: ignore
elif self.model_type == Model.MODEL_TYPES.ONNX:
return cast(ort.InferenceSession, self.model).run( # type: ignore
["note", "onset", "contour"], {"input": x}
)
return {
k: v
for k, v in zip(
["note", "onset", "contour"],
cast(ort.InferenceSession, self.model).run(
[
"StatefulPartitionedCall:1",
"StatefulPartitionedCall:2",
"StatefulPartitionedCall:0",
],
{"serving_default_input_2:0": x},
),
)
}


def window_audio_file(
Expand Down

0 comments on commit e946bb1

Please sign in to comment.