import collections
import os
import tempfile
+ from typing import Tuple, Dict

from absl import logging
import tensorflow as tf

from tensorflow_examples.lite.model_maker.third_party.efficientdet import hparams_config
from tensorflow_examples.lite.model_maker.third_party.efficientdet import utils
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import efficientdet_keras
+ from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import eval_tflite
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import inference
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import label_util
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import postprocess
@@ -202,8 +204,10 @@ def train(self,
        validation_steps=validation_steps)
    return model

-   def evaluate(self, model, dataset, steps, json_file=None):
-     """Evaluate the EfficientDet keras model."""
+   def _get_evaluator_and_label_map(
+       self, json_file: str
+   ) -> Tuple[coco_metric.EvaluationMetric, collections.OrderedDict]:
+     """Gets evaluator and label_map for evaluation."""
    label_map = label_util.get_label_map(self.config.label_map)
    # Sorts label_map.keys since pycocotools.cocoeval uses sorted catIds
    # (category ids) in COCOeval class.
@@ -213,6 +217,43 @@ def evaluate(self, model, dataset, steps, json_file=None):
        filename=json_file, label_map=label_map)

    evaluator.reset_states()
+     return evaluator, label_map
+
+   def _get_metric_dict(self, evaluator: coco_metric.EvaluationMetric,
+                        label_map: collections.OrderedDict) -> Dict[str, float]:
+     """Gets the metric dict for evaluation."""
+     metrics = evaluator.result()
+     metric_dict = {}
+     for i, name in enumerate(evaluator.metric_names):
+       metric_dict[name] = metrics[i]
+
+     if label_map:
+       for i, cid in enumerate(label_map.keys()):
+         name = 'AP_/%s' % label_map[cid]
+         metric_dict[name] = metrics[i + len(evaluator.metric_names)]
+     return metric_dict
+
+   def evaluate(self,
+                model: tf.keras.Model,
+                dataset: tf.data.Dataset,
+                steps: int,
+                json_file: str = None) -> Dict[str, float]:
+     """Evaluate the EfficientDet keras model.
+
+     Args:
+       model: The keras model to be evaluated.
+       dataset: tf.data.Dataset used for evaluation.
+       steps: Number of steps to evaluate the model.
+       json_file: JSON with COCO data format containing golden bounding boxes.
+         Used for validation. If None, use the ground truth from the dataloader.
+         Refer to
+         https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
+         for the description of COCO data format.
+
+     Returns:
+       A dict containing AP metrics.
+     """
+     evaluator, label_map = self._get_evaluator_and_label_map(json_file)
    dataset = dataset.take(steps)

    @tf.function
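For context, a hedged usage sketch of the refactored Keras-model evaluation path. The names `spec`, `keras_model`, `val_ds`, and `validation_steps` are hypothetical placeholders, not part of this change, and the `'AP'` key assumes the standard COCO metric names exposed by `coco_metric.EvaluationMetric`.

# Hypothetical usage sketch (placeholder names, not from this commit).
keras_metrics = spec.evaluate(
    model=keras_model,       # tf.keras.Model returned by spec.train(...)
    dataset=val_ds,          # tf.data.Dataset yielding (images, labels) batches
    steps=validation_steps,  # number of evaluation batches to consume
    json_file=None)          # or a path to a COCO-format annotation JSON
print(keras_metrics['AP'])   # assumes 'AP' is among evaluator.metric_names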
@@ -228,18 +269,66 @@ def _get_detections(images, labels):
      ], [])

    dataset = self.ds_strategy.experimental_distribute_dataset(dataset)
-     for (images, labels) in dataset:
+     progbar = tf.keras.utils.Progbar(steps)
+     for i, (images, labels) in enumerate(dataset):
      self.ds_strategy.run(_get_detections, (images, labels))
+       progbar.update(i)

-     metrics = evaluator.result()
-     metric_dict = {}
-     for i, name in enumerate(evaluator.metric_names):
-       metric_dict[name] = metrics[i]
+     metric_dict = self._get_metric_dict(evaluator, label_map)
+     return metric_dict

-     if label_map:
-       for i, cid in enumerate(label_map.keys()):
-         name = 'AP_/%s' % label_map[cid]
-         metric_dict[name] = metrics[i + len(evaluator.metric_names)]
+   def evaluate_tflite(self,
+                       tflite_filepath: str,
+                       dataset: tf.data.Dataset,
+                       steps: int,
+                       json_file: str = None) -> Dict[str, float]:
+     """Evaluate the EfficientDet TFLite model.
+
+     Args:
+       tflite_filepath: File path to the TFLite model.
+       dataset: tf.data.Dataset used for evaluation.
+       steps: Number of steps to evaluate the model.
+       json_file: JSON with COCO data format containing golden bounding boxes.
+         Used for validation. If None, use the ground truth from the dataloader.
+         Refer to
+         https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
+         for the description of COCO data format.
+
+     Returns:
+       A dict containing AP metrics.
+     """
+     # TODO(b/182441458): Use the task library for evaluation instead once it
+     # supports python interface.
+     evaluator, label_map = self._get_evaluator_and_label_map(json_file)
+     dataset = dataset.take(steps)
+
+     lite_runner = eval_tflite.LiteRunner(tflite_filepath, only_network=False)
+     progbar = tf.keras.utils.Progbar(steps)
+     for i, (images, labels) in enumerate(dataset):
+       # Get the output result after post-processing NMS op.
+       nms_boxes, nms_classes, nms_scores, _ = lite_runner.run(images)
+
+       # CLASS_OFFSET is used since label_id for `background` is 0 in label_map
+       # while it's not actually included in the model. We don't need to add the
+       # offset in the Android application.
+       nms_classes += postprocess.CLASS_OFFSET
+
+       height, width = utils.parse_image_size(self.config.image_size)
+       normalize_factor = tf.constant([height, width, height, width],
+                                      dtype=tf.float32)
+       nms_boxes *= normalize_factor
+       if labels['image_scales'] is not None:
+         scales = tf.expand_dims(tf.expand_dims(labels['image_scales'], -1), -1)
+         nms_boxes = nms_boxes * tf.cast(scales, nms_boxes.dtype)
+       detections = postprocess.generate_detections_from_nms_output(
+           nms_boxes, nms_classes, nms_scores, labels['source_ids'])
+
+       detections = postprocess.transform_detections(detections)
+       evaluator.update_state(labels['groundtruth_data'].numpy(),
+                              detections.numpy())
+       progbar.update(i)
+
+     metric_dict = self._get_metric_dict(evaluator, label_map)
    return metric_dict

  def export_saved_model(self,
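Similarly, a hedged sketch of how the new TFLite evaluation path might be invoked. The names `spec`, `val_ds`, `validation_steps`, and the file name 'model.tflite' are hypothetical placeholders, not part of this change.

# Hypothetical usage sketch (placeholder names, not from this commit).
tflite_metrics = spec.evaluate_tflite(
    tflite_filepath='model.tflite',  # path to an exported EfficientDet-Lite model
    dataset=val_ds,
    steps=validation_steps,
    json_file=None)
print(tflite_metrics)  # same metric dict shape as spec.evaluate(...)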
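The box rescaling inside `evaluate_tflite` goes from normalized [ymin, xmin, ymax, xmax] coordinates to model-input pixels, then back to original-image pixels via `image_scales`. A minimal standalone sketch of that arithmetic, with a made-up box, input size, and scale chosen only for illustration:

import tensorflow as tf

# One NMS box in normalized coordinates, shape [batch, num_detections, 4].
nms_boxes = tf.constant([[[0.25, 0.25, 0.75, 0.75]]], dtype=tf.float32)

# Scale to the model's input resolution (e.g. a 320x320 EfficientDet-Lite0 input).
height, width = 320, 320
normalize_factor = tf.constant([height, width, height, width], dtype=tf.float32)
boxes_px = nms_boxes * normalize_factor  # [[[80., 80., 240., 240.]]]

# image_scales (shape [batch]) maps resized-image pixels back to the original image.
image_scales = tf.constant([2.0])
scales = tf.expand_dims(tf.expand_dims(image_scales, -1), -1)  # shape [1, 1, 1]
boxes_original = boxes_px * tf.cast(scales, boxes_px.dtype)    # [[[160., 160., 480., 480.]]]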