1111
1212class ImageClassificationModel ():
1313 __input_key_image = 'Image'
14- __input_key_batch_size = "batch_size"
15- __output_key_labels = 'Labels_idx_000'
16- __output_key_confidences = 'Labels_idx_001'
14+ __output_key_confidences = 'Confidences'
1715 __output_key_prediction = 'Prediction'
1816
1917 def __init__ (self , signature ):
2018 self .__model_path = signature .model_path
2119 self .__tf_predict_fn = None
20+ self .__labels = signature .classes
2221
2322 def __load (self ):
2423 self .__tf_predict_fn = predictor .from_saved_model (self .__model_path )
@@ -34,11 +33,10 @@ def predict(self, image: Image.Image) -> PredictionResult:
3433 np_image = np_image [np .newaxis , ...]
3534
3635 predictions = self .__tf_predict_fn ({
37- self .__input_key_image : np_image ,
38- self . __input_key_batch_size : 1 })
36+ self .__input_key_image : np_image
37+ })
3938
40- labels = [label .decode ('utf-8' ) for label in predictions [self .__output_key_labels ][0 ].tolist ()]
4139 confidences = predictions [self .__output_key_confidences ][0 ]
4240 top_prediction = predictions [self .__output_key_prediction ][0 ].decode ('utf-8' )
4341
44- return PredictionResult (labels = labels , confidences = confidences , prediction = top_prediction )
42+ return PredictionResult (labels = self . __labels , confidences = confidences , prediction = top_prediction )
0 commit comments