Skip to content

Commit c180cda

Browse files
ziyeqinghan authored and copybara-github committed
add evaluate_tflite for object detection in Model Maker.
PiperOrigin-RevId: 362245976
1 parent 732de9e commit c180cda

File tree

3 files changed

+115
-11
lines changed

3 files changed

+115
-11
lines changed

tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py

Lines changed: 100 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import collections
1717
import os
1818
import tempfile
19+
from typing import Tuple, Dict
1920

2021
from absl import logging
2122
import tensorflow as tf
@@ -26,6 +27,7 @@
2627
from tensorflow_examples.lite.model_maker.third_party.efficientdet import hparams_config
2728
from tensorflow_examples.lite.model_maker.third_party.efficientdet import utils
2829
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import efficientdet_keras
30+
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import eval_tflite
2931
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import inference
3032
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import label_util
3133
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import postprocess
@@ -202,8 +204,10 @@ def train(self,
202204
validation_steps=validation_steps)
203205
return model
204206

205-
def evaluate(self, model, dataset, steps, json_file=None):
206-
"""Evaluate the EfficientDet keras model."""
207+
def _get_evaluator_and_label_map(
208+
self, json_file: str
209+
) -> Tuple[coco_metric.EvaluationMetric, collections.OrderedDict]:
210+
"""Gets evaluator and label_map for evaluation."""
207211
label_map = label_util.get_label_map(self.config.label_map)
208212
# Sorts label_map.keys since pycocotools.cocoeval uses sorted catIds
209213
# (category ids) in COCOeval class.
@@ -213,6 +217,43 @@ def evaluate(self, model, dataset, steps, json_file=None):
213217
filename=json_file, label_map=label_map)
214218

215219
evaluator.reset_states()
220+
return evaluator, label_map
221+
222+
def _get_metric_dict(self, evaluator: coco_metric.EvaluationMetric,
223+
label_map: collections.OrderedDict) -> Dict[str, float]:
224+
"""Gets the metric dict for evaluation."""
225+
metrics = evaluator.result()
226+
metric_dict = {}
227+
for i, name in enumerate(evaluator.metric_names):
228+
metric_dict[name] = metrics[i]
229+
230+
if label_map:
231+
for i, cid in enumerate(label_map.keys()):
232+
name = 'AP_/%s' % label_map[cid]
233+
metric_dict[name] = metrics[i + len(evaluator.metric_names)]
234+
return metric_dict
235+
236+
def evaluate(self,
237+
model: tf.keras.Model,
238+
dataset: tf.data.Dataset,
239+
steps: int,
240+
json_file: str = None) -> Dict[str, float]:
241+
"""Evaluate the EfficientDet keras model.
242+
243+
Args:
244+
model: The keras model to be evaluated.
245+
dataset: tf.data.Dataset used for evaluation.
246+
steps: Number of steps to evaluate the model.
247+
json_file: JSON with COCO data format containing golden bounding boxes.
248+
Used for validation. If None, use the ground truth from the dataloader.
249+
Refer to
250+
https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
251+
for the description of COCO data format.
252+
253+
Returns:
254+
A dict contains AP metrics.
255+
"""
256+
evaluator, label_map = self._get_evaluator_and_label_map(json_file)
216257
dataset = dataset.take(steps)
217258

218259
@tf.function
@@ -228,18 +269,66 @@ def _get_detections(images, labels):
228269
], [])
229270

230271
dataset = self.ds_strategy.experimental_distribute_dataset(dataset)
231-
for (images, labels) in dataset:
272+
progbar = tf.keras.utils.Progbar(steps)
273+
for i, (images, labels) in enumerate(dataset):
232274
self.ds_strategy.run(_get_detections, (images, labels))
275+
progbar.update(i)
233276

234-
metrics = evaluator.result()
235-
metric_dict = {}
236-
for i, name in enumerate(evaluator.metric_names):
237-
metric_dict[name] = metrics[i]
277+
metric_dict = self._get_metric_dict(evaluator, label_map)
278+
return metric_dict
238279

239-
if label_map:
240-
for i, cid in enumerate(label_map.keys()):
241-
name = 'AP_/%s' % label_map[cid]
242-
metric_dict[name] = metrics[i + len(evaluator.metric_names)]
280+
def evaluate_tflite(self,
                    tflite_filepath: str,
                    dataset: tf.data.Dataset,
                    steps: int,
                    json_file: str = None) -> Dict[str, float]:
  """Evaluate the EfficientDet TFLite model.

  Args:
    tflite_filepath: File path to the TFLite model.
    dataset: tf.data.Dataset used for evaluation.
    steps: Number of steps to evaluate the model.
    json_file: JSON with COCO data format containing golden bounding boxes.
      Used for validation. If None, use the ground truth from the dataloader.
      Refer to
      https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
      for the description of COCO data format.

  Returns:
    A dict contains AP metrics.
  """
  # TODO(b/182441458): Use the task library for evaluation instead once it
  # supports python interface.
  evaluator, label_map = self._get_evaluator_and_label_map(json_file)
  dataset = dataset.take(steps)

  lite_runner = eval_tflite.LiteRunner(tflite_filepath, only_network=False)
  progbar = tf.keras.utils.Progbar(steps)

  # The image size is fixed for the whole dataset, so the de-normalization
  # factor is loop-invariant: compute it once instead of per batch.
  height, width = utils.parse_image_size(self.config.image_size)
  normalize_factor = tf.constant([height, width, height, width],
                                 dtype=tf.float32)

  for i, (images, labels) in enumerate(dataset):
    # Get the output result after post-processing NMS op.
    nms_boxes, nms_classes, nms_scores, _ = lite_runner.run(images)

    # CLASS_OFFSET is used since label_id for `background` is 0 in label_map
    # while it's not actually included in the model. We don't need to add the
    # offset in the Android application.
    nms_classes += postprocess.CLASS_OFFSET

    # Boxes come out normalized to [0, 1]; scale back to pixel coordinates,
    # then undo the preprocessing resize using the recorded image scales.
    nms_boxes *= normalize_factor
    if labels['image_scales'] is not None:
      scales = tf.expand_dims(tf.expand_dims(labels['image_scales'], -1), -1)
      nms_boxes = nms_boxes * tf.cast(scales, nms_boxes.dtype)
    detections = postprocess.generate_detections_from_nms_output(
        nms_boxes, nms_classes, nms_scores, labels['source_ids'])

    detections = postprocess.transform_detections(detections)
    evaluator.update_state(labels['groundtruth_data'].numpy(),
                           detections.numpy())
    progbar.update(i)

  metric_dict = self._get_metric_dict(evaluator, label_map)
  return metric_dict
244333

245334
def export_saved_model(self,

tensorflow_examples/lite/model_maker/core/task/object_detector.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515

1616
import os
1717
import tempfile
18+
from typing import Dict
1819

1920
import tensorflow as tf
2021
from tensorflow_examples.lite.model_maker.core import compat
22+
from tensorflow_examples.lite.model_maker.core.data_util import object_detector_dataloader
2123
from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat
2224
from tensorflow_examples.lite.model_maker.core.task import custom_model
2325
from tensorflow_examples.lite.model_maker.core.task import model_spec as ms
@@ -161,6 +163,14 @@ def evaluate(self, data, batch_size=None):
161163
return self.model_spec.evaluate(self.model, ds, steps,
162164
data.annotations_json_file)
163165

166+
def evaluate_tflite(
    self, tflite_filepath: str,
    data: object_detector_dataloader.DataLoader) -> Dict[str, float]:
  """Evaluates the exported TFLite model on `data`.

  Args:
    tflite_filepath: File path to the TFLite model.
    data: Dataloader providing the evaluation dataset and the COCO
      annotations JSON file.

  Returns:
    A dict of AP metrics, as produced by the model spec.
  """
  # NOTE(review): batch_size=1 — presumably because the TFLite runner
  # processes one image at a time; confirm against eval_tflite.LiteRunner.
  eval_ds = data.gen_dataset(self.model_spec, batch_size=1, is_training=False)
  num_steps = len(data)
  return self.model_spec.evaluate_tflite(tflite_filepath, eval_ds, num_steps,
                                         data.annotations_json_file)
173+
164174
def _export_saved_model(self, saved_model_dir):
165175
"""Saves the model to Tensorflow SavedModel."""
166176
self.model_spec.export_saved_model(saved_model_dir)

tensorflow_examples/lite/model_maker/core/task/object_detector_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def testEfficientDetLite0(self):
7575
'efficientdet_lite0_metadata.json')
7676
self.assertTrue(filecmp.cmp(json_output_file, expected_json_file))
7777

78+
# Evaluate the TFLite model.
79+
task.evaluate_tflite(output_path, data)
80+
self.assertIsInstance(metrics, dict)
81+
self.assertGreaterEqual(metrics['AP'], 0)
82+
7883
# Export the model to quantized TFLite model.
7984
# TODO(b/175173304): Skips the test for stable tensorflow 2.4 for now since
8085
# it fails. Will revert this change after TF upgrade.

0 commit comments

Comments
 (0)