Skip to content

Commit a9d5159

Browse files
authored
fix signature bugs (#10)
* fix signature bugs * Update setup.py
1 parent 7960f35 commit a9d5159

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup, find_packages
22
setup(
33
name="lobe",
4-
version="0.2.0",
4+
version="0.2.1",
55
packages=find_packages("src"),
66
package_dir={"": "src"},
77
install_requires=[

src/lobe/Signature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def load(model_path: str) -> Signature:
3333
class Signature:
3434
def __init__(self, signature_path: str):
3535
signature_path = pathlib.Path(signature_path)
36-
self.__model_path = signature_path.parent
36+
self.__model_path = str(signature_path.parent)
3737

3838
with open(signature_path, "r") as f:
3939
self.__signature = json.load(f)

src/lobe/backends/_backend_tf.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@
1111

1212
class ImageClassificationModel():
1313
__input_key_image = 'Image'
14-
__input_key_batch_size = "batch_size"
15-
__output_key_labels = 'Labels_idx_000'
16-
__output_key_confidences = 'Labels_idx_001'
14+
__output_key_confidences = 'Confidences'
1715
__output_key_prediction = 'Prediction'
1816

1917
def __init__(self, signature):
2018
self.__model_path = signature.model_path
2119
self.__tf_predict_fn = None
20+
self.__labels = signature.classes
2221

2322
def __load(self):
2423
self.__tf_predict_fn = predictor.from_saved_model(self.__model_path)
@@ -34,11 +33,10 @@ def predict(self, image: Image.Image) -> PredictionResult:
3433
np_image = np_image[np.newaxis, ...]
3534

3635
predictions = self.__tf_predict_fn({
37-
self.__input_key_image: np_image,
38-
self.__input_key_batch_size: 1 })
36+
self.__input_key_image: np_image
37+
})
3938

40-
labels = [label.decode('utf-8') for label in predictions[self.__output_key_labels][0].tolist()]
4139
confidences = predictions[self.__output_key_confidences][0]
4240
top_prediction = predictions[self.__output_key_prediction][0].decode('utf-8')
4341

44-
return PredictionResult(labels=labels, confidences=confidences, prediction=top_prediction)
42+
return PredictionResult(labels=self.__labels, confidences=confidences, prediction=top_prediction)

0 commit comments

Comments
 (0)