switching to tf dataset
preritj committed Jul 6, 2018
1 parent bd5bbe8 commit 1852a19
Showing 19 changed files with 1,398 additions and 438 deletions.
4 changes: 4 additions & 0 deletions .idea/misc.xml

8 changes: 8 additions & 0 deletions .idea/modules.xml

14 changes: 14 additions & 0 deletions .idea/segmentation.iml

6 changes: 6 additions & 0 deletions .idea/vcs.xml

369 changes: 369 additions & 0 deletions .idea/workspace.xml

259 changes: 259 additions & 0 deletions dataset/data_reader.py
@@ -0,0 +1,259 @@
import os
import glob
import numpy as np
import tensorflow as tf
import functools
from utils.dataset_util import (
    rotate, random_hue, random_contrast, random_brightness,
    # the helpers below are referenced later in this file and are assumed
    # to live in utils.dataset_util as well
    random_flip_left_right, random_crop, resize, keypoints_select,
    normalize_keypoints, normalize_bboxes)


slim_example_decoder = tf.contrib.slim.tfexample_decoder


class SegDataReader(object):
def __init__(self, data_cfg):
self.data_cfg = data_cfg
self.datasets = []

for dataset in data_cfg.datasets:
data_dir = dataset['data_dir']
name = dataset['name']
weight = dataset['weight']
img_files = os.path.join(
data_dir, dataset['img_glob_string'])
            # 'mask_glob_string' is assumed to hold the glob pattern for mask files
            mask_files = dataset.get('mask_glob_string')
            if mask_files is not None:
                mask_files = os.path.join(data_dir, mask_files)
            tfrecord_path = dataset['tfrecord_file']
            if (tfrecord_path is None) or dataset['overwrite_tfrecord']:
                tfrecord_name = os.path.basename(data_dir) + '.records'
                sub_dir = os.path.dirname(tfrecord_path) if tfrecord_path else ''
                tfrecord_path = os.path.join(data_dir, sub_dir, tfrecord_name)
                tfrecord_dir = os.path.dirname(tfrecord_path)
                if not os.path.exists(tfrecord_dir):
                    os.makedirs(tfrecord_dir)
                ds = self.add_dataset(name, img_files, mask_files)
                ds.create_tf_record(tfrecord_path)
self.datasets.append({'name': name,
'tfrecord_path': tfrecord_path,
'weight': weight})

def add_dataset(self, name, img_files, mask_files):
if name == 'png_objects':
ds = PNGobjects(self.data_cfg, img_files, mask_files)
else:
raise RuntimeError('Dataset not supported')
return ds

def _get_probs(self):
probs = [ds['weight'] for ds in self.datasets]
probs = np.array(probs)
return probs / np.sum(probs)

@staticmethod
def _get_tensor(tensor):
if isinstance(tensor, tf.SparseTensor):
return tf.sparse_tensor_to_dense(tensor)
return tensor

@staticmethod
def _image_decoder(keys_to_tensors):
filename = keys_to_tensors['image/filename']
image_string = tf.read_file(filename)
# TODO: decode after crop to increase speed
image_decoded = tf.image.decode_jpeg(image_string, channels=3)
return image_decoded

@staticmethod
def _mask_decoder(keys_to_tensors):
mask_x = SegDataReader._get_tensor(
keys_to_tensors['image/mask/x'])
mask_y = SegDataReader._get_tensor(
keys_to_tensors['image/mask/y'])
shape = SegDataReader._get_tensor(
keys_to_tensors['image/shape'])

indices = tf.stack([mask_x, mask_y], axis=1)
values = tf.zeros_like(mask_x)
mask = tf.SparseTensor(indices=indices, values=values,
dense_shape=shape)
        # TODO: possibly do sparse to dense conversion after crop
mask = tf.sparse_tensor_to_dense(mask, default_value=1)
return tf.cast(mask, tf.int32)

@staticmethod
def _keypoints_decoder(keys_to_tensor, num_keypoints=15):
keypoints = keys_to_tensor['image/person/keypoints']
img_shape = keys_to_tensor['image/shape']
        num_instances = SegDataReader._get_tensor(
            keys_to_tensor['image/num_instances'])
shape = [num_instances] + [num_keypoints, 3]
if isinstance(keypoints, tf.SparseTensor):
keypoints = tf.sparse_reshape(keypoints, shape=shape)
keypoints = tf.sparse_tensor_to_dense(keypoints)
else:
keypoints = tf.reshape(keypoints, shape=shape)
keypoints = normalize_keypoints(keypoints, img_shape)
return keypoints

@staticmethod
def _bbox_decoder(keys_to_tensor):
bbox = keys_to_tensor['image/person/bbox']
img_shape = keys_to_tensor['image/shape']
        num_instances = SegDataReader._get_tensor(
            keys_to_tensor['image/num_instances'])
shape = [num_instances] + [4]
if isinstance(bbox, tf.SparseTensor):
bbox = tf.sparse_reshape(bbox, shape=shape)
bbox = tf.sparse_tensor_to_dense(bbox)
else:
bbox = tf.reshape(bbox, shape=shape)
bbox = normalize_bboxes(bbox, img_shape)
return bbox

def _decoder(self):
keys_to_features = {
'image/filename':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/shape':
tf.FixedLenFeature([2], tf.int64),
'image/num_instances':
tf.FixedLenFeature((), tf.int64),
'image/person/bbox':
tf.VarLenFeature(tf.float32),
'image/person/keypoints':
tf.VarLenFeature(tf.float32),
'image/mask/x':
tf.VarLenFeature(tf.int64),
'image/mask/y':
tf.VarLenFeature(tf.int64)
}
        # num_keypoints is assumed to be provided on the reader elsewhere
        # (it is not set in this class's __init__)
        keypoints_decoder = functools.partial(
            self._keypoints_decoder, num_keypoints=self.num_keypoints)
items_to_handlers = {
'image': slim_example_decoder.ItemHandlerCallback(
'image/filename', self._image_decoder),
'mask': slim_example_decoder.ItemHandlerCallback(
['image/mask/x', 'image/mask/y', 'image/shape'],
self._mask_decoder),
'keypoints': slim_example_decoder.ItemHandlerCallback(
['image/person/keypoints', 'image/num_instances',
'image/shape'], keypoints_decoder),
'bbox': slim_example_decoder.ItemHandlerCallback(
['image/person/bbox', 'image/num_instances',
'image/shape'], self._bbox_decoder)
}
decoder = slim_example_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
return decoder

def augment_data(self, dataset, train_cfg):
aug_cfg = train_cfg.augmentation
preprocess_cfg = train_cfg.preprocess
img_size = preprocess_cfg['image_resize']
if aug_cfg['flip_left_right']:
kp_dict = {kp_name: i for i, kp_name in
enumerate(train_cfg.train_keypoints)}
flipped_kp_indices = []
for kp_name in train_cfg.train_keypoints:
if kp_name.startswith('left'):
flipped_kp_name = 'right' + kp_name.split('left')[1]
flipped_kp_indices.append(kp_dict[flipped_kp_name])
elif kp_name.startswith('right'):
flipped_kp_name = 'left' + kp_name.split('right')[1]
flipped_kp_indices.append(kp_dict[flipped_kp_name])
else:
flipped_kp_indices.append(kp_dict[kp_name])
random_flip_left_right_fn = functools.partial(
random_flip_left_right,
flipped_keypoint_indices=flipped_kp_indices)
dataset = dataset.map(
random_flip_left_right_fn,
num_parallel_calls=train_cfg.num_parallel_map_calls
)
dataset = dataset.prefetch(train_cfg.prefetch_size)
random_crop_fn = functools.partial(
random_crop,
crop_size=img_size,
scale_range=aug_cfg['scale_range']
)
if aug_cfg['random_crop']:
dataset = dataset.map(
random_crop_fn,
num_parallel_calls=train_cfg.num_parallel_map_calls
)
dataset = dataset.prefetch(train_cfg.prefetch_size)
if aug_cfg['random_brightness']:
dataset = dataset.map(
random_brightness,
num_parallel_calls=train_cfg.num_parallel_map_calls
)
dataset = dataset.prefetch(train_cfg.prefetch_size)
if aug_cfg['random_contrast']:
dataset = dataset.map(
random_contrast,
num_parallel_calls=train_cfg.num_parallel_map_calls
)
dataset = dataset.prefetch(train_cfg.prefetch_size)
return dataset

def preprocess_data(self, dataset, train_cfg):
preprocess_cfg = train_cfg.preprocess
img_size = preprocess_cfg['image_resize']
resize_fn = functools.partial(
resize,
target_image_size=img_size)
dataset = dataset.map(
resize_fn,
num_parallel_calls=train_cfg.num_parallel_map_calls
)
        dataset = dataset.prefetch(train_cfg.prefetch_size)
return dataset

def read_data(self, train_config):
probs = self._get_probs()
probs = tf.cast(probs, tf.float32)
decoder = self._decoder()
filenames = [ds['tfrecord_path'] for ds in self.datasets]
file_ids = list(range(len(filenames)))
dataset = tf.data.Dataset.from_tensor_slices((file_ids, filenames))
dataset = dataset.apply(tf.contrib.data.rejection_resample(
class_func=lambda c, _: c,
target_dist=probs,
seed=42))
dataset = dataset.map(lambda _, a: a[1])
if train_config.shuffle:
dataset = dataset.shuffle(
train_config.filenames_shuffle_buffer_size)

dataset = dataset.repeat(train_config.num_epochs or None)

file_read_func = functools.partial(tf.data.TFRecordDataset,
buffer_size=8 * 1000 * 1000)
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
file_read_func, cycle_length=train_config.num_readers,
block_length=train_config.read_block_length, sloppy=True))
if train_config.shuffle:
dataset = dataset.shuffle(train_config.shuffle_buffer_size)

decode_fn = functools.partial(
decoder.decode, items=['image', 'keypoints', 'bbox', 'mask'])
dataset = dataset.map(
decode_fn, num_parallel_calls=train_config.num_parallel_map_calls)
dataset = dataset.prefetch(train_config.prefetch_size)

train_keypoints = [self.data_cfg.keypoints[kp_name]
for kp_name in train_config.train_keypoints]
kp_subset_fn = functools.partial(
keypoints_select, keypoints_to_keep=train_keypoints)
dataset = dataset.map(
kp_subset_fn,
num_parallel_calls=train_config.num_parallel_map_calls)
dataset = dataset.prefetch(train_config.prefetch_size)

dataset = self.augment_data(dataset, train_config)

dataset = self.preprocess_data(dataset, train_config)
return dataset
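
For reference, a minimal usage sketch (not part of this commit) showing how the tf.data pipeline built by SegDataReader.read_data might be consumed in the TF 1.x graph-mode style used above; data_cfg and train_config are assumed to be populated elsewhere:

def run_one_batch(data_cfg, train_config):
    # Sketch only: build the dataset and pull one decoded example.
    reader = SegDataReader(data_cfg)
    dataset = reader.read_data(train_config)
    iterator = dataset.make_one_shot_iterator()  # TF 1.x tf.data iterator API
    image, keypoints, bbox, mask = iterator.get_next()
    with tf.Session() as sess:
        return sess.run([image, keypoints, bbox, mask])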
72 changes: 72 additions & 0 deletions dataset/seg_data.py
@@ -0,0 +1,72 @@
import time
import numpy as np
import os
from abc import abstractmethod
from tqdm import tqdm
from utils import tfrecord_util
import tensorflow as tf


class SegData(object):
    def __init__(self, cfg, img_files, mask_files=None):
        """
        Constructor of the SegData base class
        """
        self.cfg = cfg
        self.imgs, self.ids = None, None
        # data_dir is assumed to be supplied through the config object
        self.data_dir = getattr(cfg, 'data_dir', None)
        self.product_labels = {}
        print('loading annotations into memory...')
        tic = time.time()
        self.datasets = []
        # img_files is treated as a (list of) train-list file(s); each line of a
        # train-list file holds an image path and its annotation file
        if not isinstance(img_files, list):
            img_files = [img_files]
        for train_file in img_files:
            labels_file = os.path.dirname(train_file)
            labels_file = os.path.join(labels_file, 'labels.txt')
            with open(labels_file, 'r') as f:
                for line in f:
                    label, prod_name = line.split()
                    self.product_labels[prod_name] = int(label)
            with open(train_file, 'r') as f:
                dataset = {}
                for line in f:
                    img, ann_file = line.split()
                    dataset[img] = ann_file
                self.datasets.append(dataset)
        print('Done (t={:0.2f}s)'.format(time.time() - tic))
        self.create_index()

@abstractmethod
def create_index(self):
return

def get_size(self):
return len(self.ids)

    def _create_tf_example(self, img_id):
        img_meta = self.imgs[img_id]
        img_file = img_meta['filename']
        img_file = os.path.join(self.data_dir, img_file)
        # the mask path is assumed to be stored in the image metadata under
        # 'mask_filename'; fall back to the image path if it is absent
        mask_file = img_meta.get('mask_filename')
        mask_file = (os.path.join(self.data_dir, mask_file)
                     if mask_file else img_file)
        img_shape = list(img_meta['shape'])

        feature_dict = {
            'image/filename':
                tfrecord_util.bytes_feature(img_file.encode('utf8')),
            'image/shape':
                tfrecord_util.int64_list_feature(img_shape),
            'mask/filename':
                tfrecord_util.bytes_feature(mask_file.encode('utf8'))
        }
        return tf.train.Example(features=tf.train.Features(feature=feature_dict))

def create_tf_record(self, out_path, shuffle=True):
print("Creating tf records : ", out_path)
writer = tf.python_io.TFRecordWriter(out_path)
if shuffle:
np.random.shuffle(self.ids)
for img_id in tqdm(self.ids):
tf_example = self._create_tf_example(img_id)
writer.write(tf_example.SerializeToString())
writer.close()
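
For clarity, the labels and train-list files parsed in SegData.__init__ are plain whitespace-separated text; the entries below are illustrative only (actual names and paths will differ):

labels.txt (one "label product_name" pair per line):
    1 bottle
    2 can

train list file (one "image_path annotation_path" pair per line):
    images/img_0001.png annotations/img_0001.png
    images/img_0002.png annotations/img_0002.png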