modelConfig.py
import os
import tensorflow as tf
import numpy as np
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
# Label map: maps class IDs to human-readable COCO category names.
category_index = label_map_util.create_category_index_from_labelmap(
    os.getcwd() + "/tf/ssd_mobilenet_v2_320x320_coco17_tpu-8/mscoco_label_map.pbtxt",
    use_display_name=True)
# Model: clear any existing Keras state and load the exported SavedModel once at import time.
tf.keras.backend.clear_session()
model_dir = os.getcwd() + "/tf/ssd_mobilenet_v2_320x320_coco17_tpu-8/saved_model/"
model = tf.saved_model.load(model_dir)
def run_inference_for_single_image(model, image):
    # The input needs to be a tensor; convert it using `tf.convert_to_tensor`.
    input_tensor = tf.convert_to_tensor(image)
    # The model expects a batch of images, so add an axis with `tf.newaxis`.
    input_tensor = input_tensor[tf.newaxis, ...]
    # Run inference.
    model_fn = model.signatures['serving_default']
    output_dict = model_fn(input_tensor)
    # All outputs are batched tensors.
    # Convert to numpy arrays, and take index [0] to remove the batch dimension.
    # We're only interested in the first num_detections.
    num_detections = int(output_dict.pop('num_detections'))
    output_dict = {key: value[0, :num_detections].numpy()
                   for key, value in output_dict.items()}
    output_dict['num_detections'] = num_detections
    # detection_classes should be ints.
    output_dict['detection_classes'] = output_dict['detection_classes'].astype(np.int64)
    # Handle models with masks:
    if 'detection_masks' in output_dict:
        # Reframe the box masks to the image size.
        detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
            output_dict['detection_masks'], output_dict['detection_boxes'],
            image.shape[0], image.shape[1])
        detection_masks_reframed = tf.cast(detection_masks_reframed > 0.5,
                                           tf.uint8)
        output_dict['detection_masks_reframed'] = detection_masks_reframed.numpy()
    return output_dict
def main(img_np):
    # Run detection and draw boxes, labels, and scores onto the image in place.
    output_dict = run_inference_for_single_image(model, img_np)
    vis_util.visualize_boxes_and_labels_on_image_array(
        img_np,
        output_dict['detection_boxes'],
        output_dict['detection_classes'],
        output_dict['detection_scores'],
        category_index,
        instance_masks=output_dict.get('detection_masks_reframed', None),
        use_normalized_coordinates=True,
        line_thickness=3)
    return img_np
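
# Example usage (a minimal sketch, not part of the original module): load an RGB
# image into a uint8 NumPy array and run detection on it. The file name
# 'test.jpg' and the use of Pillow are assumptions for illustration only.
if __name__ == '__main__':
    from PIL import Image
    image_np = np.array(Image.open('test.jpg').convert('RGB'))
    annotated = main(image_np)
    Image.fromarray(annotated).save('test_annotated.jpg')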