From 9eca8f448cb42aa654829312769c9b00c99b248f Mon Sep 17 00:00:00 2001
From: Nicolai Wojke
Date: Sun, 11 Feb 2018 20:06:00 +0100
Subject: [PATCH] Generate detections from frozen inference graph

---
 README.md                    |  26 +-
 generate_detections.py       | 457 -----------------------------------
 tools/freeze_model.py        | 219 +++++++++++++++++
 tools/generate_detections.py | 213 ++++++++++++++++
 4 files changed, 448 insertions(+), 467 deletions(-)
 delete mode 100644 generate_detections.py
 create mode 100644 tools/freeze_model.py
 create mode 100644 tools/generate_detections.py

diff --git a/README.md b/README.md
index c47c9bb9..12144512 100644
--- a/README.md
+++ b/README.md
@@ -65,17 +65,23 @@ The following example generates these features from standard MOT challenge
 detections. Again, we assume resources have been extracted to the repository
 root directory and MOT16 data is in `./MOT16`:
 ```
-python generate_detections.py
-    --model=resources/networks/mars-small128.ckpt \
+python tools/generate_detections.py \
+    --model=resources/networks/mars-small128.pb \
     --mot_dir=./MOT16/train \
     --output_dir=./resources/detections/MOT16_train
 ```
-For each sequence of the MOT16 dataset, the output is stored as separate binary
-file in NumPy native format. Each file contains an array of shape `Nx138`,
-where N is the number of detections in the corresponding MOT sequence.
-The first 10 columns of this array contain the raw MOT detection copied over
-from the input file. The remaining 128 columns store the appearance descriptor.
-The files generated by this command can be used as input for the
+The model has been generated with TensorFlow 1.5. If you run into
+incompatibilities, re-export the frozen inference graph to obtain a new
+`mars-small128.pb` that is compatible with your version:
+```
+python tools/freeze_model.py
+```
+For each sequence of the MOT16 dataset, `generate_detections.py` stores a
+separate binary file in NumPy native format. Each file contains an array of
+shape `Nx138`, where N is the number of detections in the corresponding MOT
+sequence. The first 10 columns of this array contain the raw MOT detection
+copied over from the input file. The remaining 128 columns store the appearance
+descriptor. The files generated by this command can be used as input for the
 `deep_sort_app.py`.
 
 ## Highlevel overview of source files
@@ -106,9 +112,9 @@ files. 
These can be computed from MOTChallenge detections using If you find this repo useful in your research, please consider citing the following papers: @inproceedings{Wojke2017simple, - title={Simple Online and Realtime Tracking with a Deep Association Metric}, author={Wojke, Nicolai and Bewley, Alex and Paulus, Dietrich}, booktitle={2017 IEEE International Conference on Image Processing (ICIP)}, + title={Simple Online and Realtime Tracking with a Deep Association Metric}, year={2017}, pages={3645--3649} } @@ -116,7 +122,7 @@ If you find this repo useful in your research, please consider citing the follow @inproceedings{Bewley2016_sort, author={Bewley, Alex and Ge, Zongyuan and Ott, Lionel and Ramos, Fabio and Upcroft, Ben}, booktitle={2016 IEEE International Conference on Image Processing (ICIP)}, - title={Simple online and realtime tracking}, + title={Simple Online and Realtime Tracking}, year={2016}, pages={3464-3468}, doi={10.1109/ICIP.2016.7533003} diff --git a/generate_detections.py b/generate_detections.py deleted file mode 100644 index 2da4d61b..00000000 --- a/generate_detections.py +++ /dev/null @@ -1,457 +0,0 @@ -# vim: expandtab:ts=4:sw=4 -import os -import errno -import argparse -import numpy as np -import cv2 - -import tensorflow as tf -import tensorflow.contrib.slim as slim - - -def _batch_norm_fn(x, scope=None): - if scope is None: - scope = tf.get_variable_scope().name + "/bn" - return slim.batch_norm(x, scope=scope) - - -def create_link( - incoming, network_builder, scope, nonlinearity=tf.nn.elu, - weights_initializer=tf.truncated_normal_initializer(stddev=1e-3), - regularizer=None, is_first=False, summarize_activations=True): - if is_first: - network = incoming - else: - network = _batch_norm_fn(incoming, scope=scope + "/bn") - network = nonlinearity(network) - if summarize_activations: - tf.summary.histogram(scope+"/activations", network) - - pre_block_network = network - post_block_network = network_builder(pre_block_network, scope) - - incoming_dim = pre_block_network.get_shape().as_list()[-1] - outgoing_dim = post_block_network.get_shape().as_list()[-1] - if incoming_dim != outgoing_dim: - assert outgoing_dim == 2 * incoming_dim, \ - "%d != %d" % (outgoing_dim, 2 * incoming) - projection = slim.conv2d( - incoming, outgoing_dim, 1, 2, padding="SAME", activation_fn=None, - scope=scope+"/projection", weights_initializer=weights_initializer, - biases_initializer=None, weights_regularizer=regularizer) - network = projection + post_block_network - else: - network = incoming + post_block_network - return network - - -def create_inner_block( - incoming, scope, nonlinearity=tf.nn.elu, - weights_initializer=tf.truncated_normal_initializer(1e-3), - bias_initializer=tf.zeros_initializer(), regularizer=None, - increase_dim=False, summarize_activations=True): - n = incoming.get_shape().as_list()[-1] - stride = 1 - if increase_dim: - n *= 2 - stride = 2 - - incoming = slim.conv2d( - incoming, n, [3, 3], stride, activation_fn=nonlinearity, padding="SAME", - normalizer_fn=_batch_norm_fn, weights_initializer=weights_initializer, - biases_initializer=bias_initializer, weights_regularizer=regularizer, - scope=scope + "/1") - if summarize_activations: - tf.summary.histogram(incoming.name + "/activations", incoming) - - incoming = slim.dropout(incoming, keep_prob=0.6) - - incoming = slim.conv2d( - incoming, n, [3, 3], 1, activation_fn=None, padding="SAME", - normalizer_fn=None, weights_initializer=weights_initializer, - biases_initializer=bias_initializer, weights_regularizer=regularizer, - 
scope=scope + "/2") - return incoming - - -def residual_block(incoming, scope, nonlinearity=tf.nn.elu, - weights_initializer=tf.truncated_normal_initializer(1e3), - bias_initializer=tf.zeros_initializer(), regularizer=None, - increase_dim=False, is_first=False, - summarize_activations=True): - - def network_builder(x, s): - return create_inner_block( - x, s, nonlinearity, weights_initializer, bias_initializer, - regularizer, increase_dim, summarize_activations) - - return create_link( - incoming, network_builder, scope, nonlinearity, weights_initializer, - regularizer, is_first, summarize_activations) - - -def _create_network(incoming, num_classes, reuse=None, l2_normalize=True, - create_summaries=True, weight_decay=1e-8): - nonlinearity = tf.nn.elu - conv_weight_init = tf.truncated_normal_initializer(stddev=1e-3) - conv_bias_init = tf.zeros_initializer() - conv_regularizer = slim.l2_regularizer(weight_decay) - fc_weight_init = tf.truncated_normal_initializer(stddev=1e-3) - fc_bias_init = tf.zeros_initializer() - fc_regularizer = slim.l2_regularizer(weight_decay) - - def batch_norm_fn(x): - return slim.batch_norm(x, scope=tf.get_variable_scope().name + "/bn") - - network = incoming - network = slim.conv2d( - network, 32, [3, 3], stride=1, activation_fn=nonlinearity, - padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_1", - weights_initializer=conv_weight_init, biases_initializer=conv_bias_init, - weights_regularizer=conv_regularizer) - if create_summaries: - tf.summary.histogram(network.name + "/activations", network) - tf.summary.image("conv1_1/weights", tf.transpose( - slim.get_variables("conv1_1/weights:0")[0], [3, 0, 1, 2]), - max_images=128) - network = slim.conv2d( - network, 32, [3, 3], stride=1, activation_fn=nonlinearity, - padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_2", - weights_initializer=conv_weight_init, biases_initializer=conv_bias_init, - weights_regularizer=conv_regularizer) - if create_summaries: - tf.summary.histogram(network.name + "/activations", network) - - # NOTE(nwojke): This is missing a padding="SAME" to match the CNN - # architecture in Table 1 of the paper. 
Information on how this affects - # performance on MOT 16 training sequences can be found in - # issue 10 https://github.com/nwojke/deep_sort/issues/10 - network = slim.max_pool2d(network, [3, 3], [2, 2], scope="pool1") - - network = residual_block( - network, "conv2_1", nonlinearity, conv_weight_init, conv_bias_init, - conv_regularizer, increase_dim=False, is_first=True, - summarize_activations=create_summaries) - network = residual_block( - network, "conv2_3", nonlinearity, conv_weight_init, conv_bias_init, - conv_regularizer, increase_dim=False, - summarize_activations=create_summaries) - - network = residual_block( - network, "conv3_1", nonlinearity, conv_weight_init, conv_bias_init, - conv_regularizer, increase_dim=True, - summarize_activations=create_summaries) - network = residual_block( - network, "conv3_3", nonlinearity, conv_weight_init, conv_bias_init, - conv_regularizer, increase_dim=False, - summarize_activations=create_summaries) - - network = residual_block( - network, "conv4_1", nonlinearity, conv_weight_init, conv_bias_init, - conv_regularizer, increase_dim=True, - summarize_activations=create_summaries) - network = residual_block( - network, "conv4_3", nonlinearity, conv_weight_init, conv_bias_init, - conv_regularizer, increase_dim=False, - summarize_activations=create_summaries) - - feature_dim = network.get_shape().as_list()[-1] - print("feature dimensionality: ", feature_dim) - network = slim.flatten(network) - - network = slim.dropout(network, keep_prob=0.6) - network = slim.fully_connected( - network, feature_dim, activation_fn=nonlinearity, - normalizer_fn=batch_norm_fn, weights_regularizer=fc_regularizer, - scope="fc1", weights_initializer=fc_weight_init, - biases_initializer=fc_bias_init) - - features = network - - if l2_normalize: - # Features in rows, normalize axis 1. - features = slim.batch_norm(features, scope="ball", reuse=reuse) - feature_norm = tf.sqrt( - tf.constant(1e-8, tf.float32) + - tf.reduce_sum(tf.square(features), [1], keep_dims=True)) - features = features / feature_norm - - with slim.variable_scope.variable_scope("ball", reuse=reuse): - weights = slim.model_variable( - "mean_vectors", (feature_dim, num_classes), - initializer=tf.truncated_normal_initializer(stddev=1e-3), - regularizer=None) - scale = slim.model_variable( - "scale", (num_classes, ), tf.float32, - tf.constant_initializer(0., tf.float32), regularizer=None) - if create_summaries: - tf.summary.histogram("scale", scale) - # scale = slim.model_variable( - # "scale", (), tf.float32, - # initializer=tf.constant_initializer(0., tf.float32), - # regularizer=slim.l2_regularizer(1e-2)) - # if create_summaries: - # tf.scalar_summary("scale", scale) - scale = tf.nn.softplus(scale) - - # Each mean vector in columns, normalize axis 0. 
- weight_norm = tf.sqrt( - tf.constant(1e-8, tf.float32) + - tf.reduce_sum(tf.square(weights), [0], keep_dims=True)) - logits = scale * tf.matmul(features, weights / weight_norm) - - else: - logits = slim.fully_connected( - features, num_classes, activation_fn=None, - normalizer_fn=None, weights_regularizer=fc_regularizer, - scope="softmax", weights_initializer=fc_weight_init, - biases_initializer=fc_bias_init) - - return features, logits - - -def _network_factory(num_classes, is_training, weight_decay=1e-8): - - def factory_fn(image, reuse, l2_normalize): - with slim.arg_scope([slim.batch_norm, slim.dropout], - is_training=is_training): - with slim.arg_scope([slim.conv2d, slim.fully_connected, - slim.batch_norm, slim.layer_norm], - reuse=reuse): - features, logits = _create_network( - image, num_classes, l2_normalize=l2_normalize, - reuse=reuse, create_summaries=is_training, - weight_decay=weight_decay) - return features, logits - - return factory_fn - - -def _preprocess(image, is_training=False, enable_more_augmentation=True): - image = image[:, :, ::-1] # BGR to RGB - if is_training: - image = tf.image.random_flip_left_right(image) - if enable_more_augmentation: - image = tf.image.random_brightness(image, max_delta=50) - image = tf.image.random_contrast(image, lower=0.8, upper=1.2) - image = tf.image.random_saturation(image, lower=0.8, upper=1.2) - return image - - -def _run_in_batches(f, data_dict, out, batch_size): - data_len = len(out) - num_batches = int(data_len / batch_size) - - s, e = 0, 0 - for i in range(num_batches): - s, e = i * batch_size, (i + 1) * batch_size - batch_data_dict = {k: v[s:e] for k, v in data_dict.items()} - out[s:e] = f(batch_data_dict) - if e < len(out): - batch_data_dict = {k: v[e:] for k, v in data_dict.items()} - out[e:] = f(batch_data_dict) - - -def extract_image_patch(image, bbox, patch_shape): - """Extract image patch from bounding box. - - Parameters - ---------- - image : ndarray - The full image. - bbox : array_like - The bounding box in format (x, y, width, height). - patch_shape : Optional[array_like] - This parameter can be used to enforce a desired patch shape - (height, width). First, the `bbox` is adapted to the aspect ratio - of the patch shape, then it is clipped at the image boundaries. - If None, the shape is computed from :arg:`bbox`. - - Returns - ------- - ndarray | NoneType - An image patch showing the :arg:`bbox`, optionally reshaped to - :arg:`patch_shape`. - Returns None if the bounding box is empty or fully outside of the image - boundaries. 
- - """ - bbox = np.array(bbox) - if patch_shape is not None: - # correct aspect ratio to patch shape - target_aspect = float(patch_shape[1]) / patch_shape[0] - new_width = target_aspect * bbox[3] - bbox[0] -= (new_width - bbox[2]) / 2 - bbox[2] = new_width - - # convert to top left, bottom right - bbox[2:] += bbox[:2] - bbox = bbox.astype(np.int) - - # clip at image boundaries - bbox[:2] = np.maximum(0, bbox[:2]) - bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:]) - if np.any(bbox[:2] >= bbox[2:]): - return None - sx, sy, ex, ey = bbox - image = image[sy:ey, sx:ex] - image = cv2.resize(image, patch_shape[::-1]) - - return image - - -def _create_image_encoder(preprocess_fn, factory_fn, image_shape, batch_size=32, - session=None, checkpoint_path=None, - loss_mode="cosine"): - image_var = tf.placeholder(tf.uint8, (None, ) + image_shape) - - preprocessed_image_var = tf.map_fn( - lambda x: preprocess_fn(x, is_training=False), - tf.cast(image_var, tf.float32)) - - l2_normalize = loss_mode == "cosine" - feature_var, _ = factory_fn( - preprocessed_image_var, l2_normalize=l2_normalize, reuse=None) - feature_dim = feature_var.get_shape().as_list()[-1] - - if session is None: - session = tf.Session() - if checkpoint_path is not None: - slim.get_or_create_global_step() - init_assign_op, init_feed_dict = slim.assign_from_checkpoint( - checkpoint_path, slim.get_variables_to_restore()) - session.run(init_assign_op, feed_dict=init_feed_dict) - - def encoder(data_x): - out = np.zeros((len(data_x), feature_dim), np.float32) - _run_in_batches( - lambda x: session.run(feature_var, feed_dict=x), - {image_var: data_x}, out, batch_size) - return out - - return encoder - - -def create_image_encoder(model_filename, batch_size=32, loss_mode="cosine", - session=None): - image_shape = 128, 64, 3 - factory_fn = _network_factory( - num_classes=1501, is_training=False, weight_decay=1e-8) - - return _create_image_encoder( - _preprocess, factory_fn, image_shape, batch_size, session, - model_filename, loss_mode) - - -def create_box_encoder(model_filename, batch_size=32, loss_mode="cosine"): - image_shape = 128, 64, 3 - image_encoder = create_image_encoder(model_filename, batch_size, loss_mode) - - def encoder(image, boxes): - image_patches = [] - for box in boxes: - patch = extract_image_patch(image, box, image_shape[:2]) - if patch is None: - print("WARNING: Failed to extract image patch: %s." % str(box)) - patch = np.random.uniform( - 0., 255., image_shape).astype(np.uint8) - image_patches.append(patch) - image_patches = np.asarray(image_patches) - return image_encoder(image_patches) - - return encoder - - -def generate_detections(encoder, mot_dir, output_dir, detection_dir=None): - """Generate detections with features. - - Parameters - ---------- - encoder : Callable[image, ndarray] -> ndarray - The encoder function takes as input a BGR color image and a matrix of - bounding boxes in format `(x, y, w, h)` and returns a matrix of - corresponding feature vectors. - mot_dir : str - Path to the MOTChallenge directory (can be either train or test). - output_dir - Path to the output directory. Will be created if it does not exist. - detection_dir - Path to custom detections. The directory structure should be the default - MOTChallenge structure: `[sequence]/det/det.txt`. If None, uses the - standard MOTChallenge detections. 
- - """ - if detection_dir is None: - detection_dir = mot_dir - try: - os.makedirs(output_dir) - except OSError as exception: - if exception.errno == errno.EEXIST and os.path.isdir(output_dir): - pass - else: - raise ValueError( - "Failed to created output directory '%s'" % output_dir) - - for sequence in os.listdir(mot_dir): - print("Processing %s" % sequence) - sequence_dir = os.path.join(mot_dir, sequence) - - image_dir = os.path.join(sequence_dir, "img1") - image_filenames = { - int(os.path.splitext(f)[0]): os.path.join(image_dir, f) - for f in os.listdir(image_dir)} - - detection_file = os.path.join( - detection_dir, sequence, "det/det.txt") - detections_in = np.loadtxt(detection_file, delimiter=',') - detections_out = [] - - frame_indices = detections_in[:, 0].astype(np.int) - min_frame_idx = frame_indices.astype(np.int).min() - max_frame_idx = frame_indices.astype(np.int).max() - for frame_idx in range(min_frame_idx, max_frame_idx + 1): - print("Frame %05d/%05d" % (frame_idx, max_frame_idx)) - mask = frame_indices == frame_idx - rows = detections_in[mask] - - if frame_idx not in image_filenames: - print("WARNING could not find image for frame %d" % frame_idx) - continue - bgr_image = cv2.imread( - image_filenames[frame_idx], cv2.IMREAD_COLOR) - features = encoder(bgr_image, rows[:, 2:6].copy()) - detections_out += [np.r_[(row, feature)] for row, feature - in zip(rows, features)] - - output_filename = os.path.join(output_dir, "%s.npy" % sequence) - np.save( - output_filename, np.asarray(detections_out), allow_pickle=False) - - -def parse_args(): - """Parse command line arguments. - """ - parser = argparse.ArgumentParser(description="Re-ID feature extractor") - parser.add_argument( - "--model", - default="resources/networks/mars-small128.ckpt-68577", - help="Path to checkpoint file") - parser.add_argument( - "--loss_mode", default="cosine", help="Network loss training mode") - parser.add_argument( - "--mot_dir", help="Path to MOTChallenge directory (train or test)", - required=True) - parser.add_argument( - "--detection_dir", help="Path to custom detections. Defaults to " - "standard MOT detections Directory structure should be the default " - "MOTChallenge structure: [sequence]/det/det.txt", default=None) - parser.add_argument( - "--output_dir", help="Output directory. 
Will be created if it does not"
-        " exist.", default="detections")
-    return parser.parse_args()
-
-
-if __name__ == "__main__":
-    args = parse_args()
-    f = create_box_encoder(args.model, batch_size=32, loss_mode=args.loss_mode)
-    generate_detections(f, args.mot_dir, args.output_dir, args.detection_dir)
diff --git a/tools/freeze_model.py b/tools/freeze_model.py
new file mode 100644
index 00000000..e89ad290
--- /dev/null
+++ b/tools/freeze_model.py
@@ -0,0 +1,219 @@
+# vim: expandtab:ts=4:sw=4
+import argparse
+import tensorflow as tf
+import tensorflow.contrib.slim as slim
+
+
+def _batch_norm_fn(x, scope=None):
+    if scope is None:
+        scope = tf.get_variable_scope().name + "/bn"
+    return slim.batch_norm(x, scope=scope)
+
+
+def create_link(
+        incoming, network_builder, scope, nonlinearity=tf.nn.elu,
+        weights_initializer=tf.truncated_normal_initializer(stddev=1e-3),
+        regularizer=None, is_first=False, summarize_activations=True):
+    if is_first:
+        network = incoming
+    else:
+        network = _batch_norm_fn(incoming, scope=scope + "/bn")
+        network = nonlinearity(network)
+        if summarize_activations:
+            tf.summary.histogram(scope+"/activations", network)
+
+    pre_block_network = network
+    post_block_network = network_builder(pre_block_network, scope)
+
+    incoming_dim = pre_block_network.get_shape().as_list()[-1]
+    outgoing_dim = post_block_network.get_shape().as_list()[-1]
+    if incoming_dim != outgoing_dim:
+        assert outgoing_dim == 2 * incoming_dim, \
+            "%d != %d" % (outgoing_dim, 2 * incoming_dim)
+        projection = slim.conv2d(
+            incoming, outgoing_dim, 1, 2, padding="SAME", activation_fn=None,
+            scope=scope+"/projection", weights_initializer=weights_initializer,
+            biases_initializer=None, weights_regularizer=regularizer)
+        network = projection + post_block_network
+    else:
+        network = incoming + post_block_network
+    return network
+
+
+def create_inner_block(
+        incoming, scope, nonlinearity=tf.nn.elu,
+        weights_initializer=tf.truncated_normal_initializer(stddev=1e-3),
+        bias_initializer=tf.zeros_initializer(), regularizer=None,
+        increase_dim=False, summarize_activations=True):
+    n = incoming.get_shape().as_list()[-1]
+    stride = 1
+    if increase_dim:
+        n *= 2
+        stride = 2
+
+    incoming = slim.conv2d(
+        incoming, n, [3, 3], stride, activation_fn=nonlinearity, padding="SAME",
+        normalizer_fn=_batch_norm_fn, weights_initializer=weights_initializer,
+        biases_initializer=bias_initializer, weights_regularizer=regularizer,
+        scope=scope + "/1")
+    if summarize_activations:
+        tf.summary.histogram(incoming.name + "/activations", incoming)
+
+    incoming = slim.dropout(incoming, keep_prob=0.6)
+
+    incoming = slim.conv2d(
+        incoming, n, [3, 3], 1, activation_fn=None, padding="SAME",
+        normalizer_fn=None, weights_initializer=weights_initializer,
+        biases_initializer=bias_initializer, weights_regularizer=regularizer,
+        scope=scope + "/2")
+    return incoming
+
+
+def residual_block(incoming, scope, nonlinearity=tf.nn.elu,
+                   weights_initializer=tf.truncated_normal_initializer(
+                       stddev=1e-3),
+                   bias_initializer=tf.zeros_initializer(), regularizer=None,
+                   increase_dim=False, is_first=False,
+                   summarize_activations=True):
+
+    def network_builder(x, s):
+        return create_inner_block(
+            x, s, nonlinearity, weights_initializer, bias_initializer,
+            regularizer, increase_dim, summarize_activations)
+
+    return create_link(
+        incoming, network_builder, scope, nonlinearity, weights_initializer,
+        regularizer, is_first, summarize_activations)
+
+
+def _create_network(incoming, reuse=None, weight_decay=1e-8):
+    nonlinearity = tf.nn.elu
+    conv_weight_init = 
tf.truncated_normal_initializer(stddev=1e-3) + conv_bias_init = tf.zeros_initializer() + conv_regularizer = slim.l2_regularizer(weight_decay) + fc_weight_init = tf.truncated_normal_initializer(stddev=1e-3) + fc_bias_init = tf.zeros_initializer() + fc_regularizer = slim.l2_regularizer(weight_decay) + + def batch_norm_fn(x): + return slim.batch_norm(x, scope=tf.get_variable_scope().name + "/bn") + + network = incoming + network = slim.conv2d( + network, 32, [3, 3], stride=1, activation_fn=nonlinearity, + padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_1", + weights_initializer=conv_weight_init, biases_initializer=conv_bias_init, + weights_regularizer=conv_regularizer) + network = slim.conv2d( + network, 32, [3, 3], stride=1, activation_fn=nonlinearity, + padding="SAME", normalizer_fn=batch_norm_fn, scope="conv1_2", + weights_initializer=conv_weight_init, biases_initializer=conv_bias_init, + weights_regularizer=conv_regularizer) + + # NOTE(nwojke): This is missing a padding="SAME" to match the CNN + # architecture in Table 1 of the paper. Information on how this affects + # performance on MOT 16 training sequences can be found in + # issue 10 https://github.com/nwojke/deep_sort/issues/10 + network = slim.max_pool2d(network, [3, 3], [2, 2], scope="pool1") + + network = residual_block( + network, "conv2_1", nonlinearity, conv_weight_init, conv_bias_init, + conv_regularizer, increase_dim=False, is_first=True) + network = residual_block( + network, "conv2_3", nonlinearity, conv_weight_init, conv_bias_init, + conv_regularizer, increase_dim=False) + + network = residual_block( + network, "conv3_1", nonlinearity, conv_weight_init, conv_bias_init, + conv_regularizer, increase_dim=True) + network = residual_block( + network, "conv3_3", nonlinearity, conv_weight_init, conv_bias_init, + conv_regularizer, increase_dim=False) + + network = residual_block( + network, "conv4_1", nonlinearity, conv_weight_init, conv_bias_init, + conv_regularizer, increase_dim=True) + network = residual_block( + network, "conv4_3", nonlinearity, conv_weight_init, conv_bias_init, + conv_regularizer, increase_dim=False) + + feature_dim = network.get_shape().as_list()[-1] + network = slim.flatten(network) + + network = slim.dropout(network, keep_prob=0.6) + network = slim.fully_connected( + network, feature_dim, activation_fn=nonlinearity, + normalizer_fn=batch_norm_fn, weights_regularizer=fc_regularizer, + scope="fc1", weights_initializer=fc_weight_init, + biases_initializer=fc_bias_init) + + features = network + + # Features in rows, normalize axis 1. + features = slim.batch_norm(features, scope="ball", reuse=reuse) + feature_norm = tf.sqrt( + tf.constant(1e-8, tf.float32) + + tf.reduce_sum(tf.square(features), [1], keepdims=True)) + features = features / feature_norm + return features, None + + +def _network_factory(weight_decay=1e-8): + + def factory_fn(image, reuse): + with slim.arg_scope([slim.batch_norm, slim.dropout], + is_training=False): + with slim.arg_scope([slim.conv2d, slim.fully_connected, + slim.batch_norm, slim.layer_norm], + reuse=reuse): + features, logits = _create_network( + image, reuse=reuse, weight_decay=weight_decay) + return features, logits + + return factory_fn + + +def _preprocess(image): + image = image[:, :, ::-1] # BGR to RGB + return image + + +def parse_args(): + """Parse command line arguments. 
+ """ + parser = argparse.ArgumentParser(description="Freeze old model") + parser.add_argument( + "--checkpoint_in", + default="resources/networks/mars-small128.ckpt-68577", + help="Path to checkpoint file") + parser.add_argument( + "--graphdef_out", + default="resources/networks/mars-small128.pb") + return parser.parse_args() + + +def main(): + args = parse_args() + + with tf.Session(graph=tf.Graph()) as session: + input_var = tf.placeholder( + tf.uint8, (None, 128, 64, 3), name="images") + image_var = tf.map_fn( + lambda x: _preprocess(x), tf.cast(input_var, tf.float32), + back_prop=False) + + factory_fn = _network_factory() + features, _ = factory_fn(image_var, reuse=None) + features = tf.identity(features, name="features") + + saver = tf.train.Saver(slim.get_variables_to_restore()) + saver.restore(session, args.checkpoint_in) + + output_graph_def = tf.graph_util.convert_variables_to_constants( + session, tf.get_default_graph().as_graph_def(), + [features.name.split(":")[0]]) + with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle: + file_handle.write(output_graph_def.SerializeToString()) + + +if __name__ == "__main__": + main() diff --git a/tools/generate_detections.py b/tools/generate_detections.py new file mode 100644 index 00000000..c7192c26 --- /dev/null +++ b/tools/generate_detections.py @@ -0,0 +1,213 @@ +# vim: expandtab:ts=4:sw=4 +import os +import errno +import argparse +import numpy as np +import cv2 +import tensorflow as tf + + +def _run_in_batches(f, data_dict, out, batch_size): + data_len = len(out) + num_batches = int(data_len / batch_size) + + s, e = 0, 0 + for i in range(num_batches): + s, e = i * batch_size, (i + 1) * batch_size + batch_data_dict = {k: v[s:e] for k, v in data_dict.items()} + out[s:e] = f(batch_data_dict) + if e < len(out): + batch_data_dict = {k: v[e:] for k, v in data_dict.items()} + out[e:] = f(batch_data_dict) + + +def extract_image_patch(image, bbox, patch_shape): + """Extract image patch from bounding box. + + Parameters + ---------- + image : ndarray + The full image. + bbox : array_like + The bounding box in format (x, y, width, height). + patch_shape : Optional[array_like] + This parameter can be used to enforce a desired patch shape + (height, width). First, the `bbox` is adapted to the aspect ratio + of the patch shape, then it is clipped at the image boundaries. + If None, the shape is computed from :arg:`bbox`. + + Returns + ------- + ndarray | NoneType + An image patch showing the :arg:`bbox`, optionally reshaped to + :arg:`patch_shape`. + Returns None if the bounding box is empty or fully outside of the image + boundaries. 
+ + """ + bbox = np.array(bbox) + if patch_shape is not None: + # correct aspect ratio to patch shape + target_aspect = float(patch_shape[1]) / patch_shape[0] + new_width = target_aspect * bbox[3] + bbox[0] -= (new_width - bbox[2]) / 2 + bbox[2] = new_width + + # convert to top left, bottom right + bbox[2:] += bbox[:2] + bbox = bbox.astype(np.int) + + # clip at image boundaries + bbox[:2] = np.maximum(0, bbox[:2]) + bbox[2:] = np.minimum(np.asarray(image.shape[:2][::-1]) - 1, bbox[2:]) + if np.any(bbox[:2] >= bbox[2:]): + return None + sx, sy, ex, ey = bbox + image = image[sy:ey, sx:ex] + image = cv2.resize(image, tuple(patch_shape[::-1])) + return image + + +class ImageEncoder(object): + + def __init__(self, checkpoint_filename, input_name="images", + output_name="features"): + self.session = tf.Session() + with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle: + graph_def = tf.GraphDef() + graph_def.ParseFromString(file_handle.read()) + tf.import_graph_def(graph_def, name="net") + self.input_var = tf.get_default_graph().get_tensor_by_name( + "net/%s:0" % input_name) + self.output_var = tf.get_default_graph().get_tensor_by_name( + "net/%s:0" % output_name) + + assert len(self.output_var.get_shape()) == 2 + assert len(self.input_var.get_shape()) == 4 + self.feature_dim = self.output_var.get_shape().as_list()[-1] + self.image_shape = self.input_var.get_shape().as_list()[1:] + + def __call__(self, data_x, batch_size=32): + out = np.zeros((len(data_x), self.feature_dim), np.float32) + _run_in_batches( + lambda x: self.session.run(self.output_var, feed_dict=x), + {self.input_var: data_x}, out, batch_size) + return out + + +def create_box_encoder(model_filename, input_name="images", + output_name="features", batch_size=32): + image_encoder = ImageEncoder(model_filename, input_name, output_name) + image_shape = image_encoder.image_shape + + def encoder(image, boxes): + image_patches = [] + for box in boxes: + patch = extract_image_patch(image, box, image_shape[:2]) + if patch is None: + print("WARNING: Failed to extract image patch: %s." % str(box)) + patch = np.random.uniform( + 0., 255., image_shape).astype(np.uint8) + image_patches.append(patch) + image_patches = np.asarray(image_patches) + return image_encoder(image_patches, batch_size) + + return encoder + + +def generate_detections(encoder, mot_dir, output_dir, detection_dir=None): + """Generate detections with features. + + Parameters + ---------- + encoder : Callable[image, ndarray] -> ndarray + The encoder function takes as input a BGR color image and a matrix of + bounding boxes in format `(x, y, w, h)` and returns a matrix of + corresponding feature vectors. + mot_dir : str + Path to the MOTChallenge directory (can be either train or test). + output_dir + Path to the output directory. Will be created if it does not exist. + detection_dir + Path to custom detections. The directory structure should be the default + MOTChallenge structure: `[sequence]/det/det.txt`. If None, uses the + standard MOTChallenge detections. 
+
+    """
+    if detection_dir is None:
+        detection_dir = mot_dir
+    try:
+        os.makedirs(output_dir)
+    except OSError as exception:
+        if exception.errno == errno.EEXIST and os.path.isdir(output_dir):
+            pass
+        else:
+            raise ValueError(
+                "Failed to create output directory '%s'" % output_dir)
+
+    for sequence in os.listdir(mot_dir):
+        print("Processing %s" % sequence)
+        sequence_dir = os.path.join(mot_dir, sequence)
+
+        image_dir = os.path.join(sequence_dir, "img1")
+        image_filenames = {
+            int(os.path.splitext(f)[0]): os.path.join(image_dir, f)
+            for f in os.listdir(image_dir)}
+
+        detection_file = os.path.join(
+            detection_dir, sequence, "det/det.txt")
+        detections_in = np.loadtxt(detection_file, delimiter=',')
+        detections_out = []
+
+        frame_indices = detections_in[:, 0].astype(np.int)
+        min_frame_idx = frame_indices.min()
+        max_frame_idx = frame_indices.max()
+        for frame_idx in range(min_frame_idx, max_frame_idx + 1):
+            print("Frame %05d/%05d" % (frame_idx, max_frame_idx))
+            mask = frame_indices == frame_idx
+            rows = detections_in[mask]
+
+            if frame_idx not in image_filenames:
+                print("WARNING: could not find image for frame %d" % frame_idx)
+                continue
+            bgr_image = cv2.imread(
+                image_filenames[frame_idx], cv2.IMREAD_COLOR)
+            features = encoder(bgr_image, rows[:, 2:6].copy())
+            detections_out += [np.r_[(row, feature)] for row, feature
+                               in zip(rows, features)]
+
+        output_filename = os.path.join(output_dir, "%s.npy" % sequence)
+        np.save(
+            output_filename, np.asarray(detections_out), allow_pickle=False)
+
+
+def parse_args():
+    """Parse command line arguments.
+    """
+    parser = argparse.ArgumentParser(description="Re-ID feature extractor")
+    parser.add_argument(
+        "--model",
+        default="resources/networks/mars-small128.pb",
+        help="Path to frozen inference graph protobuf.")
+    parser.add_argument(
+        "--mot_dir", help="Path to MOTChallenge directory (train or test)",
+        required=True)
+    parser.add_argument(
+        "--detection_dir", help="Path to custom detections. Defaults to "
+        "standard MOT detections. Directory structure should be the default "
+        "MOTChallenge structure: [sequence]/det/det.txt", default=None)
+    parser.add_argument(
+        "--output_dir", help="Output directory. Will be created if it does not"
+        " exist.", default="detections")
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    encoder = create_box_encoder(args.model, batch_size=32)
+    generate_detections(encoder, args.mot_dir, args.output_dir,
+                        args.detection_dir)
+
+
+if __name__ == "__main__":
+    main()
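
As a usage sketch of the pipeline this patch introduces (not part of the patch itself): the image path, sequence name, and box values below are illustrative placeholders, the import assumes the repository root is on `PYTHONPATH`, and the `.npy` layout follows the README description above (10 raw MOT columns plus a 128-dimensional appearance descriptor).
```
import cv2
import numpy as np

from tools.generate_detections import create_box_encoder

# Encode a few (x, y, w, h) boxes from one frame with the frozen graph.
encoder = create_box_encoder(
    "resources/networks/mars-small128.pb", batch_size=32)
bgr_image = cv2.imread(
    "MOT16/train/MOT16-02/img1/000001.jpg", cv2.IMREAD_COLOR)
boxes = np.array([[100., 150., 64., 128.], [300., 200., 58., 120.]])
features = encoder(bgr_image, boxes)  # float32 array of shape (2, 128)

# Files written by generate_detections() contain one row per detection:
# the 10 raw MOT columns followed by the 128-dim descriptor.
detections = np.load("resources/detections/MOT16_train/MOT16-02.npy")
mot_rows, descriptors = detections[:, :10], detections[:, 10:]

# The network L2-normalizes its output (see _create_network in
# freeze_model.py), so the dot product of two descriptors is their
# cosine similarity.
similarity = float(np.dot(descriptors[0], descriptors[1]))
```
Because the descriptors are unit length, cosine distance reduces to `1 - np.dot(a, b)`, which matches the cosine appearance metric used elsewhere in this repository.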