-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmy_generate_tfrecord.py
151 lines (125 loc) · 5.2 KB
/
my_generate_tfrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Usage:
# Create train data:
python my_generate_tfrecord.py --csv_input=<PATH_TO_ANNOTATIONS_FOLDER>/train_labels.csv --output_path=<PATH_TO_ANNOTATIONS_FOLDER>/train.record --label_map=<PATH_TO_ANNOTATIONS_FOLDER>/label_map.pbtxt --img_path=<PATH_TO_IMAGES_FOLDER>
# Create test data:
python my_generate_tfrecord.py --csv_input=<PATH_TO_ANNOTATIONS_FOLDER>/test_labels.csv --output_path=<PATH_TO_ANNOTATIONS_FOLDER>/test.record --label_map=<PATH_TO_ANNOTATIONS_FOLDER>/label_map.pbtxt --img_path=<PATH_TO_IMAGES_FOLDER>
# --img_path is resolved relative to the current working directory.
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import io
import pandas as pd
import tensorflow as tf
import sys
import pdb
import cv2
sys.path.append("../../models/research")
from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict
from random import shuffle
# Command-line flags; tf.app.run() parses them before calling main().
flags = tf.app.flags


def _define_string_flags(specs):
    """Register one string flag with an empty default per (name, help) pair, in order."""
    for flag_name, help_text in specs:
        flags.DEFINE_string(flag_name, "", help_text)


# Class names are mapped to ids through --label_map; no per-class flags exist.
_define_string_flags(
    (
        ("csv_input", "Path to the CSV input"),
        ("output_path", "Path to output TFRecord"),
        (
            "label_map",
            "Path to the `label_map.pbtxt` contains the <class_name>:<class_index> pairs generated by `xml_to_csv.py` or manually.",
        ),
        ("img_path", "Path to images"),
    )
)
FLAGS = flags.FLAGS
def split(df, group):
    """Partition ``df`` into one record per distinct value of column ``group``.

    Returns:
        A list of ``data(filename, object)`` namedtuples. ``filename`` is the
        group key and ``object`` is the sub-DataFrame of all rows sharing that
        key. The list follows the key order of ``df.groupby(group).groups``.
    """
    record = namedtuple("data", ["filename", "object"])
    grouping = df.groupby(group)
    out = []
    for key in grouping.groups:
        out.append(record(key, grouping.get_group(key)))
    return out
def create_tf_example(group, path, label_map):
    """Build a ``tf.train.Example`` for one image and all of its boxes.

    Args:
        group: A record with ``filename`` (image file name, relative to
            ``path``) and ``object`` (a DataFrame with one row per box, read
            for columns ``filename``, ``width``, ``height``, ``xmin``,
            ``xmax``, ``ymin``, ``ymax`` and ``class``), as built by ``split``.
        path: Directory that contains the image files.
        label_map: Dict mapping class name to integer class id.

    Returns:
        A ``tf.train.Example`` holding the encoded image bytes, the box
        coordinates divided by the annotated image width/height, and the
        class name and id of every box.

    Raises:
        OSError: if a row's recorded size is 0 and the image cannot be read
            from disk to recover it.
        ValueError: if a box's class is not present in ``label_map``.
    """
    with tf.gfile.GFile(os.path.join(path, "{}".format(group.filename)), "rb") as fid:
        encoded_jpg = fid.read()
    # The PIL handle is not used further; opening it (header parse only) makes
    # a file PIL cannot identify fail here.
    Image.open(io.BytesIO(encoded_jpg))

    filename = group.filename.encode("utf8")
    # Check that this format matches your images.
    image_format = b"jpg"

    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text = []
    classes = []
    # (height, width) read from disk, cached so each file is decoded at most once.
    disk_sizes = {}

    for _, row in group.object.iterrows():
        # Normalize by the size recorded with the annotations rather than the
        # current file: the image may have been resized after the XML was made.
        width = row["width"]
        height = row["height"]
        if width == 0 or height == 0:
            # No usable size in the CSV; fall back to the real image size.
            img_name = row["filename"]
            if img_name not in disk_sizes:
                img_path = os.path.join(path, img_name)
                img = cv2.imread(img_path)
                if img is None:  # cv2.imread returns None instead of raising
                    raise OSError("Could not read image: {}".format(img_path))
                disk_sizes[img_name] = img.shape[:2]
            height, width = disk_sizes[img_name]
            if width == 0 or height == 0:
                continue

        xmins.append(row["xmin"] / width)
        xmaxs.append(row["xmax"] / width)
        ymins.append(row["ymin"] / height)
        ymaxs.append(row["ymax"] / height)

        class_name = row["class"]
        classes_text.append(class_name.encode("utf8"))
        class_index = label_map.get(class_name)
        # Explicit raise instead of assert so the check survives `python -O`.
        if class_index is None:
            raise ValueError(
                "class label: `{}` not found in label_map: {}".format(
                    class_name, label_map
                )
            )
        classes.append(class_index)

    # height/width are those of the last row processed. int() because pandas
    # may parse the size columns as floats, which Int64List rejects.
    tf_example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "image/height": dataset_util.int64_feature(int(height)),
                "image/width": dataset_util.int64_feature(int(width)),
                "image/filename": dataset_util.bytes_feature(filename),
                "image/source_id": dataset_util.bytes_feature(filename),
                "image/encoded": dataset_util.bytes_feature(encoded_jpg),
                "image/format": dataset_util.bytes_feature(image_format),
                "image/object/bbox/xmin": dataset_util.float_list_feature(xmins),
                "image/object/bbox/xmax": dataset_util.float_list_feature(xmaxs),
                "image/object/bbox/ymin": dataset_util.float_list_feature(ymins),
                "image/object/bbox/ymax": dataset_util.float_list_feature(ymaxs),
                "image/object/class/text": dataset_util.bytes_list_feature(
                    classes_text
                ),
                "image/object/class/label": dataset_util.int64_list_feature(classes),
            }
        )
    )
    return tf_example
def main(_):
    """Convert the CSV annotations in --csv_input into a TFRecord file.

    Images are read from --img_path (joined onto the current directory), and
    class names are resolved to ids through --label_map. One example per image
    file is written, in shuffled order, to --output_path.
    """
    from object_detection.utils import label_map_util

    path = os.path.join(os.getcwd(), FLAGS.img_path)
    examples = pd.read_csv(FLAGS.csv_input)

    # Load the `label_map` from pbtxt file and flatten it to {name: id}.
    label_map_proto = label_map_util.load_labelmap(FLAGS.label_map)
    # convert_label_map_to_categories skips items whose id exceeds
    # max_num_classes. Keep the previous floor of 90, but raise it to the
    # largest id so label maps with bigger ids keep all of their classes.
    max_num_classes = max([90] + [item.id for item in label_map_proto.item])
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=max_num_classes, use_display_name=True
    )
    category_index = label_map_util.create_category_index(categories)
    label_map = {v.get("name"): v.get("id") for v in category_index.values()}

    grouped = split(examples, "filename")
    shuffle(grouped)  # this was added in order to add "randomness"

    # The context manager closes the writer even if an example raises.
    with tf.python_io.TFRecordWriter(FLAGS.output_path) as writer:
        for group in grouped:
            tf_example = create_tf_example(group, path, label_map)
            writer.write(tf_example.SerializeToString())

    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print("Successfully created the TFRecords: {}".format(output_path))
if __name__ == "__main__":
    # Parse command-line flags, then dispatch to main() explicitly.
    tf.app.run(main=main)