Commit 889ef98

feat: Support evaluation with non-integer image IDs
Previously, the evaluator could only handle image IDs (and therefore image names) that consist entirely of digits. Now it can handle arbitrary image names.
1 parent 0240cd6 commit 889ef98
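
For context, a minimal sketch (not part of the commit) of the failure mode this fixes, with made-up image IDs:

# The old code called int() on every image ID, which raises a
# ValueError for any ID that isn't all digits:
image_ids = ['000012', '2008_000008', 'street_scene_01']
# int('street_scene_01')  -> ValueError: invalid literal for int() with base 10

# The commit's approach is to keep the IDs as strings and use them
# directly, e.g. as dictionary keys when grouping ground truth per image:
ground_truth = {str(image_id): [] for image_id in image_ids}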

File tree

1 file changed: +26 −26 lines

eval_utils/average_precision_evaluator.py (26 additions, 26 deletions)
@@ -218,7 +218,6 @@ def __call__(self,
                                matching_iou_threshold=matching_iou_threshold,
                                include_border_pixels=include_border_pixels,
                                sorting_algorithm=sorting_algorithm,
-                               pred_format={'image_id': 0, 'conf': 1, 'xmin': 2, 'ymin': 3, 'xmax': 4, 'ymax': 5},
                                verbose=verbose,
                                ret=False)

@@ -404,7 +403,7 @@ def predict_on_dataset(self,
             # Iterate over all batch items.
             for k, batch_item in enumerate(y_pred):

-                image_id = int(batch_image_ids[k])
+                image_id = batch_image_ids[k]

                 for box in batch_item:
                     class_id = int(box[class_id_pred])
@@ -417,13 +416,10 @@ def predict_on_dataset(self,
                     ymin = round(box[ymin_pred], 1)
                     xmax = round(box[xmax_pred], 1)
                     ymax = round(box[ymax_pred], 1)
-                    prediction = [image_id, confidence, xmin, ymin, xmax, ymax]
+                    prediction = (image_id, confidence, xmin, ymin, xmax, ymax)
                     # Append the predicted box to the results list for its class.
                     results[class_id].append(prediction)

-        for i in range(self.n_classes + 1):
-            results[i] = np.asarray(results[i])
-
         self.prediction_results = results

         if ret:
@@ -546,7 +542,6 @@ def match_predictions(self,
                           matching_iou_threshold=0.5,
                           include_border_pixels=True,
                           sorting_algorithm='quicksort',
-                          pred_format={'image_id': 0, 'conf': 1, 'xmin': 2, 'ymin': 3, 'xmax': 4, 'ymax': 5},
                           verbose=True,
                           ret=False):
         '''
@@ -571,8 +566,6 @@ def match_predictions(self,
                 The official Matlab evaluation algorithm uses a stable sorting algorithm, so this algorithm is only guaranteed
                 to behave identically if you choose 'mergesort' as the sorting algorithm, but it will almost always behave identically
                 even if you choose 'quicksort' (but no guarantees).
-            pred_format (dict, optional): In what format to expect the predictions. This argument usually doesn't need be touched,
-                because the default setting matches what `predict_on_dataset()` outputs.
             verbose (bool, optional): If `True`, will print out the progress during runtime.
             ret (bool, optional): If `True`, returns the true and false positives.

@@ -587,13 +580,6 @@ def match_predictions(self,
         if self.prediction_results is None:
             raise ValueError("There are no prediction results. You must run `predict_on_dataset()` before calling this method.")

-        image_id_pred = pred_format['image_id']
-        conf_pred = pred_format['conf']
-        xmin_pred = pred_format['xmin']
-        ymin_pred = pred_format['ymin']
-        xmax_pred = pred_format['xmax']
-        ymax_pred = pred_format['ymax']
-
         class_id_gt = self.gt_format['class_id']
         xmin_gt = self.gt_format['xmin']
         ymin_gt = self.gt_format['ymin']
@@ -605,7 +591,7 @@ def match_predictions(self,
         ground_truth = {}
         eval_neutral_available = not (self.data_generator.eval_neutral is None) # Whether or not we have annotations to decide whether ground truth boxes should be neutral or not.
         for i in range(len(self.data_generator.image_ids)):
-            image_id = int(self.data_generator.image_ids[i])
+            image_id = str(self.data_generator.image_ids[i])
             labels = self.data_generator.labels[i]
             if ignore_neutral_boxes and eval_neutral_available:
                 ground_truth[image_id] = (np.asarray(labels), np.asarray(self.data_generator.eval_neutral[i]))
@@ -623,25 +609,39 @@ def match_predictions(self,
             predictions = self.prediction_results[class_id]

             # Store the matching results in these lists:
-            true_pos = np.zeros(predictions.shape[0], dtype=np.int) # 1 for every prediction that is a true positive, 0 otherwise
-            false_pos = np.zeros(predictions.shape[0], dtype=np.int) # 1 for every prediction that is a false positive, 0 otherwise
+            true_pos = np.zeros(len(predictions), dtype=np.int) # 1 for every prediction that is a true positive, 0 otherwise
+            false_pos = np.zeros(len(predictions), dtype=np.int) # 1 for every prediction that is a false positive, 0 otherwise

             # In case there are no predictions at all for this class, we're done here.
-            if predictions.size == 0:
+            if len(predictions) == 0:
                 print("No predictions for class {}/{}".format(class_id, self.n_classes))
                 true_positives.append(true_pos)
                 false_positives.append(false_pos)
                 continue

+            # Convert the predictions list for this class into a structured array so that we can sort it by confidence.
+
+            # Get the number of characters needed to store the image ID strings in the structured array.
+            num_chars_per_image_id = len(str(predictions[0][0])) + 6 # Keep a few characters buffer in case some image IDs are longer than others.
+            # Create the data type for the structured array.
+            preds_data_type = np.dtype([('image_id', 'U{}'.format(num_chars_per_image_id)),
+                                        ('confidence', 'f4'),
+                                        ('xmin', 'f4'),
+                                        ('ymin', 'f4'),
+                                        ('xmax', 'f4'),
+                                        ('ymax', 'f4')])
+            # Create the structured array.
+            predictions = np.array(predictions, dtype=preds_data_type)
+
             # Sort the detections by decreasing confidence.
-            descending_indices = np.argsort(-predictions[:, conf_pred], axis=0, kind=sorting_algorithm)
+            descending_indices = np.argsort(-predictions['confidence'], kind=sorting_algorithm)
             predictions_sorted = predictions[descending_indices]

             if verbose:
-                tr = trange(predictions.shape[0], file=sys.stdout)
+                tr = trange(len(predictions), file=sys.stdout)
                 tr.set_description("Matching predictions to ground truth, class {}/{}.".format(class_id, self.n_classes))
             else:
-                tr = range(predictions.shape[0])
+                tr = range(len(predictions))

             # Keep track of which ground truth boxes were already matched to a detection.
             gt_matched = {}
@@ -650,8 +650,8 @@ def match_predictions(self,
             for i in tr:

                 prediction = predictions_sorted[i]
-                image_id = int(prediction[image_id_pred])
-                pred_box = np.asarray(prediction[[conf_pred, xmin_pred, ymin_pred, xmax_pred, ymax_pred]], dtype=np.float)
+                image_id = prediction['image_id']
+                pred_box = np.asarray(list(prediction[['xmin', 'ymin', 'xmax', 'ymax']])) # Convert the structured array element to a regular array.

                 # Get the relevant ground truth boxes for this prediction,
                 # i.e. all ground truth boxes that match the prediction's
@@ -677,7 +677,7 @@ def match_predictions(self,

                 # Compute the IoU of this prediction with all ground truth boxes of the same class.
                 overlaps = iou(boxes1=gt[:,[xmin_gt, ymin_gt, xmax_gt, ymax_gt]],
-                               boxes2=pred_box[1:],
+                               boxes2=pred_box,
                                coords='corners',
                                mode='element-wise',
                                include_border_pixels=include_border_pixels)
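
The crux of the patch above is the switch from a plain 2-D float array, which cannot hold string image IDs, to a NumPy structured array that mixes a Unicode field with float fields yet can still be sorted by a single field. A standalone sketch of that technique (not the evaluator itself), with made-up box values; only the field names mirror the patch:

import numpy as np

# Two hypothetical predictions: (image_id, confidence, xmin, ymin, xmax, ymax).
predictions = [('street_scene_01', 0.92, 10.0, 20.0, 50.0, 80.0),
               ('000012',          0.97, 15.0, 25.0, 55.0, 85.0)]

preds_data_type = np.dtype([('image_id', 'U32'),
                            ('confidence', 'f4'),
                            ('xmin', 'f4'), ('ymin', 'f4'),
                            ('xmax', 'f4'), ('ymax', 'f4')])
preds = np.array(predictions, dtype=preds_data_type)

# Sort by decreasing confidence, as the patched matching loop does.
descending_indices = np.argsort(-preds['confidence'], kind='quicksort')
preds_sorted = preds[descending_indices]
print(preds_sorted['image_id'])  # ['000012' 'street_scene_01']

# Pull one record's box coordinates back out as a plain float array,
# mirroring the pred_box line in the patch:
pred_box = np.asarray(list(preds_sorted[0][['xmin', 'ymin', 'xmax', 'ymax']]))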
