Skip to content

Commit

Permalink
add freeze pb
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangpf2 committed Mar 16, 2018
1 parent f37b6bb commit e2c7a27
Show file tree
Hide file tree
Showing 5 changed files with 429 additions and 3 deletions.
79 changes: 79 additions & 0 deletions dataset_tools/dataset_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions for creating TFRecord data sets."""

import tensorflow as tf


def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def read_examples_list(path):
"""Read list of training or validation examples.
The file is assumed to contain a single example per line where the first
token in the line is an identifier that allows us to find the image and
annotation xml for that example.
For example, the line:
xyz 3
would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored).
Args:
path: absolute path to examples list file.
Returns:
list of example identifiers (strings).
"""
with tf.gfile.GFile(path) as fid:
lines = fid.readlines()
return [line.strip().split(' ')[0] for line in lines]


def recursive_parse_xml_to_dict(xml):
"""Recursively parses XML contents to python dict.
We assume that `object` tags are the only ones that can appear
multiple times at the same level of a tree.
Args:
xml: xml tree obtained by parsing XML file contents using lxml.etree
Returns:
Python dictionary holding XML contents.
"""
if not xml:
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = recursive_parse_xml_to_dict(child)
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result:
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
178 changes: 178 additions & 0 deletions dataset_tools/mscoco_label_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@

# mscoco label range [1,90], but not consecutive. Note: 0 is added for 'background' here
# here we map them to label_id range [1,80]
LABEL_IDX = {
0:0,
1:1,
2:2,
3:3,
4:4,
5:5,
6:6,
7:7,
8:8,
9:9,
10:10,
11:11,
13:12,
14:13,
15:14,
16:15,
17:16,
18:17,
19:18,
20:19,
21:20,
22:21,
23:22,
24:23,
25:24,
27:25,
28:26,
31:27,
32:28,
33:29,
34:30,
35:31,
36:32,
37:33,
38:34,
39:35,
40:36,
41:37,
42:38,
43:39,
44:40,
46:41,
47:42,
48:43,
49:44,
50:45,
51:46,
52:47,
53:48,
54:49,
55:50,
56:51,
57:52,
58:53,
59:54,
60:55,
61:56,
62:57,
63:58,
64:59,
65:60,
67:61,
70:62,
72:63,
73:64,
74:65,
75:66,
76:67,
77:68,
78:69,
79:70,
80:71,
81:72,
82:73,
84:74,
85:75,
86:76,
87:77,
88:78,
89:79,
90:80
}

# map label_id to category_text
COCO_LABELS = {
0: 'Background',
1: 'person',
2: 'bicycle',
3: 'car',
4: 'motorcycle',
5: 'airplane',
6: 'bus',
7: 'train',
8: 'truck',
9: 'boat',
10: 'traffic light',
11: 'fire hydrant',
12: 'stop sign',
13: 'parking meter',
14: 'bench',
15: 'bird',
16: 'cat',
17: 'dog',
18: 'horse',
19: 'sheep',
20: 'cow',
21: 'elephant',
22: 'bear',
23: 'zebra',
24: 'giraffe',
25: 'backpack',
26: 'umbrella',
27: 'handbag',
28: 'tie',
29: 'suitcase',
30: 'frisbee',
31: 'skis',
32: 'snowboard',
33: 'sports ball',
34: 'kite',
35: 'baseball bat',
36: 'baseball glove',
37: 'skateboard',
38: 'surfboard',
39: 'tennis racket',
40: 'bottle',
41: 'wine glass',
42: 'cup',
43: 'fork',
44: 'knife',
45: 'spoon',
46: 'bowl',
47: 'banana',
48: 'apple',
49: 'sandwich',
50: 'orange',
51: 'broccoli',
52: 'carrot',
53: 'hot dog',
54: 'pizza',
55: 'donut',
56: 'cake',
57: 'chair',
58: 'couch',
59: 'potted plant',
60: 'bed',
61: 'dining table',
62: 'toilet',
63: 'tv',
64: 'laptop',
65: 'mouse',
66: 'remote',
67: 'keyboard',
68: 'cell phone',
69: 'microwave',
70: 'oven',
71: 'toaster',
72: 'sink',
73: 'refrigerator',
74: 'book',
75: 'clock',
76: 'vase',
77: 'scissors',
78: 'teddy bear',
79: 'hair drier',
80: 'toothbrush'
}


class COCO_MAP(object):
label_idx=LABEL_IDX
coco_labels=COCO_LABELS
def __init__(self):
pass
4 changes: 3 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def main():
# prepare eval/test data and label
img=imread('data/test/t_1_0.jpeg')
img = imresize(img, (args.image_height, args.image_width))
img=preprocess(img)
img=preprocess(img).astype(np.float32)
print(img.dtype)
label=1
feed_dict={input_x:[img],input_y:[label]} # use [], because we need 4-D tensor
Expand All @@ -108,6 +108,8 @@ def main():
res=sess.run(prob, feed_dict=feed_dict)[0] # index 0 for batch_size
print('prob: {}, class: {}'.format(res, np.argmax(res)))
print('time: {}'.format(time.time()-start))
# close session
sess.close()


if __name__=='__main__':
Expand Down
6 changes: 4 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, sess, tf_files, num_sampes, epoch, batch_size, image_height,

def _build_train_graph(self):
self.x_=tf.placeholder(tf.float32, [None, self.h, self.w, 3], name='input')
self.y_=tf.placeholder(tf.int64, [None], name='label')
self.y_=tf.placeholder(tf.int32, [None], name='label')

self.global_step = tf.Variable(0, name='global_step', trainable=False)

Expand Down Expand Up @@ -70,7 +70,7 @@ def _build_train_graph(self):

def _build_test_graph(self):
self.x_ = tf.placeholder(tf.float32, [None, self.h, self.w, 3], name='input')
self.y_ = tf.placeholder(tf.int64, [None], name='label')
self.y_ = tf.placeholder(tf.int32, [None], name='label')
_, _ = self._nets(self.x_, is_train=False)

def _nets(self, X, is_train, reuse=False):
Expand Down Expand Up @@ -217,6 +217,8 @@ def _train(self):
'''

# save the last model when finish training
# graph pb file, need when freeze model
tf.train.write_graph(sess.graph_def, self.checkpoint_dir, self.model_name+'.pb')
save_path=saver.save(self.sess, os.path.join(self.checkpoint_dir, self.model_name), global_step= step)
print('Final model saved in '+save_path)
print('FINISHED TRAINING.')
Loading

0 comments on commit e2c7a27

Please sign in to comment.