Skip to content

Commit

Permalink
chore: Update for new data generator
Browse files Browse the repository at this point in the history
All files were updated to work with the new `DataGenerator`.
  • Loading branch information
pierluigiferrari committed Mar 25, 2018
1 parent eed6632 commit 5d3659c
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 142 deletions.
94 changes: 52 additions & 42 deletions eval_utils/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from math import ceil
import sys

from ssd_box_utils.ssd_box_encode_decode_utils import decode_y
from data_generator.object_detection_2d_geometric_ops import Resize
from data_generator.object_detection_2d_patch_sampling_ops import RandomPadFixedAR
from data_generator.object_detection_2d_photometric_ops import ConvertTo3Channels
from ssd_encoder_decoder.ssd_output_decoder import decode_detections
from data_generator.object_detection_2d_misc_utils import apply_inverse_transforms

def get_coco_category_maps(annotations_file):
'''
Expand Down Expand Up @@ -61,9 +65,9 @@ def predict_all_to_json(out_file,
img_height,
img_width,
classes_to_cats,
batch_generator,
data_generator,
batch_size,
batch_generator_mode='resize',
data_generator_mode='resize',
model_mode='training',
confidence_thresh=0.01,
iou_threshold=0.45,
Expand All @@ -81,9 +85,9 @@ def predict_all_to_json(out_file,
img_width (int): The input image width for the model.
classes_to_cats (dict): A dictionary that maps the consecutive class IDs predicted by the model
to the non-consecutive original MS COCO category IDs.
batch_generator (BatchGenerator): A `BatchGenerator` object with the evaluation dataset.
data_generator (DataGenerator): A `DataGenerator` object with the evaluation dataset.
batch_size (int): The batch size for the evaluation.
batch_generator_mode (str, optional): Either of 'resize' or 'pad'. If 'resize', the input images will
data_generator_mode (str, optional): Either 'resize' or 'pad'. If 'resize', the input images will
be resized (i.e. warped) to `(img_height, img_width)`. This mode does not preserve the aspect ratios of the images.
If 'pad', the input images will first be padded so that they have the aspect ratio defined by `img_height`
and `img_width` and then resized to `(img_height, img_width)`. This mode preserves the aspect ratios of the images.
Expand Down Expand Up @@ -113,75 +117,81 @@ def predict_all_to_json(out_file,
None.
'''

if batch_generator_mode == 'resize':
random_pad_and_resize=False
resize=(img_height,img_width)
elif batch_generator_mode == 'pad':
random_pad_and_resize=(img_height, img_width, 0, 3, 1.0)
resize=False
convert_to_3_channels = ConvertTo3Channels()
resize = Resize(height=img_height,width=img_width)
if data_generator_mode == 'resize':
transformations = [convert_to_3_channels,
resize]
elif data_generator_mode == 'pad':
random_pad = RandomPadFixedAR(patch_aspect_ratio=img_width/img_height, clip_boxes=False)
transformations = [convert_to_3_channels,
random_pad,
resize]
else:
raise ValueError("Unexpected argument value: `batch_generator_mode` can be either of 'resize' or 'pad', but received '{}'.".format(batch_generator_mode))
raise ValueError("Unexpected argument value: `data_generator_mode` can be either of 'resize' or 'pad', but received '{}'.".format(data_generator_mode))

# Set the generator parameters.
generator = batch_generator.generate(batch_size=batch_size,
shuffle=False,
train=False,
returns={'processed_images', 'image_ids', 'inverse_transform'},
convert_to_3_channels=True,
random_pad_and_resize=random_pad_and_resize,
resize=resize,
limit_boxes=False,
keep_images_without_gt=True)
generator = data_generator.generate(batch_size=batch_size,
shuffle=False,
transformations=transformations,
label_encoder=None,
returns={'processed_images',
'image_ids',
'inverse_transform'},
keep_images_without_gt=True)
# Put the results in this list.
results = []
# Compute the number of batches to iterate over the entire dataset.
n_images = batch_generator.get_n_samples()
n_images = data_generator.get_dataset_size()
print("Number of images in the evaluation dataset: {}".format(n_images))
n_batches = int(ceil(n_images / batch_size))
# Loop over all batches.
tr = trange(n_batches, file=sys.stdout)
tr.set_description('Producing results file')
for i in tr:
# Generate batch.
batch_X, batch_image_ids, batch_inverse_coord_transform = next(generator)
batch_X, batch_image_ids, batch_inverse_transforms = next(generator)
# Predict.
y_pred = model.predict(batch_X)
# If the model was created in 'training' mode, the raw predictions need to
# be decoded and filtered, otherwise that's already taken care of.
if model_mode == 'training':
# Decode.
y_pred = decode_y(y_pred,
confidence_thresh=confidence_thresh,
iou_threshold=iou_threshold,
top_k=top_k,
input_coords=pred_coords,
normalize_coords=normalize_coords,
img_height=img_height,
img_width=img_width)
y_pred = decode_detections(y_pred,
confidence_thresh=confidence_thresh,
iou_threshold=iou_threshold,
top_k=top_k,
input_coords=pred_coords,
normalize_coords=normalize_coords,
img_height=img_height,
img_width=img_width)
else:
# Filter out the all-zeros dummy elements of `y_pred`.
y_pred_filtered = []
for i in range(len(y_pred)):
y_pred_filtered.append(y_pred[i][y_pred[i,:,0] != 0])
y_pred = y_pred_filtered
# Convert the predicted box coordinates back to the coordinate system of the original images.
y_pred = apply_inverse_transforms(y_pred, batch_inverse_transforms)

# Convert each predicted box into the results format.
for k, batch_item in enumerate(y_pred):
# The box coordinates were predicted for the transformed
# (resized, cropped, padded, etc.) image. We now have to
# transform these coordinates back to what they would be
# in the original images.
batch_item[:,2:] *= batch_inverse_coord_transform[k,:,1]
batch_item[:,2:] += batch_inverse_coord_transform[k,:,0]
for box in batch_item:
class_id = box[0]
# Transform the consecutive class IDs back to the original COCO category IDs.
cat_id = classes_to_cats[class_id]
# Round the box coordinates to reduce the JSON file size.
xmin = round(box[2], 1)
ymin = round(box[3], 1)
xmax = round(box[4], 1)
ymax = round(box[5], 1)
xmin = float(round(box[2], 1))
ymin = float(round(box[3], 1))
xmax = float(round(box[4], 1))
ymax = float(round(box[5], 1))
width = xmax - xmin
height = ymax - ymin
bbox = [xmin, ymin, width, height]
result = {}
result['image_id'] = batch_image_ids[k]
result['category_id'] = cat_id
result['score'] = round(box[1], 3)
result['score'] = float(round(box[1], 3))
result['bbox'] = bbox
results.append(result)

Expand Down
68 changes: 39 additions & 29 deletions eval_utils/pascal_voc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
from tqdm import trange
import sys

from ssd_box_utils.ssd_box_encode_decode_utils import decode_y
from data_generator.object_detection_2d_geometric_ops import Resize
from data_generator.object_detection_2d_patch_sampling_ops import RandomPadFixedAR
from data_generator.object_detection_2d_photometric_ops import ConvertTo3Channels
from ssd_encoder_decoder.ssd_output_decoder import decode_detections
from data_generator.object_detection_2d_misc_utils import apply_inverse_transforms

def predict_all_to_txt(model,
img_height,
img_width,
batch_generator,
data_generator,
batch_size,
batch_generator_mode='resize',
data_generator_mode='resize',
classes=['background',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat',
Expand All @@ -53,9 +57,9 @@ def predict_all_to_txt(model,
model (Keras model): A Keras SSD model object.
img_height (int): The input image height for the model.
img_width (int): The input image width for the model.
batch_generator (BatchGenerator): A `BatchGenerator` object with the evaluation dataset.
data_generator (DataGenerator): A `DataGenerator` object with the evaluation dataset.
batch_size (int): The batch size for the evaluation.
batch_generator_mode (str, optional): Either of 'resize' or 'pad'. If 'resize', the input images will
data_generator_mode (str, optional): Either 'resize' or 'pad'. If 'resize', the input images will
be resized (i.e. warped) to `(img_height, img_width)`. This mode does not preserve the aspect ratios of the images.
If 'pad', the input images will first be padded so that they have the aspect ratio defined by `img_height`
and `img_width` and then resized to `(img_height, img_width)`. This mode preserves the aspect ratios of the images.
Expand Down Expand Up @@ -90,25 +94,28 @@ def predict_all_to_txt(model,
None.
'''

if batch_generator_mode == 'resize':
random_pad_and_resize=False
resize=(img_height,img_width)
elif batch_generator_mode == 'pad':
random_pad_and_resize=(img_height, img_width, 0, 3, 1.0)
resize=False
convert_to_3_channels = ConvertTo3Channels()
resize = Resize(height=img_height,width=img_width)
if data_generator_mode == 'resize':
transformations = [convert_to_3_channels,
resize]
elif data_generator_mode == 'pad':
random_pad = RandomPadFixedAR(patch_aspect_ratio=img_width/img_height, clip_boxes=False)
transformations = [convert_to_3_channels,
random_pad,
resize]
else:
raise ValueError("Unexpected argument value: `batch_generator_mode` can be either of 'resize' or 'pad', but received '{}'.".format(batch_generator_mode))
raise ValueError("Unexpected argument value: `data_generator_mode` can be either of 'resize' or 'pad', but received '{}'.".format(data_generator_mode))

# Set the generator parameters.
generator = batch_generator.generate(batch_size=batch_size,
shuffle=False,
train=False,
returns={'processed_images', 'image_ids', 'inverse_transform'},
convert_to_3_channels=True,
random_pad_and_resize=random_pad_and_resize,
resize=resize,
limit_boxes=False,
keep_images_without_gt=True)
generator = data_generator.generate(batch_size=batch_size,
shuffle=False,
transformations=transformations,
label_encoder=None,
returns={'processed_images',
'image_ids',
'inverse_transform'},
keep_images_without_gt=True)

# We have to generate a separate results file for each class.
results = []
Expand All @@ -117,15 +124,15 @@ def predict_all_to_txt(model,
results.append(open('{}{}.txt'.format(out_file_prefix, classes[i]), 'w'))

# Compute the number of batches to iterate over the entire dataset.
n_images = batch_generator.get_n_samples()
n_images = data_generator.get_dataset_size()
print("Number of images in the evaluation dataset: {}".format(n_images))
n_batches = int(ceil(n_images / batch_size))
# Loop over all batches.
tr = trange(n_batches, file=sys.stdout)
tr.set_description('Producing results files')
for j in tr:
# Generate batch.
batch_X, batch_image_ids, batch_inverse_coord_transform = next(generator)
batch_X, batch_image_ids, batch_inverse_transforms = next(generator)
# Predict.
y_pred = model.predict(batch_X)
# If the model was created in 'training' mode, the raw predictions need to
Expand All @@ -140,14 +147,17 @@ def predict_all_to_txt(model,
normalize_coords=normalize_coords,
img_height=img_height,
img_width=img_width)
else:
# Filter out the all-zeros dummy elements of `y_pred`.
y_pred_filtered = []
for i in range(len(y_pred)):
y_pred_filtered.append(y_pred[i][y_pred[i,:,0] != 0])
y_pred = y_pred_filtered
# Convert the predicted box coordinates back to the coordinate system of the original images.
y_pred = apply_inverse_transforms(y_pred, batch_inverse_transforms)

# Convert each predicted box into the results format.
for k, batch_item in enumerate(y_pred):
# The box coordinates were predicted for the transformed
# (resized, cropped, padded, etc.) image. We now have to
# transform these coordinates back to what they would be
# in the original images.
batch_item[:,2:] *= batch_inverse_coord_transform[k,:,1]
batch_item[:,2:] += batch_inverse_coord_transform[k,:,0]
for box in batch_item:
image_id = batch_image_ids[k]
class_id = int(box[0])
Expand Down
25 changes: 12 additions & 13 deletions ssd300_evaluation_COCO.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"from keras_layers.keras_layer_DecodeDetections import DecodeDetections\n",
"from keras_layers.keras_layer_DecodeDetections2 import DecodeDetections2\n",
"from keras_layers.keras_layer_L2Normalization import L2Normalization\n",
"from data_generator.ssd_batch_generator import BatchGenerator\n",
"from data_generator.object_detection_2d_data_generator import DataGenerator\n",
"from eval_utils.coco_utils import get_coco_category_maps, predict_all_to_json\n",
"\n",
"%matplotlib inline"
Expand Down Expand Up @@ -82,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {
"collapsed": true
},
Expand All @@ -106,12 +106,12 @@
" two_boxes_for_ar1=True,\n",
" steps=[8, 16, 32, 64, 100, 300],\n",
" offsets=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5],\n",
" limit_boxes=False,\n",
" clip_boxes=False,\n",
" variances=[0.1, 0.1, 0.2, 0.2],\n",
" coords='centroids',\n",
" normalize_coords=True,\n",
" subtract_mean=[123, 117, 104],\n",
" swap_channels=True,\n",
" swap_channels=[2, 1, 0],\n",
" confidence_thresh=0.01,\n",
" iou_threshold=0.45,\n",
" top_k=200,\n",
Expand All @@ -120,7 +120,6 @@
"# 2: Load the trained weights into the model.\n",
"\n",
"# TODO: Set the path of the trained weights.\n",
"\n",
"weights_path = 'path/to/trained/weights/VGG_coco_SSD_300x300_iter_400000.h5'\n",
"\n",
"model.load_weights(weights_path, by_name=True)\n",
Expand Down Expand Up @@ -187,7 +186,7 @@
},
"outputs": [],
"source": [
"dataset = BatchGenerator(box_output_format=['class_id', 'xmin', 'ymin', 'xmax', 'ymax'])\n",
"dataset = DataGenerator(labels_output_format=['class_id', 'xmin', 'ymin', 'xmax', 'ymax'])\n",
"\n",
"# TODO: Set the paths to the dataset here.\n",
"MS_COCO_dataset_images_dir = '../../datasets/MicrosoftCOCO/val2017/'\n",
Expand Down Expand Up @@ -237,7 +236,7 @@
"output_type": "stream",
"text": [
"Number of images in the evaluation dataset: 5000\n",
"Producing results file: 100%|██████████| 250/250 [17:07<00:00, 4.29s/it]\n",
"Producing results file: 100%|██████████| 250/250 [04:11<00:00, 1.05s/it]\n",
"Prediction results saved in 'detections_val2017_ssd300_results.json'\n"
]
}
Expand All @@ -248,9 +247,9 @@
" img_height=img_height,\n",
" img_width=img_width,\n",
" classes_to_cats=classes_to_cats,\n",
" batch_generator=dataset,\n",
" data_generator=dataset,\n",
" batch_size=batch_size,\n",
" batch_generator_mode='resize',\n",
" data_generator_mode='resize',\n",
" model_mode='inference',\n",
" confidence_thresh=0.01,\n",
" iou_threshold=0.45,\n",
Expand Down Expand Up @@ -278,11 +277,11 @@
"output_type": "stream",
"text": [
"loading annotations into memory...\n",
"Done (t=0.41s)\n",
"Done (t=0.46s)\n",
"creating index...\n",
"index created!\n",
"Loading and preparing results...\n",
"DONE (t=5.34s)\n",
"DONE (t=5.87s)\n",
"creating index...\n",
"index created!\n"
]
Expand All @@ -305,9 +304,9 @@
"text": [
"Running per image evaluation...\n",
"Evaluate annotation type *bbox*\n",
"DONE (t=69.19s).\n",
"DONE (t=64.15s).\n",
"Accumulating evaluation results...\n",
"DONE (t=14.10s).\n",
"DONE (t=10.58s).\n",
" Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.247\n",
" Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.424\n",
" Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.253\n",
Expand Down
Loading

0 comments on commit 5d3659c

Please sign in to comment.